Improve backend termination after message and extend timeout (#86)

* Improve backend termination after message and extend timeout

* fix: prevent premature backend termination and revert timeout

Critical fixes for executor.go termination logic:

1. Add onComplete callback to prevent premature termination
   - Parser now distinguishes between "any message" (onMessage) and
     "terminal event" (onComplete)
   - Codex: triggers onComplete on thread.completed
   - Claude: triggers onComplete on type:"result"
   - Gemini: triggers onComplete on type:"result" + terminal status

2. Fix executor to wait for completion events
   - Replace messageSeen termination trigger with completeSeen
   - Only start postMessageTerminateDelay after terminal event
   - Prevents killing backend before final answer in multi-message scenarios

3. Fix terminated flag synchronization
   - Only set terminated=true if terminateCommandFn actually succeeds
   - Prevents "marked as terminated but not actually terminated" state

4. Simplify timer cleanup logic
   - Unified non-blocking drain on messageTimer.C
   - Remove dependency on messageTimerCh nil state

5. Revert defaultTimeout from 24h to 2h
   - 24h (86400s) → 2h (7200s) to avoid operational risks
   - 12× timeout increase could cause resource exhaustion
   - Users needing longer tasks can use CODEX_TIMEOUT env var

All tests pass. Resolves early termination bug from code review.

Co-authored-by: Codeagent (Codex)

Generated with SWE-Agent.ai

Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>

---------

Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
ben
2025-12-21 15:55:01 +08:00
committed by GitHub
parent 4e2df6a80e
commit 0f359b048f
5 changed files with 272 additions and 32 deletions

View File

@@ -16,6 +16,8 @@ import (
"time" "time"
) )
const postMessageTerminateDelay = 1 * time.Second
// commandRunner abstracts exec.Cmd for testability // commandRunner abstracts exec.Cmd for testability
type commandRunner interface { type commandRunner interface {
Start() error Start() error
@@ -729,6 +731,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
// Start parse goroutine BEFORE starting the command to avoid race condition // Start parse goroutine BEFORE starting the command to avoid race condition
// where fast-completing commands close stdout before parser starts reading // where fast-completing commands close stdout before parser starts reading
messageSeen := make(chan struct{}, 1) messageSeen := make(chan struct{}, 1)
completeSeen := make(chan struct{}, 1)
parseCh := make(chan parseResult, 1) parseCh := make(chan parseResult, 1)
go func() { go func() {
msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() { msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() {
@@ -736,6 +739,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
case messageSeen <- struct{}{}: case messageSeen <- struct{}{}:
default: default:
} }
}, func() {
select {
case completeSeen <- struct{}{}:
default:
}
}) })
parseCh <- parseResult{message: msg, threadID: tid} parseCh <- parseResult{message: msg, threadID: tid}
}() }()
@@ -773,17 +781,63 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
waitCh := make(chan error, 1) waitCh := make(chan error, 1)
go func() { waitCh <- cmd.Wait() }() go func() { waitCh <- cmd.Wait() }()
var waitErr error var (
var forceKillTimer *forceKillTimer waitErr error
var ctxCancelled bool forceKillTimer *forceKillTimer
ctxCancelled bool
messageTimer *time.Timer
messageTimerCh <-chan time.Time
forcedAfterComplete bool
terminated bool
messageSeenObserved bool
completeSeenObserved bool
)
select { waitLoop:
case waitErr = <-waitCh: for {
case <-ctx.Done(): select {
ctxCancelled = true case waitErr = <-waitCh:
logErrorFn(cancelReason(commandName, ctx)) break waitLoop
forceKillTimer = terminateCommandFn(cmd) case <-ctx.Done():
waitErr = <-waitCh ctxCancelled = true
logErrorFn(cancelReason(commandName, ctx))
if !terminated {
if timer := terminateCommandFn(cmd); timer != nil {
forceKillTimer = timer
terminated = true
}
}
waitErr = <-waitCh
break waitLoop
case <-messageTimerCh:
forcedAfterComplete = true
messageTimerCh = nil
if !terminated {
logWarnFn(fmt.Sprintf("%s output parsed; terminating lingering backend", commandName))
if timer := terminateCommandFn(cmd); timer != nil {
forceKillTimer = timer
terminated = true
}
}
case <-completeSeen:
completeSeenObserved = true
if messageTimer != nil {
continue
}
messageTimer = time.NewTimer(postMessageTerminateDelay)
messageTimerCh = messageTimer.C
case <-messageSeen:
messageSeenObserved = true
}
}
if messageTimer != nil {
if !messageTimer.Stop() {
select {
case <-messageTimer.C:
default:
}
}
} }
if forceKillTimer != nil { if forceKillTimer != nil {
@@ -791,10 +845,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
} }
var parsed parseResult var parsed parseResult
if ctxCancelled { switch {
case ctxCancelled:
closeWithReason(stdout, stdoutCloseReasonCtx) closeWithReason(stdout, stdoutCloseReasonCtx)
parsed = <-parseCh parsed = <-parseCh
} else { case messageSeenObserved || completeSeenObserved:
closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh
default:
drainTimer := time.NewTimer(stdoutDrainTimeout) drainTimer := time.NewTimer(stdoutDrainTimeout)
defer drainTimer.Stop() defer drainTimer.Stop()
@@ -802,6 +860,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
case parsed = <-parseCh: case parsed = <-parseCh:
closeWithReason(stdout, stdoutCloseReasonWait) closeWithReason(stdout, stdoutCloseReasonWait)
case <-messageSeen: case <-messageSeen:
messageSeenObserved = true
closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh
case <-completeSeen:
completeSeenObserved = true
closeWithReason(stdout, stdoutCloseReasonWait) closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh parsed = <-parseCh
case <-drainTimer.C: case <-drainTimer.C:
@@ -822,17 +885,21 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
} }
if waitErr != nil { if waitErr != nil {
if exitErr, ok := waitErr.(*exec.ExitError); ok { if forcedAfterComplete && parsed.message != "" {
code := exitErr.ExitCode() logWarnFn(fmt.Sprintf("%s terminated after delivering output", commandName))
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code)) } else {
result.ExitCode = code if exitErr, ok := waitErr.(*exec.ExitError); ok {
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code)) code := exitErr.ExitCode()
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
result.ExitCode = code
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code))
return result
}
logErrorFn(commandName + " error: " + waitErr.Error())
result.ExitCode = 1
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
return result return result
} }
logErrorFn(commandName + " error: " + waitErr.Error())
result.ExitCode = 1
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
return result
} }
message := parsed.message message := parsed.message

