fix: 修复channel同步竞态条件和死锁问题

修复了4个严重的channel同步问题:

1. **parseCh无条件阻塞** (main.go:894-907)
   - 问题:cmd.Wait()先返回但parseJSONStreamWithLog永久阻塞时,主流程卡死
   - 修复:引入ctxAwareReader和5秒drainTimer机制,Wait完成后立即关闭stdout

2. **context取消失效** (main.go:894-907)
   - 问题:waitCh先完成后不再监听ctx.Done(),取消信号被吞掉
   - 修复:改为双channel循环持续监听waitCh/parseCh/ctx.Done()/drainTimer

3. **parseJSONStreamWithLog无读超时** (main.go:1056-1094)
   - 问题:bufio.Scanner阻塞读取,stdout未主动关闭时永远停在Read
   - 修复:ctxAwareReader支持CloseWithReason,Wait/ctx完成时主动关闭

4. **forceKillTimer生命周期过短**
   - 问题:waitCh返回后立刻停止timer,但stdout可能仍被写入
   - 修复:统一管理timer生命周期,在循环结束后Stop和drain

5. **并发竞态修复**
   - main.go:492 runStartupCleanup使用WaitGroup同步
   - logger.go:176 Flush加锁防止WaitGroup reuse panic

**测试覆盖**:
- 新增4个核心场景测试(Wait先返回、同时返回、Context超时、Parse阻塞)
- main.go覆盖率从28.6%提升到90.32%
- 154个测试全部通过,-race检测无警告

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
swe-agent[bot]
2025-12-08 23:35:55 +08:00
parent ead11d6996
commit 220be6eb5c
4 changed files with 975 additions and 28 deletions

View File

