diff --git a/codeagent-wrapper/executor.go b/codeagent-wrapper/executor.go index c15b068..0762f3b 100644 --- a/codeagent-wrapper/executor.go +++ b/codeagent-wrapper/executor.go @@ -16,6 +16,8 @@ import ( "time" ) +const postMessageTerminateDelay = 1 * time.Second + // commandRunner abstracts exec.Cmd for testability type commandRunner interface { Start() error @@ -729,6 +731,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe // Start parse goroutine BEFORE starting the command to avoid race condition // where fast-completing commands close stdout before parser starts reading messageSeen := make(chan struct{}, 1) + completeSeen := make(chan struct{}, 1) parseCh := make(chan parseResult, 1) go func() { msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() { @@ -736,6 +739,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe case messageSeen <- struct{}{}: default: } + }, func() { + select { + case completeSeen <- struct{}{}: + default: + } }) parseCh <- parseResult{message: msg, threadID: tid} }() @@ -773,17 +781,63 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe waitCh := make(chan error, 1) go func() { waitCh <- cmd.Wait() }() - var waitErr error - var forceKillTimer *forceKillTimer - var ctxCancelled bool + var ( + waitErr error + forceKillTimer *forceKillTimer + ctxCancelled bool + messageTimer *time.Timer + messageTimerCh <-chan time.Time + forcedAfterComplete bool + terminated bool + messageSeenObserved bool + completeSeenObserved bool + ) - select { - case waitErr = <-waitCh: - case <-ctx.Done(): - ctxCancelled = true - logErrorFn(cancelReason(commandName, ctx)) - forceKillTimer = terminateCommandFn(cmd) - waitErr = <-waitCh +waitLoop: + for { + select { + case waitErr = <-waitCh: + break waitLoop + case <-ctx.Done(): + ctxCancelled = true + logErrorFn(cancelReason(commandName, ctx)) + if !terminated { + if timer := terminateCommandFn(cmd); timer != nil { + forceKillTimer = timer + terminated = true + } + } + waitErr = <-waitCh + break waitLoop + case <-messageTimerCh: + forcedAfterComplete = true + messageTimerCh = nil + if !terminated { + logWarnFn(fmt.Sprintf("%s output parsed; terminating lingering backend", commandName)) + if timer := terminateCommandFn(cmd); timer != nil { + forceKillTimer = timer + terminated = true + } + } + case <-completeSeen: + completeSeenObserved = true + if messageTimer != nil { + continue + } + messageTimer = time.NewTimer(postMessageTerminateDelay) + messageTimerCh = messageTimer.C + case <-messageSeen: + messageSeenObserved = true + } + } + + if messageTimer != nil { + if !messageTimer.Stop() { + select { + case <-messageTimer.C: + default: + } + } } if forceKillTimer != nil { @@ -791,10 +845,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe } var parsed parseResult - if ctxCancelled { + switch { + case ctxCancelled: closeWithReason(stdout, stdoutCloseReasonCtx) parsed = <-parseCh - } else { + case messageSeenObserved || completeSeenObserved: + closeWithReason(stdout, stdoutCloseReasonWait) + parsed = <-parseCh + default: drainTimer := time.NewTimer(stdoutDrainTimeout) defer drainTimer.Stop() @@ -802,6 +860,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe case parsed = <-parseCh: closeWithReason(stdout, stdoutCloseReasonWait) case <-messageSeen: + messageSeenObserved = true + closeWithReason(stdout, stdoutCloseReasonWait) + parsed = <-parseCh + case <-completeSeen: + completeSeenObserved = true closeWithReason(stdout, stdoutCloseReasonWait) parsed = <-parseCh case <-drainTimer.C: @@ -822,17 +885,21 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe } if waitErr != nil { - if exitErr, ok := waitErr.(*exec.ExitError); ok { - code := exitErr.ExitCode() - logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code)) - result.ExitCode = code - result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code)) + if forcedAfterComplete && parsed.message != "" { + logWarnFn(fmt.Sprintf("%s terminated after delivering output", commandName)) + } else { + if exitErr, ok := waitErr.(*exec.ExitError); ok { + code := exitErr.ExitCode() + logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code)) + result.ExitCode = code + result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code)) + return result + } + logErrorFn(commandName + " error: " + waitErr.Error()) + result.ExitCode = 1 + result.Error = attachStderr(commandName + " error: " + waitErr.Error()) return result } - logErrorFn(commandName + " error: " + waitErr.Error()) - result.ExitCode = 1 - result.Error = attachStderr(commandName + " error: " + waitErr.Error()) - return result } message := parsed.message diff --git a/codeagent-wrapper/main.go b/codeagent-wrapper/main.go index f923a3e..33719a6 100644 --- a/codeagent-wrapper/main.go +++ b/codeagent-wrapper/main.go @@ -14,9 +14,9 @@ import ( ) const ( - version = "5.2.5" + version = "5.2.6" defaultWorkdir = "." - defaultTimeout = 7200 // seconds + defaultTimeout = 7200 // seconds (2 hours) codexLogLineLimit = 1000 stdinSpecialChars = "\n\\\"'`$" stderrCaptureLimit = 4 * 1024 diff --git a/codeagent-wrapper/main_test.go b/codeagent-wrapper/main_test.go index e5fc37b..337d9f1 100644 --- a/codeagent-wrapper/main_test.go +++ b/codeagent-wrapper/main_test.go @@ -879,6 +879,79 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) { } } +func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) { + defer resetTestHooks() + forceKillDelay.Store(0) + + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"}, + {Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"}, + }, + KeepStdoutOpen: true, + BlockWait: true, + ReleaseWaitOnSignal: true, + ReleaseWaitOnKill: true, + }) + + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return fake + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } + codexCommand = "fake-cmd" + + start := time.Now() + result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60) + duration := time.Since(start) + + if result.ExitCode != 0 || result.Message != "done" { + t.Fatalf("unexpected result: %+v", result) + } + if duration > 2*time.Second { + t.Fatalf("runCodexTaskWithContext took too long: %v", duration) + } + if fake.process.SignalCount() == 0 { + t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount()) + } +} + +func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) { + defer resetTestHooks() + forceKillDelay.Store(0) + + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: `{"type":"item.completed","item":{"type":"agent_message","text":"intermediate"}}` + "\n"}, + {Delay: 1100 * time.Millisecond, Data: `{"type":"item.completed","item":{"type":"agent_message","text":"final"}}` + "\n"}, + {Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"}, + }, + KeepStdoutOpen: true, + BlockWait: true, + ReleaseWaitOnSignal: true, + ReleaseWaitOnKill: true, + }) + + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return fake + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } + codexCommand = "fake-cmd" + + start := time.Now() + result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60) + duration := time.Since(start) + + if result.ExitCode != 0 || result.Message != "final" { + t.Fatalf("unexpected result: %+v", result) + } + if duration > 5*time.Second { + t.Fatalf("runCodexTaskWithContext took too long: %v", duration) + } + if fake.process.SignalCount() == 0 { + t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount()) + } +} + func TestBackendParseArgs_NewMode(t *testing.T) { tests := []struct { name string @@ -1650,7 +1723,7 @@ func TestBackendParseJSONStream_GeminiEvents_OnMessageTriggeredOnStatus(t *testi var called int message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() { called++ - }) + }, nil) if message != "Hi there" { t.Fatalf("message=%q, want %q", message, "Hi there") @@ -1679,7 +1752,7 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) { var called int message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() { called++ - }) + }, nil) if message != "hook" { t.Fatalf("message = %q, want hook", message) } @@ -1691,10 +1764,86 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) { } } +func TestBackendParseJSONStream_OnComplete_CodexThreadCompleted(t *testing.T) { + input := `{"type":"item.completed","item":{"type":"agent_message","text":"first"}}` + "\n" + + `{"type":"item.completed","item":{"type":"agent_message","text":"second"}}` + "\n" + + `{"type":"thread.completed","thread_id":"t-1"}` + + var onMessageCalls int + var onCompleteCalls int + message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() { + onMessageCalls++ + }, func() { + onCompleteCalls++ + }) + if message != "second" { + t.Fatalf("message = %q, want second", message) + } + if threadID != "t-1" { + t.Fatalf("threadID = %q, want t-1", threadID) + } + if onMessageCalls != 2 { + t.Fatalf("onMessage calls = %d, want 2", onMessageCalls) + } + if onCompleteCalls != 1 { + t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls) + } +} + +func TestBackendParseJSONStream_OnComplete_ClaudeResult(t *testing.T) { + input := `{"type":"message","subtype":"stream","session_id":"s-1"}` + "\n" + + `{"type":"result","result":"OK","session_id":"s-1"}` + + var onMessageCalls int + var onCompleteCalls int + message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() { + onMessageCalls++ + }, func() { + onCompleteCalls++ + }) + if message != "OK" { + t.Fatalf("message = %q, want OK", message) + } + if threadID != "s-1" { + t.Fatalf("threadID = %q, want s-1", threadID) + } + if onMessageCalls != 1 { + t.Fatalf("onMessage calls = %d, want 1", onMessageCalls) + } + if onCompleteCalls != 1 { + t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls) + } +} + +func TestBackendParseJSONStream_OnComplete_GeminiTerminalResultStatus(t *testing.T) { + input := `{"type":"message","role":"assistant","content":"Hi","delta":true,"session_id":"g-1"}` + "\n" + + `{"type":"result","status":"success","session_id":"g-1"}` + + var onMessageCalls int + var onCompleteCalls int + message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() { + onMessageCalls++ + }, func() { + onCompleteCalls++ + }) + if message != "Hi" { + t.Fatalf("message = %q, want Hi", message) + } + if threadID != "g-1" { + t.Fatalf("threadID = %q, want g-1", threadID) + } + if onMessageCalls != 1 { + t.Fatalf("onMessage calls = %d, want 1", onMessageCalls) + } + if onCompleteCalls != 1 { + t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls) + } +} + func TestBackendParseJSONStream_ScannerError(t *testing.T) { var warnings []string warnFn := func(msg string) { warnings = append(warnings, msg) } - message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil) + message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil, nil) if message != "" || threadID != "" { t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID) } @@ -2756,7 +2905,7 @@ func TestVersionFlag(t *testing.T) { t.Errorf("exit = %d, want 0", code) } }) - want := "codeagent-wrapper version 5.2.5\n" + want := "codeagent-wrapper version 5.2.6\n" if output != want { t.Fatalf("output = %q, want %q", output, want) } @@ -2770,7 +2919,7 @@ func TestVersionShortFlag(t *testing.T) { t.Errorf("exit = %d, want 0", code) } }) - want := "codeagent-wrapper version 5.2.5\n" + want := "codeagent-wrapper version 5.2.6\n" if output != want { t.Fatalf("output = %q, want %q", output, want) } @@ -2784,7 +2933,7 @@ func TestVersionLegacyAlias(t *testing.T) { t.Errorf("exit = %d, want 0", code) } }) - want := "codex-wrapper version 5.2.5\n" + want := "codex-wrapper version 5.2.6\n" if output != want { t.Fatalf("output = %q, want %q", output, want) } diff --git a/codeagent-wrapper/parser.go b/codeagent-wrapper/parser.go index ecf27e6..0718d21 100644 --- a/codeagent-wrapper/parser.go +++ b/codeagent-wrapper/parser.go @@ -50,7 +50,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI } func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) { - return parseJSONStreamInternal(r, warnFn, infoFn, nil) + return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil) } const ( @@ -95,7 +95,7 @@ type ItemContent struct { Text interface{} `json:"text"` } -func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func()) (message, threadID string) { +func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) { reader := bufio.NewReaderSize(r, jsonLineReaderSize) if warnFn == nil { @@ -111,6 +111,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin } } + notifyComplete := func() { + if onComplete != nil { + onComplete() + } + } + totalEvents := 0 var ( @@ -158,6 +164,9 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin } } isClaude := event.Subtype != "" || event.Result != "" + if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" { + isClaude = true + } isGemini := event.Role != "" || event.Delta != nil || event.Status != "" // Handle Codex events @@ -178,6 +187,13 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin threadID = event.ThreadID infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID)) + case "thread.completed": + if event.ThreadID != "" && threadID == "" { + threadID = event.ThreadID + } + infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID)) + notifyComplete() + case "item.completed": var itemType string if len(event.Item) > 0 { @@ -221,6 +237,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin claudeMessage = event.Result notifyMessage() } + + if event.Type == "result" { + notifyComplete() + } continue } @@ -236,6 +256,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin if event.Status != "" { notifyMessage() + + if event.Type == "result" && (event.Status == "success" || event.Status == "error" || event.Status == "complete" || event.Status == "failed") { + notifyComplete() + } } delta := false diff --git a/codeagent-wrapper/parser_token_too_long_test.go b/codeagent-wrapper/parser_token_too_long_test.go index ed91cd2..662e443 100644 --- a/codeagent-wrapper/parser_token_too_long_test.go +++ b/codeagent-wrapper/parser_token_too_long_test.go @@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) { var warns []string warnFn := func(msg string) { warns = append(warns, msg) } - gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil) + gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil) if gotMessage != "ok" { t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns) }