mirror of
https://github.com/cexll/myclaude.git
synced 2026-02-05 02:30:26 +08:00
Improve backend termination after message and extend timeout (#86)
* Improve backend termination after message and extend timeout
* fix: prevent premature backend termination and revert timeout
Critical fixes for executor.go termination logic:
1. Add onComplete callback to prevent premature termination
- Parser now distinguishes between "any message" (onMessage) and
"terminal event" (onComplete)
- Codex: triggers onComplete on thread.completed
- Claude: triggers onComplete on type:"result"
- Gemini: triggers onComplete on type:"result" + terminal status
2. Fix executor to wait for completion events
- Replace messageSeen termination trigger with completeSeen
- Only start postMessageTerminateDelay after terminal event
- Prevents killing backend before final answer in multi-message scenarios
3. Fix terminated flag synchronization
- Only set terminated=true if terminateCommandFn actually succeeds
- Prevents "marked as terminated but not actually terminated" state
4. Simplify timer cleanup logic
- Unified non-blocking drain on messageTimer.C
- Remove dependency on messageTimerCh nil state
5. Revert defaultTimeout from 24h to 2h
- 24h (86400s) → 2h (7200s) to avoid operational risks
- 12× timeout increase could cause resource exhaustion
- Users needing longer tasks can use CODEX_TIMEOUT env var
All tests pass. Resolves early termination bug from code review.
Co-authored-by: Codeagent (Codex)
Generated with SWE-Agent.ai
Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>
---------
Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
@@ -16,6 +16,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const postMessageTerminateDelay = 1 * time.Second
|
||||||
|
|
||||||
// commandRunner abstracts exec.Cmd for testability
|
// commandRunner abstracts exec.Cmd for testability
|
||||||
type commandRunner interface {
|
type commandRunner interface {
|
||||||
Start() error
|
Start() error
|
||||||
@@ -729,6 +731,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
// Start parse goroutine BEFORE starting the command to avoid race condition
|
// Start parse goroutine BEFORE starting the command to avoid race condition
|
||||||
// where fast-completing commands close stdout before parser starts reading
|
// where fast-completing commands close stdout before parser starts reading
|
||||||
messageSeen := make(chan struct{}, 1)
|
messageSeen := make(chan struct{}, 1)
|
||||||
|
completeSeen := make(chan struct{}, 1)
|
||||||
parseCh := make(chan parseResult, 1)
|
parseCh := make(chan parseResult, 1)
|
||||||
go func() {
|
go func() {
|
||||||
msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() {
|
msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() {
|
||||||
@@ -736,6 +739,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
case messageSeen <- struct{}{}:
|
case messageSeen <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
}, func() {
|
||||||
|
select {
|
||||||
|
case completeSeen <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
})
|
})
|
||||||
parseCh <- parseResult{message: msg, threadID: tid}
|
parseCh <- parseResult{message: msg, threadID: tid}
|
||||||
}()
|
}()
|
||||||
@@ -773,17 +781,63 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
waitCh := make(chan error, 1)
|
waitCh := make(chan error, 1)
|
||||||
go func() { waitCh <- cmd.Wait() }()
|
go func() { waitCh <- cmd.Wait() }()
|
||||||
|
|
||||||
var waitErr error
|
var (
|
||||||
var forceKillTimer *forceKillTimer
|
waitErr error
|
||||||
var ctxCancelled bool
|
forceKillTimer *forceKillTimer
|
||||||
|
ctxCancelled bool
|
||||||
|
messageTimer *time.Timer
|
||||||
|
messageTimerCh <-chan time.Time
|
||||||
|
forcedAfterComplete bool
|
||||||
|
terminated bool
|
||||||
|
messageSeenObserved bool
|
||||||
|
completeSeenObserved bool
|
||||||
|
)
|
||||||
|
|
||||||
select {
|
waitLoop:
|
||||||
case waitErr = <-waitCh:
|
for {
|
||||||
case <-ctx.Done():
|
select {
|
||||||
ctxCancelled = true
|
case waitErr = <-waitCh:
|
||||||
logErrorFn(cancelReason(commandName, ctx))
|
break waitLoop
|
||||||
forceKillTimer = terminateCommandFn(cmd)
|
case <-ctx.Done():
|
||||||
waitErr = <-waitCh
|
ctxCancelled = true
|
||||||
|
logErrorFn(cancelReason(commandName, ctx))
|
||||||
|
if !terminated {
|
||||||
|
if timer := terminateCommandFn(cmd); timer != nil {
|
||||||
|
forceKillTimer = timer
|
||||||
|
terminated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
waitErr = <-waitCh
|
||||||
|
break waitLoop
|
||||||
|
case <-messageTimerCh:
|
||||||
|
forcedAfterComplete = true
|
||||||
|
messageTimerCh = nil
|
||||||
|
if !terminated {
|
||||||
|
logWarnFn(fmt.Sprintf("%s output parsed; terminating lingering backend", commandName))
|
||||||
|
if timer := terminateCommandFn(cmd); timer != nil {
|
||||||
|
forceKillTimer = timer
|
||||||
|
terminated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-completeSeen:
|
||||||
|
completeSeenObserved = true
|
||||||
|
if messageTimer != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messageTimer = time.NewTimer(postMessageTerminateDelay)
|
||||||
|
messageTimerCh = messageTimer.C
|
||||||
|
case <-messageSeen:
|
||||||
|
messageSeenObserved = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageTimer != nil {
|
||||||
|
if !messageTimer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-messageTimer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if forceKillTimer != nil {
|
if forceKillTimer != nil {
|
||||||
@@ -791,10 +845,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
}
|
}
|
||||||
|
|
||||||
var parsed parseResult
|
var parsed parseResult
|
||||||
if ctxCancelled {
|
switch {
|
||||||
|
case ctxCancelled:
|
||||||
closeWithReason(stdout, stdoutCloseReasonCtx)
|
closeWithReason(stdout, stdoutCloseReasonCtx)
|
||||||
parsed = <-parseCh
|
parsed = <-parseCh
|
||||||
} else {
|
case messageSeenObserved || completeSeenObserved:
|
||||||
|
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||||
|
parsed = <-parseCh
|
||||||
|
default:
|
||||||
drainTimer := time.NewTimer(stdoutDrainTimeout)
|
drainTimer := time.NewTimer(stdoutDrainTimeout)
|
||||||
defer drainTimer.Stop()
|
defer drainTimer.Stop()
|
||||||
|
|
||||||
@@ -802,6 +860,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
case parsed = <-parseCh:
|
case parsed = <-parseCh:
|
||||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||||
case <-messageSeen:
|
case <-messageSeen:
|
||||||
|
messageSeenObserved = true
|
||||||
|
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||||
|
parsed = <-parseCh
|
||||||
|
case <-completeSeen:
|
||||||
|
completeSeenObserved = true
|
||||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||||
parsed = <-parseCh
|
parsed = <-parseCh
|
||||||
case <-drainTimer.C:
|
case <-drainTimer.C:
|
||||||
@@ -822,17 +885,21 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if waitErr != nil {
|
if waitErr != nil {
|
||||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
if forcedAfterComplete && parsed.message != "" {
|
||||||
code := exitErr.ExitCode()
|
logWarnFn(fmt.Sprintf("%s terminated after delivering output", commandName))
|
||||||
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
|
} else {
|
||||||
result.ExitCode = code
|
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||||
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code))
|
code := exitErr.ExitCode()
|
||||||
|
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||||
|
result.ExitCode = code
|
||||||
|
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
logErrorFn(commandName + " error: " + waitErr.Error())
|
||||||
|
result.ExitCode = 1
|
||||||
|
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
logErrorFn(commandName + " error: " + waitErr.Error())
|
|
||||||
result.ExitCode = 1
|
|
||||||
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message := parsed.message
|
message := parsed.message
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
version = "5.2.5"
|
version = "5.2.6"
|
||||||
defaultWorkdir = "."
|
defaultWorkdir = "."
|
||||||
defaultTimeout = 7200 // seconds
|
defaultTimeout = 7200 // seconds (2 hours)
|
||||||
codexLogLineLimit = 1000
|
codexLogLineLimit = 1000
|
||||||
stdinSpecialChars = "\n\\\"'`$"
|
stdinSpecialChars = "\n\\\"'`$"
|
||||||
stderrCaptureLimit = 4 * 1024
|
stderrCaptureLimit = 4 * 1024
|
||||||
|
|||||||
@@ -879,6 +879,79 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
|
||||||
|
defer resetTestHooks()
|
||||||
|
forceKillDelay.Store(0)
|
||||||
|
|
||||||
|
fake := newFakeCmd(fakeCmdConfig{
|
||||||
|
StdoutPlan: []fakeStdoutEvent{
|
||||||
|
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
|
||||||
|
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
|
||||||
|
},
|
||||||
|
KeepStdoutOpen: true,
|
||||||
|
BlockWait: true,
|
||||||
|
ReleaseWaitOnSignal: true,
|
||||||
|
ReleaseWaitOnKill: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
|
||||||
|
return fake
|
||||||
|
}
|
||||||
|
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
|
||||||
|
codexCommand = "fake-cmd"
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
if result.ExitCode != 0 || result.Message != "done" {
|
||||||
|
t.Fatalf("unexpected result: %+v", result)
|
||||||
|
}
|
||||||
|
if duration > 2*time.Second {
|
||||||
|
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||||
|
}
|
||||||
|
if fake.process.SignalCount() == 0 {
|
||||||
|
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
|
||||||
|
defer resetTestHooks()
|
||||||
|
forceKillDelay.Store(0)
|
||||||
|
|
||||||
|
fake := newFakeCmd(fakeCmdConfig{
|
||||||
|
StdoutPlan: []fakeStdoutEvent{
|
||||||
|
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"intermediate"}}` + "\n"},
|
||||||
|
{Delay: 1100 * time.Millisecond, Data: `{"type":"item.completed","item":{"type":"agent_message","text":"final"}}` + "\n"},
|
||||||
|
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
|
||||||
|
},
|
||||||
|
KeepStdoutOpen: true,
|
||||||
|
BlockWait: true,
|
||||||
|
ReleaseWaitOnSignal: true,
|
||||||
|
ReleaseWaitOnKill: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
|
||||||
|
return fake
|
||||||
|
}
|
||||||
|
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
|
||||||
|
codexCommand = "fake-cmd"
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
if result.ExitCode != 0 || result.Message != "final" {
|
||||||
|
t.Fatalf("unexpected result: %+v", result)
|
||||||
|
}
|
||||||
|
if duration > 5*time.Second {
|
||||||
|
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||||
|
}
|
||||||
|
if fake.process.SignalCount() == 0 {
|
||||||
|
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBackendParseArgs_NewMode(t *testing.T) {
|
func TestBackendParseArgs_NewMode(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -1650,7 +1723,7 @@ func TestBackendParseJSONStream_GeminiEvents_OnMessageTriggeredOnStatus(t *testi
|
|||||||
var called int
|
var called int
|
||||||
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||||
called++
|
called++
|
||||||
})
|
}, nil)
|
||||||
|
|
||||||
if message != "Hi there" {
|
if message != "Hi there" {
|
||||||
t.Fatalf("message=%q, want %q", message, "Hi there")
|
t.Fatalf("message=%q, want %q", message, "Hi there")
|
||||||
@@ -1679,7 +1752,7 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
|
|||||||
var called int
|
var called int
|
||||||
message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() {
|
message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() {
|
||||||
called++
|
called++
|
||||||
})
|
}, nil)
|
||||||
if message != "hook" {
|
if message != "hook" {
|
||||||
t.Fatalf("message = %q, want hook", message)
|
t.Fatalf("message = %q, want hook", message)
|
||||||
}
|
}
|
||||||
@@ -1691,10 +1764,86 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackendParseJSONStream_OnComplete_CodexThreadCompleted(t *testing.T) {
|
||||||
|
input := `{"type":"item.completed","item":{"type":"agent_message","text":"first"}}` + "\n" +
|
||||||
|
`{"type":"item.completed","item":{"type":"agent_message","text":"second"}}` + "\n" +
|
||||||
|
`{"type":"thread.completed","thread_id":"t-1"}`
|
||||||
|
|
||||||
|
var onMessageCalls int
|
||||||
|
var onCompleteCalls int
|
||||||
|
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||||
|
onMessageCalls++
|
||||||
|
}, func() {
|
||||||
|
onCompleteCalls++
|
||||||
|
})
|
||||||
|
if message != "second" {
|
||||||
|
t.Fatalf("message = %q, want second", message)
|
||||||
|
}
|
||||||
|
if threadID != "t-1" {
|
||||||
|
t.Fatalf("threadID = %q, want t-1", threadID)
|
||||||
|
}
|
||||||
|
if onMessageCalls != 2 {
|
||||||
|
t.Fatalf("onMessage calls = %d, want 2", onMessageCalls)
|
||||||
|
}
|
||||||
|
if onCompleteCalls != 1 {
|
||||||
|
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParseJSONStream_OnComplete_ClaudeResult(t *testing.T) {
|
||||||
|
input := `{"type":"message","subtype":"stream","session_id":"s-1"}` + "\n" +
|
||||||
|
`{"type":"result","result":"OK","session_id":"s-1"}`
|
||||||
|
|
||||||
|
var onMessageCalls int
|
||||||
|
var onCompleteCalls int
|
||||||
|
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||||
|
onMessageCalls++
|
||||||
|
}, func() {
|
||||||
|
onCompleteCalls++
|
||||||
|
})
|
||||||
|
if message != "OK" {
|
||||||
|
t.Fatalf("message = %q, want OK", message)
|
||||||
|
}
|
||||||
|
if threadID != "s-1" {
|
||||||
|
t.Fatalf("threadID = %q, want s-1", threadID)
|
||||||
|
}
|
||||||
|
if onMessageCalls != 1 {
|
||||||
|
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
|
||||||
|
}
|
||||||
|
if onCompleteCalls != 1 {
|
||||||
|
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParseJSONStream_OnComplete_GeminiTerminalResultStatus(t *testing.T) {
|
||||||
|
input := `{"type":"message","role":"assistant","content":"Hi","delta":true,"session_id":"g-1"}` + "\n" +
|
||||||
|
`{"type":"result","status":"success","session_id":"g-1"}`
|
||||||
|
|
||||||
|
var onMessageCalls int
|
||||||
|
var onCompleteCalls int
|
||||||
|
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||||
|
onMessageCalls++
|
||||||
|
}, func() {
|
||||||
|
onCompleteCalls++
|
||||||
|
})
|
||||||
|
if message != "Hi" {
|
||||||
|
t.Fatalf("message = %q, want Hi", message)
|
||||||
|
}
|
||||||
|
if threadID != "g-1" {
|
||||||
|
t.Fatalf("threadID = %q, want g-1", threadID)
|
||||||
|
}
|
||||||
|
if onMessageCalls != 1 {
|
||||||
|
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
|
||||||
|
}
|
||||||
|
if onCompleteCalls != 1 {
|
||||||
|
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBackendParseJSONStream_ScannerError(t *testing.T) {
|
func TestBackendParseJSONStream_ScannerError(t *testing.T) {
|
||||||
var warnings []string
|
var warnings []string
|
||||||
warnFn := func(msg string) { warnings = append(warnings, msg) }
|
warnFn := func(msg string) { warnings = append(warnings, msg) }
|
||||||
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil)
|
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil, nil)
|
||||||
if message != "" || threadID != "" {
|
if message != "" || threadID != "" {
|
||||||
t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID)
|
t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID)
|
||||||
}
|
}
|
||||||
@@ -2756,7 +2905,7 @@ func TestVersionFlag(t *testing.T) {
|
|||||||
t.Errorf("exit = %d, want 0", code)
|
t.Errorf("exit = %d, want 0", code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
want := "codeagent-wrapper version 5.2.5\n"
|
want := "codeagent-wrapper version 5.2.6\n"
|
||||||
if output != want {
|
if output != want {
|
||||||
t.Fatalf("output = %q, want %q", output, want)
|
t.Fatalf("output = %q, want %q", output, want)
|
||||||
}
|
}
|
||||||
@@ -2770,7 +2919,7 @@ func TestVersionShortFlag(t *testing.T) {
|
|||||||
t.Errorf("exit = %d, want 0", code)
|
t.Errorf("exit = %d, want 0", code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
want := "codeagent-wrapper version 5.2.5\n"
|
want := "codeagent-wrapper version 5.2.6\n"
|
||||||
if output != want {
|
if output != want {
|
||||||
t.Fatalf("output = %q, want %q", output, want)
|
t.Fatalf("output = %q, want %q", output, want)
|
||||||
}
|
}
|
||||||
@@ -2784,7 +2933,7 @@ func TestVersionLegacyAlias(t *testing.T) {
|
|||||||
t.Errorf("exit = %d, want 0", code)
|
t.Errorf("exit = %d, want 0", code)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
want := "codex-wrapper version 5.2.5\n"
|
want := "codex-wrapper version 5.2.6\n"
|
||||||
if output != want {
|
if output != want {
|
||||||
t.Fatalf("output = %q, want %q", output, want)
|
t.Fatalf("output = %q, want %q", output, want)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
|
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
|
||||||
return parseJSONStreamInternal(r, warnFn, infoFn, nil)
|
return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -95,7 +95,7 @@ type ItemContent struct {
|
|||||||
Text interface{} `json:"text"`
|
Text interface{} `json:"text"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func()) (message, threadID string) {
|
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) {
|
||||||
reader := bufio.NewReaderSize(r, jsonLineReaderSize)
|
reader := bufio.NewReaderSize(r, jsonLineReaderSize)
|
||||||
|
|
||||||
if warnFn == nil {
|
if warnFn == nil {
|
||||||
@@ -111,6 +111,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
notifyComplete := func() {
|
||||||
|
if onComplete != nil {
|
||||||
|
onComplete()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
totalEvents := 0
|
totalEvents := 0
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -158,6 +164,9 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
isClaude := event.Subtype != "" || event.Result != ""
|
isClaude := event.Subtype != "" || event.Result != ""
|
||||||
|
if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" {
|
||||||
|
isClaude = true
|
||||||
|
}
|
||||||
isGemini := event.Role != "" || event.Delta != nil || event.Status != ""
|
isGemini := event.Role != "" || event.Delta != nil || event.Status != ""
|
||||||
|
|
||||||
// Handle Codex events
|
// Handle Codex events
|
||||||
@@ -178,6 +187,13 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
|||||||
threadID = event.ThreadID
|
threadID = event.ThreadID
|
||||||
infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
|
infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
|
||||||
|
|
||||||
|
case "thread.completed":
|
||||||
|
if event.ThreadID != "" && threadID == "" {
|
||||||
|
threadID = event.ThreadID
|
||||||
|
}
|
||||||
|
infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID))
|
||||||
|
notifyComplete()
|
||||||
|
|
||||||
case "item.completed":
|
case "item.completed":
|
||||||
var itemType string
|
var itemType string
|
||||||
if len(event.Item) > 0 {
|
if len(event.Item) > 0 {
|
||||||
@@ -221,6 +237,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
|||||||
claudeMessage = event.Result
|
claudeMessage = event.Result
|
||||||
notifyMessage()
|
notifyMessage()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if event.Type == "result" {
|
||||||
|
notifyComplete()
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,6 +256,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
|||||||
|
|
||||||
if event.Status != "" {
|
if event.Status != "" {
|
||||||
notifyMessage()
|
notifyMessage()
|
||||||
|
|
||||||
|
if event.Type == "result" && (event.Status == "success" || event.Status == "error" || event.Status == "complete" || event.Status == "failed") {
|
||||||
|
notifyComplete()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delta := false
|
delta := false
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) {
|
|||||||
var warns []string
|
var warns []string
|
||||||
warnFn := func(msg string) { warns = append(warns, msg) }
|
warnFn := func(msg string) { warns = append(warns, msg) }
|
||||||
|
|
||||||
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil)
|
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil)
|
||||||
if gotMessage != "ok" {
|
if gotMessage != "ok" {
|
||||||
t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns)
|
t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user