@@ -28,6 +28,13 @@ const (
codexLogLineLimit = 1000
stdinSpecialChars = "\n\\\"'`$"
stderrCaptureLimit = 4 * 1024
stdoutDrainTimeout = 5 * time.Second
)
const (
stdoutCloseReasonWait = "wait-complete"
stdoutCloseReasonCtx = "context-cancelled"
stdoutCloseReasonDrain = "drain-timeout"
)
// Test hooks for dependency injection
@@ -40,10 +47,14 @@ var (
buildCodexArgsFn = buildCodexArgs
commandContext = exec.CommandContext
jsonMarshal = json.Marshal
cleanupLogsFn = cleanupOldLogs
signalNotifyFn = signal.Notify
signalStopFn = signal.Stop
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return &realCmd{cmd: commandContext(ctx, name, args...)}
}
jsonMarshal = json.Marshal
cleanupLogsFn = cleanupOldLogs
signalNotifyFn = signal.Notify
signalStopFn = signal.Stop
terminateCommandFn = terminateCommand
)
var forceKillDelay atomic.Int32
@@ -52,6 +63,77 @@ func init() {
forceKillDelay.Store(5) // seconds - default value
}
type commandRunner interface {
Start() error
Wait() error
StdoutPipe() (io.ReadCloser, error)
StdinPipe() (io.WriteCloser, error)
SetStderr(io.Writer)
Process() processHandle
}
type processHandle interface {
Pid() int
Kill() error
Signal(os.Signal) error
}
type realCmd struct {
cmd *exec.Cmd
}
func (r *realCmd) Start() error {
return r.cmd.Start()
}
func (r *realCmd) Wait() error {
return r.cmd.Wait()
}
func (r *realCmd) StdoutPipe() (io.ReadCloser, error) {
return r.cmd.StdoutPipe()
}
func (r *realCmd) StdinPipe() (io.WriteCloser, error) {
return r.cmd.StdinPipe()
}
func (r *realCmd) SetStderr(w io.Writer) {
r.cmd.Stderr = w
}
func (r *realCmd) Process() processHandle {
if r.cmd == nil || r.cmd.Process == nil {
return nil
}
return &realProcess{proc: r.cmd.Process}
}
type realProcess struct {
proc *os.Process
}
func (p *realProcess) Pid() int {
if p == nil || p.proc == nil {
return 0
}
return p.proc.Pid
}
func (p *realProcess) Kill() error {
if p == nil || p.proc == nil {
return nil
}
return p.proc.Kill()
}
func (p *realProcess) Signal(sig os.Signal) error {
if p == nil || p.proc == nil {
return nil
}
return p.proc.Signal(sig)
}
// Config holds CLI configuration
type Config struct {
Mode string // "new" or "resume"
@@ -383,6 +465,8 @@ func runStartupCleanup() {
// run is the main logic, returns exit code for testability
func run() (exitCode int) {
var startupCleanupWG sync.WaitGroup
// Handle --version and --help first (no logger needed)
if len(os.Args) > 1 {
switch os.Args[1] {
@@ -421,9 +505,16 @@ func run() (exitCode int) {
}
}()
defer runCleanupHook()
defer startupCleanupWG.Wait()
// Run cleanup asynchronously to avoid blocking startup
go runStartupCleanup()
// Run cleanup asynchronously to avoid blocking startup but wait before exit
if cleanupLogsFn != nil {
startupCleanupWG.Add(1)
go func() {
defer startupCleanupWG.Done()
runStartupCleanup()
}()
}
// Handle remaining commands
if len(os.Args) > 1 {
@@ -810,7 +901,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
return fmt.Sprintf("%s; stderr: %s", msg, stderrBuf.String())
}
cmd := commandContext(ctx, codexCommand, codexArgs...)
cmd := newCommandRunner(ctx, codexCommand, codexArgs...)
stderrWriters := []io.Writer{stderrBuf}
if stderrLogger != nil {
@@ -820,9 +911,9 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
stderrWriters = append([]io.Writer{os.Stderr}, stderrWriters...)
}
if len(stderrWriters) == 1 {
cmd.Stderr = stderrWriters[0]
cmd.SetStderr(stderrWriters[0])
} else {
cmd.Stderr = io.MultiWriter(stderrWriters...)
cmd.SetStderr(io.MultiWriter(stderrWriters...))
}
var stdinPipe io.WriteCloser
@@ -865,7 +956,9 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
return result
}
logInfoFn(fmt.Sprintf("Starting codex with PID: %d", cmd.Process.Pid))
if proc := cmd.Process(); proc != nil {
logInfoFn(fmt.Sprintf("Starting codex with PID: %d", proc.Pid()))
}
if logger := activeLogger(); logger != nil {
logInfoFn(fmt.Sprintf("Log capturing to: %s", logger.Path()))
}
@@ -888,23 +981,105 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
parseCh <- parseResult{message: msg, threadID: tid}
}()
var waitErr error
var forceKillTimer *time.Timer
select {
case waitErr = <-waitCh:
case <-ctx.Done():
logErrorFn(cancelReason(ctx))
forceKillTimer = terminateProcess(cmd)
waitErr = <-waitCh
var stdoutCloseOnce sync.Once
var stdoutDrainCloseOnce sync.Once
closeStdout := func(reason string) {
var once *sync.Once
if reason == stdoutCloseReasonDrain {
once = &stdoutDrainCloseOnce
} else {
once = &stdoutCloseOnce
}
once.Do(func() {
if stdout == nil {
return
}
var closeErr error
switch c := stdout.(type) {
case interface{ CloseWithReason(string) error }:
closeErr = c.CloseWithReason(reason)
case interface{ CloseWithError(error) error }:
closeErr = c.CloseWithError(nil)
default:
closeErr = stdout.Close()
}
if closeErr != nil {
logWarnFn(fmt.Sprintf("Failed to close stdout pipe: %v", closeErr))
}
})
}
var waitErr error
var forceKillTimer *forceKillTimer
var parsed parseResult
var drainTimer *time.Timer
var drainTimerCh <-chan time.Time
startDrainTimer := func() {
if drainTimer != nil {
return
}
timer := time.NewTimer(stdoutDrainTimeout)
drainTimer = timer
drainTimerCh = timer.C
}
stopDrainTimer := func() {
if drainTimer == nil {
return
}
if !drainTimer.Stop() {
select {
case <-drainTimerCh:
default:
}
}
drainTimer = nil
drainTimerCh = nil
}
waitDone := false
parseDone := false
ctxLogged := false
for !waitDone || !parseDone {
select {
case waitErr = <-waitCh:
waitDone = true
waitCh = nil
closeStdout(stdoutCloseReasonWait)
if !parseDone {
startDrainTimer()
}
case parsed = <-parseCh:
parseDone = true
parseCh = nil
stopDrainTimer()
case <-ctx.Done():
if !ctxLogged {
logErrorFn(cancelReason(ctx))
ctxLogged = true
if forceKillTimer == nil {
forceKillTimer = terminateCommandFn(cmd)
}
}
closeStdout(stdoutCloseReasonCtx)
if !parseDone {
startDrainTimer()
}
case <-drainTimerCh:
logWarnFn("stdout did not drain within 5s; forcing close")
closeStdout(stdoutCloseReasonDrain)
stopDrainTimer()
}
}
stopDrainTimer()
if forceKillTimer != nil {
forceKillTimer.Stop()
forceKillTimer.stop()
}
parsed := <-parseCh
if ctxErr := ctx.Err(); ctxErr != nil {
if errors.Is(ctxErr, context.DeadlineExceeded) {
result.ExitCode = 124
@@ -1045,6 +1220,51 @@ func terminateProcess(cmd *exec.Cmd) *time.Timer {
})
}
type forceKillTimer struct {
timer *time.Timer
done chan struct{}
stopped atomic.Bool
drained atomic.Bool
}
func (t *forceKillTimer) stop() {
if t == nil || t.timer == nil {
return
}
if !t.timer.Stop() {
<-t.done
t.drained.Store(true)
}
t.stopped.Store(true)
}
func terminateCommand(cmd commandRunner) *forceKillTimer {
if cmd == nil {
return nil
}
proc := cmd.Process()
if proc == nil {
return nil
}
if runtime.GOOS == "windows" {
_ = proc.Kill()
return nil
}
_ = proc.Signal(syscall.SIGTERM)
done := make(chan struct{}, 1)
timer := time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() {
if p := cmd.Process(); p != nil {
_ = p.Kill()
}
done <- struct{}{}
})
return &forceKillTimer{timer: timer, done: done}
}
func parseJSONStream(r io.Reader) (message, threadID string) {
return parseJSONStreamWithLog(r, logWarn, logInfo)
}