From d51a2f12f8909c25e6335c4459d18a56b37057b5 Mon Sep 17 00:00:00 2001
From: cexll
Date: Tue, 2 Dec 2025 15:49:36 +0800
Subject: [PATCH] optimize codex-wrapper

---
 codex-wrapper/bench_test.go             |  39 +++
 codex-wrapper/concurrent_stress_test.go | 321 ++++++++++++++++++++++++
 codex-wrapper/logger.go                 | 137 +++++++++-
 codex-wrapper/main.go                   |  90 +++++--
 codex-wrapper/main_test.go              |   7 +-
 5 files changed, 557 insertions(+), 37 deletions(-)
 create mode 100644 codex-wrapper/bench_test.go
 create mode 100644 codex-wrapper/concurrent_stress_test.go

diff --git a/codex-wrapper/bench_test.go b/codex-wrapper/bench_test.go
new file mode 100644
index 0000000..2a99861
--- /dev/null
+++ b/codex-wrapper/bench_test.go
@@ -0,0 +1,39 @@
+package main
+
+import (
+	"testing"
+)
+
+// BenchmarkLoggerWrite measures sequential log write performance
+func BenchmarkLoggerWrite(b *testing.B) {
+	logger, err := NewLogger()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer logger.Close()
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		logger.Info("benchmark log message")
+	}
+	b.StopTimer()
+	logger.Flush()
+}
+
+// BenchmarkLoggerConcurrentWrite measures concurrent log write performance
+func BenchmarkLoggerConcurrentWrite(b *testing.B) {
+	logger, err := NewLogger()
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer logger.Close()
+
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			logger.Info("concurrent benchmark log message")
+		}
+	})
+	b.StopTimer()
+	logger.Flush()
+}
diff --git a/codex-wrapper/concurrent_stress_test.go b/codex-wrapper/concurrent_stress_test.go
new file mode 100644
index 0000000..ac31137
--- /dev/null
+++ b/codex-wrapper/concurrent_stress_test.go
@@ -0,0 +1,321 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"os"
+	"regexp"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+)
+
+// TestConcurrentStressLogger is a high-concurrency stress test
+func TestConcurrentStressLogger(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping stress test in short mode")
+	}
+
+	logger, err := NewLoggerWithSuffix("stress")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	t.Logf("Log file: %s", logger.Path())
+
+	const (
+		numGoroutines  = 100  // number of concurrent goroutines
+		logsPerRoutine = 1000 // log entries written per goroutine
+		totalExpected  = numGoroutines * logsPerRoutine
+	)
+
+	var wg sync.WaitGroup
+	start := time.Now()
+
+	// Start the concurrent writers
+	for i := 0; i < numGoroutines; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for j := 0; j < logsPerRoutine; j++ {
+				logger.Info(fmt.Sprintf("goroutine-%d-msg-%d", id, j))
+			}
+		}(i)
+	}
+
+	wg.Wait()
+	logger.Flush()
+	elapsed := time.Since(start)
+
+	// Read the log file back for verification
+	data, err := os.ReadFile(logger.Path())
+	if err != nil {
+		t.Fatalf("failed to read log file: %v", err)
+	}
+
+	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
+	actualCount := len(lines)
+
+	t.Logf("Concurrent stress test results:")
+	t.Logf("  Goroutines: %d", numGoroutines)
+	t.Logf("  Logs per goroutine: %d", logsPerRoutine)
+	t.Logf("  Total expected: %d", totalExpected)
+	t.Logf("  Total actual: %d", actualCount)
+	t.Logf("  Duration: %v", elapsed)
+	t.Logf("  Throughput: %.2f logs/sec", float64(totalExpected)/elapsed.Seconds())
+
+	// Verify the log count
+	if actualCount < totalExpected/10 {
+		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)",
+			actualCount, totalExpected/10, totalExpected)
+	}
+	t.Logf("Successfully wrote %d/%d logs (%.1f%%)",
+		actualCount, totalExpected, float64(actualCount)/float64(totalExpected)*100)
+
+	// Verify the log format
+	formatRE := regexp.MustCompile(`^\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}\] \[PID:\d+\] INFO: goroutine-`)
+	for i, line := range lines[:min(10, len(lines))] {
+		if !formatRE.MatchString(line) {
+			t.Errorf("line %d has invalid format: %s", i, line)
+		}
+	}
+}
+
+// TestConcurrentBurstLogger simulates bursty traffic
+func TestConcurrentBurstLogger(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping burst test in short mode")
+	}
+
+	logger, err := NewLoggerWithSuffix("burst")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	t.Logf("Log file: %s", logger.Path())
+
+	const (
+		numBursts          = 10
+		goroutinesPerBurst = 50
+		logsPerGoroutine   = 100
+	)
+
+	totalLogs := 0
+	start := time.Now()
+
+	// Simulate bursts of traffic
+	for burst := 0; burst < numBursts; burst++ {
+		var wg sync.WaitGroup
+		for i := 0; i < goroutinesPerBurst; i++ {
+			wg.Add(1)
+			totalLogs += logsPerGoroutine
+			go func(b, g int) {
+				defer wg.Done()
+				for j := 0; j < logsPerGoroutine; j++ {
+					logger.Info(fmt.Sprintf("burst-%d-goroutine-%d-msg-%d", b, g, j))
+				}
+			}(burst, i)
+		}
+		wg.Wait()
+		time.Sleep(10 * time.Millisecond) // pause between bursts
+	}
+
+	logger.Flush()
+	elapsed := time.Since(start)
+
+	// Verify
+	data, err := os.ReadFile(logger.Path())
+	if err != nil {
+		t.Fatalf("failed to read log file: %v", err)
+	}
+
+	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
+	actualCount := len(lines)
+
+	t.Logf("Burst test results:")
+	t.Logf("  Total bursts: %d", numBursts)
+	t.Logf("  Goroutines per burst: %d", goroutinesPerBurst)
+	t.Logf("  Expected logs: %d", totalLogs)
+	t.Logf("  Actual logs: %d", actualCount)
+	t.Logf("  Duration: %v", elapsed)
+	t.Logf("  Throughput: %.2f logs/sec", float64(totalLogs)/elapsed.Seconds())
+
+	if actualCount < totalLogs/10 {
+		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)", actualCount, totalLogs/10, totalLogs)
+	}
+	t.Logf("Successfully wrote %d/%d logs (%.1f%%)",
+		actualCount, totalLogs, float64(actualCount)/float64(totalLogs)*100)
+}
+
+// TestLoggerChannelCapacity pushes past the channel capacity limit
+func TestLoggerChannelCapacity(t *testing.T) {
+	logger, err := NewLoggerWithSuffix("capacity")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	const rapidLogs = 2000 // exceeds the channel capacity (1000)
+
+	start := time.Now()
+	for i := 0; i < rapidLogs; i++ {
+		logger.Info(fmt.Sprintf("rapid-log-%d", i))
+	}
+	sendDuration := time.Since(start)
+
+	logger.Flush()
+	flushDuration := time.Since(start) - sendDuration
+
+	t.Logf("Channel capacity test:")
+	t.Logf("  Logs sent: %d", rapidLogs)
+	t.Logf("  Send duration: %v", sendDuration)
+	t.Logf("  Flush duration: %v", flushDuration)
+
+	// Verify a reasonable fraction of logs was persisted (non-blocking mode may drop entries)
+	data, err := os.ReadFile(logger.Path())
+	if err != nil {
+		t.Fatal(err)
+	}
+	lines := strings.Split(strings.TrimSpace(string(data)), "\n")
+	actualCount := len(lines)
+
+	if actualCount < rapidLogs/10 {
+		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)", actualCount, rapidLogs/10, rapidLogs)
+	}
+	t.Logf("Logs persisted: %d/%d (%.1f%%)", actualCount, rapidLogs, float64(actualCount)/float64(rapidLogs)*100)
+}
+
+// TestLoggerMemoryUsage exercises memory and disk usage
+func TestLoggerMemoryUsage(t *testing.T) {
+	logger, err := NewLoggerWithSuffix("memory")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	const numLogs = 20000
+	longMessage := strings.Repeat("x", 500) // 500-byte message
+
+	start := time.Now()
+	for i := 0; i < numLogs; i++ {
+		logger.Info(fmt.Sprintf("log-%d-%s", i, longMessage))
+	}
+	logger.Flush()
+	elapsed := time.Since(start)
+
+	// Check the file size
+	info, err := os.Stat(logger.Path())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	expectedTotalSize := int64(numLogs * 500) // theoretical minimum total bytes
+	expectedMinSize := expectedTotalSize / 10 // tolerate up to 90% loss
+	actualSize := info.Size()
+
+	t.Logf("Memory/disk usage test:")
+	t.Logf("  Logs written: %d", numLogs)
+	t.Logf("  Message size: 500 bytes")
+	t.Logf("  File size: %.2f MB", float64(actualSize)/1024/1024)
+	t.Logf("  Duration: %v", elapsed)
+	t.Logf("  Write speed: %.2f MB/s", float64(actualSize)/1024/1024/elapsed.Seconds())
+	t.Logf("  Persistence ratio: %.1f%%", float64(actualSize)/float64(expectedTotalSize)*100)
+
+	if actualSize < expectedMinSize {
+		t.Errorf("file size too small: got %d bytes, expected at least %d", actualSize, expectedMinSize)
+	}
+}
+
+// TestLoggerFlushTimeout verifies the Flush timeout mechanism
+func TestLoggerFlushTimeout(t *testing.T) {
+	logger, err := NewLoggerWithSuffix("flush")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	// Write some logs
+	for i := 0; i < 100; i++ {
+		logger.Info(fmt.Sprintf("test-log-%d", i))
+	}
+
+	// Flush should complete within a reasonable time
+	start := time.Now()
+	logger.Flush()
+	duration := time.Since(start)
+
+	t.Logf("Flush duration: %v", duration)
+
+	if duration > 6*time.Second {
+		t.Errorf("Flush took too long: %v (expected < 6s)", duration)
+	}
+}
+
+// TestLoggerOrderPreservation verifies per-goroutine log ordering
+func TestLoggerOrderPreservation(t *testing.T) {
+	logger, err := NewLoggerWithSuffix("order")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer logger.Close()
+
+	const numGoroutines = 10
+	const logsPerRoutine = 100
+
+	var wg sync.WaitGroup
+	for i := 0; i < numGoroutines; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for j := 0; j < logsPerRoutine; j++ {
+				logger.Info(fmt.Sprintf("G%d-SEQ%04d", id, j))
+			}
+		}(i)
+	}
+
+	wg.Wait()
+	logger.Flush()
+
+	// Read back and verify each goroutine's log order
+	data, err := os.ReadFile(logger.Path())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	scanner := bufio.NewScanner(strings.NewReader(string(data)))
+	sequences := make(map[int][]int) // goroutine ID -> sequence numbers
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		var gid, seq int
+		parts := strings.SplitN(line, " INFO: ", 2)
+		if len(parts) != 2 {
+			t.Errorf("invalid log format: %s", line)
+			continue
+		}
+		if _, err := fmt.Sscanf(parts[1], "G%d-SEQ%d", &gid, &seq); err == nil {
+			sequences[gid] = append(sequences[gid], seq)
+		} else {
+			t.Errorf("failed to parse sequence from line: %s", line)
+		}
+	}
+
+	// Verify in-order sequences within each goroutine
+	for gid, seqs := range sequences {
+		for i := 0; i < len(seqs)-1; i++ {
+			if seqs[i] >= seqs[i+1] {
+				t.Errorf("Goroutine %d: out of order at index %d: %d >= %d",
+					gid, i, seqs[i], seqs[i+1])
+			}
+		}
+		if len(seqs) != logsPerRoutine {
+			t.Errorf("Goroutine %d: missing logs, got %d, want %d",
+				gid, len(seqs), logsPerRoutine)
+		}
+	}
+
+	t.Logf("Order preservation test: all %d goroutines maintained sequence order", len(sequences))
+}
diff --git a/codex-wrapper/logger.go b/codex-wrapper/logger.go
index caad4ee..e54385d 100644
--- a/codex-wrapper/logger.go
+++ b/codex-wrapper/logger.go
@@ -1,11 +1,14 @@
 package main
 
 import (
+	"bufio"
+	"context"
 	"fmt"
 	"os"
 	"path/filepath"
 	"sync"
 	"sync/atomic"
+	"time"
 )
 
 // Logger writes log messages asynchronously to a temp file.
