fix: 增强日志清理的安全性和可靠性

必须修复的问题:
1. PID重用防护 - 添加进程启动时间检查,对比文件修改时间避免误删活动进程的日志
   - Unix: 通过 /proc/<pid>/stat 读取进程启动时间
   - Windows: 使用 GetProcessTimes API 获取创建时间
   - 7天策略: 无法获取进程启动时间时,超过7天的日志视为孤儿

2. 符号链接攻击防护 - 新增安全检查避免删除恶意符号链接
   - 使用 os.Lstat 检测符号链接
   - 使用 filepath.EvalSymlinks 解析真实路径
   - 确保所有文件在 TempDir 内(防止路径遍历)

强烈建议的改进:
3. 异步启动清理 - 通过 goroutine 运行清理避免阻塞主流程启动

4. NotExist错误语义修正 - 文件已被其他进程删除时计入 Kept 而非 Deleted
   - 更准确反映实际清理行为
   - 避免并发清理时的统计误导

5. Windows兼容性验证 - 完善Windows平台的进程时间获取

测试覆盖:
- 更新所有测试以适配新的安全检查逻辑
- 添加 stubProcessStartTime 支持PID重用测试
- 修复 setTempDirEnv 解析符号链接避免安全检查失败
- 所有测试通过(codex-wrapper: ok 6.183s)

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
swe-agent[bot]
2025-12-08 10:53:52 +08:00
parent 9452b77307
commit da257b860b
6 changed files with 317 additions and 46 deletions

View File

