optimize codex-wrapper

This commit is contained in:
cexll
2025-12-02 15:49:36 +08:00
parent 8a8771076d
commit d51a2f12f8
5 changed files with 557 additions and 37 deletions

View File

@@ -0,0 +1,39 @@
package main
import (
"testing"
)
// BenchmarkLoggerWrite measures the throughput of sequential log writes.
func BenchmarkLoggerWrite(b *testing.B) {
	lg, err := NewLogger()
	if err != nil {
		b.Fatal(err)
	}
	defer lg.Close()

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		lg.Info("benchmark log message")
	}
	b.StopTimer()

	// Drain pending entries so the file reflects everything sent.
	lg.Flush()
}
// BenchmarkLoggerConcurrentWrite measures log-write throughput when many
// goroutines emit messages in parallel via b.RunParallel.
func BenchmarkLoggerConcurrentWrite(b *testing.B) {
	lg, err := NewLogger()
	if err != nil {
		b.Fatal(err)
	}
	defer lg.Close()

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			lg.Info("concurrent benchmark log message")
		}
	})
	b.StopTimer()

	// Drain pending entries so the file reflects everything sent.
	lg.Flush()
}

View File

@@ -0,0 +1,321 @@
package main
import (
"bufio"
"fmt"
"os"
"regexp"
"strings"
"sync"
"testing"
"time"
)
// TestConcurrentStressLogger hammers the logger from many goroutines at once,
// then verifies persistence ratio and the on-disk line format.
func TestConcurrentStressLogger(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping stress test in short mode")
	}

	lg, err := NewLoggerWithSuffix("stress")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()
	t.Logf("Log file: %s", lg.Path())

	const (
		numGoroutines  = 100  // concurrent writers
		logsPerRoutine = 1000 // messages emitted by each writer
		totalExpected  = numGoroutines * logsPerRoutine
	)

	var group sync.WaitGroup
	group.Add(numGoroutines)
	begin := time.Now()
	// Fire all writers concurrently.
	for id := 0; id < numGoroutines; id++ {
		go func(worker int) {
			defer group.Done()
			for seq := 0; seq < logsPerRoutine; seq++ {
				lg.Info(fmt.Sprintf("goroutine-%d-msg-%d", worker, seq))
			}
		}(id)
	}
	group.Wait()
	lg.Flush()
	elapsed := time.Since(begin)

	// Read back the log file and count the persisted lines.
	raw, err := os.ReadFile(lg.Path())
	if err != nil {
		t.Fatalf("failed to read log file: %v", err)
	}
	entries := strings.Split(strings.TrimSpace(string(raw)), "\n")
	written := len(entries)

	t.Logf("Concurrent stress test results:")
	t.Logf(" Goroutines: %d", numGoroutines)
	t.Logf(" Logs per goroutine: %d", logsPerRoutine)
	t.Logf(" Total expected: %d", totalExpected)
	t.Logf(" Total actual: %d", written)
	t.Logf(" Duration: %v", elapsed)
	t.Logf(" Throughput: %.2f logs/sec", float64(totalExpected)/elapsed.Seconds())

	// The logger drops entries instead of blocking; tolerate up to 90% loss.
	if written < totalExpected/10 {
		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)",
			written, totalExpected/10, totalExpected)
	}
	t.Logf("Successfully wrote %d/%d logs (%.1f%%)",
		written, totalExpected, float64(written)/float64(totalExpected)*100)

	// Spot-check the first few lines against the expected entry format.
	formatRE := regexp.MustCompile(`^\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3}\] \[PID:\d+\] INFO: goroutine-`)
	for idx, entry := range entries[:min(10, len(entries))] {
		if !formatRE.MatchString(entry) {
			t.Errorf("line %d has invalid format: %s", idx, entry)
		}
	}
}
// TestConcurrentBurstLogger simulates repeated traffic spikes separated by
// short pauses and checks that a reasonable share of logs survives to disk.
func TestConcurrentBurstLogger(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping burst test in short mode")
	}

	lg, err := NewLoggerWithSuffix("burst")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()
	t.Logf("Log file: %s", lg.Path())

	const (
		numBursts          = 10
		goroutinesPerBurst = 50
		logsPerGoroutine   = 100
	)

	sent := 0
	begin := time.Now()
	// Each burst spawns a fresh wave of writers, then pauses briefly.
	for burst := 0; burst < numBursts; burst++ {
		var group sync.WaitGroup
		for worker := 0; worker < goroutinesPerBurst; worker++ {
			group.Add(1)
			sent += logsPerGoroutine
			go func(b, g int) {
				defer group.Done()
				for seq := 0; seq < logsPerGoroutine; seq++ {
					lg.Info(fmt.Sprintf("burst-%d-goroutine-%d-msg-%d", b, g, seq))
				}
			}(burst, worker)
		}
		group.Wait()
		time.Sleep(10 * time.Millisecond) // gap between bursts
	}
	lg.Flush()
	elapsed := time.Since(begin)

	// Count persisted lines.
	raw, err := os.ReadFile(lg.Path())
	if err != nil {
		t.Fatalf("failed to read log file: %v", err)
	}
	entries := strings.Split(strings.TrimSpace(string(raw)), "\n")
	written := len(entries)

	t.Logf("Burst test results:")
	t.Logf(" Total bursts: %d", numBursts)
	t.Logf(" Goroutines per burst: %d", goroutinesPerBurst)
	t.Logf(" Expected logs: %d", sent)
	t.Logf(" Actual logs: %d", written)
	t.Logf(" Duration: %v", elapsed)
	t.Logf(" Throughput: %.2f logs/sec", float64(sent)/elapsed.Seconds())

	// Non-blocking logger may drop entries; require at least 10% persisted.
	if written < sent/10 {
		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)", written, sent/10, sent)
	}
	t.Logf("Successfully wrote %d/%d logs (%.1f%%)",
		written, sent, float64(written)/float64(sent)*100)
}
// TestLoggerChannelCapacity pushes more messages than the internal channel
// can buffer and verifies the logger degrades gracefully (partial loss is
// acceptable in non-blocking mode) rather than blocking the caller.
func TestLoggerChannelCapacity(t *testing.T) {
	lg, err := NewLoggerWithSuffix("capacity")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()

	const rapidLogs = 2000 // exceeds the channel capacity (1000)

	begin := time.Now()
	for seq := 0; seq < rapidLogs; seq++ {
		lg.Info(fmt.Sprintf("rapid-log-%d", seq))
	}
	sendDuration := time.Since(begin)
	lg.Flush()
	flushDuration := time.Since(begin) - sendDuration

	t.Logf("Channel capacity test:")
	t.Logf(" Logs sent: %d", rapidLogs)
	t.Logf(" Send duration: %v", sendDuration)
	t.Logf(" Flush duration: %v", flushDuration)

	// A reasonable fraction must still reach disk despite the overflow.
	raw, err := os.ReadFile(lg.Path())
	if err != nil {
		t.Fatal(err)
	}
	entries := strings.Split(strings.TrimSpace(string(raw)), "\n")
	written := len(entries)
	if written < rapidLogs/10 {
		t.Errorf("too many logs lost: got %d, want at least %d (10%% of %d)", written, rapidLogs/10, rapidLogs)
	}
	t.Logf("Logs persisted: %d/%d (%.1f%%)", written, rapidLogs, float64(written)/float64(rapidLogs)*100)
}
// TestLoggerMemoryUsage writes many long messages and checks the resulting
// file size and write speed; it tolerates drops from the non-blocking queue.
func TestLoggerMemoryUsage(t *testing.T) {
	lg, err := NewLoggerWithSuffix("memory")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()

	const numLogs = 20000
	payload := strings.Repeat("x", 500) // 500-byte message body

	begin := time.Now()
	for seq := 0; seq < numLogs; seq++ {
		lg.Info(fmt.Sprintf("log-%d-%s", seq, payload))
	}
	lg.Flush()
	elapsed := time.Since(begin)

	// Inspect the on-disk size.
	info, err := os.Stat(lg.Path())
	if err != nil {
		t.Fatal(err)
	}
	expectedTotalSize := int64(numLogs * 500) // theoretical minimum total bytes
	expectedMinSize := expectedTotalSize / 10 // accept up to 90% loss
	actualSize := info.Size()

	t.Logf("Memory/disk usage test:")
	t.Logf(" Logs written: %d", numLogs)
	t.Logf(" Message size: 500 bytes")
	t.Logf(" File size: %.2f MB", float64(actualSize)/1024/1024)
	t.Logf(" Duration: %v", elapsed)
	t.Logf(" Write speed: %.2f MB/s", float64(actualSize)/1024/1024/elapsed.Seconds())
	t.Logf(" Persistence ratio: %.1f%%", float64(actualSize)/float64(expectedTotalSize)*100)

	if actualSize < expectedMinSize {
		t.Errorf("file size too small: got %d bytes, expected at least %d", actualSize, expectedMinSize)
	}
}
// TestLoggerFlushTimeout verifies that Flush returns within its internal
// timeout budget instead of blocking indefinitely.
func TestLoggerFlushTimeout(t *testing.T) {
	lg, err := NewLoggerWithSuffix("flush")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()

	// Queue up some entries for the worker to drain.
	for seq := 0; seq < 100; seq++ {
		lg.Info(fmt.Sprintf("test-log-%d", seq))
	}

	// Flush should complete well inside its 5s + 1s timeout windows.
	begin := time.Now()
	lg.Flush()
	took := time.Since(begin)
	t.Logf("Flush duration: %v", took)
	if took > 6*time.Second {
		t.Errorf("Flush took too long: %v (expected < 6s)", took)
	}
}
// TestLoggerOrderPreservation checks that, although goroutines interleave,
// the messages from any single goroutine appear in the file in the order
// they were emitted (strictly increasing sequence numbers per writer).
func TestLoggerOrderPreservation(t *testing.T) {
	lg, err := NewLoggerWithSuffix("order")
	if err != nil {
		t.Fatal(err)
	}
	defer lg.Close()

	const numGoroutines = 10
	const logsPerRoutine = 100

	var group sync.WaitGroup
	for id := 0; id < numGoroutines; id++ {
		group.Add(1)
		go func(worker int) {
			defer group.Done()
			for seq := 0; seq < logsPerRoutine; seq++ {
				lg.Info(fmt.Sprintf("G%d-SEQ%04d", worker, seq))
			}
		}(id)
	}
	group.Wait()
	lg.Flush()

	// Read the file back and collect per-goroutine sequence numbers.
	raw, err := os.ReadFile(lg.Path())
	if err != nil {
		t.Fatal(err)
	}
	scanner := bufio.NewScanner(strings.NewReader(string(raw)))
	sequences := make(map[int][]int) // goroutine ID -> observed sequence numbers
	for scanner.Scan() {
		entry := scanner.Text()
		parts := strings.SplitN(entry, " INFO: ", 2)
		if len(parts) != 2 {
			t.Errorf("invalid log format: %s", entry)
			continue
		}
		var gid, seq int
		if _, err := fmt.Sscanf(parts[1], "G%d-SEQ%d", &gid, &seq); err != nil {
			t.Errorf("failed to parse sequence from line: %s", entry)
			continue
		}
		sequences[gid] = append(sequences[gid], seq)
	}

	// Per-goroutine: sequence numbers must be strictly increasing and complete.
	for gid, seqs := range sequences {
		for idx := 0; idx < len(seqs)-1; idx++ {
			if seqs[idx] >= seqs[idx+1] {
				t.Errorf("Goroutine %d: out of order at index %d: %d >= %d",
					gid, idx, seqs[idx], seqs[idx+1])
			}
		}
		if len(seqs) != logsPerRoutine {
			t.Errorf("Goroutine %d: missing logs, got %d, want %d",
				gid, len(seqs), logsPerRoutine)
		}
	}
	t.Logf("Order preservation test: all %d goroutines maintained sequence order", len(sequences))
}

