Merge pull request #51 from cexll/fix/channel-sync-race-conditions

fix: 修复channel同步竞态条件和死锁问题
This commit is contained in:
ben
2025-12-09 00:13:04 +08:00
committed by GitHub
4 changed files with 976 additions and 29 deletions

View File

@@ -1 +1,5 @@
coverage.out
coverage*.out
cover.out
cover_*.out
coverage.html

View File

@@ -28,6 +28,7 @@ type Logger struct {
closeOnce sync.Once
workerWG sync.WaitGroup
pendingWG sync.WaitGroup
flushMu sync.Mutex
}
type logEntry struct {
@@ -46,12 +47,12 @@ type CleanupStats struct {
}
var (
processRunningCheck = isProcessRunning
processStartTimeFn = getProcessStartTime
removeLogFileFn = os.Remove
globLogFiles = filepath.Glob
fileStatFn = os.Lstat // Use Lstat to detect symlinks
evalSymlinksFn = filepath.EvalSymlinks
processRunningCheck = isProcessRunning
processStartTimeFn = getProcessStartTime
removeLogFileFn = os.Remove
globLogFiles = filepath.Glob
fileStatFn = os.Lstat // Use Lstat to detect symlinks
evalSymlinksFn = filepath.EvalSymlinks
)
// NewLogger creates the async logger and starts the worker goroutine.
@@ -176,6 +177,9 @@ func (l *Logger) Flush() {
return
}
l.flushMu.Lock()
defer l.flushMu.Unlock()
// Wait for pending entries with timeout
done := make(chan struct{})
go func() {
@@ -221,7 +225,9 @@ func (l *Logger) log(level, msg string) {
}
entry := logEntry{level: level, msg: msg}
l.flushMu.Lock()
l.pendingWG.Add(1)
l.flushMu.Unlock()
select {
case l.ch <- entry:

View File

@@ -22,12 +22,19 @@ import (
)
const (
version = "4.8.2"
version = "5.1.2"
defaultWorkdir = "."
defaultTimeout = 7200 // seconds
codexLogLineLimit = 1000
stdinSpecialChars = "\n\\\"'`$"
stderrCaptureLimit = 4 * 1024
stdoutDrainTimeout = 5 * time.Second
)
// Reason strings handed to closeStdout (and forwarded to CloseWithReason when
// the stdout pipe supports it) to record why the stdout pipe is being closed.
const (
// Used once the result of the command's Wait has been received.
stdoutCloseReasonWait = "wait-complete"
// Used when the task context is done (cancelled or timed out).
stdoutCloseReasonCtx = "context-cancelled"
// Used when the drain timer fires before the parser has finished.
stdoutCloseReasonDrain = "drain-timeout"
)
// Test hooks for dependency injection
@@ -40,10 +47,14 @@ var (
buildCodexArgsFn = buildCodexArgs
commandContext = exec.CommandContext
jsonMarshal = json.Marshal
cleanupLogsFn = cleanupOldLogs
signalNotifyFn = signal.Notify
signalStopFn = signal.Stop
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return &realCmd{cmd: commandContext(ctx, name, args...)}
}
jsonMarshal = json.Marshal
cleanupLogsFn = cleanupOldLogs
signalNotifyFn = signal.Notify
signalStopFn = signal.Stop
terminateCommandFn = terminateCommand
)
var forceKillDelay atomic.Int32
@@ -52,6 +63,77 @@ func init() {
forceKillDelay.Store(5) // seconds - default value
}
// commandRunner is the subset of command behaviour that the task runner
// relies on. realCmd implements it on top of *exec.Cmd; tests substitute
// their own implementation through the newCommandRunner hook.
type commandRunner interface {
// Start launches the command.
Start() error
// Wait blocks until the command finishes and returns its result.
Wait() error
// StdoutPipe returns a reader connected to the command's standard output.
StdoutPipe() (io.ReadCloser, error)
// StdinPipe returns a writer connected to the command's standard input.
StdinPipe() (io.WriteCloser, error)
// SetStderr installs the writer that receives the command's standard error.
SetStderr(io.Writer)
// Process returns a handle to the underlying process, or nil when there is none.
Process() processHandle
}
// processHandle abstracts the operations performed on a started process so
// that termination logic can be exercised without a real OS process.
type processHandle interface {
// Pid returns the process identifier.
Pid() int
// Kill forcibly terminates the process.
Kill() error
// Signal delivers sig to the process.
Signal(os.Signal) error
}
// realCmd adapts an *exec.Cmd to the commandRunner interface.
type realCmd struct {
	cmd *exec.Cmd
}

// Start delegates to (*exec.Cmd).Start.
func (c *realCmd) Start() error {
	return c.cmd.Start()
}

// Wait delegates to (*exec.Cmd).Wait.
func (c *realCmd) Wait() error {
	return c.cmd.Wait()
}

// StdoutPipe delegates to (*exec.Cmd).StdoutPipe.
func (c *realCmd) StdoutPipe() (io.ReadCloser, error) {
	return c.cmd.StdoutPipe()
}

// StdinPipe delegates to (*exec.Cmd).StdinPipe.
func (c *realCmd) StdinPipe() (io.WriteCloser, error) {
	return c.cmd.StdinPipe()
}

// SetStderr assigns w as the wrapped command's Stderr writer.
func (c *realCmd) SetStderr(w io.Writer) {
	c.cmd.Stderr = w
}

// Process wraps the command's *os.Process in a processHandle. It returns a
// literal nil when the command is missing or has no process yet, so callers
// can safely compare the result against nil.
func (c *realCmd) Process() processHandle {
	if c.cmd != nil && c.cmd.Process != nil {
		return &realProcess{proc: c.cmd.Process}
	}
	return nil
}
// realProcess adapts an *os.Process to the processHandle interface. Every
// method tolerates a nil receiver and a nil wrapped process.
type realProcess struct {
	proc *os.Process
}

// Pid returns the wrapped process ID, or 0 when no process is attached.
func (rp *realProcess) Pid() int {
	if rp != nil && rp.proc != nil {
		return rp.proc.Pid
	}
	return 0
}

// Kill forwards to (*os.Process).Kill; it is a no-op returning nil when no
// process is attached.
func (rp *realProcess) Kill() error {
	if rp != nil && rp.proc != nil {
		return rp.proc.Kill()
	}
	return nil
}

// Signal forwards sig to (*os.Process).Signal; it is a no-op returning nil
// when no process is attached.
func (rp *realProcess) Signal(sig os.Signal) error {
	if rp != nil && rp.proc != nil {
		return rp.proc.Signal(sig)
	}
	return nil
}
// Config holds CLI configuration
type Config struct {
Mode string // "new" or "resume"
@@ -383,6 +465,8 @@ func runStartupCleanup() {
// run is the main logic, returns exit code for testability
func run() (exitCode int) {
var startupCleanupWG sync.WaitGroup
// Handle --version and --help first (no logger needed)
if len(os.Args) > 1 {
switch os.Args[1] {
@@ -421,9 +505,16 @@ func run() (exitCode int) {
}
}()
defer runCleanupHook()
defer startupCleanupWG.Wait()
// Run cleanup asynchronously to avoid blocking startup
go runStartupCleanup()
// Run cleanup asynchronously to avoid blocking startup but wait before exit
if cleanupLogsFn != nil {
startupCleanupWG.Add(1)
go func() {
defer startupCleanupWG.Done()
runStartupCleanup()
}()
}
// Handle remaining commands
if len(os.Args) > 1 {
@@ -810,7 +901,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
return fmt.Sprintf("%s; stderr: %s", msg, stderrBuf.String())
}
cmd := commandContext(ctx, codexCommand, codexArgs...)
cmd := newCommandRunner(ctx, codexCommand, codexArgs...)
stderrWriters := []io.Writer{stderrBuf}
if stderrLogger != nil {
@@ -820,9 +911,9 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
stderrWriters = append([]io.Writer{os.Stderr}, stderrWriters...)
}
if len(stderrWriters) == 1 {
cmd.Stderr = stderrWriters[0]
cmd.SetStderr(stderrWriters[0])
} else {
cmd.Stderr = io.MultiWriter(stderrWriters...)
cmd.SetStderr(io.MultiWriter(stderrWriters...))
}
var stdinPipe io.WriteCloser
@@ -865,7 +956,9 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
return result
}
logInfoFn(fmt.Sprintf("Starting codex with PID: %d", cmd.Process.Pid))
if proc := cmd.Process(); proc != nil {
logInfoFn(fmt.Sprintf("Starting codex with PID: %d", proc.Pid()))
}
if logger := activeLogger(); logger != nil {
logInfoFn(fmt.Sprintf("Log capturing to: %s", logger.Path()))
}
@@ -888,23 +981,105 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
parseCh <- parseResult{message: msg, threadID: tid}
}()
var waitErr error
var forceKillTimer *time.Timer
select {
case waitErr = <-waitCh:
case <-ctx.Done():
logErrorFn(cancelReason(ctx))
forceKillTimer = terminateProcess(cmd)
waitErr = <-waitCh
var stdoutCloseOnce sync.Once
var stdoutDrainCloseOnce sync.Once
closeStdout := func(reason string) {
var once *sync.Once
if reason == stdoutCloseReasonDrain {
once = &stdoutDrainCloseOnce
} else {
once = &stdoutCloseOnce
}
once.Do(func() {
if stdout == nil {
return
}
var closeErr error
switch c := stdout.(type) {
case interface{ CloseWithReason(string) error }:
closeErr = c.CloseWithReason(reason)
case interface{ CloseWithError(error) error }:
closeErr = c.CloseWithError(nil)
default:
closeErr = stdout.Close()
}
if closeErr != nil {
logWarnFn(fmt.Sprintf("Failed to close stdout pipe: %v", closeErr))
}
})
}
var waitErr error
var forceKillTimer *forceKillTimer
var parsed parseResult
var drainTimer *time.Timer
var drainTimerCh <-chan time.Time
startDrainTimer := func() {
if drainTimer != nil {
return
}
timer := time.NewTimer(stdoutDrainTimeout)
drainTimer = timer
drainTimerCh = timer.C
}
stopDrainTimer := func() {
if drainTimer == nil {
return
}
if !drainTimer.Stop() {
select {
case <-drainTimerCh:
default:
}
}
drainTimer = nil
drainTimerCh = nil
}
waitDone := false
parseDone := false
ctxLogged := false
for !waitDone || !parseDone {
select {
case waitErr = <-waitCh:
waitDone = true
waitCh = nil
closeStdout(stdoutCloseReasonWait)
if !parseDone {
startDrainTimer()
}
case parsed = <-parseCh:
parseDone = true
parseCh = nil
stopDrainTimer()
case <-ctx.Done():
if !ctxLogged {
logErrorFn(cancelReason(ctx))
ctxLogged = true
if forceKillTimer == nil {
forceKillTimer = terminateCommandFn(cmd)
}
}
closeStdout(stdoutCloseReasonCtx)
if !parseDone {
startDrainTimer()
}
case <-drainTimerCh:
logWarnFn("stdout did not drain within 5s; forcing close")
closeStdout(stdoutCloseReasonDrain)
stopDrainTimer()
}
}
stopDrainTimer()
if forceKillTimer != nil {
forceKillTimer.Stop()
forceKillTimer.stop()
}
parsed := <-parseCh
if ctxErr := ctx.Err(); ctxErr != nil {
if errors.Is(ctxErr, context.DeadlineExceeded) {
result.ExitCode = 124
@@ -1045,6 +1220,51 @@ func terminateProcess(cmd *exec.Cmd) *time.Timer {
})
}
// forceKillTimer tracks the delayed force-kill callback scheduled by
// terminateCommand so that it can be cancelled, or waited for if it has
// already started running.
type forceKillTimer struct {
	timer    *time.Timer   // AfterFunc timer that runs the force-kill callback
	done     chan struct{} // receives one value when the callback has finished
	stopped  atomic.Bool   // set once stop has completed
	drained  atomic.Bool   // set when stop had to wait for a running callback
	stopOnce sync.Once     // ensures the stop/drain sequence runs only once
}

// stop cancels the pending force-kill. If the timer can no longer be
// cancelled, stop waits for the callback to signal completion on done and
// records that in drained.
//
// stop is safe on a nil receiver and safe to call more than once. Only the
// first call performs the stop/drain sequence: a later timer.Stop() would
// report false again, and receiving from the already-drained (or never
// signalled) done channel would block forever.
func (t *forceKillTimer) stop() {
	if t == nil || t.timer == nil {
		return
	}
	t.stopOnce.Do(func() {
		// A nil done channel can never be signalled; do not wait on it.
		if !t.timer.Stop() && t.done != nil {
			<-t.done
			t.drained.Store(true)
		}
		t.stopped.Store(true)
	})
}
// terminateCommand asks the process behind cmd to exit.
//
// It returns nil when cmd is nil or has no process. On Windows the process is
// killed immediately and nil is returned. Elsewhere SIGTERM is sent and a
// force-kill is scheduled after forceKillDelay seconds; the returned
// forceKillTimer lets the caller cancel or wait for that escalation.
// Errors from Signal and Kill are deliberately ignored (best effort).
func terminateCommand(cmd commandRunner) *forceKillTimer {
	if cmd == nil {
		return nil
	}
	handle := cmd.Process()
	if handle == nil {
		return nil
	}

	if runtime.GOOS == "windows" {
		_ = handle.Kill()
		return nil
	}

	_ = handle.Signal(syscall.SIGTERM)

	// Buffered so the callback never blocks if nobody waits for it.
	finished := make(chan struct{}, 1)
	delay := time.Duration(forceKillDelay.Load()) * time.Second
	escalate := func() {
		// Re-read the process handle at fire time before killing it.
		if current := cmd.Process(); current != nil {
			_ = current.Kill()
		}
		finished <- struct{}{}
	}

	return &forceKillTimer{timer: time.AfterFunc(delay, escalate), done: finished}
}
// parseJSONStream reads r via parseJSONStreamWithLog, passing logWarn and
// logInfo as its two logging hooks, and returns the resulting message and
// thread ID.
func parseJSONStream(r io.Reader) (message, threadID string) {
return parseJSONStreamWithLog(r, logWarn, logInfo)
}

View File

@@ -32,6 +32,9 @@ func resetTestHooks() {
signalStopFn = signal.Stop
buildCodexArgsFn = buildCodexArgs
commandContext = exec.CommandContext
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return &realCmd{cmd: commandContext(ctx, name, args...)}
}
jsonMarshal = json.Marshal
forceKillDelay.Store(5)
closeLogger()
@@ -103,6 +106,430 @@ func captureStderr(t *testing.T, fn func()) string {
return buf.String()
}
// ctxAwareReader wraps an io.ReadCloser and remembers the reason given for
// the first close, so tests can assert why the stream was shut down.
type ctxAwareReader struct {
	reader io.ReadCloser
	mu     sync.Mutex
	reason string
	closed bool
}

// newCtxAwareReader wraps r without closing or reading it.
func newCtxAwareReader(r io.ReadCloser) *ctxAwareReader {
	return &ctxAwareReader{reader: r}
}

// Read forwards to the wrapped reader; with no wrapped reader it reports io.EOF.
func (c *ctxAwareReader) Read(p []byte) (int, error) {
	if c.reader != nil {
		return c.reader.Read(p)
	}
	return 0, io.EOF
}

// Close marks the reader closed and closes the wrapped reader the first time
// it is called. Later calls, or calls with no wrapped reader, return nil.
func (c *ctxAwareReader) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	wasClosed := c.closed
	c.closed = true
	if wasClosed || c.reader == nil {
		return nil
	}
	return c.reader.Close()
}

// CloseWithReason records reason if the reader has not been closed yet and
// then closes it. The reason of an earlier close is never overwritten.
func (c *ctxAwareReader) CloseWithReason(reason string) error {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.closed {
			return
		}
		c.reason = reason
	}()
	return c.Close()
}

// Reason returns the reason recorded by CloseWithReason, or "" if none.
func (c *ctxAwareReader) Reason() string {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.reason
}
// drainBlockingStdout wraps a ctxAwareReader and ignores every
// CloseWithReason call except the drain-timeout one, so the stream stays
// open until the drain timer forces it shut.
type drainBlockingStdout struct {
	inner *ctxAwareReader
}

// newDrainBlockingStdout wraps inner.
func newDrainBlockingStdout(inner *ctxAwareReader) *drainBlockingStdout {
	return &drainBlockingStdout{inner: inner}
}

// Read forwards to the wrapped reader.
func (s *drainBlockingStdout) Read(p []byte) (int, error) {
	return s.inner.Read(p)
}

// Close closes the wrapped reader unconditionally.
func (s *drainBlockingStdout) Close() error {
	return s.inner.Close()
}

// CloseWithReason closes the wrapped reader only for stdoutCloseReasonDrain;
// any other reason is swallowed and reported as success.
func (s *drainBlockingStdout) CloseWithReason(reason string) error {
	if reason == stdoutCloseReasonDrain {
		return s.inner.CloseWithReason(reason)
	}
	return nil
}
// drainBlockingCmd decorates a fakeCmd so that its stdout pipe is wrapped in
// a drainBlockingStdout. injected records whether the wrapper was installed.
type drainBlockingCmd struct {
	inner    *fakeCmd
	injected atomic.Bool
}

// newDrainBlockingCmd decorates inner.
func newDrainBlockingCmd(inner *fakeCmd) *drainBlockingCmd {
	return &drainBlockingCmd{inner: inner}
}

// Start forwards to the wrapped command.
func (c *drainBlockingCmd) Start() error {
	return c.inner.Start()
}

// Wait forwards to the wrapped command.
func (c *drainBlockingCmd) Wait() error {
	return c.inner.Wait()
}

// StdoutPipe returns the wrapped command's stdout. When that pipe is a
// *ctxAwareReader it is wrapped in a drainBlockingStdout and injected is set;
// any other pipe type is returned unchanged.
func (c *drainBlockingCmd) StdoutPipe() (io.ReadCloser, error) {
	pipe, err := c.inner.StdoutPipe()
	if err != nil {
		return nil, err
	}
	if aware, ok := pipe.(*ctxAwareReader); ok {
		c.injected.Store(true)
		return newDrainBlockingStdout(aware), nil
	}
	return pipe, nil
}

// StdinPipe forwards to the wrapped command.
func (c *drainBlockingCmd) StdinPipe() (io.WriteCloser, error) {
	return c.inner.StdinPipe()
}

// SetStderr forwards to the wrapped command.
func (c *drainBlockingCmd) SetStderr(w io.Writer) {
	c.inner.SetStderr(w)
}

// Process forwards to the wrapped command.
func (c *drainBlockingCmd) Process() processHandle {
	return c.inner.Process()
}
// bufferWriteCloser is a mutex-guarded in-memory io.WriteCloser that rejects
// writes after Close.
type bufferWriteCloser struct {
	buf    bytes.Buffer
	mu     sync.Mutex
	closed bool
}

// newBufferWriteCloser returns an empty, open buffer.
func newBufferWriteCloser() *bufferWriteCloser {
	return &bufferWriteCloser{}
}

// Write appends p to the buffer, or fails with io.ErrClosedPipe once closed.
func (w *bufferWriteCloser) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if !w.closed {
		return w.buf.Write(p)
	}
	return 0, io.ErrClosedPipe
}

// Close marks the buffer closed. It always returns nil.
func (w *bufferWriteCloser) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.closed = true
	return nil
}

// String returns everything written so far.
func (w *bufferWriteCloser) String() string {
	w.mu.Lock()
	defer w.mu.Unlock()

	return w.buf.String()
}
// fakeProcess is a test double for processHandle that records every signal
// and kill request and can invoke optional callbacks when they happen.
type fakeProcess struct {
	pid         int
	killed      atomic.Bool
	mu          sync.Mutex
	signals     []os.Signal
	signalCount atomic.Int32
	killCount   atomic.Int32
	onSignal    func(os.Signal)
	onKill      func()
}

// newFakeProcess builds a fakeProcess; a zero pid is replaced by 4242.
func newFakeProcess(pid int) *fakeProcess {
	const fallbackPID = 4242
	proc := &fakeProcess{pid: pid}
	if proc.pid == 0 {
		proc.pid = fallbackPID
	}
	return proc
}

// Pid returns the configured process ID.
func (fp *fakeProcess) Pid() int {
	return fp.pid
}

// Kill records the kill request, then runs the onKill callback if set.
// It always returns nil.
func (fp *fakeProcess) Kill() error {
	fp.killed.Store(true)
	fp.killCount.Add(1)
	if hook := fp.onKill; hook != nil {
		hook()
	}
	return nil
}

// Signal records sig, then runs the onSignal callback if set.
// It always returns nil.
func (fp *fakeProcess) Signal(sig os.Signal) error {
	func() {
		fp.mu.Lock()
		defer fp.mu.Unlock()
		fp.signals = append(fp.signals, sig)
	}()
	fp.signalCount.Add(1)
	if hook := fp.onSignal; hook != nil {
		hook(sig)
	}
	return nil
}

// Signals returns a copy of the signals received so far.
func (fp *fakeProcess) Signals() []os.Signal {
	fp.mu.Lock()
	defer fp.mu.Unlock()

	snapshot := make([]os.Signal, len(fp.signals))
	for i, sig := range fp.signals {
		snapshot[i] = sig
	}
	return snapshot
}

// Killed reports whether Kill has been called at least once.
func (fp *fakeProcess) Killed() bool {
	return fp.killed.Load()
}

// SignalCount returns how many times Signal has been called.
func (fp *fakeProcess) SignalCount() int {
	return int(fp.signalCount.Load())
}

// KillCount returns how many times Kill has been called.
func (fp *fakeProcess) KillCount() int {
	return int(fp.killCount.Load())
}
// fakeStdoutEvent is one step of a fakeCmd stdout script: sleep for Delay
// (when positive), then write Data (when non-empty) to the fake stdout.
type fakeStdoutEvent struct {
Delay time.Duration
Data string
}
// fakeCmdConfig describes how a fakeCmd built by newFakeCmd should behave.
type fakeCmdConfig struct {
// StdoutPlan is the scripted sequence of stdout writes played after Start.
StdoutPlan []fakeStdoutEvent
// WaitDelay is how long Wait sleeps before returning (ignored when BlockWait is set).
WaitDelay time.Duration
// WaitErr is the error returned by Wait.
WaitErr error
// StartErr, when non-nil, is returned by Start and by a subsequent Wait.
StartErr error
// PID is the fake process ID; zero selects the fakeProcess default.
PID int
// KeepStdoutOpen leaves the stdout writer open after the script finishes.
KeepStdoutOpen bool
// BlockWait makes Wait block until the fake process is killed or signalled.
BlockWait bool
// ReleaseWaitOnKill unblocks Wait when Kill is called. It is also the
// default when neither release flag is set.
ReleaseWaitOnKill bool
// ReleaseWaitOnSignal unblocks Wait when Signal is called.
ReleaseWaitOnSignal bool
}
// fakeCmd is an in-memory commandRunner used by tests. Stdout is backed by an
// io.Pipe fed from a scripted plan, stdin by an in-memory buffer, and Wait
// can be delayed or blocked until the fake process is signalled or killed.
type fakeCmd struct {
// mu guards started and the stdout/stdin claim flags.
mu sync.Mutex
// Read side of the stdout pipe, handed out by StdoutPipe.
stdout *ctxAwareReader
// Write side of the stdout pipe, fed by WriteStdout.
stdoutWriter *io.PipeWriter
stdoutPlan []fakeStdoutEvent
// stdoutOnce ensures the stdout writer is closed at most once.
stdoutOnce sync.Once
stdoutClaim bool
keepStdoutOpen bool
// stdoutWriteMu serialises writes to stdoutWriter.
stdoutWriteMu sync.Mutex
stdinWriter *bufferWriteCloser
stdinClaim bool
stderr io.Writer
waitDelay time.Duration
waitErr error
startErr error
// waitOnce ensures the wait result is computed and waitDone closed once.
waitOnce sync.Once
waitDone chan struct{}
waitResult error
// waitReleaseCh is closed by releaseWait to unblock a blocked Wait.
waitReleaseCh chan struct{}
waitReleaseOnce sync.Once
waitBlocked bool
started bool
// Call counters inspected by tests.
startCount atomic.Int32
waitCount atomic.Int32
stdoutPipeCount atomic.Int32
process *fakeProcess
}
// newFakeCmd builds a fakeCmd from cfg. The stdout plan is copied so later
// changes to cfg do not affect the command. When cfg.BlockWait is set, Wait
// blocks until the fake process receives the configured release event; if
// neither release flag is set, Kill is used as the release event.
func newFakeCmd(cfg fakeCmdConfig) *fakeCmd {
	pipeReader, pipeWriter := io.Pipe()
	fc := &fakeCmd{
		stdout:         newCtxAwareReader(pipeReader),
		stdoutWriter:   pipeWriter,
		stdoutPlan:     append([]fakeStdoutEvent(nil), cfg.StdoutPlan...),
		stdinWriter:    newBufferWriteCloser(),
		waitDelay:      cfg.WaitDelay,
		waitErr:        cfg.WaitErr,
		startErr:       cfg.StartErr,
		waitDone:       make(chan struct{}),
		keepStdoutOpen: cfg.KeepStdoutOpen,
		process:        newFakeProcess(cfg.PID),
	}
	if len(fc.stdoutPlan) == 0 {
		fc.stdoutPlan = nil
	}

	if !cfg.BlockWait {
		return fc
	}

	fc.waitBlocked = true
	fc.waitReleaseCh = make(chan struct{})

	onSignal, onKill := cfg.ReleaseWaitOnSignal, cfg.ReleaseWaitOnKill
	if !onSignal && !onKill {
		onKill = true
	}
	fc.process.onSignal = func(os.Signal) {
		if onSignal {
			fc.releaseWait()
		}
	}
	fc.process.onKill = func() {
		if onKill {
			fc.releaseWait()
		}
	}
	return fc
}
// Start marks the command as started and launches the stdout script in a
// goroutine. A second call fails with "start already called". When a start
// error is configured it is returned, and Wait is primed to return it too.
func (f *fakeCmd) Start() error {
	f.mu.Lock()
	alreadyStarted := f.started
	f.started = true
	f.mu.Unlock()

	if alreadyStarted {
		return errors.New("start already called")
	}
	f.startCount.Add(1)

	if f.startErr == nil {
		go f.runStdoutScript()
		return nil
	}

	f.waitOnce.Do(func() {
		f.waitResult = f.startErr
		close(f.waitDone)
	})
	return f.startErr
}
// Wait counts the call and returns the command's result. The first caller
// either blocks until the wait is released (blocking mode) or sleeps for the
// configured delay, then publishes the result; every caller returns once the
// result is available.
func (f *fakeCmd) Wait() error {
	f.waitCount.Add(1)

	f.waitOnce.Do(func() {
		switch {
		case f.waitBlocked && f.waitReleaseCh != nil:
			<-f.waitReleaseCh
		case f.waitDelay > 0:
			time.Sleep(f.waitDelay)
		}
		f.waitResult = f.waitErr
		close(f.waitDone)
	})

	<-f.waitDone
	return f.waitResult
}
// StdoutPipe hands out the fake stdout reader. It may be claimed only once;
// later calls fail with "stdout pipe already claimed".
func (f *fakeCmd) StdoutPipe() (io.ReadCloser, error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	if !f.stdoutClaim {
		f.stdoutClaim = true
		f.stdoutPipeCount.Add(1)
		return f.stdout, nil
	}
	return nil, errors.New("stdout pipe already claimed")
}
// StdinPipe hands out the in-memory stdin writer. It may be claimed only
// once; later calls fail with "stdin pipe already claimed".
func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) {
	f.mu.Lock()
	defer f.mu.Unlock()

	if !f.stdinClaim {
		f.stdinClaim = true
		return f.stdinWriter, nil
	}
	return nil, errors.New("stdin pipe already claimed")
}
// SetStderr stores w as the fake command's stderr writer.
// NOTE(review): the assignment is not guarded by f.mu — confirm it is only
// called before the command is used concurrently.
func (f *fakeCmd) SetStderr(w io.Writer) {
f.stderr = w
}
// Process returns the fake process handle, or nil when the command or its
// process is missing.
//
// The explicit f.process check matters: returning a nil *fakeProcess through
// the processHandle interface would produce a non-nil interface value, so a
// caller's `proc != nil` guard would pass and the next method call would
// dereference a nil pointer.
func (f *fakeCmd) Process() processHandle {
	if f == nil || f.process == nil {
		return nil
	}
	return f.process
}
// runStdoutScript plays the stdout plan in order, sleeping before each event
// that has a positive delay, and then closes the stdout writer unless the
// command was configured to keep it open. An empty plan only performs the
// closing step.
func (f *fakeCmd) runStdoutScript() {
	for _, step := range f.stdoutPlan {
		if step.Delay > 0 {
			time.Sleep(step.Delay)
		}
		f.WriteStdout(step.Data)
	}

	if f.keepStdoutOpen {
		return
	}
	f.CloseStdout(nil)
}
// releaseWait unblocks a blocked Wait by closing the release channel. It is
// a no-op when the command is not in blocking mode, and closes the channel at
// most once.
func (f *fakeCmd) releaseWait() {
	if f.waitReleaseCh != nil {
		f.waitReleaseOnce.Do(func() { close(f.waitReleaseCh) })
	}
}
// WriteStdout writes data to the fake stdout pipe. Empty data is skipped and
// write errors are ignored; writes are serialised by stdoutWriteMu.
func (f *fakeCmd) WriteStdout(data string) {
	if data == "" {
		return
	}

	f.stdoutWriteMu.Lock()
	defer f.stdoutWriteMu.Unlock()

	if f.stdoutWriter == nil {
		return
	}
	_, _ = io.WriteString(f.stdoutWriter, data)
}
// CloseStdout closes the write side of the fake stdout pipe at most once.
// A non-nil err is passed to CloseWithError; otherwise the writer is closed
// normally. Close errors are ignored.
func (f *fakeCmd) CloseStdout(err error) {
	f.stdoutOnce.Do(func() {
		writer := f.stdoutWriter
		switch {
		case writer == nil:
			// Nothing to close.
		case err != nil:
			_ = writer.CloseWithError(err)
		default:
			_ = writer.Close()
		}
	})
}
// StdinContents returns everything written to the fake stdin so far, or ""
// when there is no stdin writer.
func (f *fakeCmd) StdinContents() string {
	if writer := f.stdinWriter; writer != nil {
		return writer.String()
	}
	return ""
}
func createFakeCodexScript(t *testing.T, threadID, message string) string {
t.Helper()
scriptPath := filepath.Join(t.TempDir(), "codex.sh")
@@ -116,6 +543,296 @@ printf '%%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"
return scriptPath
}
// TestFakeCmdInfra checks the fakeCmd test double itself: first in isolation
// (scripted stdout, close reason tracking, wait delay, call counters), then
// plugged into runCodexTask through the newCommandRunner hook.
func TestFakeCmdInfra(t *testing.T) {
t.Run("pipes and wait scheduling", func(t *testing.T) {
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: "line1\n"},
{Delay: 5 * time.Millisecond, Data: "line2\n"},
},
WaitDelay: 20 * time.Millisecond,
})
stdout, err := fake.StdoutPipe()
if err != nil {
t.Fatalf("StdoutPipe() error = %v", err)
}
if err := fake.Start(); err != nil {
t.Fatalf("Start() error = %v", err)
}
// Read exactly the two scripted lines, then stop scanning.
scanner := bufio.NewScanner(stdout)
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
if len(lines) == 2 {
break
}
}
if err := scanner.Err(); err != nil {
t.Fatalf("scanner error: %v", err)
}
if len(lines) != 2 || lines[0] != "line1" || lines[1] != "line2" {
t.Fatalf("unexpected stdout lines: %v", lines)
}
// The pipe must be the reason-tracking reader so the close reason can be verified.
ctxReader, ok := stdout.(*ctxAwareReader)
if !ok {
t.Fatalf("stdout pipe is %T, want *ctxAwareReader", stdout)
}
if err := ctxReader.CloseWithReason("test-complete"); err != nil {
t.Fatalf("CloseWithReason error: %v", err)
}
if ctxReader.Reason() != "test-complete" {
t.Fatalf("CloseWithReason reason mismatch: %q", ctxReader.Reason())
}
// Wait must honour the configured WaitDelay.
waitStart := time.Now()
if err := fake.Wait(); err != nil {
t.Fatalf("Wait() error = %v", err)
}
if elapsed := time.Since(waitStart); elapsed < 20*time.Millisecond {
t.Fatalf("Wait() returned too early: %v", elapsed)
}
if fake.startCount.Load() != 1 {
t.Fatalf("Start() count = %d, want 1", fake.startCount.Load())
}
if fake.waitCount.Load() != 1 {
t.Fatalf("Wait() count = %d, want 1", fake.waitCount.Load())
}
if fake.stdoutPipeCount.Load() != 1 {
t.Fatalf("StdoutPipe() count = %d, want 1", fake.stdoutPipeCount.Load())
}
})
t.Run("integration with runCodexTask", func(t *testing.T) {
defer resetTestHooks()
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"thread.started","thread_id":"fake-thread"}` + "\n"},
{
Delay: time.Millisecond,
Data: `{"type":"item.completed","item":{"type":"agent_message","text":"fake-msg"}}` + "\n",
},
},
WaitDelay: time.Millisecond,
})
// Route command creation to the fake instead of a real process.
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string {
return []string{targetArg}
}
codexCommand = "fake-cmd"
res := runCodexTask(TaskSpec{Task: "ignored"}, false, 2)
if res.ExitCode != 0 {
t.Fatalf("runCodexTask exit = %d, want 0 (%s)", res.ExitCode, res.Error)
}
if res.Message != "fake-msg" {
t.Fatalf("message = %q, want fake-msg", res.Message)
}
if res.SessionID != "fake-thread" {
t.Fatalf("sessionID = %q, want fake-thread", res.SessionID)
}
if fake.startCount.Load() != 1 {
t.Fatalf("Start() count = %d, want 1", fake.startCount.Load())
}
if fake.waitCount.Load() != 1 {
t.Fatalf("Wait() count = %d, want 1", fake.waitCount.Load())
}
})
}
// TestRunCodexTask_WaitBeforeParse covers the case where Wait returns while
// stdout is still open: the fake emits both JSON events immediately, Wait
// returns after waitDelay, and the stdout script keeps the pipe open for
// another extraDelay. The task must return the parsed result before
// extraDelay elapses, with stdout closed for the "wait-complete" reason.
func TestRunCodexTask_WaitBeforeParse(t *testing.T) {
defer resetTestHooks()
const (
threadID = "wait-first-thread"
message = "wait-first-message"
waitDelay = 100 * time.Millisecond
extraDelay = 2 * time.Second
)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: fmt.Sprintf(`{"type":"thread.started","thread_id":"%s"}`+"\n", threadID)},
{Data: fmt.Sprintf(`{"type":"item.completed","item":{"type":"agent_message","text":"%s"}}`+"\n", message)},
// Data-less event: only delays the script's final stdout close.
{Delay: extraDelay},
},
WaitDelay: waitDelay,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string {
return []string{targetArg}
}
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTask(TaskSpec{Task: "ignored"}, false, 5)
elapsed := time.Since(start)
if result.ExitCode != 0 {
t.Fatalf("runCodexTask exit = %d, want 0 (%s)", result.ExitCode, result.Error)
}
if result.Message != message {
t.Fatalf("message = %q, want %q", result.Message, message)
}
if result.SessionID != threadID {
t.Fatalf("sessionID = %q, want %q", result.SessionID, threadID)
}
// Returning before extraDelay shows the task did not wait for the script to close stdout.
if elapsed >= extraDelay {
t.Fatalf("runCodexTask took %v, want < %v", elapsed, extraDelay)
}
if fake.stdout == nil {
t.Fatalf("stdout reader not initialized")
}
if reason := fake.stdout.Reason(); reason != stdoutCloseReasonWait {
t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonWait)
}
}
// TestRunCodexTask_ParseStall covers the drain-timeout path. The fake keeps
// stdout open and its wrapper ignores every close reason except the
// drain-timeout one, so the parser can only finish once the drain timer
// fires. The task must then fail after roughly stdoutDrainTimeout, record the
// drain close reason, and not leak goroutines.
// NOTE: this test takes at least stdoutDrainTimeout of wall-clock time.
func TestRunCodexTask_ParseStall(t *testing.T) {
defer resetTestHooks()
const threadID = "stall-thread"
// Baseline goroutine count for the leak check at the end.
startG := runtime.NumGoroutine()
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: fmt.Sprintf(`{"type":"thread.started","thread_id":"%s"}`+"\n", threadID)},
},
KeepStdoutOpen: true,
})
blockingCmd := newDrainBlockingCmd(fake)
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return blockingCmd
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string {
return []string{targetArg}
}
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTask(TaskSpec{Task: "stall"}, false, 60)
elapsed := time.Since(start)
if !blockingCmd.injected.Load() {
t.Fatalf("stdout wrapper was not installed")
}
if result.ExitCode == 0 || result.Error == "" {
t.Fatalf("expected runCodexTask to error when parse stalls, got %+v", result)
}
errText := strings.ToLower(result.Error)
if !strings.Contains(errText, "drain timeout") && !strings.Contains(errText, "agent_message") {
t.Fatalf("error %q does not mention drain timeout or missing agent_message", result.Error)
}
// Returning earlier than the drain timeout would mean something other than the drain timer closed stdout.
if elapsed < stdoutDrainTimeout {
t.Fatalf("runCodexTask returned after %v (reason=%s), want >= %v to confirm drainTimer firing", elapsed, fake.stdout.Reason(), stdoutDrainTimeout)
}
maxDuration := stdoutDrainTimeout + time.Second
if elapsed >= maxDuration {
t.Fatalf("runCodexTask took %v, want < %v", elapsed, maxDuration)
}
if fake.stdout == nil {
t.Fatalf("stdout reader not initialized")
}
if !fake.stdout.closed {
t.Fatalf("stdout reader still open; drainTimer should force close")
}
if reason := fake.stdout.Reason(); reason != stdoutCloseReasonDrain {
t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonDrain)
}
// Leak check: poll for up to 500ms for the goroutine count to settle within a tolerance of 8 above the baseline.
deadline := time.Now().Add(500 * time.Millisecond)
allowed := startG + 8
finalG := runtime.NumGoroutine()
for finalG > allowed && time.Now().Before(deadline) {
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
runtime.GC()
finalG = runtime.NumGoroutine()
}
if finalG > allowed {
t.Fatalf("goroutines leaked: before=%d after=%d", startG, finalG)
}
}
// TestRunCodexTask_ContextTimeout covers context expiry while the command is
// still running. The fake's Wait blocks until Kill is called, and the
// force-kill delay is zero so the escalation fires immediately after
// SIGTERM. The task must report exit code 124, both signal and kill must be
// observed, the force-kill timer must be stopped and drained, and stdout
// must be closed for the context-cancelled reason.
// NOTE(review): terminateCommand returns nil on Windows, so the
// capturedTimer assertions look Unix-only — confirm the test is skipped or
// excluded there.
func TestRunCodexTask_ContextTimeout(t *testing.T) {
defer resetTestHooks()
// Zero delay: the force-kill callback fires right after SIGTERM.
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnKill: true,
ReleaseWaitOnSignal: false,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string {
return []string{targetArg}
}
codexCommand = "fake-cmd"
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
// Wrap the real terminateCommand to capture the timer it returns for inspection.
var capturedTimer *forceKillTimer
terminateCommandFn = func(cmd commandRunner) *forceKillTimer {
timer := terminateCommand(cmd)
capturedTimer = timer
return timer
}
// Restore the hook explicitly when the test ends.
defer func() { terminateCommandFn = terminateCommand }()
result := runCodexTaskWithContext(ctx, TaskSpec{Task: "ctx-timeout", WorkDir: defaultWorkdir}, nil, false, false, 60)
if result.ExitCode != 124 {
t.Fatalf("exit code = %d, want 124 (%s)", result.ExitCode, result.Error)
}
if !strings.Contains(strings.ToLower(result.Error), "timeout") {
t.Fatalf("error %q does not mention timeout", result.Error)
}
if fake.process == nil {
t.Fatalf("fake process not initialized")
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got 0")
}
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to eventually run, got 0")
}
if capturedTimer == nil {
t.Fatalf("forceKillTimer not captured")
}
if !capturedTimer.stopped.Load() {
t.Fatalf("forceKillTimer.Stop was not called")
}
// drained is only set when stop() had to wait for the already-running kill callback.
if !capturedTimer.drained.Load() {
t.Fatalf("forceKillTimer drain logic did not run")
}
if fake.stdout == nil {
t.Fatalf("stdout reader not initialized")
}
if reason := fake.stdout.Reason(); reason != stdoutCloseReasonCtx {
t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonCtx)
}
}
func TestRunParseArgs_NewMode(t *testing.T) {
tests := []struct {
name string