@@ -46,9 +46,12 @@ type CleanupStats struct {
}
var (
processRunningCheck = isProcessRunning
removeLogFileFn = os.Remove
globLogFiles = filepath.Glob
processRunningCheck = isProcessRunning
processStartTimeFn = getProcessStartTime
removeLogFileFn = os.Remove
globLogFiles = filepath.Glob
fileStatFn = os.Lstat // Use Lstat to detect symlinks
evalSymlinksFn = filepath.EvalSymlinks
)
// NewLogger creates the async logger and starts the worker goroutine.
@@ -263,6 +266,9 @@ func (l *Logger) run() {
// cleanupOldLogs scans os.TempDir() for codex-wrapper-*.log files and removes those
// whose owning process is no longer running (i.e., orphaned logs).
// It includes safety checks for:
// - PID reuse: Compares file modification time with process start time
// - Symlink attacks: Ensures files are within TempDir and not symlinks
func cleanupOldLogs() (CleanupStats, error) {
var stats CleanupStats
tempDir := os.TempDir()
@@ -279,30 +285,66 @@ func cleanupOldLogs() (CleanupStats, error) {
for _, path := range matches {
stats.Scanned++
filename := filepath.Base(path)
// Security check: Verify file is not a symlink and is within tempDir
if shouldSkipFile, reason := isUnsafeFile(path, tempDir); shouldSkipFile {
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename)
if reason != "" {
logWarn(fmt.Sprintf("cleanupOldLogs: skipping %s: %s", filename, reason))
}
continue
}
pid, ok := parsePIDFromLog(path)
if !ok {
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename)
continue
}
if processRunningCheck(pid) {
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename)
continue
}
if err := removeLogFileFn(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
stats.Deleted++
stats.DeletedFiles = append(stats.DeletedFiles, filename)
// Check if process is running
if !processRunningCheck(pid) {
// Process not running, safe to delete
if err := removeLogFileFn(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
// File already deleted by another process, don't count as success
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename+" (already deleted)")
continue
}
stats.Errors++
logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s: %v", filename, err))
removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err))
continue
}
stats.Errors++
logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s: %v", filename, err))
removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err))
stats.Deleted++
stats.DeletedFiles = append(stats.DeletedFiles, filename)
continue
}
stats.Deleted++
stats.DeletedFiles = append(stats.DeletedFiles, filename)
// Process is running, check for PID reuse
if isPIDReused(path, pid) {
// PID was reused, the log file is orphaned
if err := removeLogFileFn(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename+" (already deleted)")
continue
}
stats.Errors++
logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s (PID reused): %v", filename, err))
removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err))
continue
}
stats.Deleted++
stats.DeletedFiles = append(stats.DeletedFiles, filename)
continue
}
// Process is running and owns this log file
stats.Kept++
stats.KeptFiles = append(stats.KeptFiles, filename)
}
if removeErr != nil {
@@ -312,6 +354,72 @@ func cleanupOldLogs() (CleanupStats, error) {
return stats, nil
}
// isUnsafeFile checks if a file is unsafe to delete (symlink or outside tempDir).
// Returns (true, reason) if the file should be skipped.
func isUnsafeFile(path string, tempDir string) (bool, string) {
// Check if file is a symlink
info, err := fileStatFn(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return true, "" // File disappeared, skip silently
}
return true, fmt.Sprintf("stat failed: %v", err)
}
// Check if it's a symlink
if info.Mode()&os.ModeSymlink != 0 {
return true, "refusing to delete symlink"
}
// Resolve any path traversal and verify it's within tempDir
resolvedPath, err := evalSymlinksFn(path)
if err != nil {
return true, fmt.Sprintf("path resolution failed: %v", err)
}
// Get absolute path of tempDir
absTempDir, err := filepath.Abs(tempDir)
if err != nil {
return true, fmt.Sprintf("tempDir resolution failed: %v", err)
}
// Ensure resolved path is within tempDir
relPath, err := filepath.Rel(absTempDir, resolvedPath)
if err != nil || strings.HasPrefix(relPath, "..") {
return true, "file is outside tempDir"
}
return false, ""
}
// isPIDReused checks if a PID has been reused by comparing file modification time
// with process start time. Returns true if the log file was created by a different
// process that previously had the same PID.
func isPIDReused(logPath string, pid int) bool {
// Get file modification time (when log was last written)
info, err := fileStatFn(logPath)
if err != nil {
// If we can't stat the file, be conservative and keep it
return false
}
fileModTime := info.ModTime()
// Get process start time
procStartTime := processStartTimeFn(pid)
if procStartTime.IsZero() {
// Can't determine process start time
// Check if file is very old (>7 days), likely from a dead process
if time.Since(fileModTime) > 7*24*time.Hour {
return true // File is old enough to be from a different process
}
return false // Be conservative for recent files
}
// If the log file was modified before the process started, PID was reused
// Add a small buffer (1 second) to account for clock skew and file system timing
return fileModTime.Add(1 * time.Second).Before(procStartTime)
}
func parsePIDFromLog(path string) (int, bool) {
name := filepath.Base(path)
if !strings.HasPrefix(name, "codex-wrapper-") || !strings.HasSuffix(name, ".log") {

View File

@@ -201,8 +201,7 @@ func TestRunTerminateProcessNil(t *testing.T) {
}
func TestRunCleanupOldLogsRemovesOrphans(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
orphan1 := createTempLog(t, tempDir, "codex-wrapper-111.log")
orphan2 := createTempLog(t, tempDir, "codex-wrapper-222-suffix.log")
@@ -215,6 +214,15 @@ func TestRunCleanupOldLogsRemovesOrphans(t *testing.T) {
return runningPIDs[pid]
})
// Stub process start time to be in the past so files won't be considered as PID reused
stubProcessStartTime(t, func(pid int) time.Time {
if runningPIDs[pid] {
// Return a time before file creation
return time.Now().Add(-1 * time.Hour)
}
return time.Time{}
})
stats, err := cleanupOldLogs()
if err != nil {
t.Fatalf("cleanupOldLogs() unexpected error: %v", err)
@@ -243,8 +251,7 @@ func TestRunCleanupOldLogsRemovesOrphans(t *testing.T) {
}
func TestRunCleanupOldLogsHandlesInvalidNamesAndErrors(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
invalid := []string{
"codex-wrapper-.log",
@@ -263,6 +270,10 @@ func TestRunCleanupOldLogsHandlesInvalidNamesAndErrors(t *testing.T) {
return false
})
stubProcessStartTime(t, func(pid int) time.Time {
return time.Time{} // Return zero time for processes not running
})
removeErr := errors.New("remove failure")
callCount := 0
stubRemoveLogFile(t, func(path string) error {
@@ -302,6 +313,9 @@ func TestRunCleanupOldLogsHandlesGlobFailures(t *testing.T) {
t.Fatalf("process check should not run when glob fails")
return false
})
stubProcessStartTime(t, func(int) time.Time {
return time.Time{}
})
globErr := errors.New("glob failure")
stubGlobLogFiles(t, func(pattern string) ([]string, error) {
@@ -321,13 +335,15 @@ func TestRunCleanupOldLogsHandlesGlobFailures(t *testing.T) {
}
func TestRunCleanupOldLogsEmptyDirectoryStats(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
setTempDirEnv(t, t.TempDir())
stubProcessRunning(t, func(int) bool {
t.Fatalf("process check should not run for empty directory")
return false
})
stubProcessStartTime(t, func(int) time.Time {
return time.Time{}
})
stats, err := cleanupOldLogs()
if err != nil {
@@ -339,8 +355,7 @@ func TestRunCleanupOldLogsEmptyDirectoryStats(t *testing.T) {
}
func TestRunCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
paths := []string{
createTempLog(t, tempDir, "codex-wrapper-6100.log"),
@@ -348,6 +363,7 @@ func TestRunCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) {
}
stubProcessRunning(t, func(int) bool { return false })
stubProcessStartTime(t, func(int) time.Time { return time.Time{} })
var attempts int
stubRemoveLogFile(t, func(path string) error {
@@ -379,13 +395,13 @@ func TestRunCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) {
}
func TestRunCleanupOldLogsHandlesPermissionDeniedFile(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
protected := createTempLog(t, tempDir, "codex-wrapper-6200.log")
deletable := createTempLog(t, tempDir, "codex-wrapper-6201.log")
stubProcessRunning(t, func(int) bool { return false })
stubProcessStartTime(t, func(int) time.Time { return time.Time{} })
stubRemoveLogFile(t, func(path string) error {
if path == protected {
@@ -416,19 +432,20 @@ func TestRunCleanupOldLogsHandlesPermissionDeniedFile(t *testing.T) {
}
func TestRunCleanupOldLogsPerformanceBound(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
const fileCount = 400
fakePaths := make([]string, fileCount)
for i := 0; i < fileCount; i++ {
fakePaths[i] = filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", 10000+i))
name := fmt.Sprintf("codex-wrapper-%d.log", 10000+i)
fakePaths[i] = createTempLog(t, tempDir, name)
}
stubGlobLogFiles(t, func(pattern string) ([]string, error) {
return fakePaths, nil
})
stubProcessRunning(t, func(int) bool { return false })
stubProcessStartTime(t, func(int) time.Time { return time.Time{} })
var removed int
stubRemoveLogFile(t, func(path string) error {
@@ -468,8 +485,7 @@ func TestRunLoggerCoverageSuite(t *testing.T) {
}
func TestRunCleanupOldLogsKeepsCurrentProcessLog(t *testing.T) {
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
currentPID := os.Getpid()
currentLog := createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", currentPID))
@@ -480,6 +496,12 @@ func TestRunCleanupOldLogsKeepsCurrentProcessLog(t *testing.T) {
}
return true
})
stubProcessStartTime(t, func(pid int) time.Time {
if pid == currentPID {
return time.Now().Add(-1 * time.Hour)
}
return time.Time{}
})
stats, err := cleanupOldLogs()
if err != nil {
@@ -580,11 +602,16 @@ func createTempLog(t *testing.T, dir, name string) string {
return path
}
func setTempDirEnv(t *testing.T, dir string) {
func setTempDirEnv(t *testing.T, dir string) string {
t.Helper()
t.Setenv("TMPDIR", dir)
t.Setenv("TEMP", dir)
t.Setenv("TMP", dir)
resolved := dir
if eval, err := filepath.EvalSymlinks(dir); err == nil {
resolved = eval
}
t.Setenv("TMPDIR", resolved)
t.Setenv("TEMP", resolved)
t.Setenv("TMP", resolved)
return resolved
}
func stubProcessRunning(t *testing.T, fn func(int) bool) {
@@ -596,6 +623,15 @@ func stubProcessRunning(t *testing.T, fn func(int) bool) {
})
}
func stubProcessStartTime(t *testing.T, fn func(int) time.Time) {
t.Helper()
original := processStartTimeFn
processStartTimeFn = fn
t.Cleanup(func() {
processStartTimeFn = original
})
}
func stubRemoveLogFile(t *testing.T, fn func(string) error) {
t.Helper()
original := removeLogFileFn

View File

@@ -422,7 +422,8 @@ func run() (exitCode int) {
}()
defer runCleanupHook()
runStartupCleanup()
// Run cleanup asynchronously to avoid blocking startup
go runStartupCleanup()
// Handle remaining commands
if len(os.Args) > 1 {

View File

@@ -403,8 +403,7 @@ func TestRunConcurrentSpeedupBenchmark(t *testing.T) {
func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) {
defer resetTestHooks()
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
orphanA := createTempLog(t, tempDir, "codex-wrapper-5001.log")
orphanB := createTempLog(t, tempDir, "codex-wrapper-5002-extra.log")
@@ -416,6 +415,12 @@ func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) {
stubProcessRunning(t, func(pid int) bool {
return pid == runningPID || pid == os.Getpid()
})
stubProcessStartTime(t, func(pid int) time.Time {
if pid == runningPID || pid == os.Getpid() {
return time.Now().Add(-1 * time.Hour)
}
return time.Time{}
})
codexCommand = createFakeCodexScript(t, "tid-startup", "ok")
stdinReader = strings.NewReader("")
@@ -442,8 +447,7 @@ func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) {
func TestRunStartupCleanupConcurrentWrappers(t *testing.T) {
defer resetTestHooks()
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
const totalLogs = 40
for i := 0; i < totalLogs; i++ {
@@ -453,6 +457,7 @@ func TestRunStartupCleanupConcurrentWrappers(t *testing.T) {
stubProcessRunning(t, func(pid int) bool {
return false
})
stubProcessStartTime(t, func(int) time.Time { return time.Time{} })
var wg sync.WaitGroup
const instances = 5
@@ -482,8 +487,7 @@ func TestRunStartupCleanupConcurrentWrappers(t *testing.T) {
func TestRunCleanupFlagEndToEnd_Success(t *testing.T) {
defer resetTestHooks()
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
staleA := createTempLog(t, tempDir, "codex-wrapper-2100.log")
staleB := createTempLog(t, tempDir, "codex-wrapper-2200-extra.log")
@@ -492,6 +496,12 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) {
stubProcessRunning(t, func(pid int) bool {
return pid == 2300 || pid == os.Getpid()
})
stubProcessStartTime(t, func(pid int) time.Time {
if pid == 2300 || pid == os.Getpid() {
return time.Now().Add(-1 * time.Hour)
}
return time.Time{}
})
os.Args = []string{"codex-wrapper", "--cleanup"}
@@ -544,8 +554,7 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) {
func TestRunCleanupFlagEndToEnd_FailureDoesNotAffectStartup(t *testing.T) {
defer resetTestHooks()
tempDir := t.TempDir()
setTempDirEnv(t, tempDir)
tempDir := setTempDirEnv(t, t.TempDir())
calls := 0
cleanupLogsFn = func() (CleanupStats, error) {

View File

@@ -5,11 +5,16 @@ package main
import (
"errors"
"fmt"
"os"
"strconv"
"strings"
"syscall"
"time"
)
var findProcess = os.FindProcess
var readFileFn = os.ReadFile
// isProcessRunning returns true if a process with the given pid is running on Unix-like systems.
func isProcessRunning(pid int) bool {
@@ -28,3 +33,72 @@ func isProcessRunning(pid int) bool {
}
return true
}
// getProcessStartTime returns the start time of a process on Unix-like systems.
// Returns zero time if the start time cannot be determined.
func getProcessStartTime(pid int) time.Time {
if pid <= 0 {
return time.Time{}
}
// Read /proc/<pid>/stat to get process start time
statPath := fmt.Sprintf("/proc/%d/stat", pid)
data, err := readFileFn(statPath)
if err != nil {
return time.Time{}
}
// Parse stat file: fields are space-separated, but comm (field 2) can contain spaces
// Find the last ')' to skip comm field safely
content := string(data)
lastParen := strings.LastIndex(content, ")")
if lastParen == -1 {
return time.Time{}
}
fields := strings.Fields(content[lastParen+1:])
if len(fields) < 20 {
return time.Time{}
}
// Field 22 (index 19 after comm) is starttime in clock ticks since boot
startTicks, err := strconv.ParseUint(fields[19], 10, 64)
if err != nil {
return time.Time{}
}
// Get system boot time
bootTime := getBootTime()
if bootTime.IsZero() {
return time.Time{}
}
// Convert ticks to duration (typically 100 ticks/sec on most systems)
ticksPerSec := uint64(100) // sysconf(_SC_CLK_TCK), typically 100
startTime := bootTime.Add(time.Duration(startTicks/ticksPerSec) * time.Second)
return startTime
}
// getBootTime returns the system boot time by reading /proc/stat.
func getBootTime() time.Time {
data, err := readFileFn("/proc/stat")
if err != nil {
return time.Time{}
}
lines := strings.Split(string(data), "\n")
for _, line := range lines {
if strings.HasPrefix(line, "btime ") {
fields := strings.Fields(line)
if len(fields) >= 2 {
bootSec, err := strconv.ParseInt(fields[1], 10, 64)
if err == nil {
return time.Unix(bootSec, 0)
}
}
}
}
return time.Time{}
}

View File

@@ -7,6 +7,8 @@ import (
"errors"
"os"
"syscall"
"time"
"unsafe"
)
const (
@@ -14,7 +16,12 @@ const (
stillActive = 259 // STILL_ACTIVE exit code
)
var findProcess = os.FindProcess
var (
findProcess = os.FindProcess
kernel32 = syscall.NewLazyDLL("kernel32.dll")
getProcessTimes = kernel32.NewProc("GetProcessTimes")
fileTimeToUnixFn = fileTimeToUnix
)
// isProcessRunning returns true if a process with the given pid is running on Windows.
func isProcessRunning(pid int) bool {
@@ -42,3 +49,39 @@ func isProcessRunning(pid int) bool {
return exitCode == stillActive
}
// getProcessStartTime returns the start time of a process on Windows.
// Returns zero time if the start time cannot be determined.
func getProcessStartTime(pid int) time.Time {
if pid <= 0 {
return time.Time{}
}
handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid))
if err != nil {
return time.Time{}
}
defer syscall.CloseHandle(handle)
var creationTime, exitTime, kernelTime, userTime syscall.Filetime
ret, _, _ := getProcessTimes.Call(
uintptr(handle),
uintptr(unsafe.Pointer(&creationTime)),
uintptr(unsafe.Pointer(&exitTime)),
uintptr(unsafe.Pointer(&kernelTime)),
uintptr(unsafe.Pointer(&userTime)),
)
if ret == 0 {
return time.Time{}
}
return fileTimeToUnixFn(creationTime)
}
// fileTimeToUnix converts Windows FILETIME to Unix time.
func fileTimeToUnix(ft syscall.Filetime) time.Time {
// FILETIME is 100-nanosecond intervals since January 1, 1601 UTC
nsec := ft.Nanoseconds()
return time.Unix(0, nsec)
}