mirror of
https://github.com/cexll/myclaude.git
synced 2026-02-04 02:20:42 +08:00
Improve backend termination after message and extend timeout (#86)
* Improve backend termination after message and extend timeout
* fix: prevent premature backend termination and revert timeout
Critical fixes for executor.go termination logic:
1. Add onComplete callback to prevent premature termination
- Parser now distinguishes between "any message" (onMessage) and
"terminal event" (onComplete)
- Codex: triggers onComplete on thread.completed
- Claude: triggers onComplete on type:"result"
- Gemini: triggers onComplete on type:"result" + terminal status
2. Fix executor to wait for completion events
- Replace messageSeen termination trigger with completeSeen
- Only start postMessageTerminateDelay after terminal event
- Prevents killing backend before final answer in multi-message scenarios
3. Fix terminated flag synchronization
- Only set terminated=true when terminateCommandFn returns a non-nil force-kill timer (i.e. termination was actually initiated)
- Prevents "marked as terminated but not actually terminated" state
4. Simplify timer cleanup logic
- Unified non-blocking drain on messageTimer.C
- Remove dependency on messageTimerCh nil state
5. Revert defaultTimeout from 24h to 2h
- 24h (86400s) → 2h (7200s) to avoid operational risks
- 12× timeout increase could cause resource exhaustion
- Users needing longer tasks can use CODEX_TIMEOUT env var
All tests pass. Resolves early termination bug from code review.
Co-authored-by: Codeagent (Codex)
Generated with SWE-Agent.ai
Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>
---------
Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
@@ -16,6 +16,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const postMessageTerminateDelay = 1 * time.Second
|
||||
|
||||
// commandRunner abstracts exec.Cmd for testability
|
||||
type commandRunner interface {
|
||||
Start() error
|
||||
@@ -729,6 +731,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
// Start parse goroutine BEFORE starting the command to avoid race condition
|
||||
// where fast-completing commands close stdout before parser starts reading
|
||||
messageSeen := make(chan struct{}, 1)
|
||||
completeSeen := make(chan struct{}, 1)
|
||||
parseCh := make(chan parseResult, 1)
|
||||
go func() {
|
||||
msg, tid := parseJSONStreamInternal(stdoutReader, logWarnFn, logInfoFn, func() {
|
||||
@@ -736,6 +739,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
case messageSeen <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}, func() {
|
||||
select {
|
||||
case completeSeen <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
parseCh <- parseResult{message: msg, threadID: tid}
|
||||
}()
|
||||
@@ -773,17 +781,63 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
waitCh := make(chan error, 1)
|
||||
go func() { waitCh <- cmd.Wait() }()
|
||||
|
||||
var waitErr error
|
||||
var forceKillTimer *forceKillTimer
|
||||
var ctxCancelled bool
|
||||
var (
|
||||
waitErr error
|
||||
forceKillTimer *forceKillTimer
|
||||
ctxCancelled bool
|
||||
messageTimer *time.Timer
|
||||
messageTimerCh <-chan time.Time
|
||||
forcedAfterComplete bool
|
||||
terminated bool
|
||||
messageSeenObserved bool
|
||||
completeSeenObserved bool
|
||||
)
|
||||
|
||||
select {
|
||||
case waitErr = <-waitCh:
|
||||
case <-ctx.Done():
|
||||
ctxCancelled = true
|
||||
logErrorFn(cancelReason(commandName, ctx))
|
||||
forceKillTimer = terminateCommandFn(cmd)
|
||||
waitErr = <-waitCh
|
||||
waitLoop:
|
||||
for {
|
||||
select {
|
||||
case waitErr = <-waitCh:
|
||||
break waitLoop
|
||||
case <-ctx.Done():
|
||||
ctxCancelled = true
|
||||
logErrorFn(cancelReason(commandName, ctx))
|
||||
if !terminated {
|
||||
if timer := terminateCommandFn(cmd); timer != nil {
|
||||
forceKillTimer = timer
|
||||
terminated = true
|
||||
}
|
||||
}
|
||||
waitErr = <-waitCh
|
||||
break waitLoop
|
||||
case <-messageTimerCh:
|
||||
forcedAfterComplete = true
|
||||
messageTimerCh = nil
|
||||
if !terminated {
|
||||
logWarnFn(fmt.Sprintf("%s output parsed; terminating lingering backend", commandName))
|
||||
if timer := terminateCommandFn(cmd); timer != nil {
|
||||
forceKillTimer = timer
|
||||
terminated = true
|
||||
}
|
||||
}
|
||||
case <-completeSeen:
|
||||
completeSeenObserved = true
|
||||
if messageTimer != nil {
|
||||
continue
|
||||
}
|
||||
messageTimer = time.NewTimer(postMessageTerminateDelay)
|
||||
messageTimerCh = messageTimer.C
|
||||
case <-messageSeen:
|
||||
messageSeenObserved = true
|
||||
}
|
||||
}
|
||||
|
||||
if messageTimer != nil {
|
||||
if !messageTimer.Stop() {
|
||||
select {
|
||||
case <-messageTimer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if forceKillTimer != nil {
|
||||
@@ -791,10 +845,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
}
|
||||
|
||||
var parsed parseResult
|
||||
if ctxCancelled {
|
||||
switch {
|
||||
case ctxCancelled:
|
||||
closeWithReason(stdout, stdoutCloseReasonCtx)
|
||||
parsed = <-parseCh
|
||||
} else {
|
||||
case messageSeenObserved || completeSeenObserved:
|
||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||
parsed = <-parseCh
|
||||
default:
|
||||
drainTimer := time.NewTimer(stdoutDrainTimeout)
|
||||
defer drainTimer.Stop()
|
||||
|
||||
@@ -802,6 +860,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
case parsed = <-parseCh:
|
||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||
case <-messageSeen:
|
||||
messageSeenObserved = true
|
||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||
parsed = <-parseCh
|
||||
case <-completeSeen:
|
||||
completeSeenObserved = true
|
||||
closeWithReason(stdout, stdoutCloseReasonWait)
|
||||
parsed = <-parseCh
|
||||
case <-drainTimer.C:
|
||||
@@ -822,17 +885,21 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
code := exitErr.ExitCode()
|
||||
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||
result.ExitCode = code
|
||||
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||
if forcedAfterComplete && parsed.message != "" {
|
||||
logWarnFn(fmt.Sprintf("%s terminated after delivering output", commandName))
|
||||
} else {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
code := exitErr.ExitCode()
|
||||
logErrorFn(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||
result.ExitCode = code
|
||||
result.Error = attachStderr(fmt.Sprintf("%s exited with status %d", commandName, code))
|
||||
return result
|
||||
}
|
||||
logErrorFn(commandName + " error: " + waitErr.Error())
|
||||
result.ExitCode = 1
|
||||
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
|
||||
return result
|
||||
}
|
||||
logErrorFn(commandName + " error: " + waitErr.Error())
|
||||
result.ExitCode = 1
|
||||
result.Error = attachStderr(commandName + " error: " + waitErr.Error())
|
||||
return result
|
||||
}
|
||||
|
||||
message := parsed.message
|
||||
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
version = "5.2.5"
|
||||
version = "5.2.6"
|
||||
defaultWorkdir = "."
|
||||
defaultTimeout = 7200 // seconds
|
||||
defaultTimeout = 7200 // seconds (2 hours)
|
||||
codexLogLineLimit = 1000
|
||||
stdinSpecialChars = "\n\\\"'`$"
|
||||
stderrCaptureLimit = 4 * 1024
|
||||
|
||||
@@ -879,6 +879,79 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
|
||||
defer resetTestHooks()
|
||||
forceKillDelay.Store(0)
|
||||
|
||||
fake := newFakeCmd(fakeCmdConfig{
|
||||
StdoutPlan: []fakeStdoutEvent{
|
||||
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
|
||||
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
|
||||
},
|
||||
KeepStdoutOpen: true,
|
||||
BlockWait: true,
|
||||
ReleaseWaitOnSignal: true,
|
||||
ReleaseWaitOnKill: true,
|
||||
})
|
||||
|
||||
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
|
||||
return fake
|
||||
}
|
||||
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
|
||||
codexCommand = "fake-cmd"
|
||||
|
||||
start := time.Now()
|
||||
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
|
||||
duration := time.Since(start)
|
||||
|
||||
if result.ExitCode != 0 || result.Message != "done" {
|
||||
t.Fatalf("unexpected result: %+v", result)
|
||||
}
|
||||
if duration > 2*time.Second {
|
||||
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||
}
|
||||
if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
|
||||
defer resetTestHooks()
|
||||
forceKillDelay.Store(0)
|
||||
|
||||
fake := newFakeCmd(fakeCmdConfig{
|
||||
StdoutPlan: []fakeStdoutEvent{
|
||||
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"intermediate"}}` + "\n"},
|
||||
{Delay: 1100 * time.Millisecond, Data: `{"type":"item.completed","item":{"type":"agent_message","text":"final"}}` + "\n"},
|
||||
{Data: `{"type":"thread.completed","thread_id":"tid"}` + "\n"},
|
||||
},
|
||||
KeepStdoutOpen: true,
|
||||
BlockWait: true,
|
||||
ReleaseWaitOnSignal: true,
|
||||
ReleaseWaitOnKill: true,
|
||||
})
|
||||
|
||||
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
|
||||
return fake
|
||||
}
|
||||
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
|
||||
codexCommand = "fake-cmd"
|
||||
|
||||
start := time.Now()
|
||||
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
|
||||
duration := time.Since(start)
|
||||
|
||||
if result.ExitCode != 0 || result.Message != "final" {
|
||||
t.Fatalf("unexpected result: %+v", result)
|
||||
}
|
||||
if duration > 5*time.Second {
|
||||
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||
}
|
||||
if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParseArgs_NewMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1650,7 +1723,7 @@ func TestBackendParseJSONStream_GeminiEvents_OnMessageTriggeredOnStatus(t *testi
|
||||
var called int
|
||||
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||
called++
|
||||
})
|
||||
}, nil)
|
||||
|
||||
if message != "Hi there" {
|
||||
t.Fatalf("message=%q, want %q", message, "Hi there")
|
||||
@@ -1679,7 +1752,7 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
|
||||
var called int
|
||||
message, threadID := parseJSONStreamInternal(strings.NewReader(`{"type":"item.completed","item":{"type":"agent_message","text":"hook"}}`), nil, nil, func() {
|
||||
called++
|
||||
})
|
||||
}, nil)
|
||||
if message != "hook" {
|
||||
t.Fatalf("message = %q, want hook", message)
|
||||
}
|
||||
@@ -1691,10 +1764,86 @@ func TestBackendParseJSONStream_OnMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParseJSONStream_OnComplete_CodexThreadCompleted(t *testing.T) {
|
||||
input := `{"type":"item.completed","item":{"type":"agent_message","text":"first"}}` + "\n" +
|
||||
`{"type":"item.completed","item":{"type":"agent_message","text":"second"}}` + "\n" +
|
||||
`{"type":"thread.completed","thread_id":"t-1"}`
|
||||
|
||||
var onMessageCalls int
|
||||
var onCompleteCalls int
|
||||
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||
onMessageCalls++
|
||||
}, func() {
|
||||
onCompleteCalls++
|
||||
})
|
||||
if message != "second" {
|
||||
t.Fatalf("message = %q, want second", message)
|
||||
}
|
||||
if threadID != "t-1" {
|
||||
t.Fatalf("threadID = %q, want t-1", threadID)
|
||||
}
|
||||
if onMessageCalls != 2 {
|
||||
t.Fatalf("onMessage calls = %d, want 2", onMessageCalls)
|
||||
}
|
||||
if onCompleteCalls != 1 {
|
||||
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParseJSONStream_OnComplete_ClaudeResult(t *testing.T) {
|
||||
input := `{"type":"message","subtype":"stream","session_id":"s-1"}` + "\n" +
|
||||
`{"type":"result","result":"OK","session_id":"s-1"}`
|
||||
|
||||
var onMessageCalls int
|
||||
var onCompleteCalls int
|
||||
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||
onMessageCalls++
|
||||
}, func() {
|
||||
onCompleteCalls++
|
||||
})
|
||||
if message != "OK" {
|
||||
t.Fatalf("message = %q, want OK", message)
|
||||
}
|
||||
if threadID != "s-1" {
|
||||
t.Fatalf("threadID = %q, want s-1", threadID)
|
||||
}
|
||||
if onMessageCalls != 1 {
|
||||
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
|
||||
}
|
||||
if onCompleteCalls != 1 {
|
||||
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParseJSONStream_OnComplete_GeminiTerminalResultStatus(t *testing.T) {
|
||||
input := `{"type":"message","role":"assistant","content":"Hi","delta":true,"session_id":"g-1"}` + "\n" +
|
||||
`{"type":"result","status":"success","session_id":"g-1"}`
|
||||
|
||||
var onMessageCalls int
|
||||
var onCompleteCalls int
|
||||
message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, nil, func() {
|
||||
onMessageCalls++
|
||||
}, func() {
|
||||
onCompleteCalls++
|
||||
})
|
||||
if message != "Hi" {
|
||||
t.Fatalf("message = %q, want Hi", message)
|
||||
}
|
||||
if threadID != "g-1" {
|
||||
t.Fatalf("threadID = %q, want g-1", threadID)
|
||||
}
|
||||
if onMessageCalls != 1 {
|
||||
t.Fatalf("onMessage calls = %d, want 1", onMessageCalls)
|
||||
}
|
||||
if onCompleteCalls != 1 {
|
||||
t.Fatalf("onComplete calls = %d, want 1", onCompleteCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParseJSONStream_ScannerError(t *testing.T) {
|
||||
var warnings []string
|
||||
warnFn := func(msg string) { warnings = append(warnings, msg) }
|
||||
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil)
|
||||
message, threadID := parseJSONStreamInternal(errReader{err: errors.New("scan-fail")}, warnFn, nil, nil, nil)
|
||||
if message != "" || threadID != "" {
|
||||
t.Fatalf("expected empty output on scanner error, got message=%q threadID=%q", message, threadID)
|
||||
}
|
||||
@@ -2756,7 +2905,7 @@ func TestVersionFlag(t *testing.T) {
|
||||
t.Errorf("exit = %d, want 0", code)
|
||||
}
|
||||
})
|
||||
want := "codeagent-wrapper version 5.2.5\n"
|
||||
want := "codeagent-wrapper version 5.2.6\n"
|
||||
if output != want {
|
||||
t.Fatalf("output = %q, want %q", output, want)
|
||||
}
|
||||
@@ -2770,7 +2919,7 @@ func TestVersionShortFlag(t *testing.T) {
|
||||
t.Errorf("exit = %d, want 0", code)
|
||||
}
|
||||
})
|
||||
want := "codeagent-wrapper version 5.2.5\n"
|
||||
want := "codeagent-wrapper version 5.2.6\n"
|
||||
if output != want {
|
||||
t.Fatalf("output = %q, want %q", output, want)
|
||||
}
|
||||
@@ -2784,7 +2933,7 @@ func TestVersionLegacyAlias(t *testing.T) {
|
||||
t.Errorf("exit = %d, want 0", code)
|
||||
}
|
||||
})
|
||||
want := "codex-wrapper version 5.2.5\n"
|
||||
want := "codex-wrapper version 5.2.6\n"
|
||||
if output != want {
|
||||
t.Fatalf("output = %q, want %q", output, want)
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
|
||||
}
|
||||
|
||||
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
|
||||
return parseJSONStreamInternal(r, warnFn, infoFn, nil)
|
||||
return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -95,7 +95,7 @@ type ItemContent struct {
|
||||
Text interface{} `json:"text"`
|
||||
}
|
||||
|
||||
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func()) (message, threadID string) {
|
||||
func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) {
|
||||
reader := bufio.NewReaderSize(r, jsonLineReaderSize)
|
||||
|
||||
if warnFn == nil {
|
||||
@@ -111,6 +111,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
||||
}
|
||||
}
|
||||
|
||||
notifyComplete := func() {
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}
|
||||
|
||||
totalEvents := 0
|
||||
|
||||
var (
|
||||
@@ -158,6 +164,9 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
||||
}
|
||||
}
|
||||
isClaude := event.Subtype != "" || event.Result != ""
|
||||
if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" {
|
||||
isClaude = true
|
||||
}
|
||||
isGemini := event.Role != "" || event.Delta != nil || event.Status != ""
|
||||
|
||||
// Handle Codex events
|
||||
@@ -178,6 +187,13 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
||||
threadID = event.ThreadID
|
||||
infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
|
||||
|
||||
case "thread.completed":
|
||||
if event.ThreadID != "" && threadID == "" {
|
||||
threadID = event.ThreadID
|
||||
}
|
||||
infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID))
|
||||
notifyComplete()
|
||||
|
||||
case "item.completed":
|
||||
var itemType string
|
||||
if len(event.Item) > 0 {
|
||||
@@ -221,6 +237,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
||||
claudeMessage = event.Result
|
||||
notifyMessage()
|
||||
}
|
||||
|
||||
if event.Type == "result" {
|
||||
notifyComplete()
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -236,6 +256,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
|
||||
|
||||
if event.Status != "" {
|
||||
notifyMessage()
|
||||
|
||||
if event.Type == "result" && (event.Status == "success" || event.Status == "error" || event.Status == "complete" || event.Status == "failed") {
|
||||
notifyComplete()
|
||||
}
|
||||
}
|
||||
|
||||
delta := false
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) {
|
||||
var warns []string
|
||||
warnFn := func(msg string) { warns = append(warns, msg) }
|
||||
|
||||
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil)
|
||||
gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil)
|
||||
if gotMessage != "ok" {
|
||||
t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user