View File

@@ -14,9 +14,9 @@ import (
) )
const ( const (
version = "5.2.5" version = "5.2.6"
defaultWorkdir = "." defaultWorkdir = "."
defaultTimeout = 7200 // seconds defaultTimeout = 7200 // seconds (2 hours)
codexLogLineLimit = 1000 codexLogLineLimit = 1000
stdinSpecialChars = "\n\\\"'`$" stdinSpecialChars = "\n\\\"'`$"
stderrCaptureLimit = 4 * 1024 stderrCaptureLimit = 4 * 1024

View File

@@ -879,6 +879,79 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
} }
} }
func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
defer resetTestHooks()
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
},
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnSignal: true,
ReleaseWaitOnKill: true,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
duration := time.Since(start)
if result.ExitCode != 0 || result.Message != "done" {
t.Fatalf("unexpected result: %+v", result)
}
if duration > 2*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
defer resetTestHooks()
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"intermediate"}}` + "\n"},
{Delay: 1100 * time.Millisecond, Data: `{"type":"item.completed","item":{"type":"agent_message","text":"final"}}` + "\n"},
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
},
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnSignal: true,
ReleaseWaitOnKill: true,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
duration := time.Since(start)
if result.ExitCode != 0 || result.Message != "final" {
t.Fatalf("unexpected result: %+v", result)
}
if duration > 5*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
func TestBackendParseArgs_NewMode(t *testing.T) { func TestBackendParseArgs_NewMode(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -1650,7 +1723,7 @@ func TestBackendParseJSONStream_GeminiEvents_OnMessageTriggeredOnStatus(t *testi
var called int var called int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() { message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
called++ called++
}) }, nil)
if message != "Hi there" { if message != "Hi there" {
t.Fatalf("message=%q, want %q", message, "Hi there") t.Fatalf("message=%q, want %q", message, "Hi there")
@@ -1679,7 +1752,7 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
var called int var called int
message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() { message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() {
called++ called++
}) }, nil)
if message != "hook" { if message != "hook" {
t.Fatalf("message = %q, want hook", message) t.Fatalf("message = %q, want hook", message)
} }
@@ -1691,10 +1764,86 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
} }
} }
func TestBackendParseJSONStream_OnComplete_CodexThreadCompleted(t *testing.T) {
input := `{"type":"item.completed","item":{"type":"agent_message","text":"first"}}` + "\n" +
`{"type":"item.completed","item":{"type":"agent_message","text":"second"}}` + "\n" +
`{"type":"thread.completed","thread_id":"t-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "second" {
t.Fatalf("message = %q, want second", message)
}
if threadID != "t-1" {
t.Fatalf("threadID = %q, want t-1", threadID)
}
if onMessageCalls != 2 {
t.Fatalf("onMessage calls = %d, want 2", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_OnComplete_ClaudeResult(t *testing.T) {
input := `{"type":"message","subtype":"stream","session_id":"s-1"}` + "\n" +
`{"type":"result","result":"OK","session_id":"s-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "OK" {
t.Fatalf("message = %q, want OK", message)
}
if threadID != "s-1" {
t.Fatalf("threadID = %q, want s-1", threadID)
}
if onMessageCalls != 1 {
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_OnComplete_GeminiTerminalResultStatus(t *testing.T) {
input := `{"type":"message","role":"assistant","content":"Hi","delta":true,"session_id":"g-1"}` + "\n" +
`{"type":"result","status":"success","session_id":"g-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "Hi" {
t.Fatalf("message = %q, want Hi", message)
}
if threadID != "g-1" {
t.Fatalf("threadID = %q, want g-1", threadID)
}
if onMessageCalls != 1 {
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_ScannerError(t *testing.T) { func TestBackendParseJSONStream_ScannerError(t *testing.T) {
var warnings []string var warnings []string
warnFn := func(msg string) { warnings = append(warnings, msg) } warnFn := func(msg string) { warnings = append(warnings, msg) }
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil) message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil, nil)
if message != "" || threadID != "" { if message != "" || threadID != "" {
t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID) t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID)
} }
@@ -2756,7 +2905,7 @@ func TestVersionFlag(t *testing.T) {
t.Errorf("exit = %d, want 0", code) t.Errorf("exit = %d, want 0", code)
} }
}) })
want := "codeagent-wrapper version 5.2.5\n" want := "codeagent-wrapper version 5.2.6\n"
if output != want { if output != want {
t.Fatalf("output = %q, want %q", output, want) t.Fatalf("output = %q, want %q", output, want)
} }
@@ -2770,7 +2919,7 @@ func TestVersionShortFlag(t *testing.T) {
t.Errorf("exit = %d, want 0", code) t.Errorf("exit = %d, want 0", code)
} }
}) })
want := "codeagent-wrapper version 5.2.5\n" want := "codeagent-wrapper version 5.2.6\n"
if output != want { if output != want {
t.Fatalf("output = %q, want %q", output, want) t.Fatalf("output = %q, want %q", output, want)
} }
@@ -2784,7 +2933,7 @@ func TestVersionLegacyAlias(t *testing.T) {
t.Errorf("exit = %d, want 0", code) t.Errorf("exit = %d, want 0", code)
} }
}) })
want := "codex-wrapper version 5.2.5\n" want := "codex-wrapper version 5.2.6\n"
if output != want { if output != want {
t.Fatalf("output = %q, want %q", output, want) t.Fatalf("output = %q, want %q", output, want)
} }