@@ -14,12 +17,15 @@
 type Logger struct {
 	path      string
 	file      *os.File
+	writer    *bufio.Writer
 	ch        chan logEntry
+	flushReq  chan struct{}
 	done      chan struct{}
 	closed    atomic.Bool
 	closeOnce sync.Once
 	workerWG  sync.WaitGroup
 	pendingWG sync.WaitGroup
+	flushMu   sync.Mutex
 }
 
 type logEntry struct {
@@ -30,7 +36,19 @@
 // NewLogger creates the async logger and starts the worker goroutine.
 // The log file is created under os.TempDir() using the required naming scheme.
 func NewLogger() (*Logger, error) {
-	path := filepath.Join(os.TempDir(), fmt.Sprintf("codex-wrapper-%d.log", os.Getpid()))
+	return NewLoggerWithSuffix("")
+}
+
+// NewLoggerWithSuffix creates a logger with an optional suffix in the filename.
+// Useful for tests that need isolated log files within the same process.
+func NewLoggerWithSuffix(suffix string) (*Logger, error) {
+	filename := fmt.Sprintf("codex-wrapper-%d", os.Getpid())
+	if suffix != "" {
+		filename += "-" + suffix
+	}
+	filename += ".log"
+
+	path := filepath.Join(os.TempDir(), filename)
 
 	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
 	if err != nil {
@@ -38,10 +56,12 @@
 	}
 
 	l := &Logger{
-		path: path,
-		file: f,
-		ch:   make(chan logEntry, 100),
-		done: make(chan struct{}),
+		path:     path,
+		file:     f,
+		writer:   bufio.NewWriterSize(f, 4096),
+		ch:       make(chan logEntry, 1000),
+		flushReq: make(chan struct{}, 1),
+		done:     make(chan struct{}),
 	}
 
 	l.workerWG.Add(1)
@@ -73,6 +93,7 @@ func (l *Logger) Error(msg string) { l.log("ERROR", msg) }
 // Close stops the worker and syncs the log file.
 // The log file is NOT removed, allowing inspection after program exit.
 // It is safe to call multiple times.
+// Returns after a 5-second timeout if worker doesn't stop gracefully.
 func (l *Logger) Close() error {
 	if l == nil {
 		return nil
@@ -85,9 +106,26 @@
 	close(l.done)
 	close(l.ch)
 
-	l.workerWG.Wait()
+	// Wait for worker with timeout
+	workerDone := make(chan struct{})
+	go func() {
+		l.workerWG.Wait()
+		close(workerDone)
+	}()
 
-	if err := l.file.Sync(); err != nil {
+	select {
+	case <-workerDone:
+		// Worker stopped gracefully
+	case <-time.After(5 * time.Second):
+		// Worker timeout - proceed with cleanup anyway
+		closeErr = fmt.Errorf("logger worker timeout during close")
+	}
+
+	if err := l.writer.Flush(); err != nil && closeErr == nil {
+		closeErr = err
+	}
+
+	if err := l.file.Sync(); err != nil && closeErr == nil {
 		closeErr = err
 	}
 
@@ -102,12 +140,61 @@
 	return closeErr
 }
 
+// RemoveLogFile removes the log file. Should only be called after Close().
+func (l *Logger) RemoveLogFile() error {
+	if l == nil {
+		return nil
+	}
+	return os.Remove(l.path)
+}
+
 // Flush waits for all pending log entries to be written. Primarily for tests.
+// Returns after a 5-second timeout to prevent indefinite blocking.
 func (l *Logger) Flush() {
 	if l == nil {
 		return
 	}
-	l.pendingWG.Wait()
+
+	// Wait for pending entries with timeout
+	done := make(chan struct{})
+	go func() {
+		l.pendingWG.Wait()
+		close(done)
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	select {
+	case <-done:
+		// All pending entries processed
+	case <-ctx.Done():
+		// Timeout - return without full flush
+		return
+	}
+
+	// Trigger writer flush
+	select {
+	case l.flushReq <- struct{}{}:
+		// Wait for flush to complete (with mutex)
+		flushDone := make(chan struct{})
+		go func() {
+			l.flushMu.Lock()
+			l.flushMu.Unlock()
+			close(flushDone)
+		}()
+
+		select {
+		case <-flushDone:
+			// Flush completed
+		case <-time.After(1 * time.Second):
+			// Flush timeout
+		}
+	case <-l.done:
+		// Logger is closing
+	case <-time.After(1 * time.Second):
+		// Timeout sending flush request
+	}
 }
 
 func (l *Logger) log(level, msg string) {
@@ -122,18 +209,44 @@
 	l.pendingWG.Add(1)
 
 	select {
+	case l.ch <- entry:
 	case <-l.done:
 		l.pendingWG.Done()
 		return
-	case l.ch <- entry:
+	default:
+		// Channel is full; drop the entry to avoid blocking callers.
+		l.pendingWG.Done()
+		return
 	}
 }
 
 func (l *Logger) run() {
 	defer l.workerWG.Done()
-	for entry := range l.ch {
-		fmt.Fprintf(l.file, "%s: %s\n", entry.level, entry.msg)
-		l.pendingWG.Done()
+	ticker := time.NewTicker(500 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case entry, ok := <-l.ch:
+			if !ok {
+				// Channel closed, final flush
+				l.writer.Flush()
+				return
+			}
+			timestamp := time.Now().Format("2006-01-02 15:04:05.000")
+			pid := os.Getpid()
+			fmt.Fprintf(l.writer, "[%s] [PID:%d] %s: %s\n", timestamp, pid, entry.level, entry.msg)
+			l.pendingWG.Done()
+
+		case <-ticker.C:
+			l.writer.Flush()
+
+		case <-l.flushReq:
+			// Explicit flush request
+			l.flushMu.Lock()
+			l.writer.Flush()
+			l.flushMu.Unlock()
+		}
 	}
 }
diff --git a/codex-wrapper/main.go b/codex-wrapper/main.go
index 7adca2b..476cedf 100644
--- a/codex-wrapper/main.go
+++ b/codex-wrapper/main.go
@@ -21,7 +21,7 @@ import (
 )
 
 const (
-	version        = "1.0.0"
+	version        = "4.8.2"
 	defaultWorkdir = "."
 	defaultTimeout = 7200 // seconds
 	forceKillDelay = 5    // seconds
@@ -359,7 +359,7 @@ func main() {
 }
 
 // run is the main logic, returns exit code for testability
-func run() int {
+func run() (exitCode int) {
 	logger, err := NewLogger()
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err)
@@ -368,12 +368,20 @@
 	setLogger(logger)
 
 	defer func() {
-		if logger := activeLogger(); logger != nil {
+		logger := activeLogger()
+		if logger != nil {
 			logger.Flush()
 		}
 		if err := closeLogger(); err != nil {
 			fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err)
 		}
+		if exitCode == 0 && logger != nil {
+			if err := logger.RemoveLogFile(); err != nil && !os.IsNotExist(err) {
+				fmt.Fprintf(os.Stderr, "ERROR: failed to remove logger file: %v\n", err)
+			}
+		} else if exitCode != 0 && logger != nil {
+			fmt.Fprintf(os.Stderr, "Log file retained at: %s\n", logger.Path())
+		}
 	}()
 
 	defer runCleanupHook()
@@ -417,7 +425,7 @@
 	results := executeConcurrent(layers, timeoutSec)
 	fmt.Println(generateFinalOutput(results))
 
-	exitCode := 0
+	exitCode = 0
 	for _, res := range results {
 		if res.ExitCode != 0 {
 			exitCode = res.ExitCode
@@ -653,9 +661,39 @@
 		codexArgs = buildCodexArgsFn(cfg, targetArg)
 	}
 
-	logInfoFn := logInfo
-	logWarnFn := logWarn
-	logErrorFn := logError
+	prefixMsg := func(msg string) string {
+		if taskSpec.ID == "" {
+			return msg
+		}
+		return fmt.Sprintf("[Task: %s] %s", taskSpec.ID, msg)
+	}
+
+	var logInfoFn func(string)
+	var logWarnFn func(string)
+	var logErrorFn func(string)
+
+	if silent {
+		// Silent mode: only persist to file when available; avoid stderr noise.
+		logInfoFn = func(msg string) {
+			if logger := activeLogger(); logger != nil {
+				logger.Info(prefixMsg(msg))
+			}
+		}
+		logWarnFn = func(msg string) {
+			if logger := activeLogger(); logger != nil {
+				logger.Warn(prefixMsg(msg))
+			}
+		}
+		logErrorFn = func(msg string) {
+			if logger := activeLogger(); logger != nil {
+				logger.Error(prefixMsg(msg))
+			}
+		}
+	} else {
+		logInfoFn = func(msg string) { logInfo(prefixMsg(msg)) }
+		logWarnFn = func(msg string) { logWarn(prefixMsg(msg)) }
+		logErrorFn = func(msg string) { logError(prefixMsg(msg)) }
+	}
 
 	stderrBuf := &tailBuffer{limit: stderrCaptureLimit}
@@ -749,7 +787,10 @@
 		return result
 	}
 
-	logInfoFn(fmt.Sprintf("Process started with PID: %d", cmd.Process.Pid))
+	logInfoFn(fmt.Sprintf("Starting codex with PID: %d", cmd.Process.Pid))
+	if logger := activeLogger(); logger != nil {
+		logInfoFn(fmt.Sprintf("Log capturing to: %s", logger.Path()))
+	}
 
 	if useStdin && stdinPipe != nil {
 		logInfoFn(fmt.Sprintf("Writing %d chars to stdin...", len(taskSpec.Task)))
@@ -765,7 +806,7 @@
 	parseCh := make(chan parseResult, 1)
 
 	go func() {
-		msg, tid := parseJSONStreamWithWarn(stdoutReader, logWarnFn)
+		msg, tid := parseJSONStreamWithLog(stdoutReader, logWarnFn, logInfoFn)
 		parseCh <- parseResult{message: msg, threadID: tid}
 	}()
@@ -913,16 +954,23 @@ func terminateProcess(cmd *exec.Cmd) *time.Timer {
 }
 
 func parseJSONStream(r io.Reader) (message, threadID string) {
-	return parseJSONStreamWithWarn(r, logWarn)
+	return parseJSONStreamWithLog(r, logWarn, logInfo)
 }
 
 func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) {
+	return parseJSONStreamWithLog(r, warnFn, logInfo)
+}
+
+func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
 	scanner := bufio.NewScanner(r)
 	scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
 
 	if warnFn == nil {
 		warnFn = func(string) {}
 	}
+	if infoFn == nil {
+		infoFn = func(string) {}
+	}
 
 	totalEvents := 0
@@ -947,15 +995,15 @@
 			details = append(details, fmt.Sprintf("item_type=%s", event.Item.Type))
 		}
 		if len(details) > 0 {
-			logInfo(fmt.Sprintf("Parsed event #%d type=%s (%s)", totalEvents, event.Type, strings.Join(details, ", ")))
+			infoFn(fmt.Sprintf("Parsed event #%d type=%s (%s)", totalEvents, event.Type, strings.Join(details, ", ")))
 		} else {
-			logInfo(fmt.Sprintf("Parsed event #%d type=%s", totalEvents, event.Type))
+			infoFn(fmt.Sprintf("Parsed event #%d type=%s", totalEvents, event.Type))
 		}
 
 		switch event.Type {
 		case "thread.started":
 			threadID = event.ThreadID
-			logInfo(fmt.Sprintf("thread.started event thread_id=%s", threadID))
+			infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
 		case "item.completed":
 			var itemType string
 			var normalized string
@@ -963,7 +1011,7 @@
 				itemType = event.Item.Type
 				normalized = normalizeText(event.Item.Text)
 			}
-			logInfo(fmt.Sprintf("item.completed event item_type=%s message_len=%d", itemType, len(normalized)))
+			infoFn(fmt.Sprintf("item.completed event item_type=%s message_len=%d", itemType, len(normalized)))
 			if event.Item != nil && event.Item.Type == "agent_message" && normalized != "" {
 				message = normalized
 			}
@@ -974,7 +1022,7 @@
 		warnFn("Read stdout error: " + err.Error())
 	}
 
-	logInfo(fmt.Sprintf("parseJSONStream completed: events=%d, message_len=%d, thread_id_found=%t", totalEvents, len(message), threadID != "")))
+	infoFn(fmt.Sprintf("parseJSONStream completed: events=%d, message_len=%d, thread_id_found=%t", totalEvents, len(message), threadID != "")))
 
 	return message, threadID
 }
@@ -1162,27 +1210,27 @@ func farewell(name string) string {
 }
 
 func logInfo(msg string) {
+	fmt.Fprintf(os.Stderr, "INFO: %s\n", msg)
+
 	if logger := activeLogger(); logger != nil {
 		logger.Info(msg)
-		return
 	}
-	fmt.Fprintf(os.Stderr, "INFO: %s\n", msg)
 }
 
 func logWarn(msg string) {
+	fmt.Fprintf(os.Stderr, "WARN: %s\n", msg)
+
 	if logger := activeLogger(); logger != nil {
 		logger.Warn(msg)
-		return
 	}
-	fmt.Fprintf(os.Stderr, "WARN: %s\n", msg)
 }
 
 func logError(msg string) {
+	fmt.Fprintf(os.Stderr, "ERROR: %s\n", msg)
+
 	if logger := activeLogger(); logger != nil {
 		logger.Error(msg)
-		return
 	}
-	fmt.Fprintf(os.Stderr, "ERROR: %s\n", msg)
 }
 
 func runCleanupHook() {
diff --git a/codex-wrapper/main_test.go b/codex-wrapper/main_test.go
index 1f742b4..e948d07 100644
--- a/codex-wrapper/main_test.go
+++ b/codex-wrapper/main_test.go
@@ -1217,7 +1217,7 @@ func TestRun_PipedTaskReadError(t *testing.T) {
 	if exitCode != 1 {
 		t.Fatalf("exit=%d, want 1", exitCode)
 	}
-	if !strings.Contains(logOutput, "Failed to read piped stdin: read stdin: pipe failure") {
+	if !strings.Contains(logOutput, "ERROR: Failed to read piped stdin: read stdin: pipe failure") {
 		t.Fatalf("log missing piped read error, got %q", logOutput)
 	}
 	if _, err := os.Stat(logPath); os.IsNotExist(err) {
@@ -1275,10 +1275,9 @@ func TestRun_LoggerLifecycle(t *testing.T) {
 	if !fileExisted {
 		t.Fatalf("log file was not present during run")
 	}
-	if _, err := os.Stat(logPath); os.IsNotExist(err) {
t.Fatalf("log file should exist after run") + if _, err := os.Stat(logPath); !os.IsNotExist(err) { + t.Fatalf("log file should be removed on success, but it exists") } - defer os.Remove(logPath) } func TestRun_LoggerRemovedOnSignal(t *testing.T) {