mirror of
https://github.com/cexll/myclaude.git
synced 2026-02-12 03:27:47 +08:00
修复 Windows 后端退出:taskkill 结束进程树 + turn.completed 支持 (#108)
* fix(executor): handle turn.completed and terminate process tree on Windows * fix: 修复代码审查发现的安全和资源泄漏问题 修复内容: 1. Windows 测试 taskkill 副作用:fake process 在 Windows 上返回 Pid()==0,避免真实执行 taskkill 2. taskkill PATH 劫持风险:使用 SystemRoot 环境变量构建绝对路径 3. stdinPipe 资源泄漏:在 StdoutPipe() 和 Start() 失败路径关闭 stdinPipe 4. stderr drain 并发语义:移除 500ms 超时,确保 drain 完成后再访问共享缓冲 测试验证: - go test ./... -race 通过 - TestRunCodexTask_ForcesStopAfterTurnCompleted 通过 - TestExecutorSignalAndTermination 通过 Generated with SWE-Agent.ai Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai> --------- Co-authored-by: cexll <evanxian9@gmail.com> Co-authored-by: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
@@ -243,6 +243,10 @@ func (d *drainBlockingCmd) StdoutPipe() (io.ReadCloser, error) {
|
||||
return newDrainBlockingStdout(ctxReader), nil
|
||||
}
|
||||
|
||||
func (d *drainBlockingCmd) StderrPipe() (io.ReadCloser, error) {
|
||||
return d.inner.StderrPipe()
|
||||
}
|
||||
|
||||
func (d *drainBlockingCmd) StdinPipe() (io.WriteCloser, error) {
|
||||
return d.inner.StdinPipe()
|
||||
}
|
||||
@@ -314,6 +318,9 @@ func newFakeProcess(pid int) *fakeProcess {
|
||||
}
|
||||
|
||||
func (p *fakeProcess) Pid() int {
|
||||
if runtime.GOOS == "windows" {
|
||||
return 0
|
||||
}
|
||||
return p.pid
|
||||
}
|
||||
|
||||
@@ -389,7 +396,10 @@ type fakeCmd struct {
|
||||
stdinWriter *bufferWriteCloser
|
||||
stdinClaim bool
|
||||
|
||||
stderr io.Writer
|
||||
stderr *ctxAwareReader
|
||||
stderrWriter *io.PipeWriter
|
||||
stderrOnce sync.Once
|
||||
stderrClaim bool
|
||||
|
||||
env map[string]string
|
||||
|
||||
@@ -415,6 +425,7 @@ type fakeCmd struct {
|
||||
|
||||
func newFakeCmd(cfg fakeCmdConfig) *fakeCmd {
|
||||
r, w := io.Pipe()
|
||||
stderrR, stderrW := io.Pipe()
|
||||
cmd := &fakeCmd{
|
||||
stdout: newCtxAwareReader(r),
|
||||
stdoutWriter: w,
|
||||
@@ -425,6 +436,8 @@ func newFakeCmd(cfg fakeCmdConfig) *fakeCmd {
|
||||
startErr: cfg.StartErr,
|
||||
waitDone: make(chan struct{}),
|
||||
keepStdoutOpen: cfg.KeepStdoutOpen,
|
||||
stderr: newCtxAwareReader(stderrR),
|
||||
stderrWriter: stderrW,
|
||||
process: newFakeProcess(cfg.PID),
|
||||
}
|
||||
if len(cmd.stdoutPlan) == 0 {
|
||||
@@ -501,6 +514,16 @@ func (f *fakeCmd) StdoutPipe() (io.ReadCloser, error) {
|
||||
return f.stdout, nil
|
||||
}
|
||||
|
||||
func (f *fakeCmd) StderrPipe() (io.ReadCloser, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.stderrClaim {
|
||||
return nil, errors.New("stderr pipe already claimed")
|
||||
}
|
||||
f.stderrClaim = true
|
||||
return f.stderr, nil
|
||||
}
|
||||
|
||||
func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
@@ -512,7 +535,7 @@ func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) {
|
||||
}
|
||||
|
||||
func (f *fakeCmd) SetStderr(w io.Writer) {
|
||||
f.stderr = w
|
||||
_ = w
|
||||
}
|
||||
|
||||
func (f *fakeCmd) SetDir(string) {}
|
||||
@@ -542,6 +565,7 @@ func (f *fakeCmd) runStdoutScript() {
|
||||
if len(f.stdoutPlan) == 0 {
|
||||
if !f.keepStdoutOpen {
|
||||
f.CloseStdout(nil)
|
||||
f.CloseStderr(nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -553,6 +577,7 @@ func (f *fakeCmd) runStdoutScript() {
|
||||
}
|
||||
if !f.keepStdoutOpen {
|
||||
f.CloseStdout(nil)
|
||||
f.CloseStderr(nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -589,6 +614,19 @@ func (f *fakeCmd) CloseStdout(err error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fakeCmd) CloseStderr(err error) {
|
||||
f.stderrOnce.Do(func() {
|
||||
if f.stderrWriter == nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
_ = f.stderrWriter.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
_ = f.stderrWriter.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fakeCmd) StdinContents() string {
|
||||
if f.stdinWriter == nil {
|
||||
return ""
|
||||
@@ -876,11 +914,17 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) {
|
||||
if fake.process == nil {
|
||||
t.Fatalf("fake process not initialized")
|
||||
}
|
||||
if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got 0")
|
||||
}
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to eventually run, got 0")
|
||||
if runtime.GOOS == "windows" {
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to be called, got 0")
|
||||
}
|
||||
} else {
|
||||
if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got 0")
|
||||
}
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to eventually run, got 0")
|
||||
}
|
||||
}
|
||||
if capturedTimer == nil {
|
||||
t.Fatalf("forceKillTimer not captured")
|
||||
@@ -930,7 +974,51 @@ func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) {
|
||||
if duration > 2*time.Second {
|
||||
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||
}
|
||||
if fake.process.SignalCount() == 0 {
|
||||
if runtime.GOOS == "windows" {
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to be called, got 0")
|
||||
}
|
||||
} else if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCodexTask_ForcesStopAfterTurnCompleted(t *testing.T) {
|
||||
defer resetTestHooks()
|
||||
forceKillDelay.Store(0)
|
||||
|
||||
fake := newFakeCmd(fakeCmdConfig{
|
||||
StdoutPlan: []fakeStdoutEvent{
|
||||
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}` + "\n"},
|
||||
{Data: `{"type":"turn.completed"}` + "\n"},
|
||||
},
|
||||
KeepStdoutOpen: true,
|
||||
BlockWait: true,
|
||||
ReleaseWaitOnSignal: true,
|
||||
ReleaseWaitOnKill: true,
|
||||
})
|
||||
|
||||
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
|
||||
return fake
|
||||
}
|
||||
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} }
|
||||
codexCommand = "fake-cmd"
|
||||
|
||||
start := time.Now()
|
||||
result := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "done", WorkDir: defaultWorkdir}, nil, nil, false, false, 60)
|
||||
duration := time.Since(start)
|
||||
|
||||
if result.ExitCode != 0 || result.Message != "done" {
|
||||
t.Fatalf("unexpected result: %+v", result)
|
||||
}
|
||||
if duration > 2*time.Second {
|
||||
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to be called, got 0")
|
||||
}
|
||||
} else if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||
}
|
||||
}
|
||||
@@ -967,7 +1055,11 @@ func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) {
|
||||
if duration > 5*time.Second {
|
||||
t.Fatalf("runCodexTaskWithContext took too long: %v", duration)
|
||||
}
|
||||
if fake.process.SignalCount() == 0 {
|
||||
if runtime.GOOS == "windows" {
|
||||
if fake.process.KillCount() == 0 {
|
||||
t.Fatalf("expected Kill to be called, got 0")
|
||||
}
|
||||
} else if fake.process.SignalCount() == 0 {
|
||||
t.Fatalf("expected SIGTERM to be sent, got %d", fake.process.SignalCount())
|
||||
}
|
||||
}
|
||||
@@ -2720,6 +2812,10 @@ func TestRunCodexTask_Timeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRunCodexTask_SignalHandling(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("signal-based test is not supported on Windows")
|
||||
}
|
||||
|
||||
defer resetTestHooks()
|
||||
codexCommand = "sleep"
|
||||
buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{"5"} }
|
||||
@@ -2728,7 +2824,9 @@ func TestRunCodexTask_SignalHandling(t *testing.T) {
|
||||
go func() { resultCh <- runCodexTask(TaskSpec{Task: "ignored"}, false, 5) }()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGTERM)
|
||||
if proc, err := os.FindProcess(os.Getpid()); err == nil && proc != nil {
|
||||
_ = proc.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
res := <-resultCh
|
||||
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
|
||||
@@ -3984,6 +4082,10 @@ func TestRun_LoggerLifecycle(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRun_LoggerRemovedOnSignal(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("signal-based test is not supported on Windows")
|
||||
}
|
||||
|
||||
// Skip in CI due to unreliable signal delivery in containerized environments
|
||||
if os.Getenv("CI") != "" || os.Getenv("GITHUB_ACTIONS") != "" {
|
||||
t.Skip("Skipping signal test in CI environment")
|
||||
@@ -4025,7 +4127,9 @@ printf '%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"l
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
_ = syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
if proc, err := os.FindProcess(os.Getpid()); err == nil && proc != nil {
|
||||
_ = proc.Signal(syscall.SIGINT)
|
||||
}
|
||||
|
||||
var exitCode int
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user