View File

@@ -50,7 +50,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
} }
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) { func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
return parseJSONStreamInternal(r, warnFn, infoFn, nil) return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil)
} }
const ( const (
@@ -95,7 +95,7 @@ type ItemContent struct {
Text interface{} `json:"text"` Text interface{} `json:"text"`
} }
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func()) (message, threadID string) { func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) {
reader := bufio.NewReaderSize(r, jsonLineReaderSize) reader := bufio.NewReaderSize(r, jsonLineReaderSize)
if warnFn == nil { if warnFn == nil {
@@ -111,6 +111,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
} }
} }
notifyComplete := func() {
if onComplete != nil {
onComplete()
}
}
totalEvents := 0 totalEvents := 0
var ( var (
@@ -158,6 +164,9 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
} }
} }
isClaude := event.Subtype != "" || event.Result != "" isClaude := event.Subtype != "" || event.Result != ""
if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" {
isClaude = true
}
isGemini := event.Role != "" || event.Delta != nil || event.Status != "" isGemini := event.Role != "" || event.Delta != nil || event.Status != ""
// Handle Codex events // Handle Codex events
@@ -178,6 +187,13 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
threadID = event.ThreadID threadID = event.ThreadID
infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID)) infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
case "thread.completed":
if event.ThreadID != "" && threadID == "" {
threadID = event.ThreadID
}
infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID))
notifyComplete()
case "item.completed": case "item.completed":
var itemType string var itemType string
if len(event.Item) > 0 { if len(event.Item) > 0 {
@@ -221,6 +237,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
claudeMessage = event.Result claudeMessage = event.Result
notifyMessage() notifyMessage()
} }
if event.Type == "result" {
notifyComplete()
}
continue continue
} }
@@ -236,6 +256,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
if event.Status != "" { if event.Status != "" {
notifyMessage() notifyMessage()
if event.Type == "result" && (event.Status == "success" || event.Status == "error" || event.Status == "complete" || event.Status == "failed") {
notifyComplete()
}
} }
delta := false delta := false

View File

@@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) {
var warns []string var warns []string
warnFn := func(msg string) { warns = append(warns, msg) } warnFn := func(msg string) { warns = append(warns, msg) }
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil) gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil)
if gotMessage != "ok" { if gotMessage != "ok" {
t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns) t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns)
} }