View File

@@ -1,11 +1,14 @@
package main package main
import ( import (
"bufio"
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
) )
// Logger writes log messages asynchronously to a temp file. // Logger writes log messages asynchronously to a temp file.
@@ -14,12 +17,15 @@ import (
type Logger struct { type Logger struct {
path string path string
file *os.File file *os.File
writer *bufio.Writer
ch chan logEntry ch chan logEntry
flushReq chan struct{}
done chan struct{} done chan struct{}
closed atomic.Bool closed atomic.Bool
closeOnce sync.Once closeOnce sync.Once
workerWG sync.WaitGroup workerWG sync.WaitGroup
pendingWG sync.WaitGroup pendingWG sync.WaitGroup
flushMu sync.Mutex
} }
type logEntry struct { type logEntry struct {
@@ -30,7 +36,19 @@ type logEntry struct {
// NewLogger creates the async logger and starts the worker goroutine. // NewLogger creates the async logger and starts the worker goroutine.
// The log file is created under os.TempDir() using the required naming scheme. // The log file is created under os.TempDir() using the required naming scheme.
func NewLogger() (*Logger, error) { func NewLogger() (*Logger, error) {
path := filepath.Join(os.TempDir(), fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) return NewLoggerWithSuffix("")
}
// NewLoggerWithSuffix creates a logger with an optional suffix in the filename.
// Useful for tests that need isolated log files within the same process.
func NewLoggerWithSuffix(suffix string) (*Logger, error) {
filename := fmt.Sprintf("codex-wrapper-%d", os.Getpid())
if suffix != "" {
filename += "-" + suffix
}
filename += ".log"
path := filepath.Join(os.TempDir(), filename)
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil { if err != nil {
@@ -38,10 +56,12 @@ func NewLogger() (*Logger, error) {
} }
l := &Logger{ l := &Logger{
path: path, path: path,
file: f, file: f,
ch: make(chan logEntry, 100), writer: bufio.NewWriterSize(f, 4096),
done: make(chan struct{}), ch: make(chan logEntry, 1000),
flushReq: make(chan struct{}, 1),
done: make(chan struct{}),
} }
l.workerWG.Add(1) l.workerWG.Add(1)
@@ -73,6 +93,7 @@ func (l *Logger) Error(msg string) { l.log("ERROR", msg) }
// Close stops the worker and syncs the log file. // Close stops the worker and syncs the log file.
// The log file is NOT removed, allowing inspection after program exit. // The log file is NOT removed, allowing inspection after program exit.
// It is safe to call multiple times. // It is safe to call multiple times.
// Returns after a 5-second timeout if worker doesn't stop gracefully.
func (l *Logger) Close() error { func (l *Logger) Close() error {
if l == nil { if l == nil {
return nil return nil
@@ -85,9 +106,26 @@ func (l *Logger) Close() error {
close(l.done) close(l.done)
close(l.ch) close(l.ch)
l.workerWG.Wait() // Wait for worker with timeout
workerDone := make(chan struct{})
go func() {
l.workerWG.Wait()
close(workerDone)
}()
if err := l.file.Sync(); err != nil { select {
case <-workerDone:
// Worker stopped gracefully
case <-time.After(5 * time.Second):
// Worker timeout - proceed with cleanup anyway
closeErr = fmt.Errorf("logger worker timeout during close")
}
if err := l.writer.Flush(); err != nil && closeErr == nil {
closeErr = err
}
if err := l.file.Sync(); err != nil && closeErr == nil {
closeErr = err closeErr = err
} }
@@ -102,12 +140,61 @@ func (l *Logger) Close() error {
return closeErr return closeErr
} }
// RemoveLogFile removes the log file. Should only be called after Close().
func (l *Logger) RemoveLogFile() error {
if l == nil {
return nil
}
return os.Remove(l.path)
}
// Flush waits for all pending log entries to be written. Primarily for tests. // Flush waits for all pending log entries to be written. Primarily for tests.
// Returns after a 5-second timeout to prevent indefinite blocking.
func (l *Logger) Flush() { func (l *Logger) Flush() {
if l == nil { if l == nil {
return return
} }
l.pendingWG.Wait()
// Wait for pending entries with timeout
done := make(chan struct{})
go func() {
l.pendingWG.Wait()
close(done)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-done:
// All pending entries processed
case <-ctx.Done():
// Timeout - return without full flush
return
}
// Trigger writer flush
select {
case l.flushReq <- struct{}{}:
// Wait for flush to complete (with mutex)
flushDone := make(chan struct{})
go func() {
l.flushMu.Lock()
l.flushMu.Unlock()
close(flushDone)
}()
select {
case <-flushDone:
// Flush completed
case <-time.After(1 * time.Second):
// Flush timeout
}
case <-l.done:
// Logger is closing
case <-time.After(1 * time.Second):
// Timeout sending flush request
}
} }
func (l *Logger) log(level, msg string) { func (l *Logger) log(level, msg string) {
@@ -122,18 +209,44 @@ func (l *Logger) log(level, msg string) {
l.pendingWG.Add(1) l.pendingWG.Add(1)
select { select {
case l.ch <- entry:
case <-l.done: case <-l.done:
l.pendingWG.Done() l.pendingWG.Done()
return return
case l.ch <- entry: default:
// Channel is full; drop the entry to avoid blocking callers.
l.pendingWG.Done()
return
} }
} }
func (l *Logger) run() { func (l *Logger) run() {
defer l.workerWG.Done() defer l.workerWG.Done()
for entry := range l.ch { ticker := time.NewTicker(500 * time.Millisecond)
fmt.Fprintf(l.file, "%s: %s\n", entry.level, entry.msg) defer ticker.Stop()
l.pendingWG.Done()
for {
select {
case entry, ok := <-l.ch:
if !ok {
// Channel closed, final flush
l.writer.Flush()
return
}
timestamp := time.Now().Format("2006-01-02 15:04:05.000")
pid := os.Getpid()
fmt.Fprintf(l.writer, "[%s] [PID:%d] %s: %s\n", timestamp, pid, entry.level, entry.msg)
l.pendingWG.Done()
case <-ticker.C:
l.writer.Flush()
case <-l.flushReq:
// Explicit flush request
l.flushMu.Lock()
l.writer.Flush()
l.flushMu.Unlock()
}
} }
} }

View File

@@ -21,7 +21,7 @@ import (
) )
const ( const (
version = "1.0.0" version = "4.8.2"
defaultWorkdir = "." defaultWorkdir = "."
defaultTimeout = 7200 // seconds defaultTimeout = 7200 // seconds
forceKillDelay = 5 // seconds forceKillDelay = 5 // seconds
@@ -359,7 +359,7 @@ func main() {
} }
// run is the main logic, returns exit code for testability // run is the main logic, returns exit code for testability
func run() int { func run() (exitCode int) {
logger, err := NewLogger() logger, err := NewLogger()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err) fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err)
@@ -368,12 +368,20 @@ func run() int {
setLogger(logger) setLogger(logger)
defer func() { defer func() {
if logger := activeLogger(); logger != nil { logger := activeLogger()
if logger != nil {
logger.Flush() logger.Flush()
} }
if err := closeLogger(); err != nil { if err := closeLogger(); err != nil {
fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err) fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err)
} }
if exitCode == 0 && logger != nil {
if err := logger.RemoveLogFile(); err != nil && !os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "ERROR: failed to remove logger file: %v\n", err)
}
} else if exitCode != 0 && logger != nil {
fmt.Fprintf(os.Stderr, "Log file retained at: %s\n", logger.Path())
}
}() }()
defer runCleanupHook() defer runCleanupHook()
@@ -417,7 +425,7 @@ func run() int {
results := executeConcurrent(layers, timeoutSec) results := executeConcurrent(layers, timeoutSec)
fmt.Println(generateFinalOutput(results)) fmt.Println(generateFinalOutput(results))
exitCode := 0 exitCode = 0
for _, res := range results { for _, res := range results {
if res.ExitCode != 0 { if res.ExitCode != 0 {
exitCode = res.ExitCode exitCode = res.ExitCode
@@ -653,9 +661,39 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
codexArgs = buildCodexArgsFn(cfg, targetArg) codexArgs = buildCodexArgsFn(cfg, targetArg)
} }
logInfoFn := logInfo prefixMsg := func(msg string) string {
logWarnFn := logWarn if taskSpec.ID == "" {
logErrorFn := logError return msg
}
return fmt.Sprintf("[Task: %s] %s", taskSpec.ID, msg)
}
var logInfoFn func(string)
var logWarnFn func(string)
var logErrorFn func(string)
if silent {
// Silent mode: only persist to file when available; avoid stderr noise.
logInfoFn = func(msg string) {
if logger := activeLogger(); logger != nil {
logger.Info(prefixMsg(msg))
}
}
logWarnFn = func(msg string) {
if logger := activeLogger(); logger != nil {
logger.Warn(prefixMsg(msg))
}
}
logErrorFn = func(msg string) {
if logger := activeLogger(); logger != nil {
logger.Error(prefixMsg(msg))
}
}
} else {
logInfoFn = func(msg string) { logInfo(prefixMsg(msg)) }
logWarnFn = func(msg string) { logWarn(prefixMsg(msg)) }
logErrorFn = func(msg string) { logError(prefixMsg(msg)) }
}
stderrBuf := &tailBuffer{limit: stderrCaptureLimit} stderrBuf := &tailBuffer{limit: stderrCaptureLimit}
@@ -749,7 +787,10 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
return result return result
} }
logInfoFn(fmt.Sprintf("Process started with PID: %d", cmd.Process.Pid)) logInfoFn(fmt.Sprintf("Starting codex with PID: %d", cmd.Process.Pid))
if logger := activeLogger(); logger != nil {
logInfoFn(fmt.Sprintf("Log capturing to: %s", logger.Path()))
}
if useStdin && stdinPipe != nil { if useStdin && stdinPipe != nil {
logInfoFn(fmt.Sprintf("Writing %d chars to stdin...", len(taskSpec.Task))) logInfoFn(fmt.Sprintf("Writing %d chars to stdin...", len(taskSpec.Task)))
@@ -765,7 +806,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, custo
parseCh := make(chan parseResult, 1) parseCh := make(chan parseResult, 1)
go func() { go func() {
msg, tid := parseJSONStreamWithWarn(stdoutReader, logWarnFn) msg, tid := parseJSONStreamWithLog(stdoutReader, logWarnFn, logInfoFn)
parseCh <- parseResult{message: msg, threadID: tid} parseCh <- parseResult{message: msg, threadID: tid}
}() }()
@@ -913,16 +954,23 @@ func terminateProcess(cmd *exec.Cmd) *time.Timer {
} }
func parseJSONStream(r io.Reader) (message, threadID string) { func parseJSONStream(r io.Reader) (message, threadID string) {
return parseJSONStreamWithWarn(r, logWarn) return parseJSONStreamWithLog(r, logWarn, logInfo)
} }
func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) { func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) {
return parseJSONStreamWithLog(r, warnFn, logInfo)
}
func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) {
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
if warnFn == nil { if warnFn == nil {
warnFn = func(string) {} warnFn = func(string) {}
} }
if infoFn == nil {
infoFn = func(string) {}
}
totalEvents := 0 totalEvents := 0
@@ -947,15 +995,15 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
details = append(details, fmt.Sprintf("item_type=%s", event.Item.Type)) details = append(details, fmt.Sprintf("item_type=%s", event.Item.Type))
} }
if len(details) > 0 { if len(details) > 0 {
logInfo(fmt.Sprintf("Parsed event #%d type=%s (%s)", totalEvents, event.Type, strings.Join(details, ", "))) infoFn(fmt.Sprintf("Parsed event #%d type=%s (%s)", totalEvents, event.Type, strings.Join(details, ", ")))
} else { } else {
logInfo(fmt.Sprintf("Parsed event #%d type=%s", totalEvents, event.Type)) infoFn(fmt.Sprintf("Parsed event #%d type=%s", totalEvents, event.Type))
} }
switch event.Type { switch event.Type {
case "thread.started": case "thread.started":
threadID = event.ThreadID threadID = event.ThreadID
logInfo(fmt.Sprintf("thread.started event thread_id=%s", threadID)) infoFn(fmt.Sprintf("thread.started event thread_id=%s", threadID))
case "item.completed": case "item.completed":
var itemType string var itemType string
var normalized string var normalized string
@@ -963,7 +1011,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
itemType = event.Item.Type itemType = event.Item.Type
normalized = normalizeText(event.Item.Text) normalized = normalizeText(event.Item.Text)
} }
logInfo(fmt.Sprintf("item.completed event item_type=%s message_len=%d", itemType, len(normalized))) infoFn(fmt.Sprintf("item.completed event item_type=%s message_len=%d", itemType, len(normalized)))
if event.Item != nil && event.Item.Type == "agent_message" && normalized != "" { if event.Item != nil && event.Item.Type == "agent_message" && normalized != "" {
message = normalized message = normalized
} }
@@ -974,7 +1022,7 @@ func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadI
warnFn("Read stdout error: " + err.Error()) warnFn("Read stdout error: " + err.Error())
} }
logInfo(fmt.Sprintf("parseJSONStream completed: events=%d, message_len=%d, thread_id_found=%t", totalEvents, len(message), threadID != "")) infoFn(fmt.Sprintf("parseJSONStream completed: events=%d, message_len=%d, thread_id_found=%t", totalEvents, len(message), threadID != ""))
return message, threadID return message, threadID
} }
@@ -1162,27 +1210,27 @@ func farewell(name string) string {
} }
func logInfo(msg string) { func logInfo(msg string) {
fmt.Fprintf(os.Stderr, "INFO: %s\n", msg)
if logger := activeLogger(); logger != nil { if logger := activeLogger(); logger != nil {
logger.Info(msg) logger.Info(msg)
return
} }
fmt.Fprintf(os.Stderr, "INFO: %s\n", msg)
} }
func logWarn(msg string) { func logWarn(msg string) {
fmt.Fprintf(os.Stderr, "WARN: %s\n", msg)
if logger := activeLogger(); logger != nil { if logger := activeLogger(); logger != nil {
logger.Warn(msg) logger.Warn(msg)
return
} }
fmt.Fprintf(os.Stderr, "WARN: %s\n", msg)
} }
func logError(msg string) { func logError(msg string) {
fmt.Fprintf(os.Stderr, "ERROR: %s\n", msg)
if logger := activeLogger(); logger != nil { if logger := activeLogger(); logger != nil {
logger.Error(msg) logger.Error(msg)
return
} }
fmt.Fprintf(os.Stderr, "ERROR: %s\n", msg)
} }
func runCleanupHook() { func runCleanupHook() {

View File

@@ -1217,7 +1217,7 @@ func TestRun_PipedTaskReadError(t *testing.T) {
if exitCode != 1 { if exitCode != 1 {
t.Fatalf("exit=%d, want 1", exitCode) t.Fatalf("exit=%d, want 1", exitCode)
} }
if !strings.Contains(logOutput, "Failed to read piped stdin: read stdin: pipe failure") { if !strings.Contains(logOutput, "ERROR: Failed to read piped stdin: read stdin: pipe failure") {
t.Fatalf("log missing piped read error, got %q", logOutput) t.Fatalf("log missing piped read error, got %q", logOutput)
} }
if _, err := os.Stat(logPath); os.IsNotExist(err) { if _, err := os.Stat(logPath); os.IsNotExist(err) {
@@ -1275,10 +1275,9 @@ func TestRun_LoggerLifecycle(t *testing.T) {
if !fileExisted { if !fileExisted {
t.Fatalf("log file was not present during run") t.Fatalf("log file was not present during run")
} }
if _, err := os.Stat(logPath); os.IsNotExist(err) { if _, err := os.Stat(logPath); !os.IsNotExist(err) {
t.Fatalf("log file should exist after run") t.Fatalf("log file should be removed on success, but it exists")
} }
defer os.Remove(logPath)
} }
func TestRun_LoggerRemovedOnSignal(t *testing.T) { func TestRun_LoggerRemovedOnSignal(t *testing.T) {