Compare commits

..

1 Commits

Author SHA1 Message Date
makoMako
40e2d00d35 修复 Windows 后端退出:taskkill 结束进程树 + turn.completed 支持 (#108)
* fix(executor): handle turn.completed and terminate process tree on Windows

* fix: 修复代码审查发现的安全和资源泄漏问题

修复内容:
1. Windows 测试 taskkill 副作用:fake process 在 Windows 上返回 Pid()==0,避免真实执行 taskkill
2. taskkill PATH 劫持风险:使用 SystemRoot 环境变量构建绝对路径
3. stdinPipe 资源泄漏:在 StdoutPipe() 和 Start() 失败路径关闭 stdinPipe
4. stderr drain 并发语义:移除 500ms 超时,确保 drain 完成后再访问共享缓冲

测试验证:
- go test ./... -race 通过
- TestRunCodexTask_ForcesStopAfterTurnCompleted 通过
- TestExecutorSignalAndTermination 通过

Generated with SWE-Agent.ai

Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>

---------

Co-authored-by: cexll <evanxian9@gmail.com>
Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
2026-01-08 10:33:09 +08:00
7 changed files with 315 additions and 23 deletions

View File

@@ -23,6 +23,7 @@ type commandRunner interface {
Start() error
Wait() error
StdoutPipe() (io.ReadCloser, error)
StderrPipe() (io.ReadCloser, error)
StdinPipe() (io.WriteCloser, error)
SetStderr(io.Writer)
SetDir(string)
@@ -63,6 +64,13 @@ func (r *realCmd) StdoutPipe() (io.ReadCloser, error) {
return r.cmd.StdoutPipe()
}
func (r *realCmd) StderrPipe() (io.ReadCloser, error) {
if r.cmd == nil {
return nil, errors.New("command is nil")
}
return r.cmd.StderrPipe()
}
func (r *realCmd) StdinPipe() (io.WriteCloser, error) {
if r.cmd == nil {
return nil, errors.New("command is nil")
@@ -951,33 +959,40 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
if cfg.Backend == "gemini" {
stderrFilter = newFilteringWriter(os.Stderr, geminiNoisePatterns)
stderrOut = stderrFilter
defer stderrFilter.Flush()
}
stderrWriters = append([]io.Writer{stderrOut}, stderrWriters...)
}
if len(stderrWriters) == 1 {
cmd.SetStderr(stderrWriters[0])
} else {
cmd.SetStderr(io.MultiWriter(stderrWriters...))
stderr, err := cmd.StderrPipe()
if err != nil {
logErrorFn("Failed to create stderr pipe: " + err.Error())
result.ExitCode = 1
result.Error = attachStderr("failed to create stderr pipe: " + err.Error())
return result
}
var stdinPipe io.WriteCloser
var err error
if useStdin {
stdinPipe, err = cmd.StdinPipe()
if err != nil {
logErrorFn("Failed to create stdin pipe: " + err.Error())
result.ExitCode = 1
result.Error = attachStderr("failed to create stdin pipe: " + err.Error())
closeWithReason(stderr, "stdin-pipe-failed")
return result
}
}
stderrDone := make(chan error, 1)
stdout, err := cmd.StdoutPipe()
if err != nil {
logErrorFn("Failed to create stdout pipe: " + err.Error())
result.ExitCode = 1
result.Error = attachStderr("failed to create stdout pipe: " + err.Error())
closeWithReason(stderr, "stdout-pipe-failed")
if stdinPipe != nil {
_ = stdinPipe.Close()
}
return result
}
@@ -1013,6 +1028,11 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
logInfoFn(fmt.Sprintf("Starting %s with args: %s %s...", commandName, commandName, strings.Join(codexArgs[:min(5, len(codexArgs))], " ")))
if err := cmd.Start(); err != nil {
closeWithReason(stdout, "start-failed")
closeWithReason(stderr, "start-failed")
if stdinPipe != nil {
_ = stdinPipe.Close()
}
if strings.Contains(err.Error(), "executable file not found") {
msg := fmt.Sprintf("%s command not found in PATH", commandName)
logErrorFn(msg)
@@ -1031,6 +1051,15 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
logInfoFn(fmt.Sprintf("Log capturing to: %s", logger.Path()))
}
// Start stderr drain AFTER we know the command started, but BEFORE cmd.Wait can close the pipe.
go func() {
_, copyErr := io.Copy(io.MultiWriter(stderrWriters...), stderr)
if stderrFilter != nil {
stderrFilter.Flush()
}
stderrDone <- copyErr
}()
if useStdin && stdinPipe != nil {
logInfoFn(fmt.Sprintf("Writing %d chars to stdin...", len(taskSpec.Task)))
go func(data string) {
@@ -1081,6 +1110,11 @@ waitLoop:
terminated = true
}
}
// Close pipes to unblock stream readers, then wait for process exit.
closeWithReason(stdout, "terminate")
closeWithReason(stderr, "terminate")
waitErr = <-waitCh
break waitLoop
case <-completeSeen:
completeSeenObserved = true
if messageTimer != nil {
@@ -1135,6 +1169,12 @@ waitLoop:
}
}
closeWithReason(stderr, stdoutCloseReasonWait)
// Wait for stderr drain so stderrBuf / stderrLogger are not accessed concurrently.
// Important: cmd.Wait can block on internal stderr copying if cmd.Stderr is a non-file writer.
// We use StderrPipe and drain ourselves to avoid that deadlock class (common when children inherit pipes).
<-stderrDone
if ctxErr := ctx.Err(); ctxErr != nil {
if errors.Is(ctxErr, context.DeadlineExceeded) {
result.ExitCode = 124
@@ -1209,7 +1249,7 @@ func forwardSignals(ctx context.Context, cmd commandRunner, logErrorFn func(stri
case sig := <-sigCh:
logErrorFn(fmt.Sprintf("Received signal: %v", sig))
if proc := cmd.Process(); proc != nil {
_ = proc.Signal(syscall.SIGTERM)
_ = sendTermSignal(proc)
time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() {
if p := cmd.Process(); p != nil {
_ = p.Kill()
@@ -1279,7 +1319,7 @@ func terminateCommand(cmd commandRunner) *forceKillTimer {
return nil
}
_ = proc.Signal(syscall.SIGTERM)
_ = sendTermSignal(proc)
done := make(chan struct{}, 1)
timer := time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() {
@@ -1301,7 +1341,7 @@ func terminateProcess(cmd commandRunner) *time.Timer {
return nil
}
_ = proc.Signal(syscall.SIGTERM)
_ = sendTermSignal(proc)
return time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() {
if p := cmd.Process(); p != nil {

View File

@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
@@ -32,7 +33,12 @@ type execFakeProcess struct {
mu sync.Mutex
}
func (p *execFakeProcess) Pid() int { return p.pid }
func (p *execFakeProcess) Pid() int {
if runtime.GOOS == "windows" {
return 0
}
return p.pid
}
func (p *execFakeProcess) Kill() error {
p.killed.Add(1)
return nil
@@ -84,6 +90,7 @@ func (rc *reasonReadCloser) record(reason string) {
type execFakeRunner struct {
stdout io.ReadCloser
stderr io.ReadCloser
process processHandle
stdin io.WriteCloser
dir string
@@ -92,6 +99,7 @@ type execFakeRunner struct {
waitDelay time.Duration
startErr error
stdoutErr error
stderrErr error
stdinErr error
allowNilProcess bool
started atomic.Bool
@@ -119,6 +127,15 @@ func (f *execFakeRunner) StdoutPipe() (io.ReadCloser, error) {
}
return f.stdout, nil
}
func (f *execFakeRunner) StderrPipe() (io.ReadCloser, error) {
if f.stderrErr != nil {
return nil, f.stderrErr
}
if f.stderr == nil {
f.stderr = io.NopCloser(strings.NewReader(""))
}
return f.stderr, nil
}
func (f *execFakeRunner) StdinPipe() (io.WriteCloser, error) {
if f.stdinErr != nil {
return nil, f.stdinErr
@@ -163,6 +180,9 @@ func TestExecutorHelperCoverage(t *testing.T) {
if _, err := rc.StdoutPipe(); err == nil {
t.Fatalf("expected error for nil command")
}
if _, err := rc.StderrPipe(); err == nil {
t.Fatalf("expected error for nil command")
}
if _, err := rc.StdinPipe(); err == nil {
t.Fatalf("expected error for nil command")
}
@@ -182,11 +202,14 @@ func TestExecutorHelperCoverage(t *testing.T) {
if err != nil {
t.Fatalf("StdoutPipe error: %v", err)
}
stderrPipe, err := rcProc.StderrPipe()
if err != nil {
t.Fatalf("StderrPipe error: %v", err)
}
stdinPipe, err := rcProc.StdinPipe()
if err != nil {
t.Fatalf("StdinPipe error: %v", err)
}
rcProc.SetStderr(io.Discard)
if err := rcProc.Start(); err != nil {
t.Fatalf("Start failed: %v", err)
}
@@ -200,6 +223,7 @@ func TestExecutorHelperCoverage(t *testing.T) {
_ = procHandle.Kill()
_ = rcProc.Wait()
_, _ = io.ReadAll(stdoutPipe)
_, _ = io.ReadAll(stderrPipe)
rp := &realProcess{}
if rp.Pid() != 0 {
@@ -1250,7 +1274,7 @@ func TestExecutorSignalAndTermination(t *testing.T) {
proc.mu.Lock()
signalled := len(proc.signals)
proc.mu.Unlock()
if signalled == 0 {
if runtime.GOOS != "windows" && signalled == 0 {
t.Fatalf("process did not receive signal")
}
if proc.killed.Load() == 0 {

View File

@@ -243,6 +243,10 @@ func (d *drainBlockingCmd) StdoutPipe() (io.ReadCloser, error) {
return newDrainBlockingStdout(ctxReader), nil
}
func (d *drainBlockingCmd) StderrPipe() (io.ReadCloser, error) {
return d.inner.StderrPipe()
}
func (d *drainBlockingCmd) StdinPipe() (io.WriteCloser, error) {
return d.inner.StdinPipe()
}
@@ -314,6 +318,9 @@ func newFakeProcess(pid int) *fakeProcess {
}
func (p *fakeProcess) Pid() int {
if runtime.GOOS == "windows" {
return 0
}
return p.pid
}
@@ -389,7 +396,10 @@ type fakeCmd struct {
stdinWriter *bufferWriteCloser
stdinClaim bool
stderr io.Writer
stderr *ctxAwareReader
stderrWriter *io.PipeWriter
stderrOnce sync.Once
stderrClaim bool
env map[string]string
@@ -415,6 +425,7 @@ type fakeCmd struct {
func newFakeCmd(cfg fakeCmdConfig) *fakeCmd {
r, w := io.Pipe()
stderrR, stderrW := io.Pipe()
cmd := &fakeCmd{
stdout: newCtxAwareReader(r),
stdoutWriter: w,
@@ -425,6 +436,8 @@ func newFakeCmd(cfg fakeCmdConfig) *fakeCmd {
startErr: cfg.StartErr,
waitDone: make(chan struct{}),
keepStdoutOpen: cfg.KeepStdoutOpen,
stderr: newCtxAwareReader(stderrR),
stderrWriter: stderrW,
process: newFakeProcess(cfg.PID),
}
if len(cmd.stdoutPlan) == 0 {
@@ -501,6 +514,16 @@ func (f *fakeCmd) StdoutPipe() (io.ReadCloser, error) {
return f.stdout, nil
}
func (f *fakeCmd) StderrPipe() (io.ReadCloser, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.stderrClaim {
return nil, errors.New("stderr pipe already claimed")
}
f.stderrClaim = true
return f.stderr, nil
}
func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) {
f.mu.Lock()
defer f.mu.Unlock()
@@ -512,7 +535,7 @@ func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) {
}
func (f *fakeCmd) SetStderr(w io.Writer) {
f.stderr = w
_ = w
}
func (f *fakeCmd) SetDir(string) {}
@@ -542,6 +565,7 @@ func (f *fakeCmd) runStdoutScript() {
if len(f.stdoutPlan) == 0 {
if !f.keepStdoutOpen {
f.CloseStdout(nil)
f.CloseStderr(nil)
}
return
}
@@ -553,6 +577,7 @@ func (f *fakeCmd) runStdoutScript() {
}
if !f.keepStdoutOpen {
f.CloseStdout(nil)
f.CloseStderr(nil)
}
}
@@ -589,6 +614,19 @@ func (f *fakeCmd) CloseStdout(err error) {
})
}
func (f *fakeCmd) CloseStderr(err error) {
f.stderrOnce.Do(func() {
if f.stderrWriter == nil {
return
}
if err != nil {
_ = f.stderrWriter.CloseWithError(err)
return
}
_ = f.stderrWriter.Close()
})
}
func (f *fakeCmd) StdinContents() string {
if f.stdinWriter == nil {
return ""
@@ -876,11 +914,17 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
if fake.process == nil {
t.Fatalf("fake process not initialized")
}
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got 0")
}
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to eventually run, got 0")
if runtime.GOOS == "windows" {
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to be called, got 0")
}
} else {
if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got 0")
}
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to eventually run, got 0")
}
}
if capturedTimer == nil {
t.Fatalf("forceKillTimer not captured")
@@ -930,7 +974,51 @@ func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
if duration > 2*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
if runtime.GOOS == "windows" {
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to be called, got 0")
}
} else if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
func TestRunCodexTask_ForcesStopAfterTurnCompleted(t *testing.T) {
defer resetTestHooks()
forceKillDelay.Store(0)
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
{Data: `{"type":"turn.completed"}` + "\n"},
},
KeepStdoutOpen: true,
BlockWait: true,
ReleaseWaitOnSignal: true,
ReleaseWaitOnKill: true,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
codexCommand = "fake-cmd"
start := time.Now()
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
duration := time.Since(start)
if result.ExitCode != 0 || result.Message != "done" {
t.Fatalf("unexpected result: %+v", result)
}
if duration > 2*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if runtime.GOOS == "windows" {
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to be called, got 0")
}
} else if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
@@ -967,7 +1055,11 @@ func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
if duration > 5*time.Second {
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
}
if fake.process.SignalCount() == 0 {
if runtime.GOOS == "windows" {
if fake.process.KillCount() == 0 {
t.Fatalf("expected Kill to be called, got 0")
}
} else if fake.process.SignalCount() == 0 {
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
}
}
@@ -2720,6 +2812,10 @@ func TestRunCodexTask_Timeout(t *testing.T) {
}
func TestRunCodexTask_SignalHandling(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("signal-based test is not supported on Windows")
}
defer resetTestHooks()
codexCommand = "sleep"
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{"5"} }
@@ -2728,7 +2824,9 @@ func TestRunCodexTask_SignalHandling(t *testing.T) {
go func() { resultCh <- runCodexTask(TaskSpec{Task: "ignored"}, false, 5) }()
time.Sleep(200 * time.Millisecond)
syscall.Kill(os.Getpid(), syscall.SIGTERM)
if proc, err := os.FindProcess(os.Getpid()); err == nil && proc != nil {
_ = proc.Signal(syscall.SIGTERM)
}
res := <-resultCh
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
@@ -3984,6 +4082,10 @@ func TestRun_LoggerLifecycle(t *testing.T) {
}
func TestRun_LoggerRemovedOnSignal(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("signal-based test is not supported on Windows")
}
// Skip in CI due to unreliable signal delivery in containerized environments
if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" {
t.Skip("Skipping signal test in CI environment")
@@ -4025,7 +4127,9 @@ printf '%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"l
time.Sleep(10 * time.Millisecond)
}
_ = syscall.Kill(os.Getpid(), syscall.SIGINT)
if proc, err := os.FindProcess(os.Getpid()); err == nil && proc != nil {
_ = proc.Signal(syscall.SIGINT)
}
var exitCode int
select {

View File

@@ -163,6 +163,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
isCodex = true
}
}
// Codex-specific event types without thread_id or item
if !isCodex && (event.Type == "turn.started" || event.Type == "turn.completed") {
isCodex = true
}
isClaude := event.Subtype != "" || event.Result != ""
if !isClaude && event.Type == "result" && event.SessionID != "" && event.Status == "" {
isClaude = true
@@ -194,6 +198,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
infoFn(fmt.Sprintf("thread.completed event thread_id=%s", event.ThreadID))
notifyComplete()
case "turn.completed":
infoFn("turn.completed event")
notifyComplete()
case "item.completed":
var itemType string
if len(event.Item) > 0 {

View File

@@ -0,0 +1,64 @@
//go:build windows
// +build windows
package main
import (
"os"
"testing"
"time"
)
func TestIsProcessRunning(t *testing.T) {
t.Run("boundary values", func(t *testing.T) {
if isProcessRunning(0) {
t.Fatalf("expected pid 0 to be reported as not running")
}
if isProcessRunning(-1) {
t.Fatalf("expected pid -1 to be reported as not running")
}
})
t.Run("current process", func(t *testing.T) {
if !isProcessRunning(os.Getpid()) {
t.Fatalf("expected current process (pid=%d) to be running", os.Getpid())
}
})
t.Run("fake pid", func(t *testing.T) {
const nonexistentPID = 1 << 30
if isProcessRunning(nonexistentPID) {
t.Fatalf("expected pid %d to be reported as not running", nonexistentPID)
}
})
}
func TestGetProcessStartTimeReadsProcStat(t *testing.T) {
start := getProcessStartTime(os.Getpid())
if start.IsZero() {
t.Fatalf("expected non-zero start time for current process")
}
if start.After(time.Now().Add(5 * time.Second)) {
t.Fatalf("start time is unexpectedly in the future: %v", start)
}
}
func TestGetProcessStartTimeInvalidData(t *testing.T) {
if !getProcessStartTime(0).IsZero() {
t.Fatalf("expected zero time for pid 0")
}
if !getProcessStartTime(-1).IsZero() {
t.Fatalf("expected zero time for negative pid")
}
if !getProcessStartTime(1 << 30).IsZero() {
t.Fatalf("expected zero time for non-existent pid")
}
}
func TestGetBootTimeParsesBtime(t *testing.T) {
t.Skip("getBootTime is only implemented on Unix-like systems")
}
func TestGetBootTimeInvalidData(t *testing.T) {
t.Skip("getBootTime is only implemented on Unix-like systems")
}

View File

@@ -0,0 +1,16 @@
//go:build unix || darwin || linux
// +build unix darwin linux
package main
import (
"syscall"
)
// sendTermSignal sends SIGTERM for graceful shutdown on Unix.
func sendTermSignal(proc processHandle) error {
if proc == nil {
return nil
}
return proc.Signal(syscall.SIGTERM)
}

View File

@@ -0,0 +1,36 @@
//go:build windows
// +build windows
package main
import (
"io"
"os"
"os/exec"
"path/filepath"
"strconv"
)
// sendTermSignal on Windows directly kills the process.
// SIGTERM is not supported on Windows.
func sendTermSignal(proc processHandle) error {
if proc == nil {
return nil
}
pid := proc.Pid()
if pid > 0 {
// Kill the whole process tree to avoid leaving inheriting child processes around.
// This also helps prevent exec.Cmd.Wait() from blocking on stderr/stdout pipes held open by children.
taskkill := "taskkill"
if root := os.Getenv("SystemRoot"); root != "" {
taskkill = filepath.Join(root, "System32", "taskkill.exe")
}
cmd := exec.Command(taskkill, "/PID", strconv.Itoa(pid), "/T", "/F")
cmd.Stdout = io.Discard
cmd.Stderr = io.Discard
if err := cmd.Run(); err == nil {
return nil
}
}
return proc.Kill()
}