Improve backend termination after message and extend timeout (#86)

* Improve backend termination after message and extend timeout

* fix: prevent premature backend termination and revert timeout

Critical fixes for executor.go termination logic:

1. Add onComplete callback to prevent premature termination
   - Parser now distinguishes between "any message" (onMessage) and
     "terminal event" (onComplete)
   - Codex: triggers onComplete on thread.completed
   - Claude: triggers onComplete on type:"result"
   - Gemini: triggers onComplete on type:"result" + terminal status

2. Fix executor to wait for completion events
   - Replace messageSeen termination trigger with completeSeen
   - Only start postMessageTerminateDelay after terminal event
   - Prevents killing backend before final answer in multi-message scenarios

3. Fix terminated flag synchronization
   - Only set terminated=true if terminateCommandFn actually succeeds
   - Prevents "marked as terminated but not actually terminated" state

4. Simplify timer cleanup logic
   - Unified non-blocking drain on messageTimer.C
   - Remove dependency on messageTimerCh nil state

5. Revert defaultTimeout from 24h to 2h
   - 24h (86400s) → 2h (7200s) to avoid operational risks
   - 12× timeout increase could cause resource exhaustion
   - Users needing longer tasks can use CODEX_TIMEOUT env var

All tests pass. Resolves early termination bug from code review.

Co-authored-by: Codeagent (Codex)

Generated with SWE-Agent.ai

Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>

---------

Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
ben
2025-12-21 15:55:01 +08:00
committed by GitHub
parent 4e2df6a80e
commit 0f359b048f
5 changed files with 272 additions and 32 deletions

View File

@@ -16,6 +16,8 @@ import (
"time"
)
const postMessageTerminateDelay = 1 * time.Second
// commandRunner abstracts exec.Cmd for testability
type commandRunner interface {
Start() error
@@ -729,6 +731,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
// Start parse goroutine BEFORE starting the command to avoid race condition
// where fast-completing commands close stdout before parser starts reading
messageSeen := make(chan struct{}, 1)
completeSeen := make(chan struct{}, 1)
parseCh := make(chan parseResult, 1)
go func() {
msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() {
@@ -736,6 +739,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
case messageSeen <- struct{}{}:
default:
}
}, func() {
select {
case completeSeen <- struct{}{}:
default:
}
})
parseCh <- parseResult{message: msg, threadID: tid}
}()
@@ -773,17 +781,63 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
waitCh := make(chan error, 1)
go func() { waitCh <- cmd.Wait() }()
var waitErr error
var forceKillTimer *forceKillTimer
var ctxCancelled bool
var (
waitErr error
forceKillTimer *forceKillTimer
ctxCancelled bool
messageTimer *time.Timer
messageTimerCh <-chan time.Time
forcedAfterComplete bool
terminated bool
messageSeenObserved bool
completeSeenObserved bool
)
waitLoop:
for {
select {
case waitErr = <-waitCh:
break waitLoop
case <-ctx.Done():
ctxCancelled = true
logErrorFn(cancelReason(commandName, ctx))
forceKillTimer = terminateCommandFn(cmd)
if !terminated {
if timer := terminateCommandFn(cmd); timer != nil {
forceKillTimer = timer
terminated = true
}
}
waitErr = <-waitCh
break waitLoop
case <-messageTimerCh:
forcedAfterComplete = true
messageTimerCh = nil
if !terminated {
logWarnFn(fmt.Sprintf("%s output parsed; terminating lingering backend", commandName))
if timer := terminateCommandFn(cmd); timer != nil {
forceKillTimer = timer
terminated = true
}
}
case <-completeSeen:
completeSeenObserved = true
if messageTimer != nil {
continue
}
messageTimer = time.NewTimer(postMessageTerminateDelay)
messageTimerCh = messageTimer.C
case <-messageSeen:
messageSeenObserved = true
}
}
if messageTimer != nil {
if !messageTimer.Stop() {
select {
case <-messageTimer.C:
default:
}
}
}
if forceKillTimer != nil {
@@ -791,10 +845,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
}
var parsed parseResult
if ctxCancelled {
switch {
case ctxCancelled:
closeWithReason(stdout, stdoutCloseReasonCtx)
parsed = <-parseCh
} else {
case messageSeenObserved || completeSeenObserved:
closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh
default:
drainTimer := time.NewTimer(stdoutDrainTimeout)
defer drainTimer.Stop()
@@ -802,6 +860,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
case parsed = <-parseCh:
closeWithReason(stdout, stdoutCloseReasonWait)
case <-messageSeen:
messageSeenObserved = true
closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh
case <-completeSeen:
completeSeenObserved = true
closeWithReason(stdout, stdoutCloseReasonWait)
parsed = <-parseCh
case <-drainTimer.C:
@@ -822,6 +885,9 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
}
if waitErr != nil {
if forcedAfterComplete && parsed.message != "" {
logWarnFn(fmt.Sprintf("%s terminated after delivering output", commandName))
} else {
if exitErr, ok := waitErr.(*exec.ExitError); ok {
code := exitErr.ExitCode()
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
@@ -834,6 +900,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
return result
}
}
message := parsed.message
threadID := parsed.threadID

View File

@@ -14,9 +14,9 @@ import (
)
const (
version = "5.2.5"
version = "5.2.6"
defaultWorkdir = "."
defaultTimeout = 7200 // seconds
defaultTimeout = 7200 // seconds (2 hours)
codexLogLineLimit = 1000
stdinSpecialChars = "\n\\\"'`$"
stderrCaptureLimit = 4 * 1024

View File

@@ -879,6 +879,79 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
}
}
func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
defer resetTestHooks()
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
},
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnSignal: true,
ReleaseWaitOnKill: true,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
duration := time.Since(start)
if result.ExitCode != 0 || result.Message != "done" {
t.Fatalf("unexpected result: %+v", result)
}
if duration > 2*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
defer resetTestHooks()
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"intermediate"}}` + "\n"},
{Delay: 1100 * time.Millisecond, Data: `{"type":"item.completed","item":{"type":"agent_message","text":"final"}}` + "\n"},
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
},
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnSignal: true,
ReleaseWaitOnKill: true,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
duration := time.Since(start)
if result.ExitCode != 0 || result.Message != "final" {
t.Fatalf("unexpected result: %+v", result)
}
if duration > 5*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
func TestBackendParseArgs_NewMode(t *testing.T) {
tests := []struct {
name string
@@ -1650,7 +1723,7 @@ func TestBackendParseJSONStream_GeminiEvents_OnMessageTriggeredOnStatus(t *testi
var called int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
called++
})
}, nil)
if message != "Hi there" {
t.Fatalf("message=%q, want %q", message, "Hi there")
@@ -1679,7 +1752,7 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
var called int
message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() {
called++
})
}, nil)
if message != "hook" {
t.Fatalf("message = %q, want hook", message)
}
@@ -1691,10 +1764,86 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
}
}
func TestBackendParseJSONStream_OnComplete_CodexThreadCompleted(t *testing.T) {
input := `{"type":"item.completed","item":{"type":"agent_message","text":"first"}}` + "\n" +
`{"type":"item.completed","item":{"type":"agent_message","text":"second"}}` + "\n" +
`{"type":"thread.completed","thread_id":"t-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "second" {
t.Fatalf("message = %q, want second", message)
}
if threadID != "t-1" {
t.Fatalf("threadID = %q, want t-1", threadID)
}
if onMessageCalls != 2 {
t.Fatalf("onMessage calls = %d, want 2", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_OnComplete_ClaudeResult(t *testing.T) {
input := `{"type":"message","subtype":"stream","session_id":"s-1"}` + "\n" +
`{"type":"result","result":"OK","session_id":"s-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "OK" {
t.Fatalf("message = %q, want OK", message)
}
if threadID != "s-1" {
t.Fatalf("threadID = %q, want s-1", threadID)
}
if onMessageCalls != 1 {
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_OnComplete_GeminiTerminalResultStatus(t *testing.T) {
input := `{"type":"message","role":"assistant","content":"Hi","delta":true,"session_id":"g-1"}` + "\n" +
`{"type":"result","status":"success","session_id":"g-1"}`
var onMessageCalls int
var onCompleteCalls int
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
onMessageCalls++
}, func() {
onCompleteCalls++
})
if message != "Hi" {
t.Fatalf("message = %q, want Hi", message)
}
if threadID != "g-1" {
t.Fatalf("threadID = %q, want g-1", threadID)
}
if onMessageCalls != 1 {
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
}
if onCompleteCalls != 1 {
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
}
}
func TestBackendParseJSONStream_ScannerError(t *testing.T) {
var warnings []string
warnFn := func(msg string) { warnings = append(warnings, msg) }
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil)
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil, nil)
if message != "" || threadID != "" {
t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID)
}
@@ -2756,7 +2905,7 @@ func TestVersionFlag(t *testing.T) {
t.Errorf("exit = %d, want 0", code)
}
})
want := "codeagent-wrapper version 5.2.5\n"
want := "codeagent-wrapper version 5.2.6\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
}
@@ -2770,7 +2919,7 @@ func TestVersionShortFlag(t *testing.T) {
t.Errorf("exit = %d, want 0", code)
}
})
want := "codeagent-wrapper version 5.2.5\n"
want := "codeagent-wrapper version 5.2.6\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
}
@@ -2784,7 +2933,7 @@ func TestVersionLegacyAlias(t *testing.T) {
t.Errorf("exit = %d, want 0", code)
}
})
want := "codex-wrapper version 5.2.5\n"
want := "codex-wrapper version 5.2.6\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
}

View File

@@ -50,7 +50,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
}
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
return parseJSONStreamInternal(r, warnFn, infoFn, nil)
return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil)
}
const (
@@ -95,7 +95,7 @@ type ItemContent struct {
Text interface{} `json:"text"`
}
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func()) (message, threadID string) {
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) {
reader := bufio.NewReaderSize(r, jsonLineReaderSize)
if warnFn == nil {
@@ -111,6 +111,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
}
}
notifyComplete := func() {
if onComplete != nil {
onComplete()
}
}
totalEvents := 0
var (
@@ -158,6 +164,9 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
}
}
isClaude := event.Subtype != "" || event.Result != ""
if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" {
isClaude = true
}
isGemini := event.Role != "" || event.Delta != nil || event.Status != ""
// Handle Codex events
@@ -178,6 +187,13 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
threadID = event.ThreadID
infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
case "thread.completed":
if event.ThreadID != "" && threadID == "" {
threadID = event.ThreadID
}
infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID))
notifyComplete()
case "item.completed":
var itemType string
if len(event.Item) > 0 {
@@ -221,6 +237,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
claudeMessage = event.Result
notifyMessage()
}
if event.Type == "result" {
notifyComplete()
}
continue
}
@@ -236,6 +256,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
if event.Status != "" {
notifyMessage()
if event.Type == "result" && (event.Status == "success" || event.Status == "error" || event.Status == "complete" || event.Status == "failed") {
notifyComplete()
}
}
delta := false

View File

@@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) {
var warns []string
warnFn := func(msg string) { warns = append(warns, msg) }
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil)
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil)
if gotMessage != "ok" {
t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns)
}