Files
myclaude/codeagent-wrapper/logger.go
ben fe5508228f fix: 修复多 backend 并行日志 PID 混乱并移除包装格式 (#74) (#76)
* fix(logger): 修复多 backend 并行日志 PID 混乱并移除包装格式

**问题:**
- logger.go:288 使用 os.Getpid() 导致并行任务日志 PID 混乱
- 日志文件添加时间戳/PID/级别前缀包装,应输出 backend 原始内容

**修复:**
1. Logger 结构体添加 pid 字段,创建时捕获 PID
2. 日志写入使用固定 l.pid 替代 os.Getpid()
3. 移除日志输出格式包装,直接写入原始消息
4. 添加内存缓存 ERROR/WARN 条目,ExtractRecentErrors 从缓存读取
5. 优化 executor.go context 初始化顺序,避免重复创建 logger

**测试:**
- 所有测试通过(23.7s)
- 更新相关测试用例匹配新格式

Closes #74

* fix(logger): 增强并发日志隔离和 task ID 清理

## 核心修复

### 1. Task ID Sanitization (logger.go)
- 新增 sanitizeLogSuffix(): 清理非法字符 (/, \, :, 等)
- 新增 fallbackLogSuffix(): 为空/非法 ID 生成唯一后备名
- 新增 isSafeLogRune(): 仅允许 [A-Za-z0-9._-]
- 路径穿越防护: ../../../etc/passwd → etc-passwd-{hash}.log
- 超长 ID 处理: 截断到 64 字符 + hash 确保唯一性
- 自动创建 TMPDIR (MkdirAll)

### 2. 共享日志标识 (executor.go)
- 新增 taskLoggerHandle 结构: 封装 logger、路径、共享标志
- 新增 newTaskLoggerHandle(): 统一处理 logger 创建和回退
- printTaskStart(): 显示 "Log (shared)" 标识
- generateFinalOutput(): 在 summary 中标记共享日志
- 并发失败时明确标识所有任务使用共享主日志

### 3. 内部标志 (config.go)
- TaskResult.sharedLog: 非导出字段,标识共享日志状态

### 4. Race Detector 修复 (logger.go:209-219)
- Close() 在关闭 channel 前先等待 pendingWG
- 消除 Logger.Close() 与 Logger.log() 之间的竞态条件

## 测试覆盖

### 新增测试 (logger_suffix_test.go)
- TestLoggerWithSuffixSanitizesUnsafeSuffix: 非法字符清理
- TestLoggerWithSuffixReturnsErrorWhenTempDirNotWritable: 只读目录处理

### 新增测试 (executor_concurrent_test.go)
- TestConcurrentTaskLoggerFailure: 多任务失败时共享日志标识
- TestSanitizeTaskID: 并发场景下 task ID 清理验证

## 验证结果

- 所有单元测试通过
- Race detector 无竞态 (65.4s)
- 路径穿越攻击防护
- 并发日志完全隔离
- 边界情况正确处理

Resolves: PR #76 review feedback
Co-Authored-By: Codex Review <codex@anthropic.ai>

Generated with swe-agent-bot

Co-Authored-By: swe-agent-bot <agent@swe-agent.ai>

* fix(logger): 修复关键 bug 并优化日志系统 (v5.2.5)

修复 P0 级别问题:
- sanitizeLogSuffix 的 trim 碰撞(防止多 task 日志文件名冲突)
- ExtractRecentErrors 边界检查(防止 slice 越界)
- Logger.Close 阻塞风险(新增可配置超时机制)

代码质量改进:
- 删除无用字段 Logger.pid 和 logEntry.level
- 优化 sharedLog 标记绑定到最终 LogPath
- 移除日志前缀,直接输出 backend 原始内容

测试覆盖增强:
- 新增 4 个测试用例(碰撞防护、边界检查、缓存上限、shared 判定)
- 优化测试注释和逻辑

版本更新:5.2.4 → 5.2.5

Generated with swe-agent-bot

Co-Authored-By: swe-agent-bot <agent@swe-agent.ai>

---------

Co-authored-by: swe-agent-bot <agent@swe-agent.ai>
2025-12-17 10:33:38 +08:00

663 lines
16 KiB
Go

package main
import (
"bufio"
"context"
"errors"
"fmt"
"hash/crc32"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
// Logger writes log messages asynchronously to a temp file.
// It is intentionally minimal: a buffered channel + single worker goroutine
// to avoid contention while keeping ordering guarantees.
//
// Producers enqueue entries via log(); the worker goroutine (run) writes them
// to a buffered writer, flushes periodically, and finalizes the file when
// Close is called.
type Logger struct {
// path is the log file location, as reported by Path.
path string
// file is the open log file; the worker syncs and closes it on shutdown.
file *os.File
// writer buffers writes to file.
writer *bufio.Writer
// ch carries entries from log() to the worker goroutine.
ch chan logEntry
// flushReq carries explicit flush requests; the worker closes the received
// channel once it has flushed the writer and synced the file.
flushReq chan chan struct{}
// done is closed by Close to tell the worker (and blocked producers) to stop.
done chan struct{}
// closed is set by Close; log() drops messages once it is true.
closed atomic.Bool
// closeOnce ensures the shutdown sequence in Close runs only once.
closeOnce sync.Once
// workerWG tracks the worker goroutine so Close can wait for it to exit.
workerWG sync.WaitGroup
// pendingWG counts entries that were accepted by log() but not yet written.
pendingWG sync.WaitGroup
// flushMu is held by Flush for its whole duration and briefly by log() around
// pendingWG.Add, so Add never runs concurrently with Flush's Wait.
flushMu sync.Mutex
// workerErr is the first error hit while flushing/syncing/closing the file
// during shutdown; Close returns it.
workerErr error
errorEntries []string // Cache of recent ERROR/WARN entries
// errorMu guards errorEntries.
errorMu sync.Mutex
}
// logEntry is a single queued message travelling from log() to the worker.
type logEntry struct {
// msg is written to the log file verbatim, followed by a newline.
msg string
isError bool // true for ERROR or WARN levels
}
// CleanupStats captures the outcome of a cleanupOldLogs run.
type CleanupStats struct {
// Scanned is the number of matching log files examined.
Scanned int
// Deleted is the number of files successfully removed.
Deleted int
// Kept is the number of files left alone: unsafe files, files without a
// parsable PID, files owned by a live process, and files that had already
// disappeared when removal was attempted.
Kept int
// Errors is the number of removals that failed for a reason other than the
// file no longer existing.
Errors int
// DeletedFiles lists the base names of removed files.
DeletedFiles []string
// KeptFiles lists the base names of kept files; entries for files that were
// already gone carry an " (already deleted)" suffix.
KeptFiles []string
}
// Package-level function hooks used by the log cleanup code instead of calling
// the underlying functions directly.
// NOTE(review): presumably these exist so tests can substitute fakes — confirm.
var (
// processRunningCheck reports whether the process with the given PID is alive.
processRunningCheck = isProcessRunning
// processStartTimeFn returns the start time of a process (zero if unknown).
processStartTimeFn = getProcessStartTime
// removeLogFileFn deletes a log file.
removeLogFileFn = os.Remove
// globLogFiles lists files matching a glob pattern.
globLogFiles = filepath.Glob
fileStatFn = os.Lstat // Use Lstat to detect symlinks
// evalSymlinksFn resolves symlinks in a path.
evalSymlinksFn = filepath.EvalSymlinks
)
// maxLogSuffixLen caps the byte length of a sanitized log file suffix.
const maxLogSuffixLen = 64
// logSuffixCounter feeds fallbackLogSuffix, which uses it to number the
// generated "task-N" suffixes.
var logSuffixCounter atomic.Uint64
// NewLogger builds an asynchronous logger whose file name carries no suffix
// and starts its worker goroutine. It is shorthand for NewLoggerWithSuffix("").
func NewLogger() (*Logger, error) {
	const noSuffix = ""
	return NewLoggerWithSuffix(noSuffix)
}
// NewLoggerWithSuffix builds an asynchronous logger and starts its worker.
//
// The log file lives in os.TempDir() and is named "<prefix>-<pid>.log", or
// "<prefix>-<pid>-<suffix>.log" when suffix is non-empty. The suffix is passed
// through sanitizeLogSuffix first, so the file name never contains characters
// from the raw suffix that are unsafe in a path component.
//
// The directory is created if missing (mode 0700) and the file is opened in
// append mode (mode 0600). Any error from either step is returned as-is.
func NewLoggerWithSuffix(suffix string) (*Logger, error) {
	var name strings.Builder
	fmt.Fprintf(&name, "%s-%d", primaryLogPrefix(), os.Getpid())
	if suffix != "" {
		if safe := sanitizeLogSuffix(suffix); safe != "" {
			name.WriteString("-")
			name.WriteString(safe)
		}
	}
	name.WriteString(".log")

	logPath := filepath.Clean(filepath.Join(os.TempDir(), name.String()))
	if err := os.MkdirAll(filepath.Dir(logPath), 0o700); err != nil {
		return nil, err
	}

	file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
	if err != nil {
		return nil, err
	}

	logger := &Logger{
		path:     logPath,
		file:     file,
		writer:   bufio.NewWriterSize(file, 4096),
		ch:       make(chan logEntry, 1000),
		flushReq: make(chan chan struct{}, 1),
		done:     make(chan struct{}),
	}

	// Register the worker before starting it so Close can always wait on it.
	logger.workerWG.Add(1)
	go logger.run()

	return logger, nil
}
// sanitizeLogSuffix turns an arbitrary string into a suffix that is safe to
// embed in a log file name.
//
// Every rune rejected by isSafeLogRune becomes '-', at most maxLogSuffixLen
// bytes are kept, and leading/trailing '-' and '.' are stripped. Whenever the
// result differs from the (space-trimmed) input, or the length cap was hit, a
// hex CRC32 of the trimmed input is appended so that distinct inputs which
// sanitize to the same text still produce distinct suffixes. Inputs that are
// empty, or become empty after sanitizing, get a generated fallback suffix.
func sanitizeLogSuffix(raw string) string {
	input := strings.TrimSpace(raw)
	if input == "" {
		return fallbackLogSuffix()
	}

	var buf strings.Builder
	altered := false
	for _, r := range input {
		if !isSafeLogRune(r) {
			r = '-'
			altered = true
		}
		buf.WriteRune(r)
		if buf.Len() >= maxLogSuffixLen {
			// Reaching the cap counts as a modification, even on the last rune.
			altered = true
			break
		}
	}

	mapped := buf.String()
	cleaned := strings.Trim(mapped, "-.")
	switch {
	case cleaned == "":
		return fallbackLogSuffix()
	case cleaned != mapped:
		// Trimming removed characters, so two inputs may now collide.
		altered = true
	}

	if !altered && len(cleaned) <= maxLogSuffixLen {
		return cleaned
	}

	// Disambiguate with a checksum of the original input, shortening the
	// readable part so the whole suffix stays within maxLogSuffixLen.
	sum := fmt.Sprintf("%x", crc32.ChecksumIEEE([]byte(input)))
	keep := maxLogSuffixLen - len(sum) - 1
	if keep < 1 {
		keep = 1
	}
	if len(cleaned) > keep {
		cleaned = cleaned[:keep]
	}
	return cleaned + "-" + sum
}
// fallbackLogSuffix returns a generated suffix of the form "task-N", where N
// comes from a process-wide atomic counter, for inputs that yield no usable
// suffix of their own.
func fallbackLogSuffix() string {
	return "task-" + strconv.FormatUint(logSuffixCounter.Add(1), 10)
}
// isSafeLogRune reports whether r may appear unchanged in a log file suffix:
// ASCII letters, ASCII digits, '-', '_' and '.'.
func isSafeLogRune(r rune) bool {
	if r == '-' || r == '_' || r == '.' {
		return true
	}
	isLower := 'a' <= r && r <= 'z'
	isUpper := 'A' <= r && r <= 'Z'
	isDigit := '0' <= r && r <= '9'
	return isLower || isUpper || isDigit
}
// Path reports where the log file lives on disk. A nil Logger yields "".
func (l *Logger) Path() string {
	if l != nil {
		return l.path
	}
	return ""
}
// Info queues msg at INFO level.
func (l *Logger) Info(msg string) {
	l.log("INFO", msg)
}

// Warn queues msg at WARN level; WARN entries are also kept in the in-memory
// error cache.
func (l *Logger) Warn(msg string) {
	l.log("WARN", msg)
}

// Debug queues msg at DEBUG level.
func (l *Logger) Debug(msg string) {
	l.log("DEBUG", msg)
}

// Error queues msg at ERROR level; ERROR entries are also kept in the
// in-memory error cache.
func (l *Logger) Error(msg string) {
	l.log("ERROR", msg)
}
// Close stops the logger: it marks it closed, signals the worker goroutine to
// drain, flush and close the log file, and waits for the worker to finish.
//
// The log file itself is left on disk so it can be inspected afterwards.
//
// The wait is bounded by CODEAGENT_LOGGER_CLOSE_TIMEOUT_MS (default 5000 ms);
// a value of 0 waits indefinitely. If the worker does not finish in time a
// timeout error is returned; otherwise the first error the worker hit while
// finalizing the file (if any) is returned.
//
// Close is safe on a nil Logger and safe to call repeatedly. Only the first
// call performs the shutdown and reports its outcome; later calls return nil.
func (l *Logger) Close() error {
	if l == nil {
		return nil
	}

	var result error
	l.closeOnce.Do(func() {
		l.closed.Store(true)
		close(l.done)

		limit := loggerCloseTimeout()

		stopped := make(chan struct{})
		go func() {
			defer close(stopped)
			l.workerWG.Wait()
		}()

		// A nil channel never becomes ready, which gives the "wait forever"
		// behaviour when no positive timeout is configured.
		var expired <-chan time.Time
		if limit > 0 {
			expired = time.After(limit)
		}

		select {
		case <-stopped:
			// Worker exited; surface whatever it recorded while finalizing.
			result = l.workerErr
		case <-expired:
			result = fmt.Errorf("logger worker timeout during close")
		}
	})
	return result
}
// loggerCloseTimeout returns how long Close should wait for the worker.
//
// The value is read from CODEAGENT_LOGGER_CLOSE_TIMEOUT_MS (milliseconds):
//   - unset, blank, or not an integer: the 5 second default;
//   - zero or negative: 0, which Close treats as "wait indefinitely";
//   - positive: that many milliseconds, clamped to the largest representable
//     time.Duration so that huge values cannot overflow into a negative or
//     arbitrary timeout.
func loggerCloseTimeout() time.Duration {
	const (
		defaultTimeout = 5 * time.Second
		maxTimeout     = time.Duration(1<<63 - 1)
	)

	raw := strings.TrimSpace(os.Getenv("CODEAGENT_LOGGER_CLOSE_TIMEOUT_MS"))
	if raw == "" {
		return defaultTimeout
	}

	ms, err := strconv.Atoi(raw)
	if err != nil {
		return defaultTimeout
	}
	if ms <= 0 {
		return 0
	}

	// Multiplying by time.Millisecond would wrap around for values above this
	// bound, so saturate instead.
	if int64(ms) > int64(maxTimeout/time.Millisecond) {
		return maxTimeout
	}
	return time.Duration(ms) * time.Millisecond
}
// RemoveLogFile deletes the log file from disk and returns os.Remove's error.
// Call it only after Close. A nil Logger is a no-op.
func (l *Logger) RemoveLogFile() error {
	if l != nil {
		return os.Remove(l.path)
	}
	return nil
}
// ExtractRecentErrors returns a copy of the newest cached ERROR/WARN messages,
// oldest first, limited to maxEntries. It reads only the in-memory cache, not
// the log file. The result is nil when the Logger is nil, maxEntries is not
// positive, or nothing has been cached.
func (l *Logger) ExtractRecentErrors(maxEntries int) []string {
	if l == nil || maxEntries <= 0 {
		return nil
	}

	l.errorMu.Lock()
	defer l.errorMu.Unlock()

	total := len(l.errorEntries)
	if total == 0 {
		return nil
	}

	// Keep only the tail when more entries are cached than requested.
	tail := l.errorEntries
	if total > maxEntries {
		tail = tail[total-maxEntries:]
	}

	// Copy so callers cannot observe or mutate the cache's backing array.
	out := make([]string, len(tail))
	copy(out, tail)
	return out
}
// Flush blocks until queued entries have been written and the file has been
// flushed and synced, or until a timeout expires. Primarily for tests.
//
// It never blocks indefinitely: it waits up to 5 seconds for pending entries
// to be written, then up to 1 second to hand a flush request to the worker,
// then up to 1 second for the worker to acknowledge it. It returns silently
// when any of these limits is hit or when the logger is closing.
//
// flushMu is held for the whole call, so producers calling log() wait until
// Flush returns.
func (l *Logger) Flush() {
	if l == nil {
		return
	}

	l.flushMu.Lock()
	defer l.flushMu.Unlock()

	// Phase 1: wait (bounded) until the worker has consumed every entry.
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		l.pendingWG.Wait()
	}()

	waitCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	select {
	case <-waitCtx.Done():
		// Entries are still pending; give up without flushing.
		return
	case <-drained:
	}

	// Phase 2: ask the worker to flush the buffered writer and sync the file.
	ack := make(chan struct{})
	select {
	case <-l.done:
		// Logger is closing; the worker finalizes the file on its own.
		return
	case <-time.After(1 * time.Second):
		// Could not deliver the request in time.
		return
	case l.flushReq <- ack:
	}

	// Phase 3: wait (bounded) for the worker to acknowledge the flush.
	select {
	case <-ack:
	case <-time.After(1 * time.Second):
	}
}
// log queues msg for the worker goroutine. The message is written verbatim;
// level only decides whether the entry is also kept in the in-memory error
// cache ("WARN" and "ERROR" are).
//
// Calls on a nil or already-closed Logger are ignored. If the queue is full
// the call blocks until there is room or the logger starts closing, in which
// case the entry is dropped.
func (l *Logger) log(level, msg string) {
	if l == nil || l.closed.Load() {
		return
	}

	entry := logEntry{msg: msg}
	switch level {
	case "WARN", "ERROR":
		entry.isError = true
	}

	// Register the entry as pending under flushMu so the Add cannot run while
	// Flush is waiting on pendingWG.
	l.flushMu.Lock()
	l.pendingWG.Add(1)
	l.flushMu.Unlock()

	select {
	case l.ch <- entry:
		// The worker marks the entry done once it has been written.
	case <-l.done:
		// Shutting down: drop the entry and undo the pending count.
		l.pendingWG.Done()
	}
}
// run is the worker loop, started as a goroutine by NewLoggerWithSuffix. It is
// the only code that writes to l.writer and l.file: it writes queued entries,
// flushes the buffer every 500ms, serves explicit flush requests, and on
// shutdown drains the queue and finalizes the file before returning.
func (l *Logger) run() {
defer l.workerWG.Done()
// Periodic flush so buffered output reaches the file without an explicit Flush.
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
// writeEntry appends one message plus a newline to the buffered writer and
// marks the entry as no longer pending. The write error is not checked here.
writeEntry := func(entry logEntry) {
fmt.Fprintf(l.writer, "%s\n", entry.msg)
// Cache error/warn entries in memory for fast extraction
if entry.isError {
l.errorMu.Lock()
l.errorEntries = append(l.errorEntries, entry.msg)
if len(l.errorEntries) > 100 { // Keep last 100
l.errorEntries = l.errorEntries[1:]
}
l.errorMu.Unlock()
}
l.pendingWG.Done()
}
// finalize flushes the buffer, syncs and closes the file, recording only the
// first error encountered in l.workerErr (read by Close after the worker exits).
finalize := func() {
if err := l.writer.Flush(); err != nil && l.workerErr == nil {
l.workerErr = err
}
if err := l.file.Sync(); err != nil && l.workerErr == nil {
l.workerErr = err
}
if err := l.file.Close(); err != nil && l.workerErr == nil {
l.workerErr = err
}
}
for {
select {
case entry, ok := <-l.ch:
// ok is false only if l.ch has been closed.
// NOTE(review): no close(l.ch) is visible in this file, so this branch
// looks like a defensive path — confirm.
if !ok {
finalize()
return
}
writeEntry(entry)
case <-ticker.C:
// Best-effort periodic flush; the error is deliberately ignored here.
_ = l.writer.Flush()
case flushDone := <-l.flushReq:
// Explicit flush request - flush writer and sync to disk
_ = l.writer.Flush()
_ = l.file.Sync()
// Closing the channel signals the requester (Flush) that the flush is done.
close(flushDone)
case <-l.done:
// Shutdown: write whatever is already queued without blocking, then
// finalize as soon as the queue is empty.
for {
select {
case entry, ok := <-l.ch:
if !ok {
finalize()
return
}
writeEntry(entry)
default:
// Queue is empty.
finalize()
return
}
}
}
}
}
// cleanupOldLogs removes orphaned wrapper log files from os.TempDir().
//
// Candidates are files named "<prefix>-*.log" for every known log prefix. A
// file is deleted when the PID encoded in its name belongs to no running
// process, or when that PID is running but isPIDReused indicates it belongs
// to a different, later process. Everything else is kept, including:
//   - files rejected by isUnsafeFile (symlinks, paths outside the temp dir);
//   - files whose name carries no parsable PID;
//   - files that vanished before they could be removed.
//
// The returned stats describe every examined file. A glob failure aborts the
// run immediately; removal failures are accumulated and returned together
// after all candidates have been processed.
func cleanupOldLogs() (CleanupStats, error) {
	var stats CleanupStats
	tempDir := os.TempDir()

	prefixes := logPrefixes()
	if len(prefixes) == 0 {
		prefixes = []string{defaultWrapperName}
	}

	// Gather candidates, de-duplicating paths matched by more than one prefix.
	var candidates []string
	visited := make(map[string]struct{})
	for _, prefix := range prefixes {
		found, err := globLogFiles(filepath.Join(tempDir, fmt.Sprintf("%s-*.log", prefix)))
		if err != nil {
			logWarn(fmt.Sprintf("cleanupOldLogs: failed to list logs: %v", err))
			return stats, fmt.Errorf("cleanupOldLogs: %w", err)
		}
		for _, candidate := range found {
			if _, dup := visited[candidate]; dup {
				continue
			}
			visited[candidate] = struct{}{}
			candidates = append(candidates, candidate)
		}
	}

	var removeErr error

	// keep records a file that is left in place under the given label.
	keep := func(label string) {
		stats.Kept++
		stats.KeptFiles = append(stats.KeptFiles, label)
	}

	// remove deletes one orphaned log and updates stats and removeErr.
	remove := func(path, filename string, pidReused bool) {
		err := removeLogFileFn(path)
		switch {
		case err == nil:
			stats.Deleted++
			stats.DeletedFiles = append(stats.DeletedFiles, filename)
		case errors.Is(err, os.ErrNotExist):
			// Someone else removed it first; do not count it as our deletion.
			keep(filename + " (already deleted)")
		default:
			stats.Errors++
			if pidReused {
				logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s (PID reused): %v", filename, err))
			} else {
				logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s: %v", filename, err))
			}
			removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err))
		}
	}

	for _, path := range candidates {
		stats.Scanned++
		filename := filepath.Base(path)

		// Never touch symlinks or anything that resolves outside tempDir.
		if skip, reason := isUnsafeFile(path, tempDir); skip {
			keep(filename)
			if reason != "" {
				logWarn(fmt.Sprintf("cleanupOldLogs: skipping %s: %s", filename, reason))
			}
			continue
		}

		pid, ok := parsePIDFromLog(path)
		switch {
		case !ok:
			// No PID in the name: ownership unknown, leave it alone.
			keep(filename)
		case !processRunningCheck(pid):
			// Owner is gone.
			remove(path, filename, false)
		case isPIDReused(path, pid):
			// The PID is alive but belongs to a newer, unrelated process.
			remove(path, filename, true)
		default:
			// Owner is still running.
			keep(filename)
		}
	}

	if removeErr != nil {
		return stats, fmt.Errorf("cleanupOldLogs: %w", removeErr)
	}
	return stats, nil
}
// isUnsafeFile reports whether path must NOT be deleted by log cleanup.
//
// It returns (true, reason) when the file should be skipped and (false, "")
// when it is safe to delete. A file is skipped when:
//   - it cannot be stat'ed (reason is empty if it simply no longer exists);
//   - it is a symlink;
//   - its path cannot be resolved;
//   - its resolved location is not inside tempDir.
//
// Containment is checked against both the absolute tempDir and tempDir with
// symlinks resolved. The file path has its symlinks resolved, so comparing it
// only with an unresolved tempDir wrongly rejects every file whenever tempDir
// itself sits behind a symlink (for example /var -> /private/var on macOS).
func isUnsafeFile(path string, tempDir string) (bool, string) {
	// fileStatFn defaults to Lstat, so a symlink is reported as such rather
	// than being followed.
	info, err := fileStatFn(path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return true, "" // File disappeared, skip silently
		}
		return true, fmt.Sprintf("stat failed: %v", err)
	}
	if info.Mode()&os.ModeSymlink != 0 {
		return true, "refusing to delete symlink"
	}

	// Resolve symlinked parent directories and any path traversal.
	resolvedPath, err := evalSymlinksFn(path)
	if err != nil {
		return true, fmt.Sprintf("path resolution failed: %v", err)
	}

	absTempDir, err := filepath.Abs(tempDir)
	if err != nil {
		return true, fmt.Sprintf("tempDir resolution failed: %v", err)
	}

	// Accept the file if it lies under either spelling of the temp directory.
	// Failure to resolve tempDir is not fatal: the unresolved form still applies.
	roots := []string{absTempDir}
	if realTempDir, err := filepath.EvalSymlinks(absTempDir); err == nil && realTempDir != absTempDir {
		roots = append(roots, realTempDir)
	}

	// within reports whether target is dir itself or located beneath it. Only a
	// leading ".." path element means "outside"; a name that merely starts with
	// two dots (e.g. "..foo") is still inside.
	within := func(dir, target string) bool {
		rel, err := filepath.Rel(dir, target)
		if err != nil {
			return false
		}
		return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
	}

	for _, root := range roots {
		if within(root, resolvedPath) {
			return false, ""
		}
	}
	return true, "file is outside tempDir"
}
// isPIDReused reports whether the running process with the given pid is a
// different process from the one that wrote logPath, i.e. the PID was recycled.
//
// The decision compares the log's modification time with the process start
// time: a log last modified more than one second before the process started
// cannot belong to it. When the start time is unknown, only logs older than
// seven days are treated as orphaned. If the file cannot be stat'ed the
// function answers false, so the log is kept.
func isPIDReused(logPath string, pid int) bool {
	info, err := fileStatFn(logPath)
	if err != nil {
		// Unknown state: be conservative and keep the file.
		return false
	}
	lastWrite := info.ModTime()

	startedAt := processStartTimeFn(pid)
	if !startedAt.IsZero() {
		// The one-second margin absorbs clock skew and file system timestamp
		// granularity.
		return lastWrite.Add(1 * time.Second).Before(startedAt)
	}

	// No start time available: fall back to age alone, and only for logs old
	// enough that a live owner is implausible.
	return time.Since(lastWrite) > 7*24*time.Hour
}
// parsePIDFromLog extracts the owning process ID from a log file name of the
// form "<prefix>-<pid>.log" or "<prefix>-<pid>-<suffix>.log", trying each
// known log prefix in turn. It returns (0, false) when no prefix matches or
// the PID part is not a positive integer.
func parsePIDFromLog(path string) (int, bool) {
	base := filepath.Base(path)

	candidates := logPrefixes()
	if len(candidates) == 0 {
		candidates = []string{defaultWrapperName}
	}

	for _, candidate := range candidates {
		lead := candidate + "-"
		if !(strings.HasPrefix(base, lead) && strings.HasSuffix(base, ".log")) {
			continue
		}

		// What remains is "<pid>" or "<pid>-<suffix>"; the PID is everything up
		// to the first dash.
		middle := strings.TrimSuffix(strings.TrimPrefix(base, lead), ".log")
		pidText, _, _ := strings.Cut(middle, "-")

		// Atoi rejects the empty string, which covers a missing PID part.
		if pid, err := strconv.Atoi(pidText); err == nil && pid > 0 {
			return pid, true
		}
	}
	return 0, false
}
// logConcurrencyPlanning writes an INFO line with the worker limit and the
// total number of tasks of a parallel run. It does nothing when there is no
// active logger.
func logConcurrencyPlanning(limit, total int) {
	if lg := activeLogger(); lg != nil {
		lg.Info(fmt.Sprintf("parallel: worker_limit=%s total_tasks=%d", renderWorkerLimit(limit), total))
	}
}
// logConcurrencyState writes a DEBUG line describing a scheduling event for
// one task, with the current number of active tasks and the worker limit. It
// does nothing when there is no active logger.
func logConcurrencyState(event, taskID string, active, limit int) {
	if lg := activeLogger(); lg != nil {
		lg.Debug(fmt.Sprintf("parallel: %s task=%s active=%d limit=%s", event, taskID, active, renderWorkerLimit(limit)))
	}
}
// renderWorkerLimit formats a worker limit for log output: positive limits
// are printed as decimal numbers, anything else as "unbounded".
func renderWorkerLimit(limit int) string {
	if limit > 0 {
		return strconv.Itoa(limit)
	}
	return "unbounded"
}