diff --git a/codeagent-wrapper/.github/workflows/ci.yml b/codeagent-wrapper/.github/workflows/ci.yml new file mode 100644 index 0000000..cc59c52 --- /dev/null +++ b/codeagent-wrapper/.github/workflows/ci.yml @@ -0,0 +1,39 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + go-version: ["1.21", "1.22"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: true + - name: Test + run: make test + - name: Build + run: make build + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.22" + cache: true + - name: Lint + run: make lint + diff --git a/codeagent-wrapper/.gitignore b/codeagent-wrapper/.gitignore index f2cbd14..b47e524 100644 --- a/codeagent-wrapper/.gitignore +++ b/codeagent-wrapper/.gitignore @@ -1,4 +1,7 @@ # Build artifacts +bin/ +codeagent +codeagent.exe codeagent-wrapper codeagent-wrapper.exe *.test @@ -9,3 +12,12 @@ coverage*.out cover.out cover_*.out coverage.html + +# Logs +*.log + +# Temp files +*.tmp +*.swp +*~ +.DS_Store diff --git a/codeagent-wrapper/Makefile b/codeagent-wrapper/Makefile new file mode 100644 index 0000000..e74e27c --- /dev/null +++ b/codeagent-wrapper/Makefile @@ -0,0 +1,38 @@ +GO ?= go + +BINARY ?= codeagent +CMD_PKG := ./cmd/codeagent + +TOOLS_BIN := $(CURDIR)/bin +TOOLCHAIN ?= go1.22.0 +GOLANGCI_LINT_VERSION := v1.56.2 +STATICCHECK_VERSION := v0.4.7 + +GOLANGCI_LINT := $(TOOLS_BIN)/golangci-lint +STATICCHECK := $(TOOLS_BIN)/staticcheck + +.PHONY: build test lint clean install + +build: + $(GO) build -o $(BINARY) $(CMD_PKG) + +test: + $(GO) test ./... + +$(GOLANGCI_LINT): + @mkdir -p $(TOOLS_BIN) + GOTOOLCHAIN=$(TOOLCHAIN) GOBIN=$(TOOLS_BIN) $(GO) install github.com/golangci/golangci-lint/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION) + +$(STATICCHECK): + @mkdir -p $(TOOLS_BIN) + GOTOOLCHAIN=$(TOOLCHAIN) GOBIN=$(TOOLS_BIN) $(GO) install honnef.co/go/tools/cmd/staticcheck@$(STATICCHECK_VERSION) + +lint: $(GOLANGCI_LINT) $(STATICCHECK) + GOTOOLCHAIN=$(TOOLCHAIN) $(GOLANGCI_LINT) run ./... + GOTOOLCHAIN=$(TOOLCHAIN) $(STATICCHECK) ./... + +clean: + @python3 -c 'import glob, os; paths=["codeagent","codeagent.exe","codeagent-wrapper","codeagent-wrapper.exe","coverage.out","cover.out","coverage.html"]; paths += glob.glob("coverage*.out") + glob.glob("cover_*.out") + glob.glob("*.test"); [os.remove(p) for p in paths if os.path.exists(p)]' + +install: + $(GO) install $(CMD_PKG) diff --git a/codeagent-wrapper/README.md b/codeagent-wrapper/README.md new file mode 100644 index 0000000..f5ed350 --- /dev/null +++ b/codeagent-wrapper/README.md @@ -0,0 +1,151 @@ +# codeagent-wrapper + +`codeagent-wrapper` 是一个用 Go 编写的“多后端 AI 代码代理”命令行包装器:用统一的 CLI 入口封装不同的 AI 工具后端(Codex / Claude / Gemini / Opencode),并提供一致的参数、配置与会话恢复体验。 + +入口:`cmd/codeagent/main.go`(生成二进制名:`codeagent`)。 + +## 功能特性 + +- 多后端支持:`codex` / `claude` / `gemini` / `opencode` +- 统一命令行:`codeagent [flags] ` / `codeagent resume [workdir]` +- 自动 stdin:遇到换行/特殊字符/超长任务自动走 stdin,避免 shell quoting 地狱;也可显式使用 `-` +- 配置合并:支持配置文件与 `CODEAGENT_*` 环境变量(viper) +- Agent 预设:从 `~/.codeagent/models.json` 读取 backend/model/prompt 等预设 +- 并行执行:`--parallel` 从 stdin 读取多任务配置,支持依赖拓扑并发执行 +- 日志清理:`codeagent cleanup` 清理旧日志(日志写入系统临时目录) + +## 安装 + +要求:Go 1.21+。 + +在仓库根目录执行: + +```bash +go install ./cmd/codeagent +``` + +安装后确认: + +```bash +codeagent version +``` + +## 使用示例 + +最简单用法(默认后端:`codex`): + +```bash +codeagent "分析 internal/app/cli.go 的入口逻辑,给出改进建议" +``` + +指定后端: + +```bash +codeagent --backend claude "解释 internal/executor/parallel_config.go 的并行配置格式" +``` + +指定工作目录(第 2 个位置参数): + +```bash +codeagent "在当前 repo 下搜索潜在数据竞争" . +``` + +显式从 stdin 读取 task(使用 `-`): + +```bash +cat task.txt | codeagent - +``` + +恢复会话: + +```bash +codeagent resume "继续上次任务" +``` + +并行模式(从 stdin 读取任务配置;禁止位置参数): + +```bash +codeagent --parallel <<'EOF' +---TASK--- +id: t1 +workdir: . +backend: codex +---CONTENT--- +列出本项目的主要模块以及它们的职责。 +---TASK--- +id: t2 +dependencies: t1 +backend: claude +---CONTENT--- +基于 t1 的结论,提出重构风险点与建议。 +EOF +``` + +## 配置说明 + +### 配置文件 + +默认查找路径(当 `--config` 为空时): + +- `$HOME/.codeagent/config.(yaml|yml|json|toml|...)` + +示例(YAML): + +```yaml +backend: codex +model: gpt-4.1 +skip-permissions: false +``` + +也可以通过 `--config /path/to/config.yaml` 显式指定。 + +### 环境变量(`CODEAGENT_*`) + +通过 viper 读取并自动映射 `-` 为 `_`,常用项: + +- `CODEAGENT_BACKEND`(`codex|claude|gemini|opencode`) +- `CODEAGENT_MODEL` +- `CODEAGENT_AGENT` +- `CODEAGENT_PROMPT_FILE` +- `CODEAGENT_REASONING_EFFORT` +- `CODEAGENT_SKIP_PERMISSIONS` +- `CODEAGENT_FULL_OUTPUT`(并行模式 legacy 输出) +- `CODEAGENT_MAX_PARALLEL_WORKERS`(0 表示不限制,上限 100) + +### Agent 预设(`~/.codeagent/models.json`) + +可在 `~/.codeagent/models.json` 定义 agent → backend/model/prompt 等映射,用 `--agent ` 选择: + +```json +{ + "default_backend": "opencode", + "default_model": "opencode/grok-code", + "agents": { + "develop": { + "backend": "codex", + "model": "gpt-4.1", + "prompt_file": "~/.codeagent/prompts/develop.md", + "description": "Code development" + } + } +} +``` + +## 支持的后端 + +该项目本身不内置模型能力,依赖你本机安装并可在 `PATH` 中找到对应 CLI: + +- `codex`:执行 `codex e ...`(默认会添加 `--dangerously-bypass-approvals-and-sandbox`;如需关闭请设置 `CODEX_BYPASS_SANDBOX=false`) +- `claude`:执行 `claude -p ... --output-format stream-json`(默认会跳过权限提示;如需开启请设置 `CODEAGENT_SKIP_PERMISSIONS=false`) +- `gemini`:执行 `gemini ... -o stream-json`(可从 `~/.gemini/.env` 加载环境变量) +- `opencode`:执行 `opencode run --format json` + +## 开发 + +```bash +make build +make test +make lint +make clean +``` + diff --git a/codeagent-wrapper/agent_config.go b/codeagent-wrapper/agent_config.go deleted file mode 100644 index 72f9b57..0000000 --- a/codeagent-wrapper/agent_config.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" -) - -type AgentModelConfig struct { - Backend string `json:"backend"` - Model string `json:"model"` - PromptFile string `json:"prompt_file,omitempty"` - Description string `json:"description,omitempty"` - Yolo bool `json:"yolo,omitempty"` - Reasoning string `json:"reasoning,omitempty"` -} - -type ModelsConfig struct { - DefaultBackend string `json:"default_backend"` - DefaultModel string `json:"default_model"` - Agents map[string]AgentModelConfig `json:"agents"` -} - -var defaultModelsConfig = ModelsConfig{ - DefaultBackend: "opencode", - DefaultModel: "opencode/grok-code", - Agents: map[string]AgentModelConfig{ - "oracle": {Backend: "claude", Model: "claude-opus-4-5-20251101", PromptFile: "~/.claude/skills/omo/references/oracle.md", Description: "Technical advisor"}, - "librarian": {Backend: "claude", Model: "claude-sonnet-4-5-20250929", PromptFile: "~/.claude/skills/omo/references/librarian.md", Description: "Researcher"}, - "explore": {Backend: "opencode", Model: "opencode/grok-code", PromptFile: "~/.claude/skills/omo/references/explore.md", Description: "Code search"}, - "develop": {Backend: "codex", Model: "", PromptFile: "~/.claude/skills/omo/references/develop.md", Description: "Code development"}, - "frontend-ui-ux-engineer": {Backend: "gemini", Model: "", PromptFile: "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md", Description: "Frontend engineer"}, - "document-writer": {Backend: "gemini", Model: "", PromptFile: "~/.claude/skills/omo/references/document-writer.md", Description: "Documentation"}, - }, - } - -func loadModelsConfig() *ModelsConfig { - home, err := os.UserHomeDir() - if err != nil { - logWarn(fmt.Sprintf("Failed to resolve home directory for models config: %v; using defaults", err)) - return &defaultModelsConfig - } - - configPath := filepath.Join(home, ".codeagent", "models.json") - data, err := os.ReadFile(configPath) - if err != nil { - if !os.IsNotExist(err) { - logWarn(fmt.Sprintf("Failed to read models config %s: %v; using defaults", configPath, err)) - } - return &defaultModelsConfig - } - - var cfg ModelsConfig - if err := json.Unmarshal(data, &cfg); err != nil { - logWarn(fmt.Sprintf("Failed to parse models config %s: %v; using defaults", configPath, err)) - return &defaultModelsConfig - } - - // Merge with defaults - for name, agent := range defaultModelsConfig.Agents { - if _, exists := cfg.Agents[name]; !exists { - if cfg.Agents == nil { - cfg.Agents = make(map[string]AgentModelConfig) - } - cfg.Agents[name] = agent - } - } - - return &cfg -} - -func resolveAgentConfig(agentName string) (backend, model, promptFile, reasoning string, yolo bool) { - cfg := loadModelsConfig() - if agent, ok := cfg.Agents[agentName]; ok { - return agent.Backend, agent.Model, agent.PromptFile, agent.Reasoning, agent.Yolo - } - return cfg.DefaultBackend, cfg.DefaultModel, "", "", false -} diff --git a/codeagent-wrapper/backend.go b/codeagent-wrapper/backend.go deleted file mode 100644 index 5a73410..0000000 --- a/codeagent-wrapper/backend.go +++ /dev/null @@ -1,240 +0,0 @@ -package main - -import ( - "encoding/json" - "os" - "path/filepath" - "strings" -) - -// Backend defines the contract for invoking different AI CLI backends. -// Each backend is responsible for supplying the executable command and -// building the argument list based on the wrapper config. -type Backend interface { - Name() string - BuildArgs(cfg *Config, targetArg string) []string - Command() string -} - -type CodexBackend struct{} - -func (CodexBackend) Name() string { return "codex" } -func (CodexBackend) Command() string { - return "codex" -} -func (CodexBackend) BuildArgs(cfg *Config, targetArg string) []string { - return buildCodexArgs(cfg, targetArg) -} - -type ClaudeBackend struct{} - -func (ClaudeBackend) Name() string { return "claude" } -func (ClaudeBackend) Command() string { - return "claude" -} -func (ClaudeBackend) BuildArgs(cfg *Config, targetArg string) []string { - return buildClaudeArgs(cfg, targetArg) -} - -const maxClaudeSettingsBytes = 1 << 20 // 1MB - -type minimalClaudeSettings struct { - Env map[string]string - Model string -} - -// loadMinimalClaudeSettings 从 ~/.claude/settings.json 只提取安全的最小子集: -// - env: 只接受字符串类型的值 -// - model: 只接受字符串类型的值 -// 文件缺失/解析失败/超限都返回空。 -func loadMinimalClaudeSettings() minimalClaudeSettings { - home, err := os.UserHomeDir() - if err != nil || home == "" { - return minimalClaudeSettings{} - } - - settingPath := filepath.Join(home, ".claude", "settings.json") - info, err := os.Stat(settingPath) - if err != nil || info.Size() > maxClaudeSettingsBytes { - return minimalClaudeSettings{} - } - - data, err := os.ReadFile(settingPath) - if err != nil { - return minimalClaudeSettings{} - } - - var cfg struct { - Env map[string]any `json:"env"` - Model any `json:"model"` - } - if err := json.Unmarshal(data, &cfg); err != nil { - return minimalClaudeSettings{} - } - - out := minimalClaudeSettings{} - - if model, ok := cfg.Model.(string); ok { - out.Model = strings.TrimSpace(model) - } - - if len(cfg.Env) == 0 { - return out - } - - env := make(map[string]string, len(cfg.Env)) - for k, v := range cfg.Env { - s, ok := v.(string) - if !ok { - continue - } - env[k] = s - } - if len(env) == 0 { - return out - } - out.Env = env - return out -} - -// loadMinimalEnvSettings is kept for backwards tests; prefer loadMinimalClaudeSettings. -func loadMinimalEnvSettings() map[string]string { - settings := loadMinimalClaudeSettings() - if len(settings.Env) == 0 { - return nil - } - return settings.Env -} - -// loadGeminiEnv loads environment variables from ~/.gemini/.env -// Supports GEMINI_API_KEY, GEMINI_MODEL, GOOGLE_GEMINI_BASE_URL -// Also sets GEMINI_API_KEY_AUTH_MECHANISM=bearer for third-party API compatibility -func loadGeminiEnv() map[string]string { - home, err := os.UserHomeDir() - if err != nil || home == "" { - return nil - } - - envPath := filepath.Join(home, ".gemini", ".env") - data, err := os.ReadFile(envPath) - if err != nil { - return nil - } - - env := make(map[string]string) - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - idx := strings.IndexByte(line, '=') - if idx <= 0 { - continue - } - key := strings.TrimSpace(line[:idx]) - value := strings.TrimSpace(line[idx+1:]) - if key != "" && value != "" { - env[key] = value - } - } - - // Set bearer auth mechanism for third-party API compatibility - if _, ok := env["GEMINI_API_KEY"]; ok { - if _, hasAuth := env["GEMINI_API_KEY_AUTH_MECHANISM"]; !hasAuth { - env["GEMINI_API_KEY_AUTH_MECHANISM"] = "bearer" - } - } - - if len(env) == 0 { - return nil - } - return env -} - -func buildClaudeArgs(cfg *Config, targetArg string) []string { - if cfg == nil { - return nil - } - args := []string{"-p"} - // Default to skip permissions unless CODEAGENT_SKIP_PERMISSIONS=false - if cfg.SkipPermissions || cfg.Yolo || envFlagDefaultTrue("CODEAGENT_SKIP_PERMISSIONS") { - args = append(args, "--dangerously-skip-permissions") - } - - // Prevent infinite recursion: disable all setting sources (user, project, local) - // This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent - args = append(args, "--setting-sources", "") - - if model := strings.TrimSpace(cfg.Model); model != "" { - args = append(args, "--model", model) - } - - if cfg.Mode == "resume" { - if cfg.SessionID != "" { - // Claude CLI uses -r for resume. - args = append(args, "-r", cfg.SessionID) - } - } - // Note: claude CLI doesn't support -C flag; workdir set via cmd.Dir - - args = append(args, "--output-format", "stream-json", "--verbose", targetArg) - - return args -} - -type GeminiBackend struct{} - -func (GeminiBackend) Name() string { return "gemini" } -func (GeminiBackend) Command() string { - return "gemini" -} -func (GeminiBackend) BuildArgs(cfg *Config, targetArg string) []string { - return buildGeminiArgs(cfg, targetArg) -} - -type OpencodeBackend struct{} - -func (OpencodeBackend) Name() string { return "opencode" } -func (OpencodeBackend) Command() string { return "opencode" } -func (OpencodeBackend) BuildArgs(cfg *Config, targetArg string) []string { - args := []string{"run"} - if model := strings.TrimSpace(cfg.Model); model != "" { - args = append(args, "-m", model) - } - if cfg.Mode == "resume" && cfg.SessionID != "" { - args = append(args, "-s", cfg.SessionID) - } - args = append(args, "--format", "json") - if targetArg != "-" { - args = append(args, targetArg) - } - return args -} - -func buildGeminiArgs(cfg *Config, targetArg string) []string { - if cfg == nil { - return nil - } - args := []string{"-o", "stream-json", "-y"} - - if model := strings.TrimSpace(cfg.Model); model != "" { - args = append(args, "-m", model) - } - - if cfg.Mode == "resume" { - if cfg.SessionID != "" { - args = append(args, "-r", cfg.SessionID) - } - } - // Note: gemini CLI doesn't support -C flag; workdir set via cmd.Dir - - // Use positional argument instead of deprecated -p flag - // For stdin mode ("-"), use -p to read from stdin - if targetArg == "-" { - args = append(args, "-p", targetArg) - } else { - args = append(args, targetArg) - } - - return args -} diff --git a/codeagent-wrapper/bench_test.go b/codeagent-wrapper/bench_test.go deleted file mode 100644 index 2a99861..0000000 --- a/codeagent-wrapper/bench_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package main - -import ( - "testing" -) - -// BenchmarkLoggerWrite 测试日志写入性能 -func BenchmarkLoggerWrite(b *testing.B) { - logger, err := NewLogger() - if err != nil { - b.Fatal(err) - } - defer logger.Close() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - logger.Info("benchmark log message") - } - b.StopTimer() - logger.Flush() -} - -// BenchmarkLoggerConcurrentWrite 测试并发日志写入性能 -func BenchmarkLoggerConcurrentWrite(b *testing.B) { - logger, err := NewLogger() - if err != nil { - b.Fatal(err) - } - defer logger.Close() - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - logger.Info("concurrent benchmark log message") - } - }) - b.StopTimer() - logger.Flush() -} diff --git a/codeagent-wrapper/config.go b/codeagent-wrapper/config.go deleted file mode 100644 index cfeb639..0000000 --- a/codeagent-wrapper/config.go +++ /dev/null @@ -1,473 +0,0 @@ -package main - -import ( - "bytes" - "context" - "fmt" - "os" - "strconv" - "strings" -) - -// Config holds CLI configuration -type Config struct { - Mode string // "new" or "resume" - Task string - SessionID string - WorkDir string - Model string - ReasoningEffort string - ExplicitStdin bool - Timeout int - Backend string - Agent string - PromptFile string - PromptFileExplicit bool - SkipPermissions bool - Yolo bool - MaxParallelWorkers int -} - -// ParallelConfig defines the JSON schema for parallel execution -type ParallelConfig struct { - Tasks []TaskSpec `json:"tasks"` - GlobalBackend string `json:"backend,omitempty"` -} - -// TaskSpec describes an individual task entry in the parallel config -type TaskSpec struct { - ID string `json:"id"` - Task string `json:"task"` - WorkDir string `json:"workdir,omitempty"` - Dependencies []string `json:"dependencies,omitempty"` - SessionID string `json:"session_id,omitempty"` - Backend string `json:"backend,omitempty"` - Model string `json:"model,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` - Agent string `json:"agent,omitempty"` - PromptFile string `json:"prompt_file,omitempty"` - SkipPermissions bool `json:"skip_permissions,omitempty"` - Mode string `json:"-"` - UseStdin bool `json:"-"` - Context context.Context `json:"-"` -} - -// TaskResult captures the execution outcome of a task -type TaskResult struct { - TaskID string `json:"task_id"` - ExitCode int `json:"exit_code"` - Message string `json:"message"` - SessionID string `json:"session_id"` - Error string `json:"error"` - LogPath string `json:"log_path"` - // Structured report fields - Coverage string `json:"coverage,omitempty"` // extracted coverage percentage (e.g., "92%") - CoverageNum float64 `json:"coverage_num,omitempty"` // numeric coverage for comparison - CoverageTarget float64 `json:"coverage_target,omitempty"` // target coverage (default 90) - FilesChanged []string `json:"files_changed,omitempty"` // list of changed files - KeyOutput string `json:"key_output,omitempty"` // brief summary of what was done - TestsPassed int `json:"tests_passed,omitempty"` // number of tests passed - TestsFailed int `json:"tests_failed,omitempty"` // number of tests failed - sharedLog bool -} - -var backendRegistry = map[string]Backend{ - "codex": CodexBackend{}, - "claude": ClaudeBackend{}, - "gemini": GeminiBackend{}, - "opencode": OpencodeBackend{}, -} - -func selectBackend(name string) (Backend, error) { - key := strings.ToLower(strings.TrimSpace(name)) - if key == "" { - key = defaultBackendName - } - if backend, ok := backendRegistry[key]; ok { - return backend, nil - } - return nil, fmt.Errorf("unsupported backend %q", name) -} - -func envFlagEnabled(key string) bool { - val, ok := os.LookupEnv(key) - if !ok { - return false - } - val = strings.TrimSpace(strings.ToLower(val)) - switch val { - case "", "0", "false", "no", "off": - return false - default: - return true - } -} - -func parseBoolFlag(val string, defaultValue bool) bool { - val = strings.TrimSpace(strings.ToLower(val)) - switch val { - case "1", "true", "yes", "on": - return true - case "0", "false", "no", "off": - return false - default: - return defaultValue - } -} - -// envFlagDefaultTrue returns true unless the env var is explicitly set to false/0/no/off. -func envFlagDefaultTrue(key string) bool { - val, ok := os.LookupEnv(key) - if !ok { - return true - } - return parseBoolFlag(val, true) -} - -func validateAgentName(name string) error { - if strings.TrimSpace(name) == "" { - return fmt.Errorf("agent name is empty") - } - for _, r := range name { - switch { - case r >= 'a' && r <= 'z': - case r >= 'A' && r <= 'Z': - case r >= '0' && r <= '9': - case r == '-', r == '_': - default: - return fmt.Errorf("agent name %q contains invalid character %q", name, r) - } - } - return nil -} - -func parseParallelConfig(data []byte) (*ParallelConfig, error) { - trimmed := bytes.TrimSpace(data) - if len(trimmed) == 0 { - return nil, fmt.Errorf("parallel config is empty") - } - - tasks := strings.Split(string(trimmed), "---TASK---") - var cfg ParallelConfig - seen := make(map[string]struct{}) - - taskIndex := 0 - for _, taskBlock := range tasks { - taskBlock = strings.TrimSpace(taskBlock) - if taskBlock == "" { - continue - } - taskIndex++ - - parts := strings.SplitN(taskBlock, "---CONTENT---", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("task block #%d missing ---CONTENT--- separator", taskIndex) - } - - meta := strings.TrimSpace(parts[0]) - content := strings.TrimSpace(parts[1]) - - task := TaskSpec{WorkDir: defaultWorkdir} - agentSpecified := false - for _, line := range strings.Split(meta, "\n") { - line = strings.TrimSpace(line) - if line == "" { - continue - } - kv := strings.SplitN(line, ":", 2) - if len(kv) != 2 { - continue - } - key := strings.TrimSpace(kv[0]) - value := strings.TrimSpace(kv[1]) - - switch key { - case "id": - task.ID = value - case "workdir": - // Validate workdir: "-" is not a valid directory - if value == "-" { - return nil, fmt.Errorf("task block #%d has invalid workdir: '-' is not a valid directory path", taskIndex) - } - task.WorkDir = value - case "session_id": - task.SessionID = value - task.Mode = "resume" - case "backend": - task.Backend = value - case "model": - task.Model = value - case "reasoning_effort": - task.ReasoningEffort = value - case "agent": - agentSpecified = true - task.Agent = value - case "skip_permissions", "skip-permissions": - if value == "" { - task.SkipPermissions = true - continue - } - task.SkipPermissions = parseBoolFlag(value, false) - case "dependencies": - for _, dep := range strings.Split(value, ",") { - dep = strings.TrimSpace(dep) - if dep != "" { - task.Dependencies = append(task.Dependencies, dep) - } - } - } - } - - if task.Mode == "" { - task.Mode = "new" - } - - if agentSpecified { - if strings.TrimSpace(task.Agent) == "" { - return nil, fmt.Errorf("task block #%d has empty agent field", taskIndex) - } - if err := validateAgentName(task.Agent); err != nil { - return nil, fmt.Errorf("task block #%d invalid agent name: %w", taskIndex, err) - } - backend, model, promptFile, reasoning, _ := resolveAgentConfig(task.Agent) - if task.Backend == "" { - task.Backend = backend - } - if task.Model == "" { - task.Model = model - } - if task.ReasoningEffort == "" { - task.ReasoningEffort = reasoning - } - task.PromptFile = promptFile - } - - if task.ID == "" { - return nil, fmt.Errorf("task block #%d missing id field", taskIndex) - } - if content == "" { - return nil, fmt.Errorf("task block #%d (%q) missing content", taskIndex, task.ID) - } - if task.Mode == "resume" && strings.TrimSpace(task.SessionID) == "" { - return nil, fmt.Errorf("task block #%d (%q) has empty session_id", taskIndex, task.ID) - } - if _, exists := seen[task.ID]; exists { - return nil, fmt.Errorf("task block #%d has duplicate id: %s", taskIndex, task.ID) - } - - task.Task = content - cfg.Tasks = append(cfg.Tasks, task) - seen[task.ID] = struct{}{} - } - - if len(cfg.Tasks) == 0 { - return nil, fmt.Errorf("no tasks found") - } - - return &cfg, nil -} - -func parseArgs() (*Config, error) { - args := os.Args[1:] - if len(args) == 0 { - return nil, fmt.Errorf("task required") - } - - backendName := defaultBackendName - model := "" - reasoningEffort := "" - agentName := "" - promptFile := "" - promptFileExplicit := false - yolo := false - skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS") - filtered := make([]string, 0, len(args)) - for i := 0; i < len(args); i++ { - arg := args[i] - switch { - case arg == "--agent": - if i+1 >= len(args) { - return nil, fmt.Errorf("--agent flag requires a value") - } - value := strings.TrimSpace(args[i+1]) - if value == "" { - return nil, fmt.Errorf("--agent flag requires a value") - } - if err := validateAgentName(value); err != nil { - return nil, fmt.Errorf("--agent flag invalid value: %w", err) - } - resolvedBackend, resolvedModel, resolvedPromptFile, resolvedReasoning, resolvedYolo := resolveAgentConfig(value) - backendName = resolvedBackend - model = resolvedModel - if !promptFileExplicit { - promptFile = resolvedPromptFile - } - if reasoningEffort == "" { - reasoningEffort = resolvedReasoning - } - yolo = resolvedYolo - agentName = value - i++ - continue - case strings.HasPrefix(arg, "--agent="): - value := strings.TrimSpace(strings.TrimPrefix(arg, "--agent=")) - if value == "" { - return nil, fmt.Errorf("--agent flag requires a value") - } - if err := validateAgentName(value); err != nil { - return nil, fmt.Errorf("--agent flag invalid value: %w", err) - } - resolvedBackend, resolvedModel, resolvedPromptFile, resolvedReasoning, resolvedYolo := resolveAgentConfig(value) - backendName = resolvedBackend - model = resolvedModel - if !promptFileExplicit { - promptFile = resolvedPromptFile - } - if reasoningEffort == "" { - reasoningEffort = resolvedReasoning - } - yolo = resolvedYolo - agentName = value - continue - case arg == "--prompt-file": - if i+1 >= len(args) { - return nil, fmt.Errorf("--prompt-file flag requires a value") - } - value := strings.TrimSpace(args[i+1]) - if value == "" { - return nil, fmt.Errorf("--prompt-file flag requires a value") - } - promptFile = value - promptFileExplicit = true - i++ - continue - case strings.HasPrefix(arg, "--prompt-file="): - value := strings.TrimSpace(strings.TrimPrefix(arg, "--prompt-file=")) - if value == "" { - return nil, fmt.Errorf("--prompt-file flag requires a value") - } - promptFile = value - promptFileExplicit = true - continue - case arg == "--backend": - if i+1 >= len(args) { - return nil, fmt.Errorf("--backend flag requires a value") - } - backendName = args[i+1] - i++ - continue - case strings.HasPrefix(arg, "--backend="): - value := strings.TrimPrefix(arg, "--backend=") - if value == "" { - return nil, fmt.Errorf("--backend flag requires a value") - } - backendName = value - continue - case arg == "--skip-permissions", arg == "--dangerously-skip-permissions": - skipPermissions = true - continue - case arg == "--model": - if i+1 >= len(args) { - return nil, fmt.Errorf("--model flag requires a value") - } - model = args[i+1] - i++ - continue - case strings.HasPrefix(arg, "--model="): - value := strings.TrimPrefix(arg, "--model=") - if value == "" { - return nil, fmt.Errorf("--model flag requires a value") - } - model = value - continue - case arg == "--reasoning-effort": - if i+1 >= len(args) { - return nil, fmt.Errorf("--reasoning-effort flag requires a value") - } - value := strings.TrimSpace(args[i+1]) - if value == "" { - return nil, fmt.Errorf("--reasoning-effort flag requires a value") - } - reasoningEffort = value - i++ - continue - case strings.HasPrefix(arg, "--reasoning-effort="): - value := strings.TrimSpace(strings.TrimPrefix(arg, "--reasoning-effort=")) - if value == "" { - return nil, fmt.Errorf("--reasoning-effort flag requires a value") - } - reasoningEffort = value - continue - case strings.HasPrefix(arg, "--skip-permissions="): - skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions) - continue - case strings.HasPrefix(arg, "--dangerously-skip-permissions="): - skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--dangerously-skip-permissions="), skipPermissions) - continue - } - filtered = append(filtered, arg) - } - - if len(filtered) == 0 { - return nil, fmt.Errorf("task required") - } - args = filtered - - cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, Agent: agentName, PromptFile: promptFile, PromptFileExplicit: promptFileExplicit, SkipPermissions: skipPermissions, Yolo: yolo, Model: strings.TrimSpace(model), ReasoningEffort: strings.TrimSpace(reasoningEffort)} - cfg.MaxParallelWorkers = resolveMaxParallelWorkers() - - if args[0] == "resume" { - if len(args) < 3 { - return nil, fmt.Errorf("resume mode requires: resume ") - } - cfg.Mode = "resume" - cfg.SessionID = strings.TrimSpace(args[1]) - if cfg.SessionID == "" { - return nil, fmt.Errorf("resume mode requires non-empty session_id") - } - cfg.Task = args[2] - cfg.ExplicitStdin = (args[2] == "-") - if len(args) > 3 { - // Validate workdir: "-" is not a valid directory - if args[3] == "-" { - return nil, fmt.Errorf("invalid workdir: '-' is not a valid directory path") - } - cfg.WorkDir = args[3] - } - } else { - cfg.Mode = "new" - cfg.Task = args[0] - cfg.ExplicitStdin = (args[0] == "-") - if len(args) > 1 { - // Validate workdir: "-" is not a valid directory - if args[1] == "-" { - return nil, fmt.Errorf("invalid workdir: '-' is not a valid directory path") - } - cfg.WorkDir = args[1] - } - } - - return cfg, nil -} - -const maxParallelWorkersLimit = 100 - -func resolveMaxParallelWorkers() int { - raw := strings.TrimSpace(os.Getenv("CODEAGENT_MAX_PARALLEL_WORKERS")) - if raw == "" { - return 0 - } - - value, err := strconv.Atoi(raw) - if err != nil || value < 0 { - logWarn(fmt.Sprintf("Invalid CODEAGENT_MAX_PARALLEL_WORKERS=%q, falling back to unlimited", raw)) - return 0 - } - - if value > maxParallelWorkersLimit { - logWarn(fmt.Sprintf("CODEAGENT_MAX_PARALLEL_WORKERS=%d exceeds limit, capping at %d", value, maxParallelWorkersLimit)) - return maxParallelWorkersLimit - } - - return value -} diff --git a/codeagent-wrapper/go.mod b/codeagent-wrapper/go.mod index ae7aa47..629db73 100644 --- a/codeagent-wrapper/go.mod +++ b/codeagent-wrapper/go.mod @@ -1,3 +1,43 @@ module codeagent-wrapper go 1.21 + +require ( + github.com/goccy/go-json v0.10.5 + github.com/rs/zerolog v1.34.0 + github.com/shirou/gopsutil/v3 v3.24.5 + github.com/spf13/cobra v1.8.1 + github.com/spf13/pflag v1.0.5 + github.com/spf13/viper v1.19.0 +) + +require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/codeagent-wrapper/go.sum b/codeagent-wrapper/go.sum new file mode 100644 index 0000000..a5cf2cd --- /dev/null +++ b/codeagent-wrapper/go.sum @@ -0,0 +1,117 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= +github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= +github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= +github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/codeagent-wrapper/agent_validation_test.go b/codeagent-wrapper/internal/app/agent_validation_test.go similarity index 89% rename from codeagent-wrapper/agent_validation_test.go rename to codeagent-wrapper/internal/app/agent_validation_test.go index 5459fd8..5e00b66 100644 --- a/codeagent-wrapper/agent_validation_test.go +++ b/codeagent-wrapper/internal/app/agent_validation_test.go @@ -1,4 +1,4 @@ -package main +package wrapper import ( "context" @@ -6,6 +6,9 @@ import ( "path/filepath" "testing" "time" + + config "codeagent-wrapper/internal/config" + executor "codeagent-wrapper/internal/executor" ) func TestValidateAgentName(t *testing.T) { @@ -28,7 +31,7 @@ func TestValidateAgentName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateAgentName(tt.input) + err := config.ValidateAgentName(tt.input) if (err != nil) != tt.wantErr { t.Fatalf("validateAgentName(%q) err=%v, wantErr=%v", tt.input, err, tt.wantErr) } @@ -59,6 +62,8 @@ func TestParseParallelConfig_ResolvesAgentPromptFile(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Cleanup(config.ResetModelsConfigCacheForTest) + config.ResetModelsConfigCacheForTest() configDir := filepath.Join(home, ".codeagent") if err := os.MkdirAll(configDir, 0o755); err != nil { @@ -117,10 +122,8 @@ func TestDefaultRunCodexTaskFn_AppliesAgentPromptFile(t *testing.T) { WaitDelay: 2 * time.Millisecond, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } - selectBackendFn = func(name string) (Backend, error) { + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) + _ = executor.SetSelectBackendFn(func(name string) (Backend, error) { return testBackend{ name: name, command: "fake-cmd", @@ -128,7 +131,7 @@ func TestDefaultRunCodexTaskFn_AppliesAgentPromptFile(t *testing.T) { return []string{targetArg} }, }, nil - } + }) res := defaultRunCodexTaskFn(TaskSpec{ ID: "t", diff --git a/codeagent-wrapper/internal/app/app.go b/codeagent-wrapper/internal/app/app.go new file mode 100644 index 0000000..6963b0d --- /dev/null +++ b/codeagent-wrapper/internal/app/app.go @@ -0,0 +1,278 @@ +package wrapper + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + version = "6.0.0-alpha1" + defaultWorkdir = "." + defaultTimeout = 7200 // seconds (2 hours) + defaultCoverageTarget = 90.0 + codexLogLineLimit = 1000 + stdinSpecialChars = "\n\\\"'`$" + stderrCaptureLimit = 4 * 1024 + defaultBackendName = "codex" + defaultCodexCommand = "codex" + + // stdout close reasons + stdoutCloseReasonWait = "wait-done" + stdoutCloseReasonDrain = "drain-timeout" + stdoutCloseReasonCtx = "context-cancel" + stdoutDrainTimeout = 100 * time.Millisecond +) + +// Test hooks for dependency injection +var ( + stdinReader io.Reader = os.Stdin + isTerminalFn = defaultIsTerminal + codexCommand = defaultCodexCommand + cleanupHook func() + startupCleanupAsync = true + + buildCodexArgsFn = buildCodexArgs + selectBackendFn = selectBackend + cleanupLogsFn = cleanupOldLogs + defaultBuildArgsFn = buildCodexArgs + runTaskFn = runCodexTask + exitFn = os.Exit +) + +func runStartupCleanup() { + if cleanupLogsFn == nil { + return + } + defer func() { + if r := recover(); r != nil { + logWarn(fmt.Sprintf("cleanupOldLogs panic: %v", r)) + } + }() + if _, err := cleanupLogsFn(); err != nil { + logWarn(fmt.Sprintf("cleanupOldLogs error: %v", err)) + } +} + +func scheduleStartupCleanup() { + if !startupCleanupAsync { + runStartupCleanup() + return + } + if cleanupLogsFn == nil { + return + } + fn := cleanupLogsFn + go func() { + defer func() { + if r := recover(); r != nil { + logWarn(fmt.Sprintf("cleanupOldLogs panic: %v", r)) + } + }() + if _, err := fn(); err != nil { + logWarn(fmt.Sprintf("cleanupOldLogs error: %v", err)) + } + }() +} + +func runCleanupMode() int { + if cleanupLogsFn == nil { + fmt.Fprintln(os.Stderr, "Cleanup failed: log cleanup function not configured") + return 1 + } + + stats, err := cleanupLogsFn() + if err != nil { + fmt.Fprintf(os.Stderr, "Cleanup failed: %v\n", err) + return 1 + } + + fmt.Println("Cleanup completed") + fmt.Printf("Files scanned: %d\n", stats.Scanned) + fmt.Printf("Files deleted: %d\n", stats.Deleted) + if len(stats.DeletedFiles) > 0 { + for _, f := range stats.DeletedFiles { + fmt.Printf(" - %s\n", f) + } + } + fmt.Printf("Files kept: %d\n", stats.Kept) + if len(stats.KeptFiles) > 0 { + for _, f := range stats.KeptFiles { + fmt.Printf(" - %s\n", f) + } + } + if stats.Errors > 0 { + fmt.Printf("Deletion errors: %d\n", stats.Errors) + } + return 0 +} + +func readAgentPromptFile(path string, allowOutsideClaudeDir bool) (string, error) { + raw := strings.TrimSpace(path) + if raw == "" { + return "", nil + } + + expanded := raw + if raw == "~" || strings.HasPrefix(raw, "~/") || strings.HasPrefix(raw, "~\\") { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + if raw == "~" { + expanded = home + } else { + expanded = home + raw[1:] + } + } + + absPath, err := filepath.Abs(expanded) + if err != nil { + return "", err + } + absPath = filepath.Clean(absPath) + + home, err := os.UserHomeDir() + if err != nil { + if !allowOutsideClaudeDir { + return "", err + } + logWarn(fmt.Sprintf("Failed to resolve home directory for prompt file validation: %v; proceeding without restriction", err)) + } else { + allowedDirs := []string{ + filepath.Clean(filepath.Join(home, ".claude")), + filepath.Clean(filepath.Join(home, ".codeagent", "agents")), + } + for i := range allowedDirs { + allowedAbs, err := filepath.Abs(allowedDirs[i]) + if err == nil { + allowedDirs[i] = filepath.Clean(allowedAbs) + } + } + + isWithinDir := func(path, dir string) bool { + rel, err := filepath.Rel(dir, path) + if err != nil { + return false + } + rel = filepath.Clean(rel) + if rel == "." { + return true + } + if rel == ".." { + return false + } + prefix := ".." + string(os.PathSeparator) + return !strings.HasPrefix(rel, prefix) + } + + if !allowOutsideClaudeDir { + withinAllowed := false + for _, dir := range allowedDirs { + if isWithinDir(absPath, dir) { + withinAllowed = true + break + } + } + if !withinAllowed { + logWarn(fmt.Sprintf("Refusing to read prompt file outside allowed dirs (%s): %s", strings.Join(allowedDirs, ", "), absPath)) + return "", fmt.Errorf("prompt file must be under ~/.claude or ~/.codeagent/agents") + } + + resolvedPath, errPath := filepath.EvalSymlinks(absPath) + if errPath == nil { + resolvedPath = filepath.Clean(resolvedPath) + resolvedAllowed := make([]string, 0, len(allowedDirs)) + for _, dir := range allowedDirs { + resolvedBase, errBase := filepath.EvalSymlinks(dir) + if errBase != nil { + continue + } + resolvedAllowed = append(resolvedAllowed, filepath.Clean(resolvedBase)) + } + if len(resolvedAllowed) > 0 { + withinResolved := false + for _, dir := range resolvedAllowed { + if isWithinDir(resolvedPath, dir) { + withinResolved = true + break + } + } + if !withinResolved { + logWarn(fmt.Sprintf("Refusing to read prompt file outside allowed dirs (%s) (resolved): %s", strings.Join(resolvedAllowed, ", "), resolvedPath)) + return "", fmt.Errorf("prompt file must be under ~/.claude or ~/.codeagent/agents") + } + } + } + } else { + withinAllowed := false + for _, dir := range allowedDirs { + if isWithinDir(absPath, dir) { + withinAllowed = true + break + } + } + if !withinAllowed { + logWarn(fmt.Sprintf("Reading prompt file outside allowed dirs (%s): %s", strings.Join(allowedDirs, ", "), absPath)) + } + } + } + + data, err := os.ReadFile(absPath) + if err != nil { + return "", err + } + return strings.TrimRight(string(data), "\r\n"), nil +} + +func wrapTaskWithAgentPrompt(prompt string, task string) string { + return "\n" + prompt + "\n\n\n" + task +} + +func runCleanupHook() { + if logger := activeLogger(); logger != nil { + logger.Flush() + } + if cleanupHook != nil { + cleanupHook() + } +} + +func printHelp() { + name := currentWrapperName() + help := fmt.Sprintf(`%[1]s - Go wrapper for AI CLI backends + +Usage: + %[1]s "task" [workdir] + %[1]s --backend claude "task" [workdir] + %[1]s --prompt-file /path/to/prompt.md "task" [workdir] + %[1]s - [workdir] Read task from stdin + %[1]s resume "task" [workdir] + %[1]s resume - [workdir] + %[1]s --parallel Run tasks in parallel (config from stdin) + %[1]s --parallel --full-output Run tasks in parallel with full output (legacy) + %[1]s --version + %[1]s --help + +Parallel mode examples: + %[1]s --parallel < tasks.txt + echo '...' | %[1]s --parallel + %[1]s --parallel --full-output < tasks.txt + %[1]s --parallel <<'EOF' + +Environment Variables: + CODEX_TIMEOUT Timeout in milliseconds (default: 7200000) + CODEAGENT_ASCII_MODE Use ASCII symbols instead of Unicode (PASS/WARN/FAIL) + +Exit Codes: + 0 Success + 1 General error (missing args, no output) + 124 Timeout + 127 backend command not found + 130 Interrupted (Ctrl+C) + * Passthrough from backend process`, name) + fmt.Println(help) +} diff --git a/codeagent-wrapper/internal/app/backend.go b/codeagent-wrapper/internal/app/backend.go new file mode 100644 index 0000000..cefbd7b --- /dev/null +++ b/codeagent-wrapper/internal/app/backend.go @@ -0,0 +1,9 @@ +package wrapper + +import backend "codeagent-wrapper/internal/backend" + +type Backend = backend.Backend +type CodexBackend = backend.CodexBackend +type ClaudeBackend = backend.ClaudeBackend +type GeminiBackend = backend.GeminiBackend +type OpencodeBackend = backend.OpencodeBackend diff --git a/codeagent-wrapper/internal/app/backend_init.go b/codeagent-wrapper/internal/app/backend_init.go new file mode 100644 index 0000000..57a7c6d --- /dev/null +++ b/codeagent-wrapper/internal/app/backend_init.go @@ -0,0 +1,7 @@ +package wrapper + +import backend "codeagent-wrapper/internal/backend" + +func init() { + backend.SetLogFuncs(logWarn, logError) +} diff --git a/codeagent-wrapper/internal/app/backend_registry.go b/codeagent-wrapper/internal/app/backend_registry.go new file mode 100644 index 0000000..26af36b --- /dev/null +++ b/codeagent-wrapper/internal/app/backend_registry.go @@ -0,0 +1,5 @@ +package wrapper + +import backend "codeagent-wrapper/internal/backend" + +func selectBackend(name string) (Backend, error) { return backend.Select(name) } diff --git a/codeagent-wrapper/internal/app/bench_test.go b/codeagent-wrapper/internal/app/bench_test.go new file mode 100644 index 0000000..589e22f --- /dev/null +++ b/codeagent-wrapper/internal/app/bench_test.go @@ -0,0 +1,103 @@ +package wrapper + +import ( + "bytes" + "os" + "testing" + + config "codeagent-wrapper/internal/config" +) + +var ( + benchCmdSink any + benchConfigSink *Config + benchMessageSink string + benchThreadIDSink string +) + +// BenchmarkStartup_NewRootCommand measures CLI startup overhead (command+flags construction). +func BenchmarkStartup_NewRootCommand(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchCmdSink = newRootCommand() + } +} + +// BenchmarkConfigParse_ParseArgs measures config parsing from argv/env (steady-state). +func BenchmarkConfigParse_ParseArgs(b *testing.B) { + home := b.TempDir() + b.Setenv("HOME", home) + b.Setenv("USERPROFILE", home) + + config.ResetModelsConfigCacheForTest() + b.Cleanup(config.ResetModelsConfigCacheForTest) + + origArgs := os.Args + os.Args = []string{"codeagent-wrapper", "--agent", "develop", "task"} + b.Cleanup(func() { os.Args = origArgs }) + + if _, err := parseArgs(); err != nil { + b.Fatalf("warmup parseArgs() error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + cfg, err := parseArgs() + if err != nil { + b.Fatalf("parseArgs() error: %v", err) + } + benchConfigSink = cfg + } +} + +// BenchmarkJSONParse_ParseJSONStreamInternal measures line-delimited JSON stream parsing. +func BenchmarkJSONParse_ParseJSONStreamInternal(b *testing.B) { + stream := []byte( + `{"type":"thread.started","thread_id":"t"}` + "\n" + + `{"type":"item.completed","item":{"type":"agent_message","text":"hello"}}` + "\n" + + `{"type":"thread.completed","thread_id":"t"}` + "\n", + ) + b.SetBytes(int64(len(stream))) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + message, threadID := parseJSONStreamInternal(bytes.NewReader(stream), nil, nil, nil, nil) + benchMessageSink = message + benchThreadIDSink = threadID + } +} + +// BenchmarkLoggerWrite 测试日志写入性能 +func BenchmarkLoggerWrite(b *testing.B) { + logger, err := NewLogger() + if err != nil { + b.Fatal(err) + } + defer logger.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Info("benchmark log message") + } + b.StopTimer() + logger.Flush() +} + +// BenchmarkLoggerConcurrentWrite 测试并发日志写入性能 +func BenchmarkLoggerConcurrentWrite(b *testing.B) { + logger, err := NewLogger() + if err != nil { + b.Fatal(err) + } + defer logger.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + logger.Info("concurrent benchmark log message") + } + }) + b.StopTimer() + logger.Flush() +} diff --git a/codeagent-wrapper/internal/app/cli.go b/codeagent-wrapper/internal/app/cli.go new file mode 100644 index 0000000..6ce7bf8 --- /dev/null +++ b/codeagent-wrapper/internal/app/cli.go @@ -0,0 +1,657 @@ +package wrapper + +import ( + "errors" + "fmt" + "io" + "os" + "reflect" + "strings" + + config "codeagent-wrapper/internal/config" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +type exitError struct { + code int +} + +func (e exitError) Error() string { + return fmt.Sprintf("exit %d", e.code) +} + +type cliOptions struct { + Backend string + Model string + ReasoningEffort string + Agent string + PromptFile string + SkipPermissions bool + + Parallel bool + FullOutput bool + + Cleanup bool + Version bool + ConfigFile string +} + +func Main() { + Run() +} + +// Run is the program entrypoint for cmd/codeagent/main.go. +func Run() { + exitFn(run()) +} + +func run() int { + cmd := newRootCommand() + cmd.SetArgs(os.Args[1:]) + if err := cmd.Execute(); err != nil { + var ee exitError + if errors.As(err, &ee) { + return ee.code + } + return 1 + } + return 0 +} + +func newRootCommand() *cobra.Command { + name := currentWrapperName() + opts := &cliOptions{} + + cmd := &cobra.Command{ + Use: fmt.Sprintf("%s [flags] |resume [workdir]", name), + Short: "Go wrapper for AI CLI backends", + SilenceErrors: true, + SilenceUsage: true, + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + if opts.Version { + fmt.Printf("%s version %s\n", name, version) + return nil + } + if opts.Cleanup { + code := runCleanupMode() + if code == 0 { + return nil + } + return exitError{code: code} + } + + exitCode := runWithLoggerAndCleanup(func() int { + v, err := config.NewViper(opts.ConfigFile) + if err != nil { + logError(err.Error()) + return 1 + } + + if opts.Parallel { + return runParallelMode(cmd, args, opts, v, name) + } + + logInfo("Script started") + + cfg, err := buildSingleConfig(cmd, args, os.Args[1:], opts, v) + if err != nil { + logError(err.Error()) + return 1 + } + logInfo(fmt.Sprintf("Parsed args: mode=%s, task_len=%d, backend=%s", cfg.Mode, len(cfg.Task), cfg.Backend)) + return runSingleMode(cfg, name) + }) + + if exitCode == 0 { + return nil + } + return exitError{code: exitCode} + }, + } + cmd.CompletionOptions.DisableDefaultCmd = true + + addRootFlags(cmd.Flags(), opts) + cmd.AddCommand(newVersionCommand(name), newCleanupCommand()) + + return cmd +} + +func addRootFlags(fs *pflag.FlagSet, opts *cliOptions) { + fs.StringVar(&opts.ConfigFile, "config", "", "Config file path (default: $HOME/.codeagent/config.*)") + fs.BoolVarP(&opts.Version, "version", "v", false, "Print version and exit") + fs.BoolVar(&opts.Cleanup, "cleanup", false, "Clean up old logs and exit") + + fs.BoolVar(&opts.Parallel, "parallel", false, "Run tasks in parallel (config from stdin)") + fs.BoolVar(&opts.FullOutput, "full-output", false, "Parallel mode: include full task output (legacy)") + + fs.StringVar(&opts.Backend, "backend", defaultBackendName, "Backend to use (codex, claude, gemini, opencode)") + fs.StringVar(&opts.Model, "model", "", "Model override") + fs.StringVar(&opts.ReasoningEffort, "reasoning-effort", "", "Reasoning effort (backend-specific)") + fs.StringVar(&opts.Agent, "agent", "", "Agent preset name (from ~/.codeagent/models.json)") + fs.StringVar(&opts.PromptFile, "prompt-file", "", "Prompt file path") + + fs.BoolVar(&opts.SkipPermissions, "skip-permissions", false, "Skip permissions prompts (also via CODEAGENT_SKIP_PERMISSIONS)") + fs.BoolVar(&opts.SkipPermissions, "dangerously-skip-permissions", false, "Alias for --skip-permissions") +} + +func newVersionCommand(name string) *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print version and exit", + SilenceErrors: true, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Printf("%s version %s\n", name, version) + return nil + }, + } +} + +func newCleanupCommand() *cobra.Command { + return &cobra.Command{ + Use: "cleanup", + Short: "Clean up old logs and exit", + SilenceErrors: true, + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + code := runCleanupMode() + if code == 0 { + return nil + } + return exitError{code: code} + }, + } +} + +func runWithLoggerAndCleanup(fn func() int) (exitCode int) { + logger, err := NewLogger() + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err) + return 1 + } + setLogger(logger) + + defer func() { + logger := activeLogger() + if logger != nil { + logger.Flush() + } + if err := closeLogger(); err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err) + } + if logger == nil { + return + } + + if exitCode != 0 { + if entries := logger.ExtractRecentErrors(10); len(entries) > 0 { + fmt.Fprintln(os.Stderr, "\n=== Recent Errors ===") + for _, entry := range entries { + fmt.Fprintln(os.Stderr, entry) + } + fmt.Fprintf(os.Stderr, "Log file: %s (deleted)\n", logger.Path()) + } + } + _ = logger.RemoveLogFile() + }() + defer runCleanupHook() + + // Clean up stale logs from previous runs. + scheduleStartupCleanup() + + return fn() +} + +func parseArgs() (*Config, error) { + opts := &cliOptions{} + cmd := &cobra.Command{SilenceErrors: true, SilenceUsage: true, Args: cobra.ArbitraryArgs} + addRootFlags(cmd.Flags(), opts) + + rawArgv := os.Args[1:] + if err := cmd.ParseFlags(rawArgv); err != nil { + return nil, err + } + args := cmd.Flags().Args() + + v, err := config.NewViper(opts.ConfigFile) + if err != nil { + return nil, err + } + + return buildSingleConfig(cmd, args, rawArgv, opts, v) +} + +func buildSingleConfig(cmd *cobra.Command, args []string, rawArgv []string, opts *cliOptions, v *viper.Viper) (*Config, error) { + backendName := defaultBackendName + model := "" + reasoningEffort := "" + agentName := "" + promptFile := "" + promptFileExplicit := false + yolo := false + + if cmd.Flags().Changed("agent") { + agentName = strings.TrimSpace(opts.Agent) + if agentName == "" { + return nil, fmt.Errorf("--agent flag requires a value") + } + if err := config.ValidateAgentName(agentName); err != nil { + return nil, fmt.Errorf("--agent flag invalid value: %w", err) + } + } else { + agentName = strings.TrimSpace(v.GetString("agent")) + if agentName != "" { + if err := config.ValidateAgentName(agentName); err != nil { + return nil, fmt.Errorf("--agent flag invalid value: %w", err) + } + } + } + + var resolvedBackend, resolvedModel, resolvedPromptFile, resolvedReasoning string + if agentName != "" { + var resolvedYolo bool + resolvedBackend, resolvedModel, resolvedPromptFile, resolvedReasoning, _, _, resolvedYolo = config.ResolveAgentConfig(agentName) + yolo = resolvedYolo + } + + if cmd.Flags().Changed("prompt-file") { + promptFile = strings.TrimSpace(opts.PromptFile) + if promptFile == "" { + return nil, fmt.Errorf("--prompt-file flag requires a value") + } + promptFileExplicit = true + } else if val := strings.TrimSpace(v.GetString("prompt-file")); val != "" { + promptFile = val + promptFileExplicit = true + } else { + promptFile = resolvedPromptFile + } + + agentFlagChanged := cmd.Flags().Changed("agent") + backendFlagChanged := cmd.Flags().Changed("backend") + if backendFlagChanged { + backendName = strings.TrimSpace(opts.Backend) + if backendName == "" { + return nil, fmt.Errorf("--backend flag requires a value") + } + } + + switch { + case agentFlagChanged && backendFlagChanged && lastFlagIndex(rawArgv, "agent") > lastFlagIndex(rawArgv, "backend"): + backendName = resolvedBackend + case !backendFlagChanged && agentName != "": + backendName = resolvedBackend + case !backendFlagChanged: + if val := strings.TrimSpace(v.GetString("backend")); val != "" { + backendName = val + } + } + + modelFlagChanged := cmd.Flags().Changed("model") + if modelFlagChanged { + model = strings.TrimSpace(opts.Model) + if model == "" { + return nil, fmt.Errorf("--model flag requires a value") + } + } + + switch { + case agentFlagChanged && modelFlagChanged && lastFlagIndex(rawArgv, "agent") > lastFlagIndex(rawArgv, "model"): + model = strings.TrimSpace(resolvedModel) + case !modelFlagChanged && agentName != "": + model = strings.TrimSpace(resolvedModel) + case !modelFlagChanged: + model = strings.TrimSpace(v.GetString("model")) + } + + if cmd.Flags().Changed("reasoning-effort") { + reasoningEffort = strings.TrimSpace(opts.ReasoningEffort) + if reasoningEffort == "" { + return nil, fmt.Errorf("--reasoning-effort flag requires a value") + } + } else if val := strings.TrimSpace(v.GetString("reasoning-effort")); val != "" { + reasoningEffort = val + } else if agentName != "" { + reasoningEffort = strings.TrimSpace(resolvedReasoning) + } + + skipChanged := cmd.Flags().Changed("skip-permissions") || cmd.Flags().Changed("dangerously-skip-permissions") + skipPermissions := false + if skipChanged { + skipPermissions = opts.SkipPermissions + } else { + skipPermissions = v.GetBool("skip-permissions") + } + + if len(args) == 0 { + return nil, fmt.Errorf("task required") + } + + cfg := &Config{ + WorkDir: defaultWorkdir, + Backend: backendName, + Agent: agentName, + PromptFile: promptFile, + PromptFileExplicit: promptFileExplicit, + SkipPermissions: skipPermissions, + Yolo: yolo, + Model: model, + ReasoningEffort: reasoningEffort, + MaxParallelWorkers: config.ResolveMaxParallelWorkers(), + } + + if args[0] == "resume" { + if len(args) < 3 { + return nil, fmt.Errorf("resume mode requires: resume ") + } + cfg.Mode = "resume" + cfg.SessionID = strings.TrimSpace(args[1]) + if cfg.SessionID == "" { + return nil, fmt.Errorf("resume mode requires non-empty session_id") + } + cfg.Task = args[2] + cfg.ExplicitStdin = (args[2] == "-") + if len(args) > 3 { + if args[3] == "-" { + return nil, fmt.Errorf("invalid workdir: '-' is not a valid directory path") + } + cfg.WorkDir = args[3] + } + } else { + cfg.Mode = "new" + cfg.Task = args[0] + cfg.ExplicitStdin = (args[0] == "-") + if len(args) > 1 { + if args[1] == "-" { + return nil, fmt.Errorf("invalid workdir: '-' is not a valid directory path") + } + cfg.WorkDir = args[1] + } + } + + return cfg, nil +} + +func lastFlagIndex(argv []string, name string) int { + if len(argv) == 0 { + return -1 + } + name = strings.TrimSpace(name) + if name == "" { + return -1 + } + + needle := "--" + name + prefix := needle + "=" + last := -1 + for i, arg := range argv { + if arg == needle || strings.HasPrefix(arg, prefix) { + last = i + } + } + return last +} + +func runParallelMode(cmd *cobra.Command, args []string, opts *cliOptions, v *viper.Viper, name string) int { + if len(args) > 0 { + fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; no positional arguments are allowed.") + fmt.Fprintln(os.Stderr, "Usage examples:") + fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name) + fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name) + fmt.Fprintf(os.Stderr, " %s --parallel <<'EOF'\n", name) + fmt.Fprintf(os.Stderr, " %s --parallel --full-output <<'EOF' # include full task output\n", name) + return 1 + } + + if cmd.Flags().Changed("agent") || cmd.Flags().Changed("prompt-file") || cmd.Flags().Changed("reasoning-effort") { + fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend, --model, --full-output and --skip-permissions are allowed.") + return 1 + } + + backendName := defaultBackendName + if cmd.Flags().Changed("backend") { + backendName = strings.TrimSpace(opts.Backend) + if backendName == "" { + fmt.Fprintln(os.Stderr, "ERROR: --backend flag requires a value") + return 1 + } + } else if val := strings.TrimSpace(v.GetString("backend")); val != "" { + backendName = val + } + + model := "" + if cmd.Flags().Changed("model") { + model = strings.TrimSpace(opts.Model) + if model == "" { + fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value") + return 1 + } + } else { + model = strings.TrimSpace(v.GetString("model")) + } + + fullOutput := opts.FullOutput + if !cmd.Flags().Changed("full-output") && v.IsSet("full-output") { + fullOutput = v.GetBool("full-output") + } + + skipChanged := cmd.Flags().Changed("skip-permissions") || cmd.Flags().Changed("dangerously-skip-permissions") + skipPermissions := false + if skipChanged { + skipPermissions = opts.SkipPermissions + } else { + skipPermissions = v.GetBool("skip-permissions") + } + + backend, err := selectBackendFn(backendName) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + return 1 + } + backendName = backend.Name() + + data, err := io.ReadAll(stdinReader) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to read stdin: %v\n", err) + return 1 + } + + cfg, err := parseParallelConfig(data) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + return 1 + } + + cfg.GlobalBackend = backendName + model = strings.TrimSpace(model) + for i := range cfg.Tasks { + if strings.TrimSpace(cfg.Tasks[i].Backend) == "" { + cfg.Tasks[i].Backend = backendName + } + if strings.TrimSpace(cfg.Tasks[i].Model) == "" && model != "" { + cfg.Tasks[i].Model = model + } + cfg.Tasks[i].SkipPermissions = cfg.Tasks[i].SkipPermissions || skipPermissions + } + + timeoutSec := resolveTimeout() + layers, err := topologicalSort(cfg.Tasks) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + return 1 + } + + results := executeConcurrent(layers, timeoutSec) + + for i := range results { + results[i].CoverageTarget = defaultCoverageTarget + if results[i].Message == "" { + continue + } + + lines := strings.Split(results[i].Message, "\n") + results[i].Coverage = extractCoverageFromLines(lines) + results[i].CoverageNum = extractCoverageNum(results[i].Coverage) + results[i].FilesChanged = extractFilesChangedFromLines(lines) + results[i].TestsPassed, results[i].TestsFailed = extractTestResultsFromLines(lines) + results[i].KeyOutput = extractKeyOutputFromLines(lines, 150) + } + + fmt.Println(generateFinalOutputWithMode(results, !fullOutput)) + + exitCode := 0 + for _, res := range results { + if res.ExitCode != 0 { + exitCode = res.ExitCode + } + } + return exitCode +} + +func runSingleMode(cfg *Config, name string) int { + backend, err := selectBackendFn(cfg.Backend) + if err != nil { + logError(err.Error()) + return 1 + } + cfg.Backend = backend.Name() + + cmdInjected := codexCommand != defaultCodexCommand + argsInjected := buildCodexArgsFn != nil && reflect.ValueOf(buildCodexArgsFn).Pointer() != reflect.ValueOf(defaultBuildArgsFn).Pointer() + + if backend.Name() != defaultBackendName || !cmdInjected { + codexCommand = backend.Command() + } + if backend.Name() != defaultBackendName || !argsInjected { + buildCodexArgsFn = backend.BuildArgs + } + logInfo(fmt.Sprintf("Selected backend: %s", backend.Name())) + + timeoutSec := resolveTimeout() + logInfo(fmt.Sprintf("Timeout: %ds", timeoutSec)) + cfg.Timeout = timeoutSec + + var taskText string + var piped bool + + if cfg.ExplicitStdin { + logInfo("Explicit stdin mode: reading task from stdin") + data, err := io.ReadAll(stdinReader) + if err != nil { + logError("Failed to read stdin: " + err.Error()) + return 1 + } + taskText = string(data) + if taskText == "" { + logError("Explicit stdin mode requires task input from stdin") + return 1 + } + piped = !isTerminal() + } else { + pipedTask, err := readPipedTask() + if err != nil { + logError("Failed to read piped stdin: " + err.Error()) + return 1 + } + piped = pipedTask != "" + if piped { + taskText = pipedTask + } else { + taskText = cfg.Task + } + } + + if strings.TrimSpace(cfg.PromptFile) != "" { + prompt, err := readAgentPromptFile(cfg.PromptFile, cfg.PromptFileExplicit) + if err != nil { + logError("Failed to read prompt file: " + err.Error()) + return 1 + } + taskText = wrapTaskWithAgentPrompt(prompt, taskText) + } + + useStdin := cfg.ExplicitStdin || shouldUseStdin(taskText, piped) + + targetArg := taskText + if useStdin { + targetArg = "-" + } + codexArgs := buildCodexArgsFn(cfg, targetArg) + + logger := activeLogger() + if logger == nil { + fmt.Fprintln(os.Stderr, "ERROR: logger is not initialized") + return 1 + } + + fmt.Fprintf(os.Stderr, "[%s]\n", name) + fmt.Fprintf(os.Stderr, " Backend: %s\n", cfg.Backend) + fmt.Fprintf(os.Stderr, " Command: %s %s\n", codexCommand, strings.Join(codexArgs, " ")) + fmt.Fprintf(os.Stderr, " PID: %d\n", os.Getpid()) + fmt.Fprintf(os.Stderr, " Log: %s\n", logger.Path()) + + if useStdin { + var reasons []string + if piped { + reasons = append(reasons, "piped input") + } + if cfg.ExplicitStdin { + reasons = append(reasons, "explicit \"-\"") + } + if strings.Contains(taskText, "\n") { + reasons = append(reasons, "newline") + } + if strings.Contains(taskText, "\\") { + reasons = append(reasons, "backslash") + } + if strings.Contains(taskText, "\"") { + reasons = append(reasons, "double-quote") + } + if strings.Contains(taskText, "'") { + reasons = append(reasons, "single-quote") + } + if strings.Contains(taskText, "`") { + reasons = append(reasons, "backtick") + } + if strings.Contains(taskText, "$") { + reasons = append(reasons, "dollar") + } + if len(taskText) > 800 { + reasons = append(reasons, "length>800") + } + if len(reasons) > 0 { + logWarn(fmt.Sprintf("Using stdin mode for task due to: %s", strings.Join(reasons, ", "))) + } + } + + logInfo(fmt.Sprintf("%s running...", cfg.Backend)) + + taskSpec := TaskSpec{ + Task: taskText, + WorkDir: cfg.WorkDir, + Mode: cfg.Mode, + SessionID: cfg.SessionID, + Model: cfg.Model, + ReasoningEffort: cfg.ReasoningEffort, + Agent: cfg.Agent, + SkipPermissions: cfg.SkipPermissions, + UseStdin: useStdin, + } + + result := runTaskFn(taskSpec, false, cfg.Timeout) + + if result.ExitCode != 0 { + return result.ExitCode + } + + fmt.Println(result.Message) + if result.SessionID != "" { + fmt.Printf("\n---\nSESSION_ID: %s\n", result.SessionID) + } + + return 0 +} diff --git a/codeagent-wrapper/concurrent_stress_test.go b/codeagent-wrapper/internal/app/concurrent_stress_test.go similarity index 97% rename from codeagent-wrapper/concurrent_stress_test.go rename to codeagent-wrapper/internal/app/concurrent_stress_test.go index 822289a..bca7412 100644 --- a/codeagent-wrapper/concurrent_stress_test.go +++ b/codeagent-wrapper/internal/app/concurrent_stress_test.go @@ -1,4 +1,4 @@ -package main +package wrapper import ( "bufio" @@ -11,9 +11,20 @@ import ( "sync/atomic" "testing" "time" + + "github.com/goccy/go-json" ) func stripTimestampPrefix(line string) string { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "{") { + var evt struct { + Message string `json:"message"` + } + if err := json.Unmarshal([]byte(line), &evt); err == nil && evt.Message != "" { + return evt.Message + } + } if !strings.HasPrefix(line, "[") { return line } diff --git a/codeagent-wrapper/internal/app/config_alias.go b/codeagent-wrapper/internal/app/config_alias.go new file mode 100644 index 0000000..e975398 --- /dev/null +++ b/codeagent-wrapper/internal/app/config_alias.go @@ -0,0 +1,7 @@ +package wrapper + +import config "codeagent-wrapper/internal/config" + +// Keep the existing Config name throughout the codebase, but source the +// implementation from internal/config. +type Config = config.Config diff --git a/codeagent-wrapper/internal/app/executor_alias.go b/codeagent-wrapper/internal/app/executor_alias.go new file mode 100644 index 0000000..4ec6422 --- /dev/null +++ b/codeagent-wrapper/internal/app/executor_alias.go @@ -0,0 +1,54 @@ +package wrapper + +import ( + "context" + + backend "codeagent-wrapper/internal/backend" + config "codeagent-wrapper/internal/config" + executor "codeagent-wrapper/internal/executor" +) + +// defaultRunCodexTaskFn is the default implementation of runCodexTaskFn (exposed for test reset). +func defaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult { + return executor.DefaultRunCodexTaskFn(task, timeout) +} + +var runCodexTaskFn = defaultRunCodexTaskFn + +func topologicalSort(tasks []TaskSpec) ([][]TaskSpec, error) { + return executor.TopologicalSort(tasks) +} + +func executeConcurrent(layers [][]TaskSpec, timeout int) []TaskResult { + maxWorkers := config.ResolveMaxParallelWorkers() + return executeConcurrentWithContext(context.Background(), layers, timeout, maxWorkers) +} + +func executeConcurrentWithContext(parentCtx context.Context, layers [][]TaskSpec, timeout int, maxWorkers int) []TaskResult { + return executor.ExecuteConcurrentWithContext(parentCtx, layers, timeout, maxWorkers, runCodexTaskFn) +} + +func generateFinalOutput(results []TaskResult) string { + return executor.GenerateFinalOutput(results) +} + +func generateFinalOutputWithMode(results []TaskResult, summaryOnly bool) string { + return executor.GenerateFinalOutputWithMode(results, summaryOnly) +} + +func buildCodexArgs(cfg *Config, targetArg string) []string { + return backend.BuildCodexArgs(cfg, targetArg) +} + +func runCodexTask(taskSpec TaskSpec, silent bool, timeoutSec int) TaskResult { + return runCodexTaskWithContext(context.Background(), taskSpec, nil, nil, false, silent, timeoutSec) +} + +func runCodexProcess(parentCtx context.Context, codexArgs []string, taskText string, useStdin bool, timeoutSec int) (message, threadID string, exitCode int) { + res := runCodexTaskWithContext(parentCtx, TaskSpec{Task: taskText, WorkDir: defaultWorkdir, Mode: "new", UseStdin: useStdin}, nil, codexArgs, true, false, timeoutSec) + return res.Message, res.SessionID, res.ExitCode +} + +func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backend Backend, customArgs []string, useCustomArgs bool, silent bool, timeoutSec int) TaskResult { + return executor.RunCodexTaskWithContext(parentCtx, taskSpec, backend, codexCommand, buildCodexArgsFn, customArgs, useCustomArgs, silent, timeoutSec) +} diff --git a/codeagent-wrapper/executor_concurrent_test.go b/codeagent-wrapper/internal/app/executor_concurrent_test.go similarity index 67% rename from codeagent-wrapper/executor_concurrent_test.go rename to codeagent-wrapper/internal/app/executor_concurrent_test.go index e568d8c..db0fc74 100644 --- a/codeagent-wrapper/executor_concurrent_test.go +++ b/codeagent-wrapper/internal/app/executor_concurrent_test.go @@ -1,4 +1,4 @@ -package main +package wrapper import ( "bufio" @@ -15,9 +15,10 @@ import ( "strings" "sync" "sync/atomic" - "syscall" "testing" "time" + + executor "codeagent-wrapper/internal/executor" ) var executorTestTaskCounter atomic.Int64 @@ -91,7 +92,7 @@ func (rc *reasonReadCloser) record(reason string) { type execFakeRunner struct { stdout io.ReadCloser stderr io.ReadCloser - process processHandle + process executor.ProcessHandle stdin io.WriteCloser dir string env map[string]string @@ -158,7 +159,7 @@ func (f *execFakeRunner) SetEnv(env map[string]string) { f.env[k] = v } } -func (f *execFakeRunner) Process() processHandle { +func (f *execFakeRunner) Process() executor.ProcessHandle { if f.process != nil { return f.process } @@ -168,225 +169,15 @@ func (f *execFakeRunner) Process() processHandle { return &execFakeProcess{pid: 1} } -func TestExecutorHelperCoverage(t *testing.T) { - t.Run("realCmdAndProcess", func(t *testing.T) { - rc := &realCmd{} - if err := rc.Start(); err == nil { - t.Fatalf("expected error for nil command") - } - if err := rc.Wait(); err == nil { - t.Fatalf("expected error for nil command") - } - if _, err := rc.StdoutPipe(); err == nil { - t.Fatalf("expected error for nil command") - } - if _, err := rc.StderrPipe(); err == nil { - t.Fatalf("expected error for nil command") - } - if _, err := rc.StdinPipe(); err == nil { - t.Fatalf("expected error for nil command") - } - rc.SetStderr(io.Discard) - if rc.Process() != nil { - t.Fatalf("expected nil process") - } - rcWithCmd := &realCmd{cmd: &exec.Cmd{}} - rcWithCmd.SetStderr(io.Discard) - rcWithCmd.SetDir("/tmp") - if rcWithCmd.cmd.Dir != "/tmp" { - t.Fatalf("expected SetDir to set cmd.Dir, got %q", rcWithCmd.cmd.Dir) - } - echoCmd := exec.Command("echo", "ok") - rcProc := &realCmd{cmd: echoCmd} - stdoutPipe, err := rcProc.StdoutPipe() - if err != nil { - t.Fatalf("StdoutPipe error: %v", err) - } - stderrPipe, err := rcProc.StderrPipe() - if err != nil { - t.Fatalf("StderrPipe error: %v", err) - } - stdinPipe, err := rcProc.StdinPipe() - if err != nil { - t.Fatalf("StdinPipe error: %v", err) - } - if err := rcProc.Start(); err != nil { - t.Fatalf("Start failed: %v", err) - } - _, _ = stdinPipe.Write([]byte{}) - _ = stdinPipe.Close() - procHandle := rcProc.Process() - if procHandle == nil { - t.Fatalf("expected process handle") - } - _ = procHandle.Signal(syscall.SIGTERM) - _ = procHandle.Kill() - _ = rcProc.Wait() - _, _ = io.ReadAll(stdoutPipe) - _, _ = io.ReadAll(stderrPipe) - - rp := &realProcess{} - if rp.Pid() != 0 { - t.Fatalf("nil process should have pid 0") - } - if rp.Kill() != nil { - t.Fatalf("nil process Kill should be nil") - } - if rp.Signal(syscall.SIGTERM) != nil { - t.Fatalf("nil process Signal should be nil") - } - rpLive := &realProcess{proc: &os.Process{Pid: 99}} - if rpLive.Pid() != 99 { - t.Fatalf("expected pid 99, got %d", rpLive.Pid()) - } - _ = rpLive.Kill() - _ = rpLive.Signal(syscall.SIGTERM) - }) - - t.Run("topologicalSortAndSkip", func(t *testing.T) { - layers, err := topologicalSort([]TaskSpec{{ID: "root"}, {ID: "child", Dependencies: []string{"root"}}}) - if err != nil || len(layers) != 2 { - t.Fatalf("unexpected topological sort result: layers=%d err=%v", len(layers), err) - } - if _, err := topologicalSort([]TaskSpec{{ID: "cycle", Dependencies: []string{"cycle"}}}); err == nil { - t.Fatalf("expected cycle detection error") - } - - failed := map[string]TaskResult{"root": {ExitCode: 1}} - if skip, _ := shouldSkipTask(TaskSpec{ID: "child", Dependencies: []string{"root"}}, failed); !skip { - t.Fatalf("should skip when dependency failed") - } - if skip, _ := shouldSkipTask(TaskSpec{ID: "leaf"}, failed); skip { - t.Fatalf("should not skip task without dependencies") - } - if skip, _ := shouldSkipTask(TaskSpec{ID: "child-ok", Dependencies: []string{"root"}}, map[string]TaskResult{}); skip { - t.Fatalf("should not skip when dependencies succeeded") - } - }) - - t.Run("cancelledTaskResult", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - res := cancelledTaskResult("t1", ctx) - if res.ExitCode != 130 { - t.Fatalf("expected cancel exit code, got %d", res.ExitCode) - } - - timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), 0) - defer timeoutCancel() - res = cancelledTaskResult("t2", timeoutCtx) - if res.ExitCode != 124 { - t.Fatalf("expected timeout exit code, got %d", res.ExitCode) - } - }) - - t.Run("generateFinalOutputAndArgs", func(t *testing.T) { - const key = "CODEX_BYPASS_SANDBOX" - t.Setenv(key, "false") - - out := generateFinalOutput([]TaskResult{ - {TaskID: "ok", ExitCode: 0}, - {TaskID: "fail", ExitCode: 1, Error: "boom"}, - }) - if !strings.Contains(out, "ok") || !strings.Contains(out, "fail") { - t.Fatalf("unexpected summary output: %s", out) - } - // Test summary mode (default) - should have new format with ### headers - out = generateFinalOutput([]TaskResult{{TaskID: "rich", ExitCode: 0, SessionID: "sess", LogPath: "/tmp/log", Message: "hello"}}) - if !strings.Contains(out, "### rich") { - t.Fatalf("summary output missing task header: %s", out) - } - // Test full output mode - should have Session and Message - out = generateFinalOutputWithMode([]TaskResult{{TaskID: "rich", ExitCode: 0, SessionID: "sess", LogPath: "/tmp/log", Message: "hello"}}, false) - if !strings.Contains(out, "Session: sess") || !strings.Contains(out, "Log: /tmp/log") || !strings.Contains(out, "hello") { - t.Fatalf("full output missing fields: %s", out) - } - - args := buildCodexArgs(&Config{Mode: "new", WorkDir: "/tmp"}, "task") - if !slices.Equal(args, []string{"e", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"}) { - t.Fatalf("unexpected codex args: %+v", args) - } - args = buildCodexArgs(&Config{Mode: "resume", SessionID: "sess"}, "target") - if !slices.Equal(args, []string{"e", "--skip-git-repo-check", "--json", "resume", "sess", "target"}) { - t.Fatalf("unexpected resume args: %+v", args) - } - }) - - t.Run("generateFinalOutputASCIIMode", func(t *testing.T) { - t.Setenv("CODEAGENT_ASCII_MODE", "true") - - results := []TaskResult{ - {TaskID: "ok", ExitCode: 0, Coverage: "92%", CoverageNum: 92, CoverageTarget: 90, KeyOutput: "done"}, - {TaskID: "warn", ExitCode: 0, Coverage: "80%", CoverageNum: 80, CoverageTarget: 90, KeyOutput: "did"}, - {TaskID: "bad", ExitCode: 2, Error: "boom"}, - } - out := generateFinalOutput(results) - - for _, sym := range []string{"PASS", "WARN", "FAIL"} { - if !strings.Contains(out, sym) { - t.Fatalf("ASCII mode should include %q, got: %s", sym, out) - } - } - for _, sym := range []string{"✓", "⚠️", "✗"} { - if strings.Contains(out, sym) { - t.Fatalf("ASCII mode should not include %q, got: %s", sym, out) - } - } - }) - - t.Run("generateFinalOutputUnicodeMode", func(t *testing.T) { - t.Setenv("CODEAGENT_ASCII_MODE", "false") - - results := []TaskResult{ - {TaskID: "ok", ExitCode: 0, Coverage: "92%", CoverageNum: 92, CoverageTarget: 90, KeyOutput: "done"}, - {TaskID: "warn", ExitCode: 0, Coverage: "80%", CoverageNum: 80, CoverageTarget: 90, KeyOutput: "did"}, - {TaskID: "bad", ExitCode: 2, Error: "boom"}, - } - out := generateFinalOutput(results) - - for _, sym := range []string{"✓", "⚠️", "✗"} { - if !strings.Contains(out, sym) { - t.Fatalf("Unicode mode should include %q, got: %s", sym, out) - } - } - }) - - t.Run("executeConcurrentWrapper", func(t *testing.T) { - orig := runCodexTaskFn - defer func() { runCodexTaskFn = orig }() - runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { - return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "done"} - } - t.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", "1") - - results := executeConcurrent([][]TaskSpec{{{ID: "wrap"}}}, 1) - if len(results) != 1 || results[0].TaskID != "wrap" { - t.Fatalf("unexpected wrapper results: %+v", results) - } - - unbounded := executeConcurrentWithContext(context.Background(), [][]TaskSpec{{{ID: "unbounded"}}}, 1, 0) - if len(unbounded) != 1 || unbounded[0].ExitCode != 0 { - t.Fatalf("unexpected unbounded result: %+v", unbounded) - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - cancelled := executeConcurrentWithContext(ctx, [][]TaskSpec{{{ID: "cancel"}}}, 1, 1) - if cancelled[0].ExitCode == 0 { - t.Fatalf("expected cancelled result, got %+v", cancelled[0]) - } - }) -} - func TestExecutorRunCodexTaskWithContext(t *testing.T) { - origRunner := newCommandRunner - defer func() { newCommandRunner = origRunner }() + defer resetTestHooks() t.Run("resumeMissingSessionID", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { t.Fatalf("unexpected command execution for invalid resume config") return nil - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: ".", Mode: "resume"}, nil, nil, false, false, 1) if res.ExitCode == 0 || !strings.Contains(res.Error, "session_id") { @@ -396,13 +187,14 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { t.Run("success", func(t *testing.T) { var firstStdout *reasonReadCloser - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { rc := newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"hello"}}`) if firstStdout == nil { firstStdout = rc } return &execFakeRunner{stdout: rc, process: &execFakeProcess{pid: 1234}} - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{ID: "task-1", Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.Error != "" || res.Message != "hello" || res.ExitCode != 0 { @@ -432,17 +224,18 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("startErrors", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{startErr: errors.New("executable file not found"), process: &execFakeProcess{pid: 1}} - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.ExitCode != 127 { t.Fatalf("expected missing executable exit code, got %d", res.ExitCode) } - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{startErr: errors.New("start failed"), process: &execFakeProcess{pid: 2}} - } + }) res = runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.ExitCode == 0 { t.Fatalf("expected non-zero exit on start failure") @@ -450,13 +243,14 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("timeoutAndPipes", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"slow"}}`), process: &execFakeProcess{pid: 5}, waitDelay: 20 * time.Millisecond, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: ".", UseStdin: true}, nil, nil, false, false, 0) if res.ExitCode == 0 { t.Fatalf("expected timeout result, got %+v", res) @@ -464,17 +258,18 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("pipeErrors", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{stdoutErr: errors.New("stdout fail"), process: &execFakeProcess{pid: 6}} - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.ExitCode == 0 { t.Fatalf("expected failure on stdout pipe error") } - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{stdinErr: errors.New("stdin fail"), process: &execFakeProcess{pid: 7}} - } + }) res = runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: ".", UseStdin: true}, nil, nil, false, false, 1) if res.ExitCode == 0 { t.Fatalf("expected failure on stdin pipe error") @@ -487,13 +282,14 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { if exitErr == nil { t.Fatalf("expected exec.ExitError") } - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"ignored"}}`), process: &execFakeProcess{pid: 8}, waitErr: exitErr, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.ExitCode == 0 { t.Fatalf("expected non-zero exit on wait error") @@ -501,13 +297,14 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("contextCancelled", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"cancel"}}`), process: &execFakeProcess{pid: 9}, waitDelay: 10 * time.Millisecond, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) ctx, cancel := context.WithCancel(context.Background()) cancel() res := runCodexTaskWithContext(ctx, TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) @@ -517,12 +314,13 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("silentLogger", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"quiet"}}`), process: &execFakeProcess{pid: 10}, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) _ = closeLogger() res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, true, 1) if res.ExitCode != 0 || res.LogPath == "" { @@ -532,12 +330,13 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("injectedLogger", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"injected"}}`), process: &execFakeProcess{pid: 12}, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) _ = closeLogger() injected, err := NewLoggerWithSuffix("executor-injected") @@ -549,7 +348,7 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { _ = os.Remove(injected.Path()) }() - ctx := withTaskLogger(context.Background(), injected) + ctx := executor.WithTaskLogger(context.Background(), injected) res := runCodexTaskWithContext(ctx, TaskSpec{ID: "task-injected", Task: "payload", WorkDir: "."}, nil, nil, false, true, 1) if res.ExitCode != 0 || res.LogPath != injected.Path() { t.Fatalf("expected injected logger path, got %+v", res) @@ -569,12 +368,13 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("contextLoggerWithoutParent", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"ctx"}}`), process: &execFakeProcess{pid: 14}, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) _ = closeLogger() taskLogger, err := NewLoggerWithSuffix("executor-taskctx") @@ -586,8 +386,8 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { _ = os.Remove(taskLogger.Path()) }) - ctx := withTaskLogger(context.Background(), taskLogger) - res := runCodexTaskWithContext(nil, TaskSpec{ID: "task-context", Task: "payload", WorkDir: ".", Context: ctx}, nil, nil, false, true, 1) + ctx := executor.WithTaskLogger(context.Background(), taskLogger) + res := runCodexTaskWithContext(context.TODO(), TaskSpec{ID: "task-context", Task: "payload", WorkDir: ".", Context: ctx}, nil, nil, false, true, 1) if res.ExitCode != 0 || res.LogPath != taskLogger.Path() { t.Fatalf("expected task logger to be reused from spec context, got %+v", res) } @@ -607,16 +407,17 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { t.Run("backendSetsDirAndNilContext", func(t *testing.T) { var rc *execFakeRunner - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { rc = &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"backend"}}`), process: &execFakeProcess{pid: 13}, } return rc - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) _ = closeLogger() - res := runCodexTaskWithContext(nil, TaskSpec{ID: "task-backend", Task: "payload", WorkDir: "/tmp"}, ClaudeBackend{}, nil, false, false, 1) + res := runCodexTaskWithContext(context.TODO(), TaskSpec{ID: "task-backend", Task: "payload", WorkDir: "/tmp"}, ClaudeBackend{}, nil, false, false, 1) if res.ExitCode != 0 || res.Message != "backend" { t.Fatalf("unexpected result: %+v", res) } @@ -628,13 +429,14 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { t.Run("claudeSkipPermissionsPropagatesFromTaskSpec", func(t *testing.T) { t.Setenv("CODEAGENT_SKIP_PERMISSIONS", "false") var gotArgs []string - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { gotArgs = append([]string(nil), args...) return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}`), process: &execFakeProcess{pid: 15}, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) _ = closeLogger() res := runCodexTaskWithContext(context.Background(), TaskSpec{ID: "task-skip", Task: "payload", WorkDir: ".", SkipPermissions: true}, ClaudeBackend{}, nil, false, false, 1) @@ -647,12 +449,13 @@ func TestExecutorRunCodexTaskWithContext(t *testing.T) { }) t.Run("missingMessage", func(t *testing.T) { - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return &execFakeRunner{ stdout: newReasonReadCloser(`{"type":"item.completed","item":{"type":"task","text":"noop"}}`), process: &execFakeProcess{pid: 11}, } - } + }) + t.Cleanup(func() { executor.SetNewCommandRunner(nil) }) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "payload", WorkDir: "."}, nil, nil, false, false, 1) if res.ExitCode == 0 { t.Fatalf("expected failure when no agent_message returned") @@ -678,7 +481,7 @@ func TestExecutorParallelLogIsolation(t *testing.T) { origRun := runCodexTaskFn runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { - logger := taskLoggerFromContext(task.Context) + logger := executor.TaskLoggerFromContext(task.Context) if logger == nil { return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "missing task logger"} } @@ -702,7 +505,7 @@ func TestExecutorParallelLogIsolation(t *testing.T) { os.Stderr = stderrW defer func() { os.Stderr = oldStderr }() - results := executeConcurrentWithContext(nil, [][]TaskSpec{{{ID: taskA}, {ID: taskB}}}, 1, -1) + results := executeConcurrentWithContext(context.TODO(), [][]TaskSpec{{{ID: taskA}, {ID: taskB}}}, 1, -1) _ = stderrW.Close() os.Stderr = oldStderr @@ -768,7 +571,7 @@ func TestConcurrentExecutorParallelLogIsolationAndClosure(t *testing.T) { t.Setenv("TMPDIR", tempDir) oldArgs := os.Args - os.Args = []string{defaultWrapperName} + os.Args = []string{wrapperName} t.Cleanup(func() { os.Args = oldArgs }) mainLogger, err := NewLoggerWithSuffix("concurrent-main") @@ -814,7 +617,7 @@ func TestConcurrentExecutorParallelLogIsolationAndClosure(t *testing.T) { runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { readyCh <- struct{}{} - logger := taskLoggerFromContext(task.Context) + logger := executor.TaskLoggerFromContext(task.Context) loggerCh <- taskLoggerInfo{taskID: task.ID, logger: logger} if logger == nil { return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "missing task logger"} @@ -901,15 +704,9 @@ func TestConcurrentExecutorParallelLogIsolationAndClosure(t *testing.T) { } for taskID, logger := range loggers { - if !logger.closed.Load() { + if !logger.IsClosed() { t.Fatalf("expected task logger to be closed for %q", taskID) } - if logger.file == nil { - t.Fatalf("expected task logger file to be non-nil for %q", taskID) - } - if _, err := logger.file.Write([]byte("x")); err == nil { - t.Fatalf("expected task logger file to be closed for %q", taskID) - } } mainLogger.Flush() @@ -979,10 +776,10 @@ func parseTaskIDFromLogLine(line string) (string, bool) { } func TestExecutorTaskLoggerContext(t *testing.T) { - if taskLoggerFromContext(nil) != nil { - t.Fatalf("expected nil logger from nil context") + if executor.TaskLoggerFromContext(context.TODO()) != nil { + t.Fatalf("expected nil logger from TODO context") } - if taskLoggerFromContext(context.Background()) != nil { + if executor.TaskLoggerFromContext(context.Background()) != nil { t.Fatalf("expected nil logger when context has no logger") } @@ -995,12 +792,12 @@ func TestExecutorTaskLoggerContext(t *testing.T) { _ = os.Remove(logger.Path()) }() - ctx := withTaskLogger(context.Background(), logger) - if got := taskLoggerFromContext(ctx); got != logger { + ctx := executor.WithTaskLogger(context.Background(), logger) + if got := executor.TaskLoggerFromContext(ctx); got != logger { t.Fatalf("expected logger roundtrip, got %v", got) } - if taskLoggerFromContext(withTaskLogger(context.Background(), nil)) != nil { + if executor.TaskLoggerFromContext(executor.WithTaskLogger(context.Background(), nil)) != nil { t.Fatalf("expected nil logger when injected logger is nil") } } @@ -1157,7 +954,7 @@ func TestExecutorExecuteConcurrentWithContextBranches(t *testing.T) { orig := runCodexTaskFn runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { - logger := taskLoggerFromContext(task.Context) + logger := executor.TaskLoggerFromContext(task.Context) if logger != mainLogger { return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "unexpected logger"} } @@ -1191,9 +988,6 @@ func TestExecutorExecuteConcurrentWithContextBranches(t *testing.T) { if res.LogPath != mainLogger.Path() { t.Fatalf("shared log path mismatch: got %q want %q", res.LogPath, mainLogger.Path()) } - if !res.sharedLog { - t.Fatalf("expected sharedLog flag for %+v", res) - } if !strings.Contains(stderrOut, "Log (shared)") { t.Fatalf("stderr missing shared marker: %s", stderrOut) } @@ -1222,7 +1016,7 @@ func TestExecutorExecuteConcurrentWithContextBranches(t *testing.T) { orig := runCodexTaskFn runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { - logger := taskLoggerFromContext(task.Context) + logger := executor.TaskLoggerFromContext(task.Context) if logger == nil { return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "missing logger"} } @@ -1260,7 +1054,14 @@ func TestExecutorExecuteConcurrentWithContextBranches(t *testing.T) { if err != nil { t.Fatalf("failed to read log %q: %v", res.LogPath, err) } - if !strings.Contains(string(data), "TASK="+res.TaskID) { + found := false + for _, line := range strings.Split(string(data), "\n") { + if strings.Contains(stripTimestampPrefix(line), "TASK="+res.TaskID) { + found = true + break + } + } + if !found { t.Fatalf("log for %q missing task marker, content: %s", res.TaskID, string(data)) } _ = os.Remove(res.LogPath) @@ -1268,147 +1069,6 @@ func TestExecutorExecuteConcurrentWithContextBranches(t *testing.T) { }) } -func TestExecutorSignalAndTermination(t *testing.T) { - forceKillDelay.Store(0) - defer forceKillDelay.Store(5) - - proc := &execFakeProcess{pid: 42} - cmd := &execFakeRunner{process: proc} - - origNotify := signalNotifyFn - origStop := signalStopFn - defer func() { - signalNotifyFn = origNotify - signalStopFn = origStop - }() - - signalNotifyFn = func(c chan<- os.Signal, sigs ...os.Signal) { - go func() { c <- syscall.SIGINT }() - } - signalStopFn = func(c chan<- os.Signal) {} - - forwardSignals(context.Background(), cmd, func(string) {}) - time.Sleep(20 * time.Millisecond) - - proc.mu.Lock() - signalled := len(proc.signals) - proc.mu.Unlock() - if runtime.GOOS != "windows" && signalled == 0 { - t.Fatalf("process did not receive signal") - } - if proc.killed.Load() == 0 { - t.Fatalf("process was not killed after signal") - } - - timer := terminateProcess(cmd) - if timer == nil { - t.Fatalf("terminateProcess returned nil timer") - } - timer.Stop() - - ft := terminateCommand(cmd) - if ft == nil { - t.Fatalf("terminateCommand returned nil") - } - ft.Stop() - - cmdKill := &execFakeRunner{process: &execFakeProcess{pid: 50}} - ftKill := terminateCommand(cmdKill) - time.Sleep(10 * time.Millisecond) - if p, ok := cmdKill.process.(*execFakeProcess); ok && p.killed.Load() == 0 { - t.Fatalf("terminateCommand did not kill process") - } - ftKill.Stop() - - cmdKill2 := &execFakeRunner{process: &execFakeProcess{pid: 51}} - timer2 := terminateProcess(cmdKill2) - time.Sleep(10 * time.Millisecond) - if p, ok := cmdKill2.process.(*execFakeProcess); ok && p.killed.Load() == 0 { - t.Fatalf("terminateProcess did not kill process") - } - timer2.Stop() - - if terminateCommand(nil) != nil { - t.Fatalf("terminateCommand should return nil for nil cmd") - } - if terminateCommand(&execFakeRunner{allowNilProcess: true}) != nil { - t.Fatalf("terminateCommand should return nil when process is nil") - } - if terminateProcess(nil) != nil { - t.Fatalf("terminateProcess should return nil for nil cmd") - } - if terminateProcess(&execFakeRunner{allowNilProcess: true}) != nil { - t.Fatalf("terminateProcess should return nil when process is nil") - } - - signalNotifyFn = func(c chan<- os.Signal, sigs ...os.Signal) {} - ctxDone, cancelDone := context.WithCancel(context.Background()) - cancelDone() - forwardSignals(ctxDone, &execFakeRunner{process: &execFakeProcess{pid: 70}}, func(string) {}) -} - -func TestExecutorCancelReasonAndCloseWithReason(t *testing.T) { - if reason := cancelReason("", nil); !strings.Contains(reason, "Context") { - t.Fatalf("unexpected cancelReason for nil ctx: %s", reason) - } - ctx, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - if !strings.Contains(cancelReason("cmd", ctx), "timeout") { - t.Fatalf("expected timeout reason") - } - cancelCtx, cancelFn := context.WithCancel(context.Background()) - cancelFn() - if !strings.Contains(cancelReason("cmd", cancelCtx), "Execution cancelled") { - t.Fatalf("expected cancellation reason") - } - if !strings.Contains(cancelReason("", cancelCtx), "codex") { - t.Fatalf("expected default command name in cancel reason") - } - - rc := &reasonReadCloser{r: strings.NewReader("data"), closedC: make(chan struct{}, 1)} - closeWithReason(rc, "why") - select { - case <-rc.closedC: - default: - t.Fatalf("CloseWithReason was not called") - } - - plain := io.NopCloser(strings.NewReader("x")) - closeWithReason(plain, "noop") - closeWithReason(nil, "noop") -} - -func TestExecutorForceKillTimerStop(t *testing.T) { - done := make(chan struct{}, 1) - ft := &forceKillTimer{timer: time.AfterFunc(50*time.Millisecond, func() { done <- struct{}{} }), done: done} - ft.Stop() - - done2 := make(chan struct{}, 1) - ft2 := &forceKillTimer{timer: time.AfterFunc(0, func() { done2 <- struct{}{} }), done: done2} - time.Sleep(10 * time.Millisecond) - ft2.Stop() - - var nilTimer *forceKillTimer - nilTimer.Stop() - (&forceKillTimer{}).Stop() -} - -func TestExecutorForwardSignalsDefaults(t *testing.T) { - origNotify := signalNotifyFn - origStop := signalStopFn - signalNotifyFn = nil - signalStopFn = nil - defer func() { - signalNotifyFn = origNotify - signalStopFn = origStop - }() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - forwardSignals(ctx, &execFakeRunner{process: &execFakeProcess{pid: 80}}, func(string) {}) - time.Sleep(10 * time.Millisecond) -} - func TestExecutorSharedLogFalseWhenCustomLogPath(t *testing.T) { devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { @@ -1464,10 +1124,9 @@ func TestExecutorSharedLogFalseWhenCustomLogPath(t *testing.T) { } res := results[0] - // 关键断言:即使 handle.shared=true(因为 task logger 创建失败), - // 但因为 LogPath 不等于主 logger 的路径,sharedLog 应为 false - if res.sharedLog { - t.Fatalf("expected sharedLog=false when LogPath differs from shared logger, got true") + out := generateFinalOutputWithMode(results, false) + if strings.Contains(out, "(shared)") { + t.Fatalf("did not expect shared marker when LogPath differs from shared logger, got: %s", out) } // 验证 LogPath 确实是自定义的 diff --git a/codeagent-wrapper/internal/app/logger.go b/codeagent-wrapper/internal/app/logger.go new file mode 100644 index 0000000..abe615b --- /dev/null +++ b/codeagent-wrapper/internal/app/logger.go @@ -0,0 +1,26 @@ +package wrapper + +import ilogger "codeagent-wrapper/internal/logger" + +type Logger = ilogger.Logger +type CleanupStats = ilogger.CleanupStats + +func NewLogger() (*Logger, error) { return ilogger.NewLogger() } + +func NewLoggerWithSuffix(suffix string) (*Logger, error) { return ilogger.NewLoggerWithSuffix(suffix) } + +func setLogger(l *Logger) { ilogger.SetLogger(l) } + +func closeLogger() error { return ilogger.CloseLogger() } + +func activeLogger() *Logger { return ilogger.ActiveLogger() } + +func logInfo(msg string) { ilogger.LogInfo(msg) } + +func logWarn(msg string) { ilogger.LogWarn(msg) } + +func logError(msg string) { ilogger.LogError(msg) } + +func cleanupOldLogs() (CleanupStats, error) { return ilogger.CleanupOldLogs() } + +func sanitizeLogSuffix(raw string) string { return ilogger.SanitizeLogSuffix(raw) } diff --git a/codeagent-wrapper/main_integration_test.go b/codeagent-wrapper/internal/app/main_integration_test.go similarity index 83% rename from codeagent-wrapper/main_integration_test.go rename to codeagent-wrapper/internal/app/main_integration_test.go index 294cf6e..0b39a7c 100644 --- a/codeagent-wrapper/main_integration_test.go +++ b/codeagent-wrapper/internal/app/main_integration_test.go @@ -1,7 +1,8 @@ -package main +package wrapper import ( "bytes" + "codeagent-wrapper/internal/logger" "fmt" "io" "os" @@ -36,7 +37,9 @@ func captureStdout(t *testing.T, fn func()) string { os.Stdout = old var buf bytes.Buffer - io.Copy(&buf, r) + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("io.Copy() error = %v", err) + } return buf.String() } @@ -57,11 +60,17 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { for _, p := range parts { p = strings.TrimSpace(p) if strings.HasSuffix(p, "tasks") { - fmt.Sscanf(p, "%d tasks", &payload.Summary.Total) + if _, err := fmt.Sscanf(p, "%d tasks", &payload.Summary.Total); err != nil { + t.Fatalf("failed to parse total tasks from %q: %v", p, err) + } } else if strings.HasSuffix(p, "passed") { - fmt.Sscanf(p, "%d passed", &payload.Summary.Success) + if _, err := fmt.Sscanf(p, "%d passed", &payload.Summary.Success); err != nil { + t.Fatalf("failed to parse passed tasks from %q: %v", p, err) + } } else if strings.HasSuffix(p, "failed") { - fmt.Sscanf(p, "%d failed", &payload.Summary.Failed) + if _, err := fmt.Sscanf(p, "%d failed", &payload.Summary.Failed); err != nil { + t.Fatalf("failed to parse failed tasks from %q: %v", p, err) + } } } } else if strings.HasPrefix(line, "Total:") { @@ -70,11 +79,17 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { for _, p := range parts { p = strings.TrimSpace(p) if strings.HasPrefix(p, "Total:") { - fmt.Sscanf(p, "Total: %d", &payload.Summary.Total) + if _, err := fmt.Sscanf(p, "Total: %d", &payload.Summary.Total); err != nil { + t.Fatalf("failed to parse total tasks from %q: %v", p, err) + } } else if strings.HasPrefix(p, "Success:") { - fmt.Sscanf(p, "Success: %d", &payload.Summary.Success) + if _, err := fmt.Sscanf(p, "Success: %d", &payload.Summary.Success); err != nil { + t.Fatalf("failed to parse passed tasks from %q: %v", p, err) + } } else if strings.HasPrefix(p, "Failed:") { - fmt.Sscanf(p, "Failed: %d", &payload.Summary.Failed) + if _, err := fmt.Sscanf(p, "Failed: %d", &payload.Summary.Failed); err != nil { + t.Fatalf("failed to parse failed tasks from %q: %v", p, err) + } } } } else if line == "## Task Results" { @@ -94,34 +109,39 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { currentTask = &TaskResult{} taskLine := strings.TrimPrefix(line, "### ") - success, warning, failed := getStatusSymbols() - // Parse different formats - if strings.Contains(taskLine, " "+success) { - parts := strings.Split(taskLine, " "+success) + parseMarker := func(marker string, exitCode int) bool { + needle := " " + marker + if !strings.Contains(taskLine, needle) { + return false + } + parts := strings.Split(taskLine, needle) currentTask.TaskID = strings.TrimSpace(parts[0]) - currentTask.ExitCode = 0 - // Extract coverage if present - if len(parts) > 1 { + currentTask.ExitCode = exitCode + if exitCode == 0 && len(parts) > 1 { coveragePart := strings.TrimSpace(parts[1]) if strings.HasSuffix(coveragePart, "%") { currentTask.Coverage = coveragePart } } - } else if strings.Contains(taskLine, " "+warning) { - parts := strings.Split(taskLine, " "+warning) - currentTask.TaskID = strings.TrimSpace(parts[0]) - currentTask.ExitCode = 0 - } else if strings.Contains(taskLine, " "+failed) { - parts := strings.Split(taskLine, " "+failed) - currentTask.TaskID = strings.TrimSpace(parts[0]) - currentTask.ExitCode = 1 - } else { + return true + } + + switch { + case parseMarker("✓", 0), parseMarker("PASS", 0): + // ok + case parseMarker("⚠️", 0), parseMarker("WARN", 0): + // warning + case parseMarker("✗", 1), parseMarker("FAIL", 1): + // fail + default: currentTask.TaskID = taskLine } } else if currentTask != nil && inTaskResults { // Parse task details if strings.HasPrefix(line, "Exit code:") { - fmt.Sscanf(line, "Exit code: %d", ¤tTask.ExitCode) + if _, err := fmt.Sscanf(line, "Exit code: %d", ¤tTask.ExitCode); err != nil { + t.Fatalf("failed to parse exit code from %q: %v", line, err) + } } else if strings.HasPrefix(line, "Error:") { currentTask.Error = strings.TrimPrefix(line, "Error: ") } else if strings.HasPrefix(line, "Log:") { @@ -147,7 +167,9 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { currentTask.ExitCode = 0 } else if strings.HasPrefix(line, "Status: FAILED") { if strings.Contains(line, "exit code") { - fmt.Sscanf(line, "Status: FAILED (exit code %d)", ¤tTask.ExitCode) + if _, err := fmt.Sscanf(line, "Status: FAILED (exit code %d)", ¤tTask.ExitCode); err != nil { + t.Fatalf("failed to parse exit code from %q: %v", line, err) + } } else { currentTask.ExitCode = 1 } @@ -180,6 +202,37 @@ func findResultByID(t *testing.T, payload integrationOutput, id string) TaskResu return TaskResult{} } +func setTempDirEnv(t *testing.T, dir string) string { + t.Helper() + resolved := dir + if eval, err := filepath.EvalSymlinks(dir); err == nil { + resolved = eval + } + t.Setenv("TMPDIR", resolved) + t.Setenv("TEMP", resolved) + t.Setenv("TMP", resolved) + return resolved +} + +func createTempLog(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create temp log %s: %v", path, err) + } + return path +} + +func stubProcessRunning(t *testing.T, fn func(int) bool) { + t.Helper() + t.Cleanup(logger.SetProcessRunningCheck(fn)) +} + +func stubProcessStartTime(t *testing.T, fn func(int) time.Time) { + t.Helper() + t.Cleanup(logger.SetProcessStartTimeFn(fn)) +} + func TestRunParallelEndToEnd_OrderAndConcurrency(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn @@ -365,7 +418,7 @@ id: beta ---CONTENT--- task-beta` stdinReader = bytes.NewReader([]byte(input)) - os.Args = []string{"codex-wrapper", "--parallel"} + os.Args = []string{"codeagent-wrapper", "--parallel"} var exitCode int output := captureStdout(t, func() { @@ -418,9 +471,9 @@ id: d ---CONTENT--- ok-d` stdinReader = bytes.NewReader([]byte(input)) - os.Args = []string{"codex-wrapper", "--parallel"} + os.Args = []string{"codeagent-wrapper", "--parallel"} - expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", os.Getpid())) origRun := runCodexTaskFn runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { @@ -474,9 +527,9 @@ ok-d` // After parallel log isolation fix, each task has its own log file expectedLines := map[string]struct{}{ - fmt.Sprintf("Task a: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d-a.log", os.Getpid()))): {}, - fmt.Sprintf("Task b: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d-b.log", os.Getpid()))): {}, - fmt.Sprintf("Task d: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d-d.log", os.Getpid()))): {}, + fmt.Sprintf("Task a: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d-a.log", os.Getpid()))): {}, + fmt.Sprintf("Task b: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d-b.log", os.Getpid()))): {}, + fmt.Sprintf("Task d: Log: %s", filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d-d.log", os.Getpid()))): {}, } if len(taskLines) != len(expectedLines) { @@ -494,7 +547,7 @@ func TestRunNonParallelOutputsIncludeLogPathsIntegration(t *testing.T) { defer resetTestHooks() tempDir := setTempDirEnv(t, t.TempDir()) - os.Args = []string{"codex-wrapper", "integration-log-check"} + os.Args = []string{"codeagent-wrapper", "integration-log-check"} stdinReader = strings.NewReader("") isTerminalFn = func() bool { return true } codexCommand = "echo" @@ -512,7 +565,7 @@ func TestRunNonParallelOutputsIncludeLogPathsIntegration(t *testing.T) { if exitCode != 0 { t.Fatalf("run() exit=%d, want 0", exitCode) } - expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", os.Getpid())) wantLine := fmt.Sprintf("Log: %s", expectedLog) if !strings.Contains(stderr, wantLine) { t.Fatalf("stderr missing %q, got: %q", wantLine, stderr) @@ -693,11 +746,11 @@ func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) - orphanA := createTempLog(t, tempDir, "codex-wrapper-5001.log") - orphanB := createTempLog(t, tempDir, "codex-wrapper-5002-extra.log") - orphanC := createTempLog(t, tempDir, "codex-wrapper-5003-suffix.log") + orphanA := createTempLog(t, tempDir, "codeagent-wrapper-5001.log") + orphanB := createTempLog(t, tempDir, "codeagent-wrapper-5002-extra.log") + orphanC := createTempLog(t, tempDir, "codeagent-wrapper-5003-suffix.log") runningPID := 81234 - runningLog := createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", runningPID)) + runningLog := createTempLog(t, tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", runningPID)) unrelated := createTempLog(t, tempDir, "wrapper.log") stubProcessRunning(t, func(pid int) bool { @@ -713,7 +766,7 @@ func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) { codexCommand = createFakeCodexScript(t, "tid-startup", "ok") stdinReader = strings.NewReader("") isTerminalFn = func() bool { return true } - os.Args = []string{"codex-wrapper", "task"} + os.Args = []string{"codeagent-wrapper", "task"} if exit := run(); exit != 0 { t.Fatalf("run() exit=%d, want 0", exit) @@ -739,7 +792,7 @@ func TestRunStartupCleanupConcurrentWrappers(t *testing.T) { const totalLogs = 40 for i := 0; i < totalLogs; i++ { - createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", 9000+i)) + createTempLog(t, tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", 9000+i)) } stubProcessRunning(t, func(pid int) bool { @@ -763,7 +816,7 @@ func TestRunStartupCleanupConcurrentWrappers(t *testing.T) { close(start) wg.Wait() - matches, err := filepath.Glob(filepath.Join(tempDir, "codex-wrapper-*.log")) + matches, err := filepath.Glob(filepath.Join(tempDir, "codeagent-wrapper-*.log")) if err != nil { t.Fatalf("glob error: %v", err) } @@ -777,9 +830,9 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) - staleA := createTempLog(t, tempDir, "codex-wrapper-2100.log") - staleB := createTempLog(t, tempDir, "codex-wrapper-2200-extra.log") - keeper := createTempLog(t, tempDir, "codex-wrapper-2300.log") + staleA := createTempLog(t, tempDir, "codeagent-wrapper-2100.log") + staleB := createTempLog(t, tempDir, "codeagent-wrapper-2200-extra.log") + keeper := createTempLog(t, tempDir, "codeagent-wrapper-2300.log") stubProcessRunning(t, func(pid int) bool { return pid == 2300 || pid == os.Getpid() @@ -791,7 +844,7 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) { return time.Time{} }) - os.Args = []string{"codex-wrapper", "--cleanup"} + os.Args = []string{"codeagent-wrapper", "--cleanup"} var exitCode int output := captureStdout(t, func() { @@ -815,10 +868,10 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) { if !strings.Contains(output, "Files kept: 1") { t.Fatalf("missing 'Files kept: 1' in output: %q", output) } - if !strings.Contains(output, "codex-wrapper-2100.log") || !strings.Contains(output, "codex-wrapper-2200-extra.log") { + if !strings.Contains(output, "codeagent-wrapper-2100.log") || !strings.Contains(output, "codeagent-wrapper-2200-extra.log") { t.Fatalf("missing deleted file names in output: %q", output) } - if !strings.Contains(output, "codex-wrapper-2300.log") { + if !strings.Contains(output, "codeagent-wrapper-2300.log") { t.Fatalf("missing kept file names in output: %q", output) } @@ -831,7 +884,7 @@ func TestRunCleanupFlagEndToEnd_Success(t *testing.T) { t.Fatalf("expected kept log to remain, err=%v", err) } - currentLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + currentLog := filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", os.Getpid())) if _, err := os.Stat(currentLog); err == nil { t.Fatalf("cleanup mode should not create new log file %s", currentLog) } else if !os.IsNotExist(err) { @@ -850,7 +903,7 @@ func TestRunCleanupFlagEndToEnd_FailureDoesNotAffectStartup(t *testing.T) { return CleanupStats{Scanned: 1}, fmt.Errorf("permission denied") } - os.Args = []string{"codex-wrapper", "--cleanup"} + os.Args = []string{"codeagent-wrapper", "--cleanup"} var exitCode int errOutput := captureStderr(t, func() { @@ -867,7 +920,7 @@ func TestRunCleanupFlagEndToEnd_FailureDoesNotAffectStartup(t *testing.T) { t.Fatalf("cleanup called %d times, want 1", calls) } - currentLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + currentLog := filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", os.Getpid())) if _, err := os.Stat(currentLog); err == nil { t.Fatalf("cleanup failure should not create new log file %s", currentLog) } else if !os.IsNotExist(err) { @@ -880,7 +933,7 @@ func TestRunCleanupFlagEndToEnd_FailureDoesNotAffectStartup(t *testing.T) { codexCommand = createFakeCodexScript(t, "tid-cleanup-e2e", "ok") stdinReader = strings.NewReader("") isTerminalFn = func() bool { return true } - os.Args = []string{"codex-wrapper", "post-cleanup task"} + os.Args = []string{"codeagent-wrapper", "post-cleanup task"} var normalExit int normalOutput := captureStdout(t, func() { diff --git a/codeagent-wrapper/main_test.go b/codeagent-wrapper/internal/app/main_test.go similarity index 92% rename from codeagent-wrapper/main_test.go rename to codeagent-wrapper/internal/app/main_test.go index 7aacb7c..b6d3d93 100644 --- a/codeagent-wrapper/main_test.go +++ b/codeagent-wrapper/internal/app/main_test.go @@ -1,10 +1,9 @@ -package main +package wrapper import ( "bufio" "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -19,6 +18,11 @@ import ( "syscall" "testing" "time" + + config "codeagent-wrapper/internal/config" + executor "codeagent-wrapper/internal/executor" + + "github.com/goccy/go-json" ) // Helper to reset test hooks @@ -28,17 +32,15 @@ func resetTestHooks() { codexCommand = "codex" cleanupHook = nil cleanupLogsFn = cleanupOldLogs - signalNotifyFn = signal.Notify - signalStopFn = signal.Stop + startupCleanupAsync = false + config.ResetModelsConfigCacheForTest() + _ = executor.SetSelectBackendFn(nil) buildCodexArgsFn = buildCodexArgs selectBackendFn = selectBackend - commandContext = exec.CommandContext - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return &realCmd{cmd: commandContext(ctx, name, args...)} - } - forceKillDelay.Store(5) - closeLogger() - executablePathFn = os.Executable + _ = executor.SetCommandContextFn(nil) + _ = executor.SetNewCommandRunner(nil) + _ = executor.SetForceKillDelay(5) + _ = closeLogger() runTaskFn = runCodexTask runCodexTaskFn = defaultRunCodexTaskFn exitFn = os.Exit @@ -86,6 +88,8 @@ func (t testBackend) Command() string { return "echo" } +func (t testBackend) Env(baseURL, apiKey string) map[string]string { return nil } + func withBackend(command string, argsFn func(*Config, string) []string) func() { prev := selectBackendFn selectBackendFn = func(name string) (Backend, error) { @@ -107,7 +111,7 @@ func restoreStdoutPipe(c *capturedStdout) { } c.writer.Close() os.Stdout = c.old - io.Copy(&c.buf, c.reader) + _, _ = io.Copy(&c.buf, c.reader) } func (c *capturedStdout) String() string { @@ -127,7 +131,9 @@ func captureOutput(t *testing.T, fn func()) string { os.Stdout = old var buf bytes.Buffer - io.Copy(&buf, r) + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("io.Copy() error = %v", err) + } return buf.String() } @@ -141,7 +147,9 @@ func captureStderr(t *testing.T, fn func()) string { os.Stderr = old var buf bytes.Buffer - io.Copy(&buf, r) + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("io.Copy() error = %v", err) + } return buf.String() } @@ -262,7 +270,7 @@ func (d *drainBlockingCmd) SetEnv(env map[string]string) { d.inner.SetEnv(env) } -func (d *drainBlockingCmd) Process() processHandle { +func (d *drainBlockingCmd) Process() executor.ProcessHandle { return d.inner.Process() } @@ -553,7 +561,7 @@ func (f *fakeCmd) SetEnv(env map[string]string) { } } -func (f *fakeCmd) Process() processHandle { +func (f *fakeCmd) Process() executor.ProcessHandle { if f == nil { return nil } @@ -728,9 +736,7 @@ func TestFakeCmdInfra(t *testing.T) { WaitDelay: 5 * time.Millisecond, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } @@ -774,9 +780,7 @@ func TestRunCodexTask_WaitBeforeParse(t *testing.T) { WaitDelay: waitDelay, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } @@ -821,9 +825,7 @@ func TestRunCodexTask_ParseStall(t *testing.T) { }) blockingCmd := newDrainBlockingCmd(fake) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return blockingCmd - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return blockingCmd }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } @@ -878,7 +880,7 @@ func TestRunCodexTask_ParseStall(t *testing.T) { func TestRunCodexTask_ContextTimeout(t *testing.T) { defer resetTestHooks() - forceKillDelay.Store(0) + _ = executor.SetForceKillDelay(0) fake := newFakeCmd(fakeCmdConfig{ KeepStdoutOpen: true, @@ -887,9 +889,7 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) { ReleaseWaitOnSignal: false, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } @@ -898,14 +898,6 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - var capturedTimer *forceKillTimer - terminateCommandFn = func(cmd commandRunner) *forceKillTimer { - timer := terminateCommand(cmd) - capturedTimer = timer - return timer - } - defer func() { terminateCommandFn = terminateCommand }() - result := runCodexTaskWithContext(ctx, TaskSpec{Task: "ctx-timeout", WorkDir: defaultWorkdir}, nil, nil, false, false, 60) if result.ExitCode != 124 { @@ -929,15 +921,6 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) { t.Fatalf("expected Kill to eventually run, got 0") } } - if capturedTimer == nil { - t.Fatalf("forceKillTimer not captured") - } - if !capturedTimer.stopped.Load() { - t.Fatalf("forceKillTimer.Stop was not called") - } - if !capturedTimer.drained.Load() { - t.Fatalf("forceKillTimer drain logic did not run") - } if fake.stdout == nil { t.Fatalf("stdout reader not initialized") } @@ -948,7 +931,7 @@ func TestRunCodexTask_ContextTimeout(t *testing.T) { func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) { defer resetTestHooks() - forceKillDelay.Store(0) + _ = executor.SetForceKillDelay(0) fake := newFakeCmd(fakeCmdConfig{ StdoutPlan: []fakeStdoutEvent{ @@ -961,9 +944,7 @@ func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) { ReleaseWaitOnKill: true, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } codexCommand = "fake-cmd" @@ -988,7 +969,7 @@ func TestRunCodexTask_ForcesStopAfterCompletion(t *testing.T) { func TestRunCodexTask_ForcesStopAfterTurnCompleted(t *testing.T) { defer resetTestHooks() - forceKillDelay.Store(0) + _ = executor.SetForceKillDelay(0) fake := newFakeCmd(fakeCmdConfig{ StdoutPlan: []fakeStdoutEvent{ @@ -1001,9 +982,7 @@ func TestRunCodexTask_ForcesStopAfterTurnCompleted(t *testing.T) { ReleaseWaitOnKill: true, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } codexCommand = "fake-cmd" @@ -1028,7 +1007,7 @@ func TestRunCodexTask_ForcesStopAfterTurnCompleted(t *testing.T) { func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) { defer resetTestHooks() - forceKillDelay.Store(0) + _ = executor.SetForceKillDelay(0) fake := newFakeCmd(fakeCmdConfig{ StdoutPlan: []fakeStdoutEvent{ @@ -1042,9 +1021,7 @@ func TestRunCodexTask_DoesNotTerminateBeforeThreadCompleted(t *testing.T) { ReleaseWaitOnKill: true, }) - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { - return fake - } + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { return fake }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } codexCommand = "fake-cmd" @@ -1498,7 +1475,7 @@ func TestBackendParseBoolFlag(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := parseBoolFlag(tt.val, tt.def); got != tt.want { + if got := config.ParseBoolFlag(tt.val, tt.def); got != tt.want { t.Fatalf("parseBoolFlag(%q,%v) = %v, want %v", tt.val, tt.def, got, tt.want) } }) @@ -1508,17 +1485,17 @@ func TestBackendParseBoolFlag(t *testing.T) { func TestBackendEnvFlagEnabled(t *testing.T) { const key = "TEST_FLAG_ENABLED" t.Setenv(key, "") - if envFlagEnabled(key) { + if config.EnvFlagEnabled(key) { t.Fatalf("envFlagEnabled should be false when unset") } t.Setenv(key, "true") - if !envFlagEnabled(key) { + if !config.EnvFlagEnabled(key) { t.Fatalf("envFlagEnabled should be true for 'true'") } t.Setenv(key, "no") - if envFlagEnabled(key) { + if config.EnvFlagEnabled(key) { t.Fatalf("envFlagEnabled should be false for 'no'") } } @@ -1708,8 +1685,8 @@ func TestClaudeModel_DefaultsFromSettings(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - makeRunner := func(gotName *string, gotArgs *[]string, fake **fakeCmd) func(context.Context, string, ...string) commandRunner { - return func(ctx context.Context, name string, args ...string) commandRunner { + makeRunner := func(gotName *string, gotArgs *[]string, fake **fakeCmd) func(context.Context, string, ...string) executor.CommandRunner { + return func(ctx context.Context, name string, args ...string) executor.CommandRunner { *gotName = name *gotArgs = append([]string(nil), args...) cmd := newFakeCmd(fakeCmdConfig{ @@ -1729,9 +1706,8 @@ func TestClaudeModel_DefaultsFromSettings(t *testing.T) { gotArgs []string fake *fakeCmd ) - origRunner := newCommandRunner - newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) - t.Cleanup(func() { newCommandRunner = origRunner }) + restore := executor.SetNewCommandRunner(makeRunner(&gotName, &gotArgs, &fake)) + t.Cleanup(restore) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5) if res.ExitCode != 0 || res.Message != "ok" { @@ -1761,9 +1737,8 @@ func TestClaudeModel_DefaultsFromSettings(t *testing.T) { gotArgs []string fake *fakeCmd ) - origRunner := newCommandRunner - newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) - t.Cleanup(func() { newCommandRunner = origRunner }) + restore := executor.SetNewCommandRunner(makeRunner(&gotName, &gotArgs, &fake)) + t.Cleanup(restore) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir, Model: "sonnet"}, ClaudeBackend{}, nil, false, true, 5) if res.ExitCode != 0 || res.Message != "ok" { @@ -1787,9 +1762,8 @@ func TestClaudeModel_DefaultsFromSettings(t *testing.T) { gotArgs []string fake *fakeCmd ) - origRunner := newCommandRunner - newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) - t.Cleanup(func() { newCommandRunner = origRunner }) + restore := executor.SetNewCommandRunner(makeRunner(&gotName, &gotArgs, &fake)) + t.Cleanup(restore) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "resume", SessionID: "sid-123", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5) if res.ExitCode != 0 || res.Message != "ok" { @@ -1991,8 +1965,7 @@ func TestRunCodexTaskWithContext_CodexReasoningEffort(t *testing.T) { t.Setenv("CODEX_BYPASS_SANDBOX", "false") var gotArgs []string - origRunner := newCommandRunner - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + restore := executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { gotArgs = append([]string(nil), args...) return newFakeCmd(fakeCmdConfig{ PID: 123, @@ -2000,8 +1973,8 @@ func TestRunCodexTaskWithContext_CodexReasoningEffort(t *testing.T) { {Data: "{\"type\":\"result\",\"session_id\":\"sid\",\"result\":\"ok\"}\n"}, }, }) - } - t.Cleanup(func() { newCommandRunner = origRunner }) + }) + t.Cleanup(restore) res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir, ReasoningEffort: "high"}, nil, nil, false, true, 5) if res.ExitCode != 0 || res.Message != "ok" { @@ -2071,7 +2044,7 @@ func TestRunBuildCodexArgs_BypassSandboxEnvTrue(t *testing.T) { t.Fatalf("NewLogger() error = %v", err) } setLogger(logger) - defer closeLogger() + defer func() { _ = closeLogger() }() t.Setenv("CODEX_BYPASS_SANDBOX", "true") @@ -2716,7 +2689,7 @@ func TestRunLogFunctions(t *testing.T) { t.Fatalf("NewLogger() error = %v", err) } setLogger(logger) - defer closeLogger() + defer func() { _ = closeLogger() }() logInfo("info message") logWarn("warn message") @@ -2751,25 +2724,38 @@ func TestLoggerPathAndRemoveNil(t *testing.T) { } func TestLoggerLogDropOnDone(t *testing.T) { - logger := &Logger{ - ch: make(chan logEntry), - done: make(chan struct{}), - } - close(logger.done) - logger.log("INFO", "dropped") - logger.pendingWG.Wait() + t.Skip("internal logger behavior moved to internal/logger; exercise via public methods instead") } func TestLoggerLogAfterClose(t *testing.T) { defer resetTestHooks() + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + logger, err := NewLogger() if err != nil { t.Fatalf("NewLogger error: %v", err) } + logPath := logger.Path() + t.Cleanup(func() { _ = os.Remove(logPath) }) + + logger.Info("before close") + logger.Flush() + if err := logger.Close(); err != nil { t.Fatalf("Close error: %v", err) } - logger.log("INFO", "should be ignored") + + logger.Info("should be ignored") + logger.Flush() + + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + if strings.Contains(string(data), "should be ignored") { + t.Fatalf("expected log message to be dropped after Close, got: %s", string(data)) + } } func TestLogWriterLogLine(t *testing.T) { @@ -2788,7 +2774,7 @@ func TestLogWriterLogLine(t *testing.T) { if !strings.Contains(string(data), "P:abc") { t.Fatalf("log output missing truncated entry, got %q", string(data)) } - closeLogger() + _ = closeLogger() } func TestNewLogWriterDefaultMaxLen(t *testing.T) { @@ -2807,7 +2793,9 @@ func TestBackendPrintHelp(t *testing.T) { os.Stdout = oldStdout var buf bytes.Buffer - io.Copy(&buf, r) + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("io.Copy() error = %v", err) + } output := buf.String() expected := []string{"codeagent-wrapper", "Usage:", "resume", "CODEX_TIMEOUT", "Exit Codes:"} @@ -2929,12 +2917,12 @@ func TestRunCodexTaskFn_UsesTaskBackend(t *testing.T) { var seenName string var seenArgs []string - newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + _ = executor.SetNewCommandRunner(func(ctx context.Context, name string, args ...string) executor.CommandRunner { seenName = name seenArgs = append([]string(nil), args...) return fake - } - selectBackendFn = func(name string) (Backend, error) { + }) + _ = executor.SetSelectBackendFn(func(name string) (Backend, error) { return testBackend{ name: strings.ToLower(name), command: "custom-cli", @@ -2942,7 +2930,7 @@ func TestRunCodexTaskFn_UsesTaskBackend(t *testing.T) { return []string{"do", targetArg} }, }, nil - } + }) res := runCodexTaskFn(TaskSpec{ID: "task-1", Task: "payload", Backend: "Custom"}, 5) @@ -2966,9 +2954,9 @@ func TestRunCodexTaskFn_UsesTaskBackend(t *testing.T) { func TestRunCodexTaskFn_InvalidBackend(t *testing.T) { defer resetTestHooks() - selectBackendFn = func(name string) (Backend, error) { + _ = executor.SetSelectBackendFn(func(name string) (Backend, error) { return nil, fmt.Errorf("invalid backend: %s", name) - } + }) res := runCodexTaskFn(TaskSpec{ID: "bad-task", Task: "noop", Backend: "unknown"}, 5) if res.ExitCode == 0 { @@ -3109,11 +3097,11 @@ func TestRunCodexTask_ExitError(t *testing.T) { func TestRunCodexTask_StdinPipeError(t *testing.T) { defer resetTestHooks() - commandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + _ = executor.SetCommandContextFn(func(ctx context.Context, name string, args ...string) *exec.Cmd { cmd := exec.CommandContext(ctx, "cat") cmd.Stdin = os.Stdin return cmd - } + }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } res := runCodexTask(TaskSpec{Task: "data", UseStdin: true}, false, 1) if res.ExitCode != 1 || !strings.Contains(res.Error, "stdin pipe") { @@ -3123,11 +3111,11 @@ func TestRunCodexTask_StdinPipeError(t *testing.T) { func TestRunCodexTask_StdoutPipeError(t *testing.T) { defer resetTestHooks() - commandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + _ = executor.SetCommandContextFn(func(ctx context.Context, name string, args ...string) *exec.Cmd { cmd := exec.CommandContext(ctx, "echo", "noop") cmd.Stdout = os.Stdout return cmd - } + }) buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } res := runCodexTask(TaskSpec{Task: "noop"}, false, 1) if res.ExitCode != 1 || !strings.Contains(res.Error, "stdout pipe") { @@ -3170,36 +3158,6 @@ func TestRunCodexTask_SignalHandling(t *testing.T) { } } -func TestForwardSignals_ContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - forwardSignals(ctx, &realCmd{cmd: &exec.Cmd{}}, func(string) {}) - cancel() - time.Sleep(10 * time.Millisecond) -} - -func TestCancelReason(t *testing.T) { - const cmdName = "codex" - - if got := cancelReason(cmdName, nil); got != "Context cancelled" { - t.Fatalf("cancelReason(nil) = %q, want %q", got, "Context cancelled") - } - - ctxTimeout, cancelTimeout := context.WithTimeout(context.Background(), 1*time.Nanosecond) - defer cancelTimeout() - <-ctxTimeout.Done() - wantTimeout := fmt.Sprintf("%s execution timeout", cmdName) - if got := cancelReason(cmdName, ctxTimeout); got != wantTimeout { - t.Fatalf("cancelReason(deadline) = %q, want %q", got, wantTimeout) - } - - ctxCancelled, cancel := context.WithCancel(context.Background()) - cancel() - if got := cancelReason(cmdName, ctxCancelled); got != "Execution cancelled, terminating codex process" { - t.Fatalf("cancelReason(cancelled) = %q, want %q", got, "Execution cancelled, terminating codex process") - } -} - func TestRunCodexProcess(t *testing.T) { defer resetTestHooks() script := createFakeCodexScript(t, "proc-thread", "proc-msg") @@ -3235,7 +3193,9 @@ func TestRunSilentMode(t *testing.T) { w.Close() os.Stderr = oldStderr var buf bytes.Buffer - io.Copy(&buf, r) + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("io.Copy() error = %v", err) + } return buf.String() } @@ -3336,35 +3296,6 @@ func TestParallelTopologicalSortTasks(t *testing.T) { } } -func TestRunShouldSkipTask(t *testing.T) { - failed := map[string]TaskResult{"a": {TaskID: "a", ExitCode: 1}, "b": {TaskID: "b", ExitCode: 2}} - tests := []struct { - name string - task TaskSpec - skip bool - reasonContains []string - }{ - {"no deps", TaskSpec{ID: "c"}, false, nil}, - {"missing deps not failed", TaskSpec{ID: "d", Dependencies: []string{"x"}}, false, nil}, - {"single failed dep", TaskSpec{ID: "e", Dependencies: []string{"a"}}, true, []string{"a"}}, - {"multiple failed deps", TaskSpec{ID: "f", Dependencies: []string{"a", "b"}}, true, []string{"a", "b"}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - skip, reason := shouldSkipTask(tt.task, failed) - if skip != tt.skip { - t.Fatalf("skip=%v, want %v", skip, tt.skip) - } - for _, expect := range tt.reasonContains { - if !strings.Contains(reason, expect) { - t.Fatalf("reason %q missing %q", reason, expect) - } - } - }) - } -} - func TestRunTopologicalSort_CycleDetection(t *testing.T) { tasks := []TaskSpec{{ID: "a", Dependencies: []string{"b"}}, {ID: "b", Dependencies: []string{"a"}}} if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "cycle detected") { @@ -3701,7 +3632,7 @@ func TestParallelTriggersCleanup(t *testing.T) { oldArgs := os.Args defer func() { os.Args = oldArgs }() - os.Args = []string{"codex-wrapper", "--parallel"} + os.Args = []string{"codeagent-wrapper", "--parallel"} stdinReader = strings.NewReader(`---TASK--- id: only ---CONTENT--- @@ -3736,7 +3667,7 @@ func TestVersionFlag(t *testing.T) { } }) - want := "codeagent-wrapper version 5.6.4\n" + want := "codeagent-wrapper version 6.0.0-alpha1\n" if output != want { t.Fatalf("output = %q, want %q", output, want) @@ -3752,7 +3683,7 @@ func TestVersionShortFlag(t *testing.T) { } }) - want := "codeagent-wrapper version 5.6.4\n" + want := "codeagent-wrapper version 6.0.0-alpha1\n" if output != want { t.Fatalf("output = %q, want %q", output, want) @@ -3761,14 +3692,14 @@ func TestVersionShortFlag(t *testing.T) { func TestVersionLegacyAlias(t *testing.T) { defer resetTestHooks() - os.Args = []string{"codex-wrapper", "--version"} + os.Args = []string{"codeagent-wrapper", "--version"} output := captureOutput(t, func() { if code := run(); code != 0 { t.Errorf("exit = %d, want 0", code) } }) - want := "codex-wrapper version 5.6.4\n" + want := "codeagent-wrapper version 6.0.0-alpha1\n" if output != want { t.Fatalf("output = %q, want %q", output, want) @@ -3793,7 +3724,7 @@ func TestRun_HelpShort(t *testing.T) { func TestRun_HelpDoesNotTriggerCleanup(t *testing.T) { defer resetTestHooks() - os.Args = []string{"codex-wrapper", "--help"} + os.Args = []string{"codeagent-wrapper", "--help"} cleanupLogsFn = func() (CleanupStats, error) { t.Fatalf("cleanup should not run for --help") return CleanupStats{}, nil @@ -3806,7 +3737,7 @@ func TestRun_HelpDoesNotTriggerCleanup(t *testing.T) { func TestVersionDoesNotTriggerCleanup(t *testing.T) { defer resetTestHooks() - os.Args = []string{"codex-wrapper", "--version"} + os.Args = []string{"codeagent-wrapper", "--version"} cleanupLogsFn = func() (CleanupStats, error) { t.Fatalf("cleanup should not run for --version") return CleanupStats{}, nil @@ -3874,7 +3805,7 @@ func TestVersionCoverageFullRun(t *testing.T) { _ = closeLogger() _ = logger.RemoveLogFile() - loggerPtr.Store(nil) + setLogger(nil) }) t.Run("parseArgsError", func(t *testing.T) { @@ -4089,7 +4020,7 @@ func TestVersionMainWrapper(t *testing.T) { exitCalled := -1 exitFn = func(code int) { exitCalled = code } os.Args = []string{"codeagent-wrapper", "--version"} - main() + Main() if exitCalled != 0 { t.Fatalf("main exit = %d, want 0", exitCalled) } @@ -4102,8 +4033,8 @@ func TestBackendCleanupMode_Success(t *testing.T) { Scanned: 5, Deleted: 3, Kept: 2, - DeletedFiles: []string{"codex-wrapper-111.log", "codex-wrapper-222.log", "codex-wrapper-333.log"}, - KeptFiles: []string{"codex-wrapper-444.log", "codex-wrapper-555.log"}, + DeletedFiles: []string{"codeagent-wrapper-111.log", "codeagent-wrapper-222.log", "codeagent-wrapper-333.log"}, + KeptFiles: []string{"codeagent-wrapper-444.log", "codeagent-wrapper-555.log"}, }, nil } @@ -4114,7 +4045,7 @@ func TestBackendCleanupMode_Success(t *testing.T) { if exitCode != 0 { t.Fatalf("exit = %d, want 0", exitCode) } - want := "Cleanup completed\nFiles scanned: 5\nFiles deleted: 3\n - codex-wrapper-111.log\n - codex-wrapper-222.log\n - codex-wrapper-333.log\nFiles kept: 2\n - codex-wrapper-444.log\n - codex-wrapper-555.log\n" + want := "Cleanup completed\nFiles scanned: 5\nFiles deleted: 3\n - codeagent-wrapper-111.log\n - codeagent-wrapper-222.log\n - codeagent-wrapper-333.log\nFiles kept: 2\n - codeagent-wrapper-444.log\n - codeagent-wrapper-555.log\n" if output != want { t.Fatalf("output = %q, want %q", output, want) } @@ -4128,7 +4059,7 @@ func TestBackendCleanupMode_SuccessWithErrorsLine(t *testing.T) { Deleted: 1, Kept: 0, Errors: 1, - DeletedFiles: []string{"codex-wrapper-123.log"}, + DeletedFiles: []string{"codeagent-wrapper-123.log"}, }, nil } @@ -4139,7 +4070,7 @@ func TestBackendCleanupMode_SuccessWithErrorsLine(t *testing.T) { if exitCode != 0 { t.Fatalf("exit = %d, want 0", exitCode) } - want := "Cleanup completed\nFiles scanned: 2\nFiles deleted: 1\n - codex-wrapper-123.log\nFiles kept: 0\nDeletion errors: 1\n" + want := "Cleanup completed\nFiles scanned: 2\nFiles deleted: 1\n - codeagent-wrapper-123.log\nFiles kept: 0\nDeletion errors: 1\n" if output != want { t.Fatalf("output = %q, want %q", output, want) } @@ -4208,7 +4139,7 @@ func TestRun_CleanupFlag(t *testing.T) { oldArgs := os.Args defer func() { os.Args = oldArgs }() - os.Args = []string{"codex-wrapper", "--cleanup"} + os.Args = []string{"codeagent-wrapper", "--cleanup"} calls := 0 cleanupLogsFn = func() (CleanupStats, error) { @@ -4453,7 +4384,7 @@ func TestRun_LoggerRemovedOnSignal(t *testing.T) { defer signal.Reset(syscall.SIGINT, syscall.SIGTERM) // Set shorter delays for faster test - forceKillDelay.Store(1) + _ = executor.SetForceKillDelay(1) tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -4565,7 +4496,7 @@ func TestRun_CleanupFailureDoesNotBlock(t *testing.T) { codexCommand = createFakeCodexScript(t, "tid-cleanup", "ok") stdinReader = strings.NewReader("") isTerminalFn = func() bool { return true } - os.Args = []string{"codex-wrapper", "task"} + os.Args = []string{"codeagent-wrapper", "task"} if exit := run(); exit != 0 { t.Fatalf("exit = %d, want 0", exit) @@ -4704,73 +4635,6 @@ func TestBackendDiscardInvalidJSONBuffer(t *testing.T) { }) } -func TestRunForwardSignals(t *testing.T) { - defer resetTestHooks() - - if runtime.GOOS == "windows" { - t.Skip("sleep command not available on Windows") - } - - execCmd := exec.Command("sleep", "5") - if err := execCmd.Start(); err != nil { - t.Skipf("unable to start sleep command: %v", err) - } - defer func() { - _ = execCmd.Process.Kill() - execCmd.Wait() - }() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - forceKillDelay.Store(0) - defer forceKillDelay.Store(5) - - ready := make(chan struct{}) - var captured chan<- os.Signal - signalNotifyFn = func(ch chan<- os.Signal, sig ...os.Signal) { - captured = ch - close(ready) - } - signalStopFn = func(ch chan<- os.Signal) {} - defer func() { - signalNotifyFn = signal.Notify - signalStopFn = signal.Stop - }() - - var mu sync.Mutex - var logs []string - cmd := &realCmd{cmd: execCmd} - forwardSignals(ctx, cmd, func(msg string) { - mu.Lock() - defer mu.Unlock() - logs = append(logs, msg) - }) - - select { - case <-ready: - case <-time.After(500 * time.Millisecond): - t.Fatalf("signalNotifyFn not invoked") - } - - captured <- syscall.SIGINT - - done := make(chan error, 1) - go func() { done <- cmd.Wait() }() - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatalf("process did not exit after forwarded signal") - } - - mu.Lock() - defer mu.Unlock() - if len(logs) == 0 { - t.Fatalf("expected log entry for forwarded signal") - } -} - // Backend-focused coverage suite to ensure run() paths stay exercised under the focused pattern. func TestBackendRunCoverage(t *testing.T) { suite := []struct { @@ -4810,7 +4674,7 @@ func TestParallelLogPathInSerialMode(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) - os.Args = []string{"codex-wrapper", "do-stuff"} + os.Args = []string{"codeagent-wrapper", "do-stuff"} stdinReader = strings.NewReader("") isTerminalFn = func() bool { return true } codexCommand = "echo" @@ -4827,126 +4691,13 @@ func TestParallelLogPathInSerialMode(t *testing.T) { if exitCode != 0 { t.Fatalf("run() exit = %d, want 0", exitCode) } - expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", os.Getpid())) wantLine := fmt.Sprintf("Log: %s", expectedLog) if !strings.Contains(stderr, wantLine) { t.Fatalf("stderr missing %q, got: %q", wantLine, stderr) } } -func TestRealProcessNilSafety(t *testing.T) { - var proc *realProcess - if pid := proc.Pid(); pid != 0 { - t.Fatalf("Pid() = %d, want 0", pid) - } - if err := proc.Kill(); err != nil { - t.Fatalf("Kill() error = %v", err) - } - if err := proc.Signal(syscall.SIGTERM); err != nil { - t.Fatalf("Signal() error = %v", err) - } -} - -func TestRealProcessKill(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("sleep command not available on Windows") - } - - cmd := exec.Command("sleep", "5") - if err := cmd.Start(); err != nil { - t.Skipf("unable to start sleep command: %v", err) - } - waited := false - defer func() { - if waited { - return - } - if cmd.Process != nil { - _ = cmd.Process.Kill() - cmd.Wait() - } - }() - - proc := &realProcess{proc: cmd.Process} - if proc.Pid() == 0 { - t.Fatalf("Pid() returned 0 for active process") - } - if err := proc.Kill(); err != nil { - t.Fatalf("Kill() error = %v", err) - } - waitErr := cmd.Wait() - waited = true - if waitErr == nil { - t.Fatalf("Kill() should lead to non-nil wait error") - } -} - -func TestRealProcessSignal(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("sleep command not available on Windows") - } - - cmd := exec.Command("sleep", "5") - if err := cmd.Start(); err != nil { - t.Skipf("unable to start sleep command: %v", err) - } - waited := false - defer func() { - if waited { - return - } - if cmd.Process != nil { - _ = cmd.Process.Kill() - cmd.Wait() - } - }() - - proc := &realProcess{proc: cmd.Process} - if err := proc.Signal(syscall.SIGTERM); err != nil { - t.Fatalf("Signal() error = %v", err) - } - waitErr := cmd.Wait() - waited = true - if waitErr == nil { - t.Fatalf("Signal() should lead to non-nil wait error") - } -} - -func TestRealCmdProcess(t *testing.T) { - rc := &realCmd{} - if rc.Process() != nil { - t.Fatalf("Process() should return nil when realCmd has no command") - } - rc = &realCmd{cmd: &exec.Cmd{}} - if rc.Process() != nil { - t.Fatalf("Process() should return nil when exec.Cmd has no process") - } - - if runtime.GOOS == "windows" { - return - } - - cmd := exec.Command("sleep", "5") - if err := cmd.Start(); err != nil { - t.Skipf("unable to start sleep command: %v", err) - } - defer func() { - if cmd.Process != nil { - _ = cmd.Process.Kill() - cmd.Wait() - } - }() - - rc = &realCmd{cmd: cmd} - handle := rc.Process() - if handle == nil { - t.Fatalf("expected non-nil process handle") - } - if pid := handle.Pid(); pid == 0 { - t.Fatalf("process handle returned pid=0") - } -} - func TestRun_CLI_Success(t *testing.T) { defer resetTestHooks() os.Args = []string{"codeagent-wrapper", "do-things"} @@ -4988,7 +4739,7 @@ func TestResolveMaxParallelWorkers(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", tt.envValue) - got := resolveMaxParallelWorkers() + got := config.ResolveMaxParallelWorkers() if got != tt.want { t.Errorf("resolveMaxParallelWorkers() = %d, want %d", got, tt.want) } diff --git a/codeagent-wrapper/internal/app/parallel_config.go b/codeagent-wrapper/internal/app/parallel_config.go new file mode 100644 index 0000000..d21c7fc --- /dev/null +++ b/codeagent-wrapper/internal/app/parallel_config.go @@ -0,0 +1,9 @@ +package wrapper + +import ( + executor "codeagent-wrapper/internal/executor" +) + +func parseParallelConfig(data []byte) (*ParallelConfig, error) { + return executor.ParseParallelConfig(data) +} diff --git a/codeagent-wrapper/internal/app/parser.go b/codeagent-wrapper/internal/app/parser.go new file mode 100644 index 0000000..3afc4d9 --- /dev/null +++ b/codeagent-wrapper/internal/app/parser.go @@ -0,0 +1,34 @@ +package wrapper + +import ( + "bufio" + "io" + + parser "codeagent-wrapper/internal/parser" + + "github.com/goccy/go-json" +) + +func parseJSONStream(r io.Reader) (message, threadID string) { + return parseJSONStreamWithLog(r, logWarn, logInfo) +} + +func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) { + return parseJSONStreamWithLog(r, warnFn, logInfo) +} + +func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) { + return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil) +} + +func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) { + return parser.ParseJSONStreamInternal(r, warnFn, infoFn, onMessage, onComplete) +} + +func hasKey(m map[string]json.RawMessage, key string) bool { return parser.HasKey(m, key) } + +func discardInvalidJSON(decoder *json.Decoder, reader *bufio.Reader) (*bufio.Reader, error) { + return parser.DiscardInvalidJSON(decoder, reader) +} + +func normalizeText(text interface{}) string { return parser.NormalizeText(text) } diff --git a/codeagent-wrapper/internal/app/task_types.go b/codeagent-wrapper/internal/app/task_types.go new file mode 100644 index 0000000..416d054 --- /dev/null +++ b/codeagent-wrapper/internal/app/task_types.go @@ -0,0 +1,8 @@ +package wrapper + +import executor "codeagent-wrapper/internal/executor" + +// Type aliases to keep existing names in the wrapper package. +type ParallelConfig = executor.ParallelConfig +type TaskSpec = executor.TaskSpec +type TaskResult = executor.TaskResult diff --git a/codeagent-wrapper/internal/app/terminal_test.go b/codeagent-wrapper/internal/app/terminal_test.go new file mode 100644 index 0000000..c8f0d87 --- /dev/null +++ b/codeagent-wrapper/internal/app/terminal_test.go @@ -0,0 +1,30 @@ +package wrapper + +import ( + "os" + "testing" +) + +func TestDefaultIsTerminalCoverage(t *testing.T) { + oldStdin := os.Stdin + t.Cleanup(func() { os.Stdin = oldStdin }) + + f, err := os.CreateTemp(t.TempDir(), "stdin-*") + if err != nil { + t.Fatalf("os.CreateTemp() error = %v", err) + } + defer os.Remove(f.Name()) + + os.Stdin = f + if got := defaultIsTerminal(); got { + t.Fatalf("defaultIsTerminal() = %v, want false for regular file", got) + } + + if err := f.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + os.Stdin = f + if got := defaultIsTerminal(); !got { + t.Fatalf("defaultIsTerminal() = %v, want true when Stat fails", got) + } +} diff --git a/codeagent-wrapper/utils.go b/codeagent-wrapper/internal/app/utils.go similarity index 73% rename from codeagent-wrapper/utils.go rename to codeagent-wrapper/internal/app/utils.go index fdcc97c..abec759 100644 --- a/codeagent-wrapper/utils.go +++ b/codeagent-wrapper/internal/app/utils.go @@ -1,4 +1,4 @@ -package main +package wrapper import ( "bytes" @@ -7,6 +7,8 @@ import ( "os" "strconv" "strings" + + utils "codeagent-wrapper/internal/utils" ) func resolveTimeout() int { @@ -52,7 +54,7 @@ func shouldUseStdin(taskText string, piped bool) bool { if len(taskText) > 800 { return true } - return strings.IndexAny(taskText, stdinSpecialChars) >= 0 + return strings.ContainsAny(taskText, stdinSpecialChars) } func defaultIsTerminal() bool { @@ -196,69 +198,21 @@ func (b *tailBuffer) String() string { } func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - if maxLen < 0 { - return "" - } - return s[:maxLen] + "..." + return utils.Truncate(s, maxLen) } // safeTruncate safely truncates string to maxLen, avoiding panic and UTF-8 corruption. func safeTruncate(s string, maxLen int) string { - if maxLen <= 0 || s == "" { - return "" - } - - runes := []rune(s) - if len(runes) <= maxLen { - return s - } - - if maxLen < 4 { - return string(runes[:1]) - } - - cutoff := maxLen - 3 - if cutoff <= 0 { - return string(runes[:1]) - } - if len(runes) <= cutoff { - return s - } - return string(runes[:cutoff]) + "..." + return utils.SafeTruncate(s, maxLen) } // sanitizeOutput removes ANSI escape sequences and control characters. func sanitizeOutput(s string) string { - var result strings.Builder - inEscape := false - for i := 0; i < len(s); i++ { - if s[i] == '\x1b' && i+1 < len(s) && s[i+1] == '[' { - inEscape = true - i++ // skip '[' - continue - } - if inEscape { - if (s[i] >= 'A' && s[i] <= 'Z') || (s[i] >= 'a' && s[i] <= 'z') { - inEscape = false - } - continue - } - // Keep printable chars and common whitespace. - if s[i] >= 32 || s[i] == '\n' || s[i] == '\t' { - result.WriteByte(s[i]) - } - } - return result.String() + return utils.SanitizeOutput(s) } func min(a, b int) int { - if a < b { - return a - } - return b + return utils.Min(a, b) } func hello() string { @@ -381,7 +335,7 @@ func extractFilesChangedFromLines(lines []string) []string { for _, prefix := range []string{"Modified:", "Created:", "Updated:", "Edited:", "Wrote:", "Changed:"} { if strings.HasPrefix(line, prefix) { file := strings.TrimSpace(strings.TrimPrefix(line, prefix)) - file = strings.Trim(file, "`,\"'()[],:") + file = strings.Trim(file, "`\"'()[],:") file = strings.TrimPrefix(file, "@") if file != "" && !seen[file] { files = append(files, file) @@ -398,7 +352,7 @@ func extractFilesChangedFromLines(lines []string) []string { // Pattern 2: Tokens that look like file paths (allow root files, strip @ prefix). parts := strings.Fields(line) for _, part := range parts { - part = strings.Trim(part, "`,\"'()[],:") + part = strings.Trim(part, "`\"'()[],:") part = strings.TrimPrefix(part, "@") for _, ext := range exts { if strings.HasSuffix(part, ext) && !seen[part] { @@ -567,116 +521,3 @@ func extractKeyOutputFromLines(lines []string, maxLen int) string { clean := strings.TrimSpace(strings.Join(lines, "\n")) return safeTruncate(clean, maxLen) } - -// extractCoverageGap extracts what's missing from coverage reports -// Looks for uncovered lines, branches, or functions -func extractCoverageGap(message string) string { - if message == "" { - return "" - } - - lower := strings.ToLower(message) - lines := strings.Split(message, "\n") - - // Look for uncovered/missing patterns - for _, line := range lines { - lineLower := strings.ToLower(line) - line = strings.TrimSpace(line) - - // Common patterns for uncovered code - if strings.Contains(lineLower, "uncovered") || - strings.Contains(lineLower, "not covered") || - strings.Contains(lineLower, "missing coverage") || - strings.Contains(lineLower, "lines not covered") { - if len(line) > 100 { - return line[:97] + "..." - } - return line - } - - // Look for specific file:line patterns in coverage reports - if strings.Contains(lineLower, "branch") && strings.Contains(lineLower, "not taken") { - if len(line) > 100 { - return line[:97] + "..." - } - return line - } - } - - // Look for function names that aren't covered - if strings.Contains(lower, "function") && strings.Contains(lower, "0%") { - for _, line := range lines { - if strings.Contains(strings.ToLower(line), "0%") && strings.Contains(line, "function") { - line = strings.TrimSpace(line) - if len(line) > 100 { - return line[:97] + "..." - } - return line - } - } - } - - return "" -} - -// extractErrorDetail extracts meaningful error context from task output -// Returns the most relevant error information up to maxLen characters -func extractErrorDetail(message string, maxLen int) string { - if message == "" || maxLen <= 0 { - return "" - } - - lines := strings.Split(message, "\n") - var errorLines []string - - // Look for error-related lines - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - lower := strings.ToLower(line) - - // Skip noise lines - if strings.HasPrefix(line, "at ") && strings.Contains(line, "(") { - // Stack trace line - only keep first one - if len(errorLines) > 0 && strings.HasPrefix(strings.ToLower(errorLines[len(errorLines)-1]), "at ") { - continue - } - } - - // Prioritize error/fail lines - if strings.Contains(lower, "error") || - strings.Contains(lower, "fail") || - strings.Contains(lower, "exception") || - strings.Contains(lower, "assert") || - strings.Contains(lower, "expected") || - strings.Contains(lower, "timeout") || - strings.Contains(lower, "not found") || - strings.Contains(lower, "cannot") || - strings.Contains(lower, "undefined") || - strings.HasPrefix(line, "FAIL") || - strings.HasPrefix(line, "●") { - errorLines = append(errorLines, line) - } - } - - if len(errorLines) == 0 { - // No specific error lines found, take last few lines - start := len(lines) - 5 - if start < 0 { - start = 0 - } - for _, line := range lines[start:] { - line = strings.TrimSpace(line) - if line != "" { - errorLines = append(errorLines, line) - } - } - } - - // Join and truncate - result := strings.Join(errorLines, " | ") - return safeTruncate(result, maxLen) -} diff --git a/codeagent-wrapper/utils_test.go b/codeagent-wrapper/internal/app/utils_test.go similarity index 99% rename from codeagent-wrapper/utils_test.go rename to codeagent-wrapper/internal/app/utils_test.go index 98a7427..64c3f6b 100644 --- a/codeagent-wrapper/utils_test.go +++ b/codeagent-wrapper/internal/app/utils_test.go @@ -1,4 +1,4 @@ -package main +package wrapper import ( "fmt" diff --git a/codeagent-wrapper/internal/app/wrapper_name.go b/codeagent-wrapper/internal/app/wrapper_name.go new file mode 100644 index 0000000..be721fd --- /dev/null +++ b/codeagent-wrapper/internal/app/wrapper_name.go @@ -0,0 +1,9 @@ +package wrapper + +import ilogger "codeagent-wrapper/internal/logger" + +const wrapperName = ilogger.WrapperName + +func currentWrapperName() string { return ilogger.CurrentWrapperName() } + +func primaryLogPrefix() string { return ilogger.PrimaryLogPrefix() } diff --git a/codeagent-wrapper/internal/backend/backend.go b/codeagent-wrapper/internal/backend/backend.go new file mode 100644 index 0000000..bf6db5e --- /dev/null +++ b/codeagent-wrapper/internal/backend/backend.go @@ -0,0 +1,33 @@ +package backend + +import config "codeagent-wrapper/internal/config" + +// Backend defines the contract for invoking different AI CLI backends. +// Each backend is responsible for supplying the executable command and +// building the argument list based on the wrapper config. +type Backend interface { + Name() string + BuildArgs(cfg *config.Config, targetArg string) []string + Command() string + Env(baseURL, apiKey string) map[string]string +} + +var ( + logWarnFn = func(string) {} + logErrorFn = func(string) {} +) + +// SetLogFuncs configures optional logging hooks used by some backends. +// Callers can safely pass nil to disable the hook. +func SetLogFuncs(warnFn, errorFn func(string)) { + if warnFn != nil { + logWarnFn = warnFn + } else { + logWarnFn = func(string) {} + } + if errorFn != nil { + logErrorFn = errorFn + } else { + logErrorFn = func(string) {} + } +} diff --git a/codeagent-wrapper/backend_test.go b/codeagent-wrapper/internal/backend/backend_test.go similarity index 71% rename from codeagent-wrapper/backend_test.go rename to codeagent-wrapper/internal/backend/backend_test.go index d0c2cec..272622f 100644 --- a/codeagent-wrapper/backend_test.go +++ b/codeagent-wrapper/internal/backend/backend_test.go @@ -1,4 +1,4 @@ -package main +package backend import ( "bytes" @@ -6,6 +6,8 @@ import ( "path/filepath" "reflect" "testing" + + config "codeagent-wrapper/internal/config" ) func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { @@ -13,7 +15,7 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { t.Run("new mode omits skip-permissions when env disabled", func(t *testing.T) { t.Setenv("CODEAGENT_SKIP_PERMISSIONS", "false") - cfg := &Config{Mode: "new", WorkDir: "/repo"} + cfg := &config.Config{Mode: "new", WorkDir: "/repo"} got := backend.BuildArgs(cfg, "todo") want := []string{"-p", "--setting-sources", "", "--output-format", "stream-json", "--verbose", "todo"} if !reflect.DeepEqual(got, want) { @@ -22,7 +24,7 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { }) t.Run("new mode includes skip-permissions by default", func(t *testing.T) { - cfg := &Config{Mode: "new", SkipPermissions: false} + cfg := &config.Config{Mode: "new", SkipPermissions: false} got := backend.BuildArgs(cfg, "-") want := []string{"-p", "--dangerously-skip-permissions", "--setting-sources", "", "--output-format", "stream-json", "--verbose", "-"} if !reflect.DeepEqual(got, want) { @@ -32,7 +34,7 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { t.Run("resume mode includes session id", func(t *testing.T) { t.Setenv("CODEAGENT_SKIP_PERMISSIONS", "false") - cfg := &Config{Mode: "resume", SessionID: "sid-123", WorkDir: "/ignored"} + cfg := &config.Config{Mode: "resume", SessionID: "sid-123", WorkDir: "/ignored"} got := backend.BuildArgs(cfg, "resume-task") want := []string{"-p", "--setting-sources", "", "-r", "sid-123", "--output-format", "stream-json", "--verbose", "resume-task"} if !reflect.DeepEqual(got, want) { @@ -42,7 +44,7 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { t.Run("resume mode without session still returns base flags", func(t *testing.T) { t.Setenv("CODEAGENT_SKIP_PERMISSIONS", "false") - cfg := &Config{Mode: "resume", WorkDir: "/ignored"} + cfg := &config.Config{Mode: "resume", WorkDir: "/ignored"} got := backend.BuildArgs(cfg, "follow-up") want := []string{"-p", "--setting-sources", "", "--output-format", "stream-json", "--verbose", "follow-up"} if !reflect.DeepEqual(got, want) { @@ -51,7 +53,7 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { }) t.Run("resume mode can opt-in skip permissions", func(t *testing.T) { - cfg := &Config{Mode: "resume", SessionID: "sid-123", SkipPermissions: true} + cfg := &config.Config{Mode: "resume", SessionID: "sid-123", SkipPermissions: true} got := backend.BuildArgs(cfg, "resume-task") want := []string{"-p", "--dangerously-skip-permissions", "--setting-sources", "", "-r", "sid-123", "--output-format", "stream-json", "--verbose", "resume-task"} if !reflect.DeepEqual(got, want) { @@ -70,7 +72,7 @@ func TestBackendBuildArgs_Model(t *testing.T) { t.Run("claude includes --model when set", func(t *testing.T) { t.Setenv("CODEAGENT_SKIP_PERMISSIONS", "false") backend := ClaudeBackend{} - cfg := &Config{Mode: "new", Model: "opus"} + cfg := &config.Config{Mode: "new", Model: "opus"} got := backend.BuildArgs(cfg, "todo") want := []string{"-p", "--setting-sources", "", "--model", "opus", "--output-format", "stream-json", "--verbose", "todo"} if !reflect.DeepEqual(got, want) { @@ -80,7 +82,7 @@ func TestBackendBuildArgs_Model(t *testing.T) { t.Run("gemini includes -m when set", func(t *testing.T) { backend := GeminiBackend{} - cfg := &Config{Mode: "new", Model: "gemini-3-pro-preview"} + cfg := &config.Config{Mode: "new", Model: "gemini-3-pro-preview"} got := backend.BuildArgs(cfg, "task") want := []string{"-o", "stream-json", "-y", "-m", "gemini-3-pro-preview", "task"} if !reflect.DeepEqual(got, want) { @@ -93,7 +95,7 @@ func TestBackendBuildArgs_Model(t *testing.T) { t.Setenv(key, "false") backend := CodexBackend{} - cfg := &Config{Mode: "new", WorkDir: "/tmp", Model: "o3"} + cfg := &config.Config{Mode: "new", WorkDir: "/tmp", Model: "o3"} got := backend.BuildArgs(cfg, "task") want := []string{"e", "--model", "o3", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"} if !reflect.DeepEqual(got, want) { @@ -105,7 +107,7 @@ func TestBackendBuildArgs_Model(t *testing.T) { func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Run("gemini new mode defaults workdir", func(t *testing.T) { backend := GeminiBackend{} - cfg := &Config{Mode: "new", WorkDir: "/workspace"} + cfg := &config.Config{Mode: "new", WorkDir: "/workspace"} got := backend.BuildArgs(cfg, "task") want := []string{"-o", "stream-json", "-y", "task"} if !reflect.DeepEqual(got, want) { @@ -115,7 +117,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Run("gemini resume mode uses session id", func(t *testing.T) { backend := GeminiBackend{} - cfg := &Config{Mode: "resume", SessionID: "sid-999"} + cfg := &config.Config{Mode: "resume", SessionID: "sid-999"} got := backend.BuildArgs(cfg, "resume") want := []string{"-o", "stream-json", "-y", "-r", "sid-999", "resume"} if !reflect.DeepEqual(got, want) { @@ -125,7 +127,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Run("gemini resume mode without session omits identifier", func(t *testing.T) { backend := GeminiBackend{} - cfg := &Config{Mode: "resume"} + cfg := &config.Config{Mode: "resume"} got := backend.BuildArgs(cfg, "resume") want := []string{"-o", "stream-json", "-y", "resume"} if !reflect.DeepEqual(got, want) { @@ -142,7 +144,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Run("gemini stdin mode uses -p flag", func(t *testing.T) { backend := GeminiBackend{} - cfg := &Config{Mode: "new"} + cfg := &config.Config{Mode: "new"} got := backend.BuildArgs(cfg, "-") want := []string{"-o", "stream-json", "-y", "-p", "-"} if !reflect.DeepEqual(got, want) { @@ -155,7 +157,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Setenv(key, "false") backend := CodexBackend{} - cfg := &Config{Mode: "new", WorkDir: "/tmp"} + cfg := &config.Config{Mode: "new", WorkDir: "/tmp"} got := backend.BuildArgs(cfg, "task") want := []string{"e", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"} if !reflect.DeepEqual(got, want) { @@ -168,7 +170,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Setenv(key, "true") backend := CodexBackend{} - cfg := &Config{Mode: "new", WorkDir: "/tmp"} + cfg := &config.Config{Mode: "new", WorkDir: "/tmp"} got := backend.BuildArgs(cfg, "task") want := []string{"e", "--dangerously-bypass-approvals-and-sandbox", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"} if !reflect.DeepEqual(got, want) { @@ -204,7 +206,7 @@ func TestLoadMinimalEnvSettings(t *testing.T) { t.Setenv("USERPROFILE", home) t.Run("missing file returns empty", func(t *testing.T) { - if got := loadMinimalEnvSettings(); len(got) != 0 { + if got := LoadMinimalEnvSettings(); len(got) != 0 { t.Fatalf("got %v, want empty", got) } }) @@ -220,7 +222,7 @@ func TestLoadMinimalEnvSettings(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - got := loadMinimalEnvSettings() + got := LoadMinimalEnvSettings() if got["ANTHROPIC_API_KEY"] != "secret" || got["FOO"] != "bar" { t.Fatalf("got %v, want keys present", got) } @@ -234,7 +236,7 @@ func TestLoadMinimalEnvSettings(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - got := loadMinimalEnvSettings() + got := LoadMinimalEnvSettings() if got["GOOD"] != "ok" { t.Fatalf("got %v, want GOOD=ok", got) } @@ -249,12 +251,72 @@ func TestLoadMinimalEnvSettings(t *testing.T) { t.Run("oversized file returns empty", func(t *testing.T) { dir := filepath.Join(home, ".claude") path := filepath.Join(dir, "settings.json") - data := bytes.Repeat([]byte("a"), maxClaudeSettingsBytes+1) + data := bytes.Repeat([]byte("a"), MaxClaudeSettingsBytes+1) if err := os.WriteFile(path, data, 0o600); err != nil { t.Fatalf("WriteFile: %v", err) } - if got := loadMinimalEnvSettings(); len(got) != 0 { + if got := LoadMinimalEnvSettings(); len(got) != 0 { t.Fatalf("got %v, want empty", got) } }) } + +func TestOpencodeBackend_BuildArgs(t *testing.T) { + backend := OpencodeBackend{} + + t.Run("basic", func(t *testing.T) { + cfg := &config.Config{Mode: "new"} + got := backend.BuildArgs(cfg, "hello") + want := []string{"run", "--format", "json", "hello"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("with model", func(t *testing.T) { + cfg := &config.Config{Mode: "new", Model: "opencode/grok-code"} + got := backend.BuildArgs(cfg, "task") + want := []string{"run", "-m", "opencode/grok-code", "--format", "json", "task"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("resume mode", func(t *testing.T) { + cfg := &config.Config{Mode: "resume", SessionID: "ses_123", Model: "opencode/grok-code"} + got := backend.BuildArgs(cfg, "follow-up") + want := []string{"run", "-m", "opencode/grok-code", "-s", "ses_123", "--format", "json", "follow-up"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("resume without session", func(t *testing.T) { + cfg := &config.Config{Mode: "resume"} + got := backend.BuildArgs(cfg, "task") + want := []string{"run", "--format", "json", "task"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("stdin mode omits dash", func(t *testing.T) { + cfg := &config.Config{Mode: "new"} + got := backend.BuildArgs(cfg, "-") + want := []string{"run", "--format", "json"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + }) +} + +func TestOpencodeBackend_Interface(t *testing.T) { + backend := OpencodeBackend{} + + if backend.Name() != "opencode" { + t.Errorf("Name() = %q, want %q", backend.Name(), "opencode") + } + if backend.Command() != "opencode" { + t.Errorf("Command() = %q, want %q", backend.Command(), "opencode") + } +} diff --git a/codeagent-wrapper/internal/backend/claude.go b/codeagent-wrapper/internal/backend/claude.go new file mode 100644 index 0000000..510fe80 --- /dev/null +++ b/codeagent-wrapper/internal/backend/claude.go @@ -0,0 +1,139 @@ +package backend + +import ( + "os" + "path/filepath" + "strings" + + config "codeagent-wrapper/internal/config" + + "github.com/goccy/go-json" +) + +type ClaudeBackend struct{} + +func (ClaudeBackend) Name() string { return "claude" } +func (ClaudeBackend) Command() string { return "claude" } +func (ClaudeBackend) Env(baseURL, apiKey string) map[string]string { + baseURL = strings.TrimSpace(baseURL) + apiKey = strings.TrimSpace(apiKey) + if baseURL == "" && apiKey == "" { + return nil + } + env := make(map[string]string, 2) + if baseURL != "" { + env["ANTHROPIC_BASE_URL"] = baseURL + } + if apiKey != "" { + env["ANTHROPIC_API_KEY"] = apiKey + } + return env +} +func (ClaudeBackend) BuildArgs(cfg *config.Config, targetArg string) []string { + return buildClaudeArgs(cfg, targetArg) +} + +const MaxClaudeSettingsBytes = 1 << 20 // 1MB + +type MinimalClaudeSettings struct { + Env map[string]string + Model string +} + +// LoadMinimalClaudeSettings 从 ~/.claude/settings.json 只提取安全的最小子集: +// - env: 只接受字符串类型的值 +// - model: 只接受字符串类型的值 +// 文件缺失/解析失败/超限都返回空。 +func LoadMinimalClaudeSettings() MinimalClaudeSettings { + home, err := os.UserHomeDir() + if err != nil || home == "" { + return MinimalClaudeSettings{} + } + + claudeDir := filepath.Clean(filepath.Join(home, ".claude")) + settingPath := filepath.Clean(filepath.Join(claudeDir, "settings.json")) + rel, err := filepath.Rel(claudeDir, settingPath) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return MinimalClaudeSettings{} + } + + info, err := os.Stat(settingPath) + if err != nil || info.Size() > MaxClaudeSettingsBytes { + return MinimalClaudeSettings{} + } + + data, err := os.ReadFile(settingPath) // #nosec G304 -- path is fixed under user home and validated to stay within claudeDir + if err != nil { + return MinimalClaudeSettings{} + } + + var cfg struct { + Env map[string]any `json:"env"` + Model any `json:"model"` + } + if err := json.Unmarshal(data, &cfg); err != nil { + return MinimalClaudeSettings{} + } + + out := MinimalClaudeSettings{} + + if model, ok := cfg.Model.(string); ok { + out.Model = strings.TrimSpace(model) + } + + if len(cfg.Env) == 0 { + return out + } + + env := make(map[string]string, len(cfg.Env)) + for k, v := range cfg.Env { + s, ok := v.(string) + if !ok { + continue + } + env[k] = s + } + if len(env) == 0 { + return out + } + out.Env = env + return out +} + +func LoadMinimalEnvSettings() map[string]string { + settings := LoadMinimalClaudeSettings() + if len(settings.Env) == 0 { + return nil + } + return settings.Env +} + +func buildClaudeArgs(cfg *config.Config, targetArg string) []string { + if cfg == nil { + return nil + } + args := []string{"-p"} + // Default to skip permissions unless CODEAGENT_SKIP_PERMISSIONS=false + if cfg.SkipPermissions || cfg.Yolo || config.EnvFlagDefaultTrue("CODEAGENT_SKIP_PERMISSIONS") { + args = append(args, "--dangerously-skip-permissions") + } + + // Prevent infinite recursion: disable all setting sources (user, project, local) + // This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent + args = append(args, "--setting-sources", "") + + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "--model", model) + } + + if cfg.Mode == "resume" { + if cfg.SessionID != "" { + // Claude CLI uses -r for resume. + args = append(args, "-r", cfg.SessionID) + } + } + + args = append(args, "--output-format", "stream-json", "--verbose", targetArg) + + return args +} diff --git a/codeagent-wrapper/internal/backend/codex.go b/codeagent-wrapper/internal/backend/codex.go new file mode 100644 index 0000000..b3a759c --- /dev/null +++ b/codeagent-wrapper/internal/backend/codex.go @@ -0,0 +1,79 @@ +package backend + +import ( + "strings" + + config "codeagent-wrapper/internal/config" +) + +type CodexBackend struct{} + +func (CodexBackend) Name() string { return "codex" } +func (CodexBackend) Command() string { return "codex" } +func (CodexBackend) Env(baseURL, apiKey string) map[string]string { + baseURL = strings.TrimSpace(baseURL) + apiKey = strings.TrimSpace(apiKey) + if baseURL == "" && apiKey == "" { + return nil + } + env := make(map[string]string, 2) + if baseURL != "" { + env["OPENAI_BASE_URL"] = baseURL + } + if apiKey != "" { + env["OPENAI_API_KEY"] = apiKey + } + return env +} +func (CodexBackend) BuildArgs(cfg *config.Config, targetArg string) []string { + return BuildCodexArgs(cfg, targetArg) +} + +func BuildCodexArgs(cfg *config.Config, targetArg string) []string { + if cfg == nil { + panic("buildCodexArgs: nil config") + } + + var resumeSessionID string + isResume := cfg.Mode == "resume" + if isResume { + resumeSessionID = strings.TrimSpace(cfg.SessionID) + if resumeSessionID == "" { + logErrorFn("invalid config: resume mode requires non-empty session_id") + isResume = false + } + } + + args := []string{"e"} + + // Default to bypass sandbox unless CODEX_BYPASS_SANDBOX=false + if cfg.Yolo || config.EnvFlagDefaultTrue("CODEX_BYPASS_SANDBOX") { + logWarnFn("YOLO mode or CODEX_BYPASS_SANDBOX enabled: running without approval/sandbox protection") + args = append(args, "--dangerously-bypass-approvals-and-sandbox") + } + + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "--model", model) + } + + if reasoningEffort := strings.TrimSpace(cfg.ReasoningEffort); reasoningEffort != "" { + args = append(args, "-c", "model_reasoning_effort="+reasoningEffort) + } + + args = append(args, "--skip-git-repo-check") + + if isResume { + return append(args, + "--json", + "resume", + resumeSessionID, + targetArg, + ) + } + + return append(args, + "-C", cfg.WorkDir, + "--json", + targetArg, + ) +} diff --git a/codeagent-wrapper/internal/backend/gemini.go b/codeagent-wrapper/internal/backend/gemini.go new file mode 100644 index 0000000..5b0b0c3 --- /dev/null +++ b/codeagent-wrapper/internal/backend/gemini.go @@ -0,0 +1,110 @@ +package backend + +import ( + "os" + "path/filepath" + "strings" + + config "codeagent-wrapper/internal/config" +) + +type GeminiBackend struct{} + +func (GeminiBackend) Name() string { return "gemini" } +func (GeminiBackend) Command() string { return "gemini" } +func (GeminiBackend) Env(baseURL, apiKey string) map[string]string { + baseURL = strings.TrimSpace(baseURL) + apiKey = strings.TrimSpace(apiKey) + if baseURL == "" && apiKey == "" { + return nil + } + env := make(map[string]string, 2) + if baseURL != "" { + env["GOOGLE_GEMINI_BASE_URL"] = baseURL + } + if apiKey != "" { + env["GEMINI_API_KEY"] = apiKey + } + return env +} +func (GeminiBackend) BuildArgs(cfg *config.Config, targetArg string) []string { + return buildGeminiArgs(cfg, targetArg) +} + +// LoadGeminiEnv loads environment variables from ~/.gemini/.env +// Supports GEMINI_API_KEY, GEMINI_MODEL, GOOGLE_GEMINI_BASE_URL +// Also sets GEMINI_API_KEY_AUTH_MECHANISM=bearer for third-party API compatibility +func LoadGeminiEnv() map[string]string { + home, err := os.UserHomeDir() + if err != nil || home == "" { + return nil + } + + envDir := filepath.Clean(filepath.Join(home, ".gemini")) + envPath := filepath.Clean(filepath.Join(envDir, ".env")) + rel, err := filepath.Rel(envDir, envPath) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return nil + } + + data, err := os.ReadFile(envPath) // #nosec G304 -- path is fixed under user home and validated to stay within envDir + if err != nil { + return nil + } + + env := make(map[string]string) + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + idx := strings.IndexByte(line, '=') + if idx <= 0 { + continue + } + key := strings.TrimSpace(line[:idx]) + value := strings.TrimSpace(line[idx+1:]) + if key != "" && value != "" { + env[key] = value + } + } + + // Set bearer auth mechanism for third-party API compatibility + if _, ok := env["GEMINI_API_KEY"]; ok { + if _, hasAuth := env["GEMINI_API_KEY_AUTH_MECHANISM"]; !hasAuth { + env["GEMINI_API_KEY_AUTH_MECHANISM"] = "bearer" + } + } + + if len(env) == 0 { + return nil + } + return env +} + +func buildGeminiArgs(cfg *config.Config, targetArg string) []string { + if cfg == nil { + return nil + } + args := []string{"-o", "stream-json", "-y"} + + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "-m", model) + } + + if cfg.Mode == "resume" { + if cfg.SessionID != "" { + args = append(args, "-r", cfg.SessionID) + } + } + + // Use positional argument instead of deprecated -p flag. + // For stdin mode ("-"), use -p to read from stdin. + if targetArg == "-" { + args = append(args, "-p", targetArg) + } else { + args = append(args, targetArg) + } + + return args +} diff --git a/codeagent-wrapper/internal/backend/opencode.go b/codeagent-wrapper/internal/backend/opencode.go new file mode 100644 index 0000000..67f425b --- /dev/null +++ b/codeagent-wrapper/internal/backend/opencode.go @@ -0,0 +1,29 @@ +package backend + +import ( + "strings" + + config "codeagent-wrapper/internal/config" +) + +type OpencodeBackend struct{} + +func (OpencodeBackend) Name() string { return "opencode" } +func (OpencodeBackend) Command() string { return "opencode" } +func (OpencodeBackend) Env(baseURL, apiKey string) map[string]string { return nil } +func (OpencodeBackend) BuildArgs(cfg *config.Config, targetArg string) []string { + args := []string{"run"} + if cfg != nil { + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "-m", model) + } + if cfg.Mode == "resume" && cfg.SessionID != "" { + args = append(args, "-s", cfg.SessionID) + } + } + args = append(args, "--format", "json") + if targetArg != "-" { + args = append(args, targetArg) + } + return args +} diff --git a/codeagent-wrapper/internal/backend/registry.go b/codeagent-wrapper/internal/backend/registry.go new file mode 100644 index 0000000..7b421c6 --- /dev/null +++ b/codeagent-wrapper/internal/backend/registry.go @@ -0,0 +1,29 @@ +package backend + +import ( + "fmt" + "strings" +) + +var registry = map[string]Backend{ + "codex": CodexBackend{}, + "claude": ClaudeBackend{}, + "gemini": GeminiBackend{}, + "opencode": OpencodeBackend{}, +} + +// Registry exposes the available backends. Intended for internal inspection/tests. +func Registry() map[string]Backend { + return registry +} + +func Select(name string) (Backend, error) { + key := strings.ToLower(strings.TrimSpace(name)) + if key == "" { + key = "codex" + } + if backend, ok := registry[key]; ok { + return backend, nil + } + return nil, fmt.Errorf("unsupported backend %q", name) +} diff --git a/codeagent-wrapper/internal/config/agent.go b/codeagent-wrapper/internal/config/agent.go new file mode 100644 index 0000000..9bdcfb8 --- /dev/null +++ b/codeagent-wrapper/internal/config/agent.go @@ -0,0 +1,220 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + ilogger "codeagent-wrapper/internal/logger" + + "github.com/goccy/go-json" +) + +type BackendConfig struct { + BaseURL string `json:"base_url,omitempty"` + APIKey string `json:"api_key,omitempty"` +} + +type AgentModelConfig struct { + Backend string `json:"backend"` + Model string `json:"model"` + PromptFile string `json:"prompt_file,omitempty"` + Description string `json:"description,omitempty"` + Yolo bool `json:"yolo,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + BaseURL string `json:"base_url,omitempty"` + APIKey string `json:"api_key,omitempty"` +} + +type ModelsConfig struct { + DefaultBackend string `json:"default_backend"` + DefaultModel string `json:"default_model"` + Agents map[string]AgentModelConfig `json:"agents"` + Backends map[string]BackendConfig `json:"backends,omitempty"` +} + +var defaultModelsConfig = ModelsConfig{ + DefaultBackend: "opencode", + DefaultModel: "opencode/grok-code", + Agents: map[string]AgentModelConfig{ + "oracle": {Backend: "claude", Model: "claude-opus-4-5-20251101", PromptFile: "~/.claude/skills/omo/references/oracle.md", Description: "Technical advisor"}, + "librarian": {Backend: "claude", Model: "claude-sonnet-4-5-20250929", PromptFile: "~/.claude/skills/omo/references/librarian.md", Description: "Researcher"}, + "explore": {Backend: "opencode", Model: "opencode/grok-code", PromptFile: "~/.claude/skills/omo/references/explore.md", Description: "Code search"}, + "develop": {Backend: "codex", Model: "", PromptFile: "~/.claude/skills/omo/references/develop.md", Description: "Code development"}, + "frontend-ui-ux-engineer": {Backend: "gemini", Model: "", PromptFile: "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md", Description: "Frontend engineer"}, + "document-writer": {Backend: "gemini", Model: "", PromptFile: "~/.claude/skills/omo/references/document-writer.md", Description: "Documentation"}, + }, +} + +var ( + modelsConfigOnce sync.Once + modelsConfigCached *ModelsConfig +) + +func modelsConfig() *ModelsConfig { + modelsConfigOnce.Do(func() { + modelsConfigCached = loadModelsConfig() + }) + if modelsConfigCached == nil { + return &defaultModelsConfig + } + return modelsConfigCached +} + +func loadModelsConfig() *ModelsConfig { + home, err := os.UserHomeDir() + if err != nil { + ilogger.LogWarn(fmt.Sprintf("Failed to resolve home directory for models config: %v; using defaults", err)) + return &defaultModelsConfig + } + + configDir := filepath.Clean(filepath.Join(home, ".codeagent")) + configPath := filepath.Clean(filepath.Join(configDir, "models.json")) + rel, err := filepath.Rel(configDir, configPath) + if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return &defaultModelsConfig + } + + data, err := os.ReadFile(configPath) // #nosec G304 -- path is fixed under user home and validated to stay within configDir + if err != nil { + if !os.IsNotExist(err) { + ilogger.LogWarn(fmt.Sprintf("Failed to read models config %s: %v; using defaults", configPath, err)) + } + return &defaultModelsConfig + } + + var cfg ModelsConfig + if err := json.Unmarshal(data, &cfg); err != nil { + ilogger.LogWarn(fmt.Sprintf("Failed to parse models config %s: %v; using defaults", configPath, err)) + return &defaultModelsConfig + } + + cfg.DefaultBackend = strings.TrimSpace(cfg.DefaultBackend) + if cfg.DefaultBackend == "" { + cfg.DefaultBackend = defaultModelsConfig.DefaultBackend + } + cfg.DefaultModel = strings.TrimSpace(cfg.DefaultModel) + if cfg.DefaultModel == "" { + cfg.DefaultModel = defaultModelsConfig.DefaultModel + } + + // Merge with defaults + for name, agent := range defaultModelsConfig.Agents { + if _, exists := cfg.Agents[name]; !exists { + if cfg.Agents == nil { + cfg.Agents = make(map[string]AgentModelConfig) + } + cfg.Agents[name] = agent + } + } + + // Normalize backend keys so lookups can be case-insensitive. + if len(cfg.Backends) > 0 { + normalized := make(map[string]BackendConfig, len(cfg.Backends)) + for k, v := range cfg.Backends { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + continue + } + normalized[key] = v + } + if len(normalized) > 0 { + cfg.Backends = normalized + } else { + cfg.Backends = nil + } + } + + return &cfg +} + +func LoadDynamicAgent(name string) (AgentModelConfig, bool) { + if err := ValidateAgentName(name); err != nil { + return AgentModelConfig{}, false + } + + home, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(home) == "" { + return AgentModelConfig{}, false + } + + absPath := filepath.Join(home, ".codeagent", "agents", name+".md") + info, err := os.Stat(absPath) + if err != nil || info.IsDir() { + return AgentModelConfig{}, false + } + + return AgentModelConfig{PromptFile: "~/.codeagent/agents/" + name + ".md"}, true +} + +func ResolveBackendConfig(backendName string) (baseURL, apiKey string) { + cfg := modelsConfig() + resolved := resolveBackendConfig(cfg, backendName) + return strings.TrimSpace(resolved.BaseURL), strings.TrimSpace(resolved.APIKey) +} + +func resolveBackendConfig(cfg *ModelsConfig, backendName string) BackendConfig { + if cfg == nil || len(cfg.Backends) == 0 { + return BackendConfig{} + } + key := strings.ToLower(strings.TrimSpace(backendName)) + if key == "" { + key = strings.ToLower(strings.TrimSpace(cfg.DefaultBackend)) + } + if key == "" { + return BackendConfig{} + } + if backend, ok := cfg.Backends[key]; ok { + return backend + } + return BackendConfig{} +} + +func resolveAgentConfig(agentName string) (backend, model, promptFile, reasoning, baseURL, apiKey string, yolo bool) { + cfg := modelsConfig() + if agent, ok := cfg.Agents[agentName]; ok { + backend = strings.TrimSpace(agent.Backend) + if backend == "" { + backend = cfg.DefaultBackend + } + backendCfg := resolveBackendConfig(cfg, backend) + + baseURL = strings.TrimSpace(agent.BaseURL) + if baseURL == "" { + baseURL = strings.TrimSpace(backendCfg.BaseURL) + } + apiKey = strings.TrimSpace(agent.APIKey) + if apiKey == "" { + apiKey = strings.TrimSpace(backendCfg.APIKey) + } + + return backend, strings.TrimSpace(agent.Model), agent.PromptFile, agent.Reasoning, baseURL, apiKey, agent.Yolo + } + + if dynamic, ok := LoadDynamicAgent(agentName); ok { + backend = cfg.DefaultBackend + model = cfg.DefaultModel + backendCfg := resolveBackendConfig(cfg, backend) + baseURL = strings.TrimSpace(backendCfg.BaseURL) + apiKey = strings.TrimSpace(backendCfg.APIKey) + return backend, model, dynamic.PromptFile, "", baseURL, apiKey, false + } + + backend = cfg.DefaultBackend + model = cfg.DefaultModel + backendCfg := resolveBackendConfig(cfg, backend) + baseURL = strings.TrimSpace(backendCfg.BaseURL) + apiKey = strings.TrimSpace(backendCfg.APIKey) + return backend, model, "", "", baseURL, apiKey, false +} + +func ResolveAgentConfig(agentName string) (backend, model, promptFile, reasoning, baseURL, apiKey string, yolo bool) { + return resolveAgentConfig(agentName) +} + +func ResetModelsConfigCacheForTest() { + modelsConfigCached = nil + modelsConfigOnce = sync.Once{} +} diff --git a/codeagent-wrapper/agent_config_test.go b/codeagent-wrapper/internal/config/agent_config_test.go similarity index 53% rename from codeagent-wrapper/agent_config_test.go rename to codeagent-wrapper/internal/config/agent_config_test.go index a3602be..58876fd 100644 --- a/codeagent-wrapper/agent_config_test.go +++ b/codeagent-wrapper/internal/config/agent_config_test.go @@ -1,9 +1,8 @@ -package main +package config import ( "os" "path/filepath" - "reflect" "testing" ) @@ -11,6 +10,8 @@ func TestResolveAgentConfig_Defaults(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() // Test that default agents resolve correctly without config file tests := []struct { @@ -19,16 +20,16 @@ func TestResolveAgentConfig_Defaults(t *testing.T) { wantModel string wantPromptFile string }{ - {"oracle", "claude", "claude-opus-4-5-20251101", "~/.claude/skills/omo/references/oracle.md"}, - {"librarian", "claude", "claude-sonnet-4-5-20250929", "~/.claude/skills/omo/references/librarian.md"}, - {"explore", "opencode", "opencode/grok-code", "~/.claude/skills/omo/references/explore.md"}, - {"frontend-ui-ux-engineer", "gemini", "", "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md"}, - {"document-writer", "gemini", "", "~/.claude/skills/omo/references/document-writer.md"}, - } + {"oracle", "claude", "claude-opus-4-5-20251101", "~/.claude/skills/omo/references/oracle.md"}, + {"librarian", "claude", "claude-sonnet-4-5-20250929", "~/.claude/skills/omo/references/librarian.md"}, + {"explore", "opencode", "opencode/grok-code", "~/.claude/skills/omo/references/explore.md"}, + {"frontend-ui-ux-engineer", "gemini", "", "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md"}, + {"document-writer", "gemini", "", "~/.claude/skills/omo/references/document-writer.md"}, + } for _, tt := range tests { t.Run(tt.agent, func(t *testing.T) { - backend, model, promptFile, _, _ := resolveAgentConfig(tt.agent) + backend, model, promptFile, _, _, _, _ := resolveAgentConfig(tt.agent) if backend != tt.wantBackend { t.Errorf("backend = %q, want %q", backend, tt.wantBackend) } @@ -46,8 +47,10 @@ func TestResolveAgentConfig_UnknownAgent(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() - backend, model, promptFile, _, _ := resolveAgentConfig("unknown-agent") + backend, model, promptFile, _, _, _, _ := resolveAgentConfig("unknown-agent") if backend != "opencode" { t.Errorf("unknown agent backend = %q, want %q", backend, "opencode") } @@ -63,6 +66,8 @@ func TestLoadModelsConfig_NoFile(t *testing.T) { home := "/nonexistent/path/that/does/not/exist" t.Setenv("HOME", home) t.Setenv("USERPROFILE", home) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() cfg := loadModelsConfig() if cfg.DefaultBackend != "opencode" { @@ -84,11 +89,23 @@ func TestLoadModelsConfig_WithFile(t *testing.T) { configContent := `{ "default_backend": "claude", "default_model": "claude-opus-4", + "backends": { + "Claude": { + "base_url": "https://backend.example", + "api_key": "backend-key" + }, + "codex": { + "base_url": "https://openai.example", + "api_key": "openai-key" + } + }, "agents": { "custom-agent": { "backend": "codex", "model": "gpt-4o", - "description": "Custom agent" + "description": "Custom agent", + "base_url": "https://agent.example", + "api_key": "agent-key" } } }` @@ -99,6 +116,8 @@ func TestLoadModelsConfig_WithFile(t *testing.T) { t.Setenv("HOME", tmpDir) t.Setenv("USERPROFILE", tmpDir) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() cfg := loadModelsConfig() @@ -125,6 +144,55 @@ func TestLoadModelsConfig_WithFile(t *testing.T) { if _, ok := cfg.Agents["oracle"]; !ok { t.Error("default agent oracle should be merged") } + + baseURL, apiKey := ResolveBackendConfig("claude") + if baseURL != "https://backend.example" { + t.Errorf("ResolveBackendConfig(baseURL) = %q, want %q", baseURL, "https://backend.example") + } + if apiKey != "backend-key" { + t.Errorf("ResolveBackendConfig(apiKey) = %q, want %q", apiKey, "backend-key") + } + + backend, model, _, _, agentBaseURL, agentAPIKey, _ := ResolveAgentConfig("custom-agent") + if backend != "codex" { + t.Errorf("ResolveAgentConfig(backend) = %q, want %q", backend, "codex") + } + if model != "gpt-4o" { + t.Errorf("ResolveAgentConfig(model) = %q, want %q", model, "gpt-4o") + } + if agentBaseURL != "https://agent.example" { + t.Errorf("ResolveAgentConfig(baseURL) = %q, want %q", agentBaseURL, "https://agent.example") + } + if agentAPIKey != "agent-key" { + t.Errorf("ResolveAgentConfig(apiKey) = %q, want %q", agentAPIKey, "agent-key") + } +} + +func TestResolveAgentConfig_DynamicAgent(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() + + agentDir := filepath.Join(home, ".codeagent", "agents") + if err := os.MkdirAll(agentDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + if err := os.WriteFile(filepath.Join(agentDir, "sarsh.md"), []byte("prompt\n"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + backend, model, promptFile, _, _, _, _ := resolveAgentConfig("sarsh") + if backend != "opencode" { + t.Errorf("backend = %q, want %q", backend, "opencode") + } + if model != "opencode/grok-code" { + t.Errorf("model = %q, want %q", model, "opencode/grok-code") + } + if promptFile != "~/.codeagent/agents/sarsh.md" { + t.Errorf("promptFile = %q, want %q", promptFile, "~/.codeagent/agents/sarsh.md") + } } func TestLoadModelsConfig_InvalidJSON(t *testing.T) { @@ -142,6 +210,8 @@ func TestLoadModelsConfig_InvalidJSON(t *testing.T) { t.Setenv("HOME", tmpDir) t.Setenv("USERPROFILE", tmpDir) + t.Cleanup(ResetModelsConfigCacheForTest) + ResetModelsConfigCacheForTest() cfg := loadModelsConfig() // Should fall back to defaults @@ -149,69 +219,3 @@ func TestLoadModelsConfig_InvalidJSON(t *testing.T) { t.Errorf("invalid JSON should fallback, got DefaultBackend = %q", cfg.DefaultBackend) } } - -func TestOpencodeBackend_BuildArgs(t *testing.T) { - backend := OpencodeBackend{} - - t.Run("basic", func(t *testing.T) { - cfg := &Config{Mode: "new"} - got := backend.BuildArgs(cfg, "hello") - want := []string{"run", "--format", "json", "hello"} - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - }) - - t.Run("with model", func(t *testing.T) { - cfg := &Config{Mode: "new", Model: "opencode/grok-code"} - got := backend.BuildArgs(cfg, "task") - want := []string{"run", "-m", "opencode/grok-code", "--format", "json", "task"} - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - }) - - t.Run("resume mode", func(t *testing.T) { - cfg := &Config{Mode: "resume", SessionID: "ses_123", Model: "opencode/grok-code"} - got := backend.BuildArgs(cfg, "follow-up") - want := []string{"run", "-m", "opencode/grok-code", "-s", "ses_123", "--format", "json", "follow-up"} - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - }) - - t.Run("resume without session", func(t *testing.T) { - cfg := &Config{Mode: "resume"} - got := backend.BuildArgs(cfg, "task") - want := []string{"run", "--format", "json", "task"} - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - }) - - t.Run("stdin mode omits dash", func(t *testing.T) { - cfg := &Config{Mode: "new"} - got := backend.BuildArgs(cfg, "-") - want := []string{"run", "--format", "json"} - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } - }) -} - -func TestOpencodeBackend_Interface(t *testing.T) { - backend := OpencodeBackend{} - - if backend.Name() != "opencode" { - t.Errorf("Name() = %q, want %q", backend.Name(), "opencode") - } - if backend.Command() != "opencode" { - t.Errorf("Command() = %q, want %q", backend.Command(), "opencode") - } -} - -func TestBackendRegistry_IncludesOpencode(t *testing.T) { - if _, ok := backendRegistry["opencode"]; !ok { - t.Error("backendRegistry should include opencode") - } -} diff --git a/codeagent-wrapper/internal/config/config.go b/codeagent-wrapper/internal/config/config.go new file mode 100644 index 0000000..9d4c70c --- /dev/null +++ b/codeagent-wrapper/internal/config/config.go @@ -0,0 +1,102 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// Config holds CLI configuration. +type Config struct { + Mode string // "new" or "resume" + Task string + SessionID string + WorkDir string + Model string + ReasoningEffort string + ExplicitStdin bool + Timeout int + Backend string + Agent string + PromptFile string + PromptFileExplicit bool + SkipPermissions bool + Yolo bool + MaxParallelWorkers int +} + +// EnvFlagEnabled returns true when the environment variable exists and is not +// explicitly set to a falsey value ("0/false/no/off"). +func EnvFlagEnabled(key string) bool { + val, ok := os.LookupEnv(key) + if !ok { + return false + } + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "", "0", "false", "no", "off": + return false + default: + return true + } +} + +func ParseBoolFlag(val string, defaultValue bool) bool { + val = strings.TrimSpace(strings.ToLower(val)) + switch val { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return defaultValue + } +} + +// EnvFlagDefaultTrue returns true unless the env var is explicitly set to +// false/0/no/off. +func EnvFlagDefaultTrue(key string) bool { + val, ok := os.LookupEnv(key) + if !ok { + return true + } + return ParseBoolFlag(val, true) +} + +func ValidateAgentName(name string) error { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("agent name is empty") + } + for _, r := range name { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '-', r == '_': + default: + return fmt.Errorf("agent name %q contains invalid character %q", name, r) + } + } + return nil +} + +const maxParallelWorkersLimit = 100 + +// ResolveMaxParallelWorkers reads CODEAGENT_MAX_PARALLEL_WORKERS. It returns 0 +// for "unlimited". +func ResolveMaxParallelWorkers() int { + raw := strings.TrimSpace(os.Getenv("CODEAGENT_MAX_PARALLEL_WORKERS")) + if raw == "" { + return 0 + } + + value, err := strconv.Atoi(raw) + if err != nil || value < 0 { + return 0 + } + if value > maxParallelWorkersLimit { + return maxParallelWorkersLimit + } + return value +} diff --git a/codeagent-wrapper/internal/config/viper.go b/codeagent-wrapper/internal/config/viper.go new file mode 100644 index 0000000..4cbe9e3 --- /dev/null +++ b/codeagent-wrapper/internal/config/viper.go @@ -0,0 +1,47 @@ +package config + +import ( + "errors" + "os" + "path/filepath" + "strings" + + "github.com/spf13/viper" +) + +// NewViper returns a viper instance configured for CODEAGENT_* environment +// variables and an optional config file. +// +// Search order when configFile is empty: +// - $HOME/.codeagent/config.(yaml|yml|json|toml|...) +func NewViper(configFile string) (*viper.Viper, error) { + v := viper.New() + v.SetEnvPrefix("CODEAGENT") + v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + v.AutomaticEnv() + + if strings.TrimSpace(configFile) != "" { + v.SetConfigFile(configFile) + if err := v.ReadInConfig(); err != nil { + return nil, err + } + return v, nil + } + + home, err := os.UserHomeDir() + if err != nil || strings.TrimSpace(home) == "" { + return v, nil + } + + v.SetConfigName("config") + v.AddConfigPath(filepath.Join(home, ".codeagent")) + if err := v.ReadInConfig(); err != nil { + var notFound viper.ConfigFileNotFoundError + if errors.As(err, ¬Found) { + return v, nil + } + return nil, err + } + + return v, nil +} diff --git a/codeagent-wrapper/executor.go b/codeagent-wrapper/internal/executor/executor.go similarity index 86% rename from codeagent-wrapper/executor.go rename to codeagent-wrapper/internal/executor/executor.go index 653942a..ab75dd8 100644 --- a/codeagent-wrapper/executor.go +++ b/codeagent-wrapper/internal/executor/executor.go @@ -1,4 +1,4 @@ -package main +package executor import ( "context" @@ -14,11 +14,92 @@ import ( "sync/atomic" "syscall" "time" + + backend "codeagent-wrapper/internal/backend" + config "codeagent-wrapper/internal/config" + ilogger "codeagent-wrapper/internal/logger" + parser "codeagent-wrapper/internal/parser" + utils "codeagent-wrapper/internal/utils" ) const postMessageTerminateDelay = 1 * time.Second const forceKillWaitTimeout = 5 * time.Second +// Defaults duplicated from wrapper for module decoupling. +const ( + defaultWorkdir = "." + defaultCoverageTarget = 90.0 + defaultBackendName = "codex" + + codexLogLineLimit = 1000 + stderrCaptureLimit = 4 * 1024 +) + +const ( + // stdout close reasons + stdoutCloseReasonWait = "wait-done" + stdoutCloseReasonDrain = "drain-timeout" + stdoutCloseReasonCtx = "context-cancel" + stdoutDrainTimeout = 100 * time.Millisecond +) + +// Hook points (tests can override inside this package). +var ( + selectBackendFn = backend.Select + commandContext = exec.CommandContext + terminateCommandFn = terminateCommand +) + +var forceKillDelay atomic.Int32 + +func init() { + forceKillDelay.Store(5) // seconds - default value +} + +type ( + Backend = backend.Backend + Config = config.Config + Logger = ilogger.Logger +) + +type minimalClaudeSettings = backend.MinimalClaudeSettings + +func loadMinimalClaudeSettings() minimalClaudeSettings { return backend.LoadMinimalClaudeSettings() } + +func loadGeminiEnv() map[string]string { return backend.LoadGeminiEnv() } + +func NewLogger() (*Logger, error) { return ilogger.NewLogger() } + +func NewLoggerWithSuffix(suffix string) (*Logger, error) { return ilogger.NewLoggerWithSuffix(suffix) } + +func setLogger(l *Logger) { ilogger.SetLogger(l) } + +func closeLogger() error { return ilogger.CloseLogger() } + +func activeLogger() *Logger { return ilogger.ActiveLogger() } + +func logInfo(msg string) { ilogger.LogInfo(msg) } + +func logWarn(msg string) { ilogger.LogWarn(msg) } + +func logError(msg string) { ilogger.LogError(msg) } + +func logConcurrencyPlanning(limit, total int) { ilogger.LogConcurrencyPlanning(limit, total) } + +func logConcurrencyState(event, taskID string, active, limit int) { + ilogger.LogConcurrencyState(event, taskID, active, limit) +} + +func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) { + return parser.ParseJSONStreamInternal(r, warnFn, infoFn, onMessage, onComplete) +} + +func sanitizeOutput(s string) string { return utils.SanitizeOutput(s) } + +func safeTruncate(s string, maxLen int) string { return utils.SafeTruncate(s, maxLen) } + +func min(a, b int) int { return utils.Min(a, b) } + // commandRunner abstracts exec.Cmd for testability type commandRunner interface { Start() error @@ -230,7 +311,7 @@ func newTaskLoggerHandle(taskID string) taskLoggerHandle { } // defaultRunCodexTaskFn is the default implementation of runCodexTaskFn (exposed for test reset) -func defaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult { +func DefaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult { if task.WorkDir == "" { task.WorkDir = defaultWorkdir } @@ -238,13 +319,13 @@ func defaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult { task.Mode = "new" } if strings.TrimSpace(task.PromptFile) != "" { - prompt, err := readAgentPromptFile(task.PromptFile, false) + prompt, err := ReadAgentPromptFile(task.PromptFile, false) if err != nil { return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "failed to read prompt file: " + err.Error()} } - task.Task = wrapTaskWithAgentPrompt(prompt, task.Task) + task.Task = WrapTaskWithAgentPrompt(prompt, task.Task) } - if task.UseStdin || shouldUseStdin(task.Task, false) { + if task.UseStdin || ShouldUseStdin(task.Task, false) { task.UseStdin = true } @@ -263,12 +344,10 @@ func defaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult { if parentCtx == nil { parentCtx = context.Background() } - return runCodexTaskWithContext(parentCtx, task, backend, nil, false, true, timeout) + return RunCodexTaskWithContext(parentCtx, task, backend, "", nil, nil, false, true, timeout) } -var runCodexTaskFn = defaultRunCodexTaskFn - -func topologicalSort(tasks []TaskSpec) ([][]TaskSpec, error) { +func TopologicalSort(tasks []TaskSpec) ([][]TaskSpec, error) { idToTask := make(map[string]TaskSpec, len(tasks)) indegree := make(map[string]int, len(tasks)) adj := make(map[string][]string, len(tasks)) @@ -334,12 +413,16 @@ func topologicalSort(tasks []TaskSpec) ([][]TaskSpec, error) { return layers, nil } -func executeConcurrent(layers [][]TaskSpec, timeout int) []TaskResult { - maxWorkers := resolveMaxParallelWorkers() - return executeConcurrentWithContext(context.Background(), layers, timeout, maxWorkers) +func ExecuteConcurrent(layers [][]TaskSpec, timeout int, runTask func(TaskSpec, int) TaskResult) []TaskResult { + maxWorkers := config.ResolveMaxParallelWorkers() + return ExecuteConcurrentWithContext(context.Background(), layers, timeout, maxWorkers, runTask) } -func executeConcurrentWithContext(parentCtx context.Context, layers [][]TaskSpec, timeout int, maxWorkers int) []TaskResult { +func ExecuteConcurrentWithContext(parentCtx context.Context, layers [][]TaskSpec, timeout int, maxWorkers int, runTask func(TaskSpec, int) TaskResult) []TaskResult { + if runTask == nil { + runTask = DefaultRunCodexTaskFn + } + totalTasks := 0 for _, layer := range layers { totalTasks += len(layer) @@ -470,7 +553,7 @@ func executeConcurrentWithContext(parentCtx context.Context, layers [][]TaskSpec printTaskStart(ts.ID, taskLogPath, handle.shared) - res := runCodexTaskFn(ts, timeout) + res := runTask(ts, timeout) if taskLogPath != "" { if res.LogPath == "" || (handle.shared && handle.logger != nil && res.LogPath == handle.logger.Path()) { res.LogPath = taskLogPath @@ -535,14 +618,14 @@ func getStatusSymbols() (success, warning, failed string) { return "✓", "⚠️", "✗" } -func generateFinalOutput(results []TaskResult) string { - return generateFinalOutputWithMode(results, true) // default to summary mode +func GenerateFinalOutput(results []TaskResult) string { + return GenerateFinalOutputWithMode(results, true) // default to summary mode } // generateFinalOutputWithMode generates output based on mode // summaryOnly=true: structured report - every token has value // summaryOnly=false: full output with complete messages (legacy behavior) -func generateFinalOutputWithMode(results []TaskResult, summaryOnly bool) string { +func GenerateFinalOutputWithMode(results []TaskResult, summaryOnly bool) string { var sb strings.Builder successSymbol, warningSymbol, failedSymbol := getStatusSymbols() @@ -756,7 +839,7 @@ func buildCodexArgs(cfg *Config, targetArg string) []string { args := []string{"e"} // Default to bypass sandbox unless CODEX_BYPASS_SANDBOX=false - if cfg.Yolo || envFlagDefaultTrue("CODEX_BYPASS_SANDBOX") { + if cfg.Yolo || config.EnvFlagDefaultTrue("CODEX_BYPASS_SANDBOX") { logWarn("YOLO mode or CODEX_BYPASS_SANDBOX enabled: running without approval/sandbox protection") args = append(args, "--dangerously-bypass-approvals-and-sandbox") } @@ -787,25 +870,20 @@ func buildCodexArgs(cfg *Config, targetArg string) []string { ) } -func runCodexTask(taskSpec TaskSpec, silent bool, timeoutSec int) TaskResult { - return runCodexTaskWithContext(context.Background(), taskSpec, nil, nil, false, silent, timeoutSec) -} - -func runCodexProcess(parentCtx context.Context, codexArgs []string, taskText string, useStdin bool, timeoutSec int) (message, threadID string, exitCode int) { - res := runCodexTaskWithContext(parentCtx, TaskSpec{Task: taskText, WorkDir: defaultWorkdir, Mode: "new", UseStdin: useStdin}, nil, codexArgs, true, false, timeoutSec) - return res.Message, res.SessionID, res.ExitCode -} - -func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backend Backend, customArgs []string, useCustomArgs bool, silent bool, timeoutSec int) TaskResult { +func RunCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backend Backend, defaultCommandName string, defaultArgsBuilder func(*Config, string) []string, customArgs []string, useCustomArgs bool, silent bool, timeoutSec int) TaskResult { + taskCtx := taskSpec.Context if parentCtx == nil { - parentCtx = taskSpec.Context + parentCtx = taskCtx } if parentCtx == nil { parentCtx = context.Background() } result := TaskResult{TaskID: taskSpec.ID} - injectedLogger := taskLoggerFromContext(parentCtx) + injectedLogger := taskLoggerFromContext(taskCtx) + if injectedLogger == nil { + injectedLogger = taskLoggerFromContext(parentCtx) + } logger := injectedLogger cfg := &Config{ @@ -819,8 +897,14 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe Backend: defaultBackendName, } - commandName := codexCommand - argsBuilder := buildCodexArgsFn + commandName := strings.TrimSpace(defaultCommandName) + if commandName == "" { + commandName = defaultBackendName + } + argsBuilder := defaultArgsBuilder + if argsBuilder == nil { + argsBuilder = buildCodexArgs + } if backend != nil { commandName = backend.Command() argsBuilder = backend.BuildArgs @@ -844,19 +928,18 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe return result } - var claudeEnv map[string]string + var fileEnv map[string]string if cfg.Backend == "claude" { settings := loadMinimalClaudeSettings() - claudeEnv = settings.Env + fileEnv = settings.Env if cfg.Mode != "resume" && strings.TrimSpace(cfg.Model) == "" && settings.Model != "" { cfg.Model = settings.Model } } // Load gemini env from ~/.gemini/.env if exists - var geminiEnv map[string]string if cfg.Backend == "gemini" { - geminiEnv = loadGeminiEnv() + fileEnv = loadGeminiEnv() } useStdin := taskSpec.UseStdin @@ -958,11 +1041,28 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe cmd := newCommandRunner(ctx, commandName, codexArgs...) - if cfg.Backend == "claude" && len(claudeEnv) > 0 { - cmd.SetEnv(claudeEnv) + if len(fileEnv) > 0 { + cmd.SetEnv(fileEnv) } - if cfg.Backend == "gemini" && len(geminiEnv) > 0 { - cmd.SetEnv(geminiEnv) + + envBackend := backend + if envBackend == nil && cfg.Backend != "" { + if b, err := selectBackendFn(cfg.Backend); err == nil { + envBackend = b + } + } + + if envBackend != nil { + baseURL, apiKey := config.ResolveBackendConfig(cfg.Backend) + if agentName := strings.TrimSpace(taskSpec.Agent); agentName != "" { + agentBackend, _, _, _, agentBaseURL, agentAPIKey, _ := config.ResolveAgentConfig(agentName) + if strings.EqualFold(strings.TrimSpace(agentBackend), strings.TrimSpace(cfg.Backend)) { + baseURL, apiKey = agentBaseURL, agentAPIKey + } + } + if injected := envBackend.Env(baseURL, apiKey); len(injected) > 0 { + cmd.SetEnv(injected) + } } // For backends that don't support -C flag (claude, gemini), set working directory via cmd.Dir @@ -1202,11 +1302,9 @@ waitLoop: case parsed = <-parseCh: closeWithReason(stdout, stdoutCloseReasonWait) case <-messageSeen: - messageSeenObserved = true closeWithReason(stdout, stdoutCloseReasonWait) parsed = <-parseCh case <-completeSeen: - completeSeenObserved = true closeWithReason(stdout, stdoutCloseReasonWait) parsed = <-parseCh case <-drainTimer.C: @@ -1276,44 +1374,13 @@ waitLoop: return result } -func forwardSignals(ctx context.Context, cmd commandRunner, logErrorFn func(string)) { - notify := signalNotifyFn - stop := signalStopFn - if notify == nil { - notify = signal.Notify - } - if stop == nil { - stop = signal.Stop - } - - sigCh := make(chan os.Signal, 1) - notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - go func() { - defer stop(sigCh) - select { - case sig := <-sigCh: - logErrorFn(fmt.Sprintf("Received signal: %v", sig)) - if proc := cmd.Process(); proc != nil { - _ = sendTermSignal(proc) - time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() { - if p := cmd.Process(); p != nil { - _ = p.Kill() - } - }) - } - case <-ctx.Done(): - } - }() -} - func cancelReason(commandName string, ctx context.Context) string { if ctx == nil { return "Context cancelled" } if commandName == "" { - commandName = codexCommand + commandName = defaultBackendName } if errors.Is(ctx.Err(), context.DeadlineExceeded) { @@ -1377,21 +1444,3 @@ func terminateCommand(cmd commandRunner) *forceKillTimer { return &forceKillTimer{timer: timer, done: done} } - -func terminateProcess(cmd commandRunner) *time.Timer { - if cmd == nil { - return nil - } - proc := cmd.Process() - if proc == nil { - return nil - } - - _ = sendTermSignal(proc) - - return time.AfterFunc(time.Duration(forceKillDelay.Load())*time.Second, func() { - if p := cmd.Process(); p != nil { - _ = p.Kill() - } - }) -} diff --git a/codeagent-wrapper/filter.go b/codeagent-wrapper/internal/executor/filter.go similarity index 94% rename from codeagent-wrapper/filter.go rename to codeagent-wrapper/internal/executor/filter.go index 9f37445..fbcbf08 100644 --- a/codeagent-wrapper/filter.go +++ b/codeagent-wrapper/internal/executor/filter.go @@ -1,4 +1,4 @@ -package main +package executor import ( "bytes" @@ -45,7 +45,7 @@ func (f *filteringWriter) Write(p []byte) (n int, err error) { break } if !f.shouldFilter(line) { - f.w.Write([]byte(line)) + _, _ = f.w.Write([]byte(line)) } } return len(p), nil @@ -65,7 +65,7 @@ func (f *filteringWriter) Flush() { if f.buf.Len() > 0 { remaining := f.buf.String() if !f.shouldFilter(remaining) { - f.w.Write([]byte(remaining)) + _, _ = f.w.Write([]byte(remaining)) } f.buf.Reset() } diff --git a/codeagent-wrapper/filter_test.go b/codeagent-wrapper/internal/executor/filter_test.go similarity index 92% rename from codeagent-wrapper/filter_test.go rename to codeagent-wrapper/internal/executor/filter_test.go index 12042f8..55d2949 100644 --- a/codeagent-wrapper/filter_test.go +++ b/codeagent-wrapper/internal/executor/filter_test.go @@ -1,4 +1,4 @@ -package main +package executor import ( "bytes" @@ -48,7 +48,7 @@ func TestFilteringWriter(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var buf bytes.Buffer fw := newFilteringWriter(&buf, tt.patterns) - fw.Write([]byte(tt.input)) + _, _ = fw.Write([]byte(tt.input)) fw.Flush() if got := buf.String(); got != tt.want { @@ -63,8 +63,8 @@ func TestFilteringWriterPartialLines(t *testing.T) { fw := newFilteringWriter(&buf, geminiNoisePatterns) // Write partial line - fw.Write([]byte("Hello ")) - fw.Write([]byte("World\n")) + _, _ = fw.Write([]byte("Hello ")) + _, _ = fw.Write([]byte("World\n")) fw.Flush() if got := buf.String(); got != "Hello World\n" { diff --git a/codeagent-wrapper/internal/executor/log_helpers.go b/codeagent-wrapper/internal/executor/log_helpers.go new file mode 100644 index 0000000..fd5ffd4 --- /dev/null +++ b/codeagent-wrapper/internal/executor/log_helpers.go @@ -0,0 +1,124 @@ +package executor + +import "bytes" + +type logWriter struct { + prefix string + maxLen int + buf bytes.Buffer + dropped bool +} + +func newLogWriter(prefix string, maxLen int) *logWriter { + if maxLen <= 0 { + maxLen = codexLogLineLimit + } + return &logWriter{prefix: prefix, maxLen: maxLen} +} + +func (lw *logWriter) Write(p []byte) (int, error) { + if lw == nil { + return len(p), nil + } + total := len(p) + for len(p) > 0 { + if idx := bytes.IndexByte(p, '\n'); idx >= 0 { + lw.writeLimited(p[:idx]) + lw.logLine(true) + p = p[idx+1:] + continue + } + lw.writeLimited(p) + break + } + return total, nil +} + +func (lw *logWriter) Flush() { + if lw == nil || lw.buf.Len() == 0 { + return + } + lw.logLine(false) +} + +func (lw *logWriter) logLine(force bool) { + if lw == nil { + return + } + line := lw.buf.String() + dropped := lw.dropped + lw.dropped = false + lw.buf.Reset() + if line == "" && !force { + return + } + if lw.maxLen > 0 { + if dropped { + if lw.maxLen > 3 { + line = line[:min(len(line), lw.maxLen-3)] + "..." + } else { + line = line[:min(len(line), lw.maxLen)] + } + } else if len(line) > lw.maxLen { + cutoff := lw.maxLen + if cutoff > 3 { + line = line[:cutoff-3] + "..." + } else { + line = line[:cutoff] + } + } + } + logInfo(lw.prefix + line) +} + +func (lw *logWriter) writeLimited(p []byte) { + if lw == nil || len(p) == 0 { + return + } + if lw.maxLen <= 0 { + lw.buf.Write(p) + return + } + + remaining := lw.maxLen - lw.buf.Len() + if remaining <= 0 { + lw.dropped = true + return + } + if len(p) <= remaining { + lw.buf.Write(p) + return + } + lw.buf.Write(p[:remaining]) + lw.dropped = true +} + +type tailBuffer struct { + limit int + data []byte +} + +func (b *tailBuffer) Write(p []byte) (int, error) { + if b.limit <= 0 { + return len(p), nil + } + + if len(p) >= b.limit { + b.data = append(b.data[:0], p[len(p)-b.limit:]...) + return len(p), nil + } + + total := len(b.data) + len(p) + if total <= b.limit { + b.data = append(b.data, p...) + return len(p), nil + } + + overflow := total - b.limit + b.data = append(b.data[overflow:], p...) + return len(p), nil +} + +func (b *tailBuffer) String() string { + return string(b.data) +} diff --git a/codeagent-wrapper/log_writer_limit_test.go b/codeagent-wrapper/internal/executor/log_writer_limit_test.go similarity index 92% rename from codeagent-wrapper/log_writer_limit_test.go rename to codeagent-wrapper/internal/executor/log_writer_limit_test.go index a89558c..a035d4d 100644 --- a/codeagent-wrapper/log_writer_limit_test.go +++ b/codeagent-wrapper/internal/executor/log_writer_limit_test.go @@ -1,4 +1,4 @@ -package main +package executor import ( "os" @@ -7,14 +7,12 @@ import ( ) func TestLogWriterWriteLimitsBuffer(t *testing.T) { - defer resetTestHooks() - logger, err := NewLogger() if err != nil { t.Fatalf("NewLogger error: %v", err) } setLogger(logger) - defer closeLogger() + t.Cleanup(func() { _ = closeLogger() }) lw := newLogWriter("P:", 10) _, _ = lw.Write([]byte(strings.Repeat("a", 100))) diff --git a/codeagent-wrapper/internal/executor/parallel_config.go b/codeagent-wrapper/internal/executor/parallel_config.go new file mode 100644 index 0000000..57bdff3 --- /dev/null +++ b/codeagent-wrapper/internal/executor/parallel_config.go @@ -0,0 +1,135 @@ +package executor + +import ( + "bytes" + "fmt" + "strings" + + config "codeagent-wrapper/internal/config" +) + +func ParseParallelConfig(data []byte) (*ParallelConfig, error) { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 { + return nil, fmt.Errorf("parallel config is empty") + } + + tasks := strings.Split(string(trimmed), "---TASK---") + var cfg ParallelConfig + seen := make(map[string]struct{}) + + taskIndex := 0 + for _, taskBlock := range tasks { + taskBlock = strings.TrimSpace(taskBlock) + if taskBlock == "" { + continue + } + taskIndex++ + + parts := strings.SplitN(taskBlock, "---CONTENT---", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("task block #%d missing ---CONTENT--- separator", taskIndex) + } + + meta := strings.TrimSpace(parts[0]) + content := strings.TrimSpace(parts[1]) + + task := TaskSpec{WorkDir: defaultWorkdir} + agentSpecified := false + for _, line := range strings.Split(meta, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + kv := strings.SplitN(line, ":", 2) + if len(kv) != 2 { + continue + } + key := strings.TrimSpace(kv[0]) + value := strings.TrimSpace(kv[1]) + + switch key { + case "id": + task.ID = value + case "workdir": + // Validate workdir: "-" is not a valid directory + if value == "-" { + return nil, fmt.Errorf("task block #%d has invalid workdir: '-' is not a valid directory path", taskIndex) + } + task.WorkDir = value + case "session_id": + task.SessionID = value + task.Mode = "resume" + case "backend": + task.Backend = value + case "model": + task.Model = value + case "reasoning_effort": + task.ReasoningEffort = value + case "agent": + agentSpecified = true + task.Agent = value + case "skip_permissions", "skip-permissions": + if value == "" { + task.SkipPermissions = true + continue + } + task.SkipPermissions = config.ParseBoolFlag(value, false) + case "dependencies": + for _, dep := range strings.Split(value, ",") { + dep = strings.TrimSpace(dep) + if dep != "" { + task.Dependencies = append(task.Dependencies, dep) + } + } + } + } + + if task.Mode == "" { + task.Mode = "new" + } + + if agentSpecified { + if strings.TrimSpace(task.Agent) == "" { + return nil, fmt.Errorf("task block #%d has empty agent field", taskIndex) + } + if err := config.ValidateAgentName(task.Agent); err != nil { + return nil, fmt.Errorf("task block #%d invalid agent name: %w", taskIndex, err) + } + backend, model, promptFile, reasoning, _, _, _ := config.ResolveAgentConfig(task.Agent) + if task.Backend == "" { + task.Backend = backend + } + if task.Model == "" { + task.Model = model + } + if task.ReasoningEffort == "" { + task.ReasoningEffort = reasoning + } + task.PromptFile = promptFile + } + + if task.ID == "" { + return nil, fmt.Errorf("task block #%d missing id field", taskIndex) + } + if content == "" { + return nil, fmt.Errorf("task block #%d (%q) missing content", taskIndex, task.ID) + } + if task.Mode == "resume" && strings.TrimSpace(task.SessionID) == "" { + return nil, fmt.Errorf("task block #%d (%q) has empty session_id", taskIndex, task.ID) + } + if _, exists := seen[task.ID]; exists { + return nil, fmt.Errorf("task block #%d has duplicate id: %s", taskIndex, task.ID) + } + + task.Task = content + cfg.Tasks = append(cfg.Tasks, task) + seen[task.ID] = struct{}{} + } + + if len(cfg.Tasks) == 0 { + return nil, fmt.Errorf("no tasks found") + } + + return &cfg, nil +} diff --git a/codeagent-wrapper/internal/executor/prompt.go b/codeagent-wrapper/internal/executor/prompt.go new file mode 100644 index 0000000..726756d --- /dev/null +++ b/codeagent-wrapper/internal/executor/prompt.go @@ -0,0 +1,130 @@ +package executor + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +func ReadAgentPromptFile(path string, allowOutsideClaudeDir bool) (string, error) { + raw := strings.TrimSpace(path) + if raw == "" { + return "", nil + } + + expanded := raw + if raw == "~" || strings.HasPrefix(raw, "~/") || strings.HasPrefix(raw, "~\\") { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + if raw == "~" { + expanded = home + } else { + expanded = home + raw[1:] + } + } + + absPath, err := filepath.Abs(expanded) + if err != nil { + return "", err + } + absPath = filepath.Clean(absPath) + + home, err := os.UserHomeDir() + if err != nil { + if !allowOutsideClaudeDir { + return "", err + } + logWarn(fmt.Sprintf("Failed to resolve home directory for prompt file validation: %v; proceeding without restriction", err)) + } else { + allowedDirs := []string{ + filepath.Clean(filepath.Join(home, ".claude")), + filepath.Clean(filepath.Join(home, ".codeagent", "agents")), + } + for i := range allowedDirs { + allowedAbs, err := filepath.Abs(allowedDirs[i]) + if err == nil { + allowedDirs[i] = filepath.Clean(allowedAbs) + } + } + + isWithinDir := func(path, dir string) bool { + rel, err := filepath.Rel(dir, path) + if err != nil { + return false + } + rel = filepath.Clean(rel) + if rel == "." { + return true + } + if rel == ".." { + return false + } + prefix := ".." + string(os.PathSeparator) + return !strings.HasPrefix(rel, prefix) + } + + if !allowOutsideClaudeDir { + withinAllowed := false + for _, dir := range allowedDirs { + if isWithinDir(absPath, dir) { + withinAllowed = true + break + } + } + if !withinAllowed { + logWarn(fmt.Sprintf("Refusing to read prompt file outside allowed dirs (%s): %s", strings.Join(allowedDirs, ", "), absPath)) + return "", fmt.Errorf("prompt file must be under ~/.claude or ~/.codeagent/agents") + } + + resolvedPath, errPath := filepath.EvalSymlinks(absPath) + if errPath == nil { + resolvedPath = filepath.Clean(resolvedPath) + resolvedAllowed := make([]string, 0, len(allowedDirs)) + for _, dir := range allowedDirs { + resolvedBase, errBase := filepath.EvalSymlinks(dir) + if errBase != nil { + continue + } + resolvedAllowed = append(resolvedAllowed, filepath.Clean(resolvedBase)) + } + if len(resolvedAllowed) > 0 { + withinResolved := false + for _, dir := range resolvedAllowed { + if isWithinDir(resolvedPath, dir) { + withinResolved = true + break + } + } + if !withinResolved { + logWarn(fmt.Sprintf("Refusing to read prompt file outside allowed dirs (%s) (resolved): %s", strings.Join(resolvedAllowed, ", "), resolvedPath)) + return "", fmt.Errorf("prompt file must be under ~/.claude or ~/.codeagent/agents") + } + } + } + } else { + withinAllowed := false + for _, dir := range allowedDirs { + if isWithinDir(absPath, dir) { + withinAllowed = true + break + } + } + if !withinAllowed { + logWarn(fmt.Sprintf("Reading prompt file outside allowed dirs (%s): %s", strings.Join(allowedDirs, ", "), absPath)) + } + } + } + + data, err := os.ReadFile(absPath) + if err != nil { + return "", err + } + return strings.TrimRight(string(data), "\r\n"), nil +} + +func WrapTaskWithAgentPrompt(prompt string, task string) string { + return "\n" + prompt + "\n\n\n" + task +} diff --git a/codeagent-wrapper/prompt_file_test.go b/codeagent-wrapper/internal/executor/prompt_file_test.go similarity index 77% rename from codeagent-wrapper/prompt_file_test.go rename to codeagent-wrapper/internal/executor/prompt_file_test.go index 7ad4cbc..7c3c59d 100644 --- a/codeagent-wrapper/prompt_file_test.go +++ b/codeagent-wrapper/internal/executor/prompt_file_test.go @@ -1,4 +1,4 @@ -package main +package executor import ( "os" @@ -9,7 +9,7 @@ import ( ) func TestWrapTaskWithAgentPrompt(t *testing.T) { - got := wrapTaskWithAgentPrompt("P", "do") + got := WrapTaskWithAgentPrompt("P", "do") want := "\nP\n\n\ndo" if got != want { t.Fatalf("wrapTaskWithAgentPrompt mismatch:\n got=%q\nwant=%q", got, want) @@ -18,7 +18,7 @@ func TestWrapTaskWithAgentPrompt(t *testing.T) { func TestReadAgentPromptFile_EmptyPath(t *testing.T) { for _, allowOutside := range []bool{false, true} { - got, err := readAgentPromptFile(" ", allowOutside) + got, err := ReadAgentPromptFile(" ", allowOutside) if err != nil { t.Fatalf("unexpected error (allowOutside=%v): %v", allowOutside, err) } @@ -35,7 +35,7 @@ func TestReadAgentPromptFile_ExplicitAbsolutePath(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - got, err := readAgentPromptFile(path, true) + got, err := ReadAgentPromptFile(path, true) if err != nil { t.Fatalf("readAgentPromptFile error: %v", err) } @@ -54,7 +54,7 @@ func TestReadAgentPromptFile_ExplicitTildeExpansion(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - got, err := readAgentPromptFile("~/prompt.md", true) + got, err := ReadAgentPromptFile("~/prompt.md", true) if err != nil { t.Fatalf("readAgentPromptFile error: %v", err) } @@ -77,7 +77,30 @@ func TestReadAgentPromptFile_RestrictedAllowsClaudeDir(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - got, err := readAgentPromptFile("~/.claude/prompt.md", false) + got, err := ReadAgentPromptFile("~/.claude/prompt.md", false) + if err != nil { + t.Fatalf("readAgentPromptFile error: %v", err) + } + if got != "OK" { + t.Fatalf("got %q, want %q", got, "OK") + } +} + +func TestReadAgentPromptFile_RestrictedAllowsCodeagentAgentsDir(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + agentDir := filepath.Join(home, ".codeagent", "agents") + if err := os.MkdirAll(agentDir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + path := filepath.Join(agentDir, "sarsh.md") + if err := os.WriteFile(path, []byte("OK\n"), 0o644); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + got, err := ReadAgentPromptFile("~/.codeagent/agents/sarsh.md", false) if err != nil { t.Fatalf("readAgentPromptFile error: %v", err) } @@ -96,7 +119,7 @@ func TestReadAgentPromptFile_RestrictedRejectsOutsideClaudeDir(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - if _, err := readAgentPromptFile("~/prompt.md", false); err == nil { + if _, err := ReadAgentPromptFile("~/prompt.md", false); err == nil { t.Fatalf("expected error for prompt file outside ~/.claude, got nil") } } @@ -111,7 +134,7 @@ func TestReadAgentPromptFile_RestrictedRejectsTraversal(t *testing.T) { t.Fatalf("WriteFile: %v", err) } - if _, err := readAgentPromptFile("~/.claude/../secret.md", false); err == nil { + if _, err := ReadAgentPromptFile("~/.claude/../secret.md", false); err == nil { t.Fatalf("expected traversal to be rejected, got nil") } } @@ -126,7 +149,7 @@ func TestReadAgentPromptFile_NotFound(t *testing.T) { t.Fatalf("MkdirAll: %v", err) } - _, err := readAgentPromptFile("~/.claude/missing.md", false) + _, err := ReadAgentPromptFile("~/.claude/missing.md", false) if err == nil || !os.IsNotExist(err) { t.Fatalf("expected not-exist error, got %v", err) } @@ -153,7 +176,7 @@ func TestReadAgentPromptFile_PermissionDenied(t *testing.T) { t.Fatalf("Chmod: %v", err) } - _, err := readAgentPromptFile("~/.claude/private.md", false) + _, err := ReadAgentPromptFile("~/.claude/private.md", false) if err == nil { t.Fatalf("expected permission error, got nil") } diff --git a/codeagent-wrapper/internal/executor/report_helpers.go b/codeagent-wrapper/internal/executor/report_helpers.go new file mode 100644 index 0000000..f3c084b --- /dev/null +++ b/codeagent-wrapper/internal/executor/report_helpers.go @@ -0,0 +1,104 @@ +package executor + +import "strings" + +// extractCoverageGap extracts what's missing from coverage reports. +func extractCoverageGap(message string) string { + if message == "" { + return "" + } + + lower := strings.ToLower(message) + lines := strings.Split(message, "\n") + + for _, line := range lines { + lineLower := strings.ToLower(line) + line = strings.TrimSpace(line) + + if strings.Contains(lineLower, "uncovered") || + strings.Contains(lineLower, "not covered") || + strings.Contains(lineLower, "missing coverage") || + strings.Contains(lineLower, "lines not covered") { + if len(line) > 100 { + return line[:97] + "..." + } + return line + } + + if strings.Contains(lineLower, "branch") && strings.Contains(lineLower, "not taken") { + if len(line) > 100 { + return line[:97] + "..." + } + return line + } + } + + if strings.Contains(lower, "function") && strings.Contains(lower, "0%") { + for _, line := range lines { + if strings.Contains(strings.ToLower(line), "0%") && strings.Contains(line, "function") { + line = strings.TrimSpace(line) + if len(line) > 100 { + return line[:97] + "..." + } + return line + } + } + } + + return "" +} + +// extractErrorDetail extracts meaningful error context from task output. +func extractErrorDetail(message string, maxLen int) string { + if message == "" || maxLen <= 0 { + return "" + } + + lines := strings.Split(message, "\n") + var errorLines []string + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + lower := strings.ToLower(line) + + if strings.HasPrefix(line, "at ") && strings.Contains(line, "(") { + if len(errorLines) > 0 && strings.HasPrefix(strings.ToLower(errorLines[len(errorLines)-1]), "at ") { + continue + } + } + + if strings.Contains(lower, "error") || + strings.Contains(lower, "fail") || + strings.Contains(lower, "exception") || + strings.Contains(lower, "assert") || + strings.Contains(lower, "expected") || + strings.Contains(lower, "timeout") || + strings.Contains(lower, "not found") || + strings.Contains(lower, "cannot") || + strings.Contains(lower, "undefined") || + strings.HasPrefix(line, "FAIL") || + strings.HasPrefix(line, "●") { + errorLines = append(errorLines, line) + } + } + + if len(errorLines) == 0 { + start := len(lines) - 5 + if start < 0 { + start = 0 + } + for _, line := range lines[start:] { + line = strings.TrimSpace(line) + if line != "" { + errorLines = append(errorLines, line) + } + } + } + + result := strings.Join(errorLines, " | ") + return safeTruncate(result, maxLen) +} diff --git a/codeagent-wrapper/signal_unix.go b/codeagent-wrapper/internal/executor/signal_unix.go similarity index 94% rename from codeagent-wrapper/signal_unix.go rename to codeagent-wrapper/internal/executor/signal_unix.go index f89bf0d..6f97082 100644 --- a/codeagent-wrapper/signal_unix.go +++ b/codeagent-wrapper/internal/executor/signal_unix.go @@ -1,7 +1,7 @@ //go:build unix || darwin || linux // +build unix darwin linux -package main +package executor import ( "syscall" diff --git a/codeagent-wrapper/signal_windows.go b/codeagent-wrapper/internal/executor/signal_windows.go similarity index 99% rename from codeagent-wrapper/signal_windows.go rename to codeagent-wrapper/internal/executor/signal_windows.go index cafcaa0..3f3934a 100644 --- a/codeagent-wrapper/signal_windows.go +++ b/codeagent-wrapper/internal/executor/signal_windows.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package main +package executor import ( "io" diff --git a/codeagent-wrapper/internal/executor/stdin.go b/codeagent-wrapper/internal/executor/stdin.go new file mode 100644 index 0000000..eb744eb --- /dev/null +++ b/codeagent-wrapper/internal/executor/stdin.go @@ -0,0 +1,15 @@ +package executor + +import "strings" + +const stdinSpecialChars = "\n\\\"'`$" + +func ShouldUseStdin(taskText string, piped bool) bool { + if piped { + return true + } + if len(taskText) > 800 { + return true + } + return strings.ContainsAny(taskText, stdinSpecialChars) +} diff --git a/codeagent-wrapper/internal/executor/task_types.go b/codeagent-wrapper/internal/executor/task_types.go new file mode 100644 index 0000000..ab6c298 --- /dev/null +++ b/codeagent-wrapper/internal/executor/task_types.go @@ -0,0 +1,46 @@ +package executor + +import "context" + +// ParallelConfig defines the JSON schema for parallel execution. +type ParallelConfig struct { + Tasks []TaskSpec `json:"tasks"` + GlobalBackend string `json:"backend,omitempty"` +} + +// TaskSpec describes an individual task entry in the parallel config. +type TaskSpec struct { + ID string `json:"id"` + Task string `json:"task"` + WorkDir string `json:"workdir,omitempty"` + Dependencies []string `json:"dependencies,omitempty"` + SessionID string `json:"session_id,omitempty"` + Backend string `json:"backend,omitempty"` + Model string `json:"model,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + Agent string `json:"agent,omitempty"` + PromptFile string `json:"prompt_file,omitempty"` + SkipPermissions bool `json:"skip_permissions,omitempty"` + Mode string `json:"-"` + UseStdin bool `json:"-"` + Context context.Context `json:"-"` +} + +// TaskResult captures the execution outcome of a task. +type TaskResult struct { + TaskID string `json:"task_id"` + ExitCode int `json:"exit_code"` + Message string `json:"message"` + SessionID string `json:"session_id"` + Error string `json:"error"` + LogPath string `json:"log_path"` + // Structured report fields + Coverage string `json:"coverage,omitempty"` // extracted coverage percentage (e.g., "92%") + CoverageNum float64 `json:"coverage_num,omitempty"` // numeric coverage for comparison + CoverageTarget float64 `json:"coverage_target,omitempty"` // target coverage (default 90) + FilesChanged []string `json:"files_changed,omitempty"` // list of changed files + KeyOutput string `json:"key_output,omitempty"` // brief summary of what was done + TestsPassed int `json:"tests_passed,omitempty"` // number of tests passed + TestsFailed int `json:"tests_failed,omitempty"` // number of tests failed + sharedLog bool +} diff --git a/codeagent-wrapper/internal/executor/testhooks.go b/codeagent-wrapper/internal/executor/testhooks.go new file mode 100644 index 0000000..f49947a --- /dev/null +++ b/codeagent-wrapper/internal/executor/testhooks.go @@ -0,0 +1,57 @@ +package executor + +import ( + "context" + "os/exec" + + backend "codeagent-wrapper/internal/backend" +) + +type CommandRunner = commandRunner +type ProcessHandle = processHandle + +func SetForceKillDelay(seconds int32) (restore func()) { + prev := forceKillDelay.Load() + forceKillDelay.Store(seconds) + return func() { forceKillDelay.Store(prev) } +} + +func SetSelectBackendFn(fn func(string) (Backend, error)) (restore func()) { + prev := selectBackendFn + if fn != nil { + selectBackendFn = fn + } else { + selectBackendFn = backend.Select + } + return func() { selectBackendFn = prev } +} + +func SetCommandContextFn(fn func(context.Context, string, ...string) *exec.Cmd) (restore func()) { + prev := commandContext + if fn != nil { + commandContext = fn + } else { + commandContext = exec.CommandContext + } + return func() { commandContext = prev } +} + +func SetNewCommandRunner(fn func(context.Context, string, ...string) CommandRunner) (restore func()) { + prev := newCommandRunner + if fn != nil { + newCommandRunner = fn + } else { + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return &realCmd{cmd: commandContext(ctx, name, args...)} + } + } + return func() { newCommandRunner = prev } +} + +func WithTaskLogger(ctx context.Context, logger *Logger) context.Context { + return withTaskLogger(ctx, logger) +} + +func TaskLoggerFromContext(ctx context.Context) *Logger { + return taskLoggerFromContext(ctx) +} diff --git a/codeagent-wrapper/internal/logger/active.go b/codeagent-wrapper/internal/logger/active.go new file mode 100644 index 0000000..a4b1c18 --- /dev/null +++ b/codeagent-wrapper/internal/logger/active.go @@ -0,0 +1,59 @@ +package logger + +import "sync/atomic" + +var loggerPtr atomic.Pointer[Logger] + +func setLogger(l *Logger) { + loggerPtr.Store(l) +} + +func closeLogger() error { + logger := loggerPtr.Swap(nil) + if logger == nil { + return nil + } + return logger.Close() +} + +func activeLogger() *Logger { + return loggerPtr.Load() +} + +func logDebug(msg string) { + if logger := activeLogger(); logger != nil { + logger.Debug(msg) + } +} + +func logInfo(msg string) { + if logger := activeLogger(); logger != nil { + logger.Info(msg) + } +} + +func logWarn(msg string) { + if logger := activeLogger(); logger != nil { + logger.Warn(msg) + } +} + +func logError(msg string) { + if logger := activeLogger(); logger != nil { + logger.Error(msg) + } +} + +func SetLogger(l *Logger) { setLogger(l) } + +func CloseLogger() error { return closeLogger() } + +func ActiveLogger() *Logger { return activeLogger() } + +func LogInfo(msg string) { logInfo(msg) } + +func LogDebug(msg string) { logDebug(msg) } + +func LogWarn(msg string) { logWarn(msg) } + +func LogError(msg string) { logError(msg) } diff --git a/codeagent-wrapper/logger.go b/codeagent-wrapper/internal/logger/logger.go similarity index 89% rename from codeagent-wrapper/logger.go rename to codeagent-wrapper/internal/logger/logger.go index 425cfc4..5cfa66c 100644 --- a/codeagent-wrapper/logger.go +++ b/codeagent-wrapper/internal/logger/logger.go @@ -1,4 +1,4 @@ -package main +package logger import ( "bufio" @@ -13,6 +13,8 @@ import ( "sync" "sync/atomic" "time" + + "github.com/rs/zerolog" ) // Logger writes log messages asynchronously to a temp file. @@ -22,6 +24,7 @@ type Logger struct { path string file *os.File writer *bufio.Writer + zlogger zerolog.Logger ch chan logEntry flushReq chan chan struct{} done chan struct{} @@ -37,6 +40,7 @@ type Logger struct { type logEntry struct { msg string + level zerolog.Level isError bool // true for ERROR or WARN levels } @@ -73,7 +77,7 @@ func NewLogger() (*Logger, error) { // Useful for tests that need isolated log files within the same process. func NewLoggerWithSuffix(suffix string) (*Logger, error) { pid := os.Getpid() - filename := fmt.Sprintf("%s-%d", primaryLogPrefix(), pid) + filename := fmt.Sprintf("%s-%d", PrimaryLogPrefix(), pid) var safeSuffix string if suffix != "" { safeSuffix = sanitizeLogSuffix(suffix) @@ -103,6 +107,8 @@ func NewLoggerWithSuffix(suffix string) (*Logger, error) { done: make(chan struct{}), } + l.zlogger = zerolog.New(l.writer).With().Timestamp().Logger() + l.workerWG.Add(1) go l.run() @@ -184,17 +190,24 @@ func (l *Logger) Path() string { return l.path } +func (l *Logger) IsClosed() bool { + if l == nil { + return true + } + return l.closed.Load() +} + // Info logs at INFO level. -func (l *Logger) Info(msg string) { l.log("INFO", msg) } +func (l *Logger) Info(msg string) { l.logWithLevel(zerolog.InfoLevel, msg) } // Warn logs at WARN level. -func (l *Logger) Warn(msg string) { l.log("WARN", msg) } +func (l *Logger) Warn(msg string) { l.logWithLevel(zerolog.WarnLevel, msg) } // Debug logs at DEBUG level. -func (l *Logger) Debug(msg string) { l.log("DEBUG", msg) } +func (l *Logger) Debug(msg string) { l.logWithLevel(zerolog.DebugLevel, msg) } // Error logs at ERROR level. -func (l *Logger) Error(msg string) { l.log("ERROR", msg) } +func (l *Logger) Error(msg string) { l.logWithLevel(zerolog.ErrorLevel, msg) } // Close signals the worker to flush and close the log file. // The log file is NOT removed, allowing inspection after program exit. @@ -335,7 +348,7 @@ func (l *Logger) Flush() { } } -func (l *Logger) log(level, msg string) { +func (l *Logger) logWithLevel(entryLevel zerolog.Level, msg string) { if l == nil { return } @@ -343,8 +356,8 @@ func (l *Logger) log(level, msg string) { return } - isError := level == "WARN" || level == "ERROR" - entry := logEntry{msg: msg, isError: isError} + isError := entryLevel == zerolog.WarnLevel || entryLevel == zerolog.ErrorLevel + entry := logEntry{msg: msg, level: entryLevel, isError: isError} l.flushMu.Lock() l.pendingWG.Add(1) l.flushMu.Unlock() @@ -366,8 +379,7 @@ func (l *Logger) run() { defer ticker.Stop() writeEntry := func(entry logEntry) { - timestamp := time.Now().Format("2006-01-02 15:04:05.000") - fmt.Fprintf(l.writer, "[%s] %s\n", timestamp, entry.msg) + l.zlogger.WithLevel(entry.level).Msg(entry.msg) // Cache error/warn entries in memory for fast extraction if entry.isError { @@ -439,10 +451,7 @@ func cleanupOldLogs() (CleanupStats, error) { var stats CleanupStats tempDir := os.TempDir() - prefixes := logPrefixes() - if len(prefixes) == 0 { - prefixes = []string{defaultWrapperName} - } + prefixes := LogPrefixes() seen := make(map[string]struct{}) var matches []string @@ -473,7 +482,8 @@ func cleanupOldLogs() (CleanupStats, error) { stats.Kept++ stats.KeptFiles = append(stats.KeptFiles, filename) if reason != "" { - logWarn(fmt.Sprintf("cleanupOldLogs: skipping %s: %s", filename, reason)) + // Use Debug level to avoid polluting Recent Errors with cleanup noise + logDebug(fmt.Sprintf("cleanupOldLogs: skipping %s: %s", filename, reason)) } continue } @@ -591,10 +601,7 @@ func isPIDReused(logPath string, pid int) bool { if procStartTime.IsZero() { // Can't determine process start time // Check if file is very old (>7 days), likely from a dead process - if time.Since(fileModTime) > 7*24*time.Hour { - return true // File is old enough to be from a different process - } - return false // Be conservative for recent files + return time.Since(fileModTime) > 7*24*time.Hour } // If the log file was modified before the process started, PID was reused @@ -604,10 +611,7 @@ func isPIDReused(logPath string, pid int) bool { func parsePIDFromLog(path string) (int, bool) { name := filepath.Base(path) - prefixes := logPrefixes() - if len(prefixes) == 0 { - prefixes = []string{defaultWrapperName} - } + prefixes := LogPrefixes() for _, prefix := range prefixes { prefixWithDash := fmt.Sprintf("%s-", prefix) @@ -661,3 +665,19 @@ func renderWorkerLimit(limit int) string { } return strconv.Itoa(limit) } + +func CleanupOldLogs() (CleanupStats, error) { return cleanupOldLogs() } + +func IsUnsafeFile(path string, tempDir string) (bool, string) { return isUnsafeFile(path, tempDir) } + +func IsPIDReused(logPath string, pid int) bool { return isPIDReused(logPath, pid) } + +func ParsePIDFromLog(path string) (int, bool) { return parsePIDFromLog(path) } + +func LogConcurrencyPlanning(limit, total int) { logConcurrencyPlanning(limit, total) } + +func LogConcurrencyState(event, taskID string, active, limit int) { + logConcurrencyState(event, taskID, active, limit) +} + +func SanitizeLogSuffix(raw string) string { return sanitizeLogSuffix(raw) } diff --git a/codeagent-wrapper/logger_additional_coverage_test.go b/codeagent-wrapper/internal/logger/logger_additional_coverage_test.go similarity index 97% rename from codeagent-wrapper/logger_additional_coverage_test.go rename to codeagent-wrapper/internal/logger/logger_additional_coverage_test.go index 0e8be30..615f3cf 100644 --- a/codeagent-wrapper/logger_additional_coverage_test.go +++ b/codeagent-wrapper/internal/logger/logger_additional_coverage_test.go @@ -1,4 +1,4 @@ -package main +package logger import ( "fmt" @@ -28,7 +28,7 @@ func TestLoggerConcurrencyLogHelpers(t *testing.T) { t.Fatalf("NewLoggerWithSuffix error: %v", err) } setLogger(logger) - defer closeLogger() + defer func() { _ = closeLogger() }() logConcurrencyPlanning(0, 2) logConcurrencyPlanning(3, 2) @@ -64,8 +64,8 @@ func TestLoggerConcurrencyLogHelpersNoopWithoutActiveLogger(t *testing.T) { func TestLoggerCleanupOldLogsSkipsUnsafeAndHandlesAlreadyDeleted(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) - unsafePath := createTempLog(t, tempDir, fmt.Sprintf("%s-%d.log", primaryLogPrefix(), 222)) - orphanPath := createTempLog(t, tempDir, fmt.Sprintf("%s-%d.log", primaryLogPrefix(), 111)) + unsafePath := createTempLog(t, tempDir, fmt.Sprintf("%s-%d.log", PrimaryLogPrefix(), 222)) + orphanPath := createTempLog(t, tempDir, fmt.Sprintf("%s-%d.log", PrimaryLogPrefix(), 111)) stubFileStat(t, func(path string) (os.FileInfo, error) { if path == unsafePath { diff --git a/codeagent-wrapper/logger_suffix_test.go b/codeagent-wrapper/internal/logger/logger_suffix_test.go similarity index 93% rename from codeagent-wrapper/logger_suffix_test.go rename to codeagent-wrapper/internal/logger/logger_suffix_test.go index dc4a94f..dffbfbd 100644 --- a/codeagent-wrapper/logger_suffix_test.go +++ b/codeagent-wrapper/internal/logger/logger_suffix_test.go @@ -1,4 +1,4 @@ -package main +package logger import ( "fmt" @@ -26,12 +26,12 @@ func TestLoggerWithSuffixNamingAndIsolation(t *testing.T) { } defer loggerB.Close() - wantA := filepath.Join(tempDir, fmt.Sprintf("%s-%d-%s.log", primaryLogPrefix(), os.Getpid(), taskA)) + wantA := filepath.Join(tempDir, fmt.Sprintf("%s-%d-%s.log", PrimaryLogPrefix(), os.Getpid(), taskA)) if loggerA.Path() != wantA { t.Fatalf("loggerA path = %q, want %q", loggerA.Path(), wantA) } - wantB := filepath.Join(tempDir, fmt.Sprintf("%s-%d-%s.log", primaryLogPrefix(), os.Getpid(), taskB)) + wantB := filepath.Join(tempDir, fmt.Sprintf("%s-%d-%s.log", PrimaryLogPrefix(), os.Getpid(), taskB)) if loggerB.Path() != wantB { t.Fatalf("loggerB path = %q, want %q", loggerB.Path(), wantB) } @@ -105,7 +105,7 @@ func TestLoggerWithSuffixSanitizesUnsafeSuffix(t *testing.T) { _ = os.Remove(logger.Path()) }) - wantBase := fmt.Sprintf("%s-%d-%s.log", primaryLogPrefix(), os.Getpid(), safe) + wantBase := fmt.Sprintf("%s-%d-%s.log", PrimaryLogPrefix(), os.Getpid(), safe) if gotBase := filepath.Base(logger.Path()); gotBase != wantBase { t.Fatalf("log filename = %q, want %q", gotBase, wantBase) } diff --git a/codeagent-wrapper/logger_test.go b/codeagent-wrapper/internal/logger/logger_test.go similarity index 68% rename from codeagent-wrapper/logger_test.go rename to codeagent-wrapper/internal/logger/logger_test.go index e0f5e31..6061f1b 100644 --- a/codeagent-wrapper/logger_test.go +++ b/codeagent-wrapper/internal/logger/logger_test.go @@ -1,4 +1,4 @@ -package main +package logger import ( "bufio" @@ -6,7 +6,6 @@ import ( "fmt" "math" "os" - "os/exec" "path/filepath" "strconv" "strings" @@ -77,30 +76,6 @@ func TestLoggerWritesLevels(t *testing.T) { } } -func TestLoggerDefaultIsTerminalCoverage(t *testing.T) { - oldStdin := os.Stdin - t.Cleanup(func() { os.Stdin = oldStdin }) - - f, err := os.CreateTemp(t.TempDir(), "stdin-*") - if err != nil { - t.Fatalf("os.CreateTemp() error = %v", err) - } - defer os.Remove(f.Name()) - - os.Stdin = f - if got := defaultIsTerminal(); got { - t.Fatalf("defaultIsTerminal() = %v, want false for regular file", got) - } - - if err := f.Close(); err != nil { - t.Fatalf("Close() error = %v", err) - } - os.Stdin = f - if got := defaultIsTerminal(); !got { - t.Fatalf("defaultIsTerminal() = %v, want true when Stat fails", got) - } -} - func TestLoggerCloseStopsWorkerAndKeepsFile(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -118,11 +93,6 @@ func TestLoggerCloseStopsWorkerAndKeepsFile(t *testing.T) { if err := logger.Close(); err != nil { t.Fatalf("Close() returned error: %v", err) } - if logger.file != nil { - if _, err := logger.file.Write([]byte("x")); err == nil { - t.Fatalf("expected file to be closed after Close()") - } - } // After recent changes, log file is kept for debugging - NOT removed if _, err := os.Stat(logPath); os.IsNotExist(err) { @@ -131,18 +101,6 @@ func TestLoggerCloseStopsWorkerAndKeepsFile(t *testing.T) { // Clean up manually for test defer os.Remove(logPath) - - done := make(chan struct{}) - go func() { - logger.workerWG.Wait() - close(done) - }() - - select { - case <-done: - case <-time.After(200 * time.Millisecond): - t.Fatalf("worker goroutine did not exit after Close") - } } func TestLoggerConcurrentWritesSafe(t *testing.T) { @@ -194,50 +152,13 @@ func TestLoggerConcurrentWritesSafe(t *testing.T) { } } -func TestLoggerTerminateProcessActive(t *testing.T) { - cmd := exec.Command("sleep", "5") - if err := cmd.Start(); err != nil { - t.Skipf("cannot start sleep command: %v", err) - } - - timer := terminateProcess(&realCmd{cmd: cmd}) - if timer == nil { - t.Fatalf("terminateProcess returned nil timer for active process") - } - defer timer.Stop() - - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - }() - - select { - case <-time.After(500 * time.Millisecond): - t.Fatalf("process not terminated promptly") - case <-done: - } - - // Force the timer callback to run immediately to cover the kill branch. - timer.Reset(0) - time.Sleep(10 * time.Millisecond) -} - -func TestLoggerTerminateProcessNil(t *testing.T) { - if timer := terminateProcess(nil); timer != nil { - t.Fatalf("terminateProcess(nil) should return nil timer") - } - if timer := terminateProcess(&realCmd{cmd: &exec.Cmd{}}); timer != nil { - t.Fatalf("terminateProcess with nil process should return nil timer") - } -} - func TestLoggerCleanupOldLogsRemovesOrphans(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) - orphan1 := createTempLog(t, tempDir, "codex-wrapper-111.log") - orphan2 := createTempLog(t, tempDir, "codex-wrapper-222-suffix.log") - running1 := createTempLog(t, tempDir, "codex-wrapper-333.log") - running2 := createTempLog(t, tempDir, "codex-wrapper-444-extra-info.log") + orphan1 := createTempLog(t, tempDir, "codeagent-wrapper-111.log") + orphan2 := createTempLog(t, tempDir, "codeagent-wrapper-222-suffix.log") + running1 := createTempLog(t, tempDir, "codeagent-wrapper-333.log") + running2 := createTempLog(t, tempDir, "codeagent-wrapper-444-extra-info.log") untouched := createTempLog(t, tempDir, "unrelated.log") runningPIDs := map[int]bool{333: true, 444: true} @@ -285,15 +206,15 @@ func TestLoggerCleanupOldLogsHandlesInvalidNamesAndErrors(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) invalid := []string{ - "codex-wrapper-.log", - "codex-wrapper.log", - "codex-wrapper-foo-bar.txt", + "codeagent-wrapper-.log", + "codeagent-wrapper.log", + "codeagent-wrapper-foo-bar.txt", "not-a-codex.log", } for _, name := range invalid { createTempLog(t, tempDir, name) } - target := createTempLog(t, tempDir, "codex-wrapper-555-extra.log") + target := createTempLog(t, tempDir, "codeagent-wrapper-555-extra.log") var checked []int stubProcessRunning(t, func(pid int) bool { @@ -389,8 +310,8 @@ func TestLoggerCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) paths := []string{ - createTempLog(t, tempDir, "codex-wrapper-6100.log"), - createTempLog(t, tempDir, "codex-wrapper-6101.log"), + createTempLog(t, tempDir, "codeagent-wrapper-6100.log"), + createTempLog(t, tempDir, "codeagent-wrapper-6101.log"), } stubProcessRunning(t, func(int) bool { return false }) @@ -428,8 +349,8 @@ func TestLoggerCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) { func TestLoggerCleanupOldLogsHandlesPermissionDeniedFile(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) - protected := createTempLog(t, tempDir, "codex-wrapper-6200.log") - deletable := createTempLog(t, tempDir, "codex-wrapper-6201.log") + protected := createTempLog(t, tempDir, "codeagent-wrapper-6200.log") + deletable := createTempLog(t, tempDir, "codeagent-wrapper-6201.log") stubProcessRunning(t, func(int) bool { return false }) stubProcessStartTime(t, func(int) time.Time { return time.Time{} }) @@ -468,7 +389,7 @@ func TestLoggerCleanupOldLogsPerformanceBound(t *testing.T) { const fileCount = 400 fakePaths := make([]string, fileCount) for i := 0; i < fileCount; i++ { - name := fmt.Sprintf("codex-wrapper-%d.log", 10000+i) + name := fmt.Sprintf("codeagent-wrapper-%d.log", 10000+i) fakePaths[i] = createTempLog(t, tempDir, name) } @@ -505,102 +426,11 @@ func TestLoggerCleanupOldLogsPerformanceBound(t *testing.T) { } } -func TestLoggerCleanupOldLogsCoverageSuite(t *testing.T) { - TestBackendParseJSONStream_CoverageSuite(t) -} - -// Reuse the existing coverage suite so the focused TestLogger run still exercises -// the rest of the codebase and keeps coverage high. -func TestLoggerCoverageSuite(t *testing.T) { - suite := []struct { - name string - fn func(*testing.T) - }{ - {"TestBackendParseJSONStream_CoverageSuite", TestBackendParseJSONStream_CoverageSuite}, - {"TestVersionCoverageFullRun", TestVersionCoverageFullRun}, - {"TestVersionMainWrapper", TestVersionMainWrapper}, - - {"TestExecutorHelperCoverage", TestExecutorHelperCoverage}, - {"TestExecutorRunCodexTaskWithContext", TestExecutorRunCodexTaskWithContext}, - {"TestExecutorParallelLogIsolation", TestExecutorParallelLogIsolation}, - {"TestExecutorTaskLoggerContext", TestExecutorTaskLoggerContext}, - {"TestExecutorExecuteConcurrentWithContextBranches", TestExecutorExecuteConcurrentWithContextBranches}, - {"TestExecutorSignalAndTermination", TestExecutorSignalAndTermination}, - {"TestExecutorCancelReasonAndCloseWithReason", TestExecutorCancelReasonAndCloseWithReason}, - {"TestExecutorForceKillTimerStop", TestExecutorForceKillTimerStop}, - {"TestExecutorForwardSignalsDefaults", TestExecutorForwardSignalsDefaults}, - - {"TestBackendParseArgs_NewMode", TestBackendParseArgs_NewMode}, - {"TestBackendParseArgs_ResumeMode", TestBackendParseArgs_ResumeMode}, - {"TestBackendParseArgs_BackendFlag", TestBackendParseArgs_BackendFlag}, - {"TestBackendParseArgs_SkipPermissions", TestBackendParseArgs_SkipPermissions}, - {"TestBackendParseBoolFlag", TestBackendParseBoolFlag}, - {"TestBackendEnvFlagEnabled", TestBackendEnvFlagEnabled}, - {"TestRunResolveTimeout", TestRunResolveTimeout}, - {"TestRunIsTerminal", TestRunIsTerminal}, - {"TestRunReadPipedTask", TestRunReadPipedTask}, - {"TestTailBufferWrite", TestTailBufferWrite}, - {"TestLogWriterWriteLimitsBuffer", TestLogWriterWriteLimitsBuffer}, - {"TestLogWriterLogLine", TestLogWriterLogLine}, - {"TestNewLogWriterDefaultMaxLen", TestNewLogWriterDefaultMaxLen}, - {"TestNewLogWriterDefaultLimit", TestNewLogWriterDefaultLimit}, - {"TestRunHello", TestRunHello}, - {"TestRunGreet", TestRunGreet}, - {"TestRunFarewell", TestRunFarewell}, - {"TestRunFarewellEmpty", TestRunFarewellEmpty}, - - {"TestParallelParseConfig_Success", TestParallelParseConfig_Success}, - {"TestParallelParseConfig_Backend", TestParallelParseConfig_Backend}, - {"TestParallelParseConfig_InvalidFormat", TestParallelParseConfig_InvalidFormat}, - {"TestParallelParseConfig_EmptyTasks", TestParallelParseConfig_EmptyTasks}, - {"TestParallelParseConfig_MissingID", TestParallelParseConfig_MissingID}, - {"TestParallelParseConfig_MissingTask", TestParallelParseConfig_MissingTask}, - {"TestParallelParseConfig_DuplicateID", TestParallelParseConfig_DuplicateID}, - {"TestParallelParseConfig_DelimiterFormat", TestParallelParseConfig_DelimiterFormat}, - - {"TestBackendSelectBackend", TestBackendSelectBackend}, - {"TestBackendSelectBackend_Invalid", TestBackendSelectBackend_Invalid}, - {"TestBackendSelectBackend_DefaultOnEmpty", TestBackendSelectBackend_DefaultOnEmpty}, - {"TestBackendBuildArgs_CodexBackend", TestBackendBuildArgs_CodexBackend}, - {"TestBackendBuildArgs_ClaudeBackend", TestBackendBuildArgs_ClaudeBackend}, - {"TestClaudeBackendBuildArgs_OutputValidation", TestClaudeBackendBuildArgs_OutputValidation}, - {"TestBackendBuildArgs_GeminiBackend", TestBackendBuildArgs_GeminiBackend}, - {"TestGeminiBackendBuildArgs_OutputValidation", TestGeminiBackendBuildArgs_OutputValidation}, - {"TestBackendNamesAndCommands", TestBackendNamesAndCommands}, - - {"TestBackendParseJSONStream", TestBackendParseJSONStream}, - {"TestBackendParseJSONStream_ClaudeEvents", TestBackendParseJSONStream_ClaudeEvents}, - {"TestBackendParseJSONStream_GeminiEvents", TestBackendParseJSONStream_GeminiEvents}, - {"TestBackendParseJSONStreamWithWarn_InvalidLine", TestBackendParseJSONStreamWithWarn_InvalidLine}, - {"TestBackendParseJSONStream_OnMessage", TestBackendParseJSONStream_OnMessage}, - {"TestBackendParseJSONStream_ScannerError", TestBackendParseJSONStream_ScannerError}, - {"TestBackendDiscardInvalidJSON", TestBackendDiscardInvalidJSON}, - {"TestBackendDiscardInvalidJSONBuffer", TestBackendDiscardInvalidJSONBuffer}, - - {"TestCurrentWrapperNameFallsBackToExecutable", TestCurrentWrapperNameFallsBackToExecutable}, - {"TestCurrentWrapperNameDetectsLegacyAliasSymlink", TestCurrentWrapperNameDetectsLegacyAliasSymlink}, - - {"TestIsProcessRunning", TestIsProcessRunning}, - {"TestGetProcessStartTimeReadsProcStat", TestGetProcessStartTimeReadsProcStat}, - {"TestGetProcessStartTimeInvalidData", TestGetProcessStartTimeInvalidData}, - {"TestGetBootTimeParsesBtime", TestGetBootTimeParsesBtime}, - {"TestGetBootTimeInvalidData", TestGetBootTimeInvalidData}, - - {"TestClaudeBuildArgs_ModesAndPermissions", TestClaudeBuildArgs_ModesAndPermissions}, - {"TestClaudeBuildArgs_GeminiAndCodexModes", TestClaudeBuildArgs_GeminiAndCodexModes}, - {"TestClaudeBuildArgs_BackendMetadata", TestClaudeBuildArgs_BackendMetadata}, - } - - for _, tc := range suite { - t.Run(tc.name, tc.fn) - } -} - func TestLoggerCleanupOldLogsKeepsCurrentProcessLog(t *testing.T) { tempDir := setTempDirEnv(t, t.TempDir()) currentPID := os.Getpid() - currentLog := createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", currentPID)) + currentLog := createTempLog(t, tempDir, fmt.Sprintf("codeagent-wrapper-%d.log", currentPID)) stubProcessRunning(t, func(pid int) bool { if pid != currentPID { @@ -676,7 +506,7 @@ func TestLoggerIsUnsafeFileSecurityChecks(t *testing.T) { stubEvalSymlinks(t, func(path string) (string, error) { return filepath.Join(absTempDir, filepath.Base(path)), nil }) - unsafe, reason := isUnsafeFile(filepath.Join(absTempDir, "codex-wrapper-1.log"), tempDir) + unsafe, reason := isUnsafeFile(filepath.Join(absTempDir, "codeagent-wrapper-1.log"), tempDir) if !unsafe || reason != "refusing to delete symlink" { t.Fatalf("expected symlink to be rejected, got unsafe=%v reason=%q", unsafe, reason) } @@ -702,9 +532,9 @@ func TestLoggerIsUnsafeFileSecurityChecks(t *testing.T) { }) otherDir := t.TempDir() stubEvalSymlinks(t, func(string) (string, error) { - return filepath.Join(otherDir, "codex-wrapper-9.log"), nil + return filepath.Join(otherDir, "codeagent-wrapper-9.log"), nil }) - unsafe, reason := isUnsafeFile(filepath.Join(otherDir, "codex-wrapper-9.log"), tempDir) + unsafe, reason := isUnsafeFile(filepath.Join(otherDir, "codeagent-wrapper-9.log"), tempDir) if !unsafe || reason != "file is outside tempDir" { t.Fatalf("expected outside file to be rejected, got unsafe=%v reason=%q", unsafe, reason) } @@ -713,15 +543,21 @@ func TestLoggerIsUnsafeFileSecurityChecks(t *testing.T) { func TestLoggerPathAndRemove(t *testing.T) { tempDir := t.TempDir() - path := filepath.Join(tempDir, "sample.log") - if err := os.WriteFile(path, []byte("test"), 0o644); err != nil { - t.Fatalf("failed to create temp file: %v", err) + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLoggerWithSuffix("sample") + if err != nil { + t.Fatalf("NewLoggerWithSuffix() error = %v", err) + } + path := logger.Path() + if path == "" { + _ = logger.Close() + t.Fatalf("logger.Path() returned empty path") + } + if err := logger.Close(); err != nil { + t.Fatalf("Close() error = %v", err) } - logger := &Logger{path: path} - if got := logger.Path(); got != path { - t.Fatalf("Path() = %q, want %q", got, path) - } if err := logger.RemoveLogFile(); err != nil { t.Fatalf("RemoveLogFile() error = %v", err) } @@ -738,43 +574,6 @@ func TestLoggerPathAndRemove(t *testing.T) { } } -func TestLoggerTruncateBytesCoverage(t *testing.T) { - if got := truncateBytes([]byte("abc"), 3); got != "abc" { - t.Fatalf("truncateBytes() = %q, want %q", got, "abc") - } - if got := truncateBytes([]byte("abcd"), 3); got != "abc..." { - t.Fatalf("truncateBytes() = %q, want %q", got, "abc...") - } - if got := truncateBytes([]byte("abcd"), -1); got != "" { - t.Fatalf("truncateBytes() = %q, want empty string", got) - } -} - -func TestLoggerInternalLog(t *testing.T) { - logger := &Logger{ - ch: make(chan logEntry, 1), - done: make(chan struct{}), - pendingWG: sync.WaitGroup{}, - } - - done := make(chan logEntry, 1) - go func() { - entry := <-logger.ch - logger.pendingWG.Done() - done <- entry - }() - - logger.log("INFO", "hello") - entry := <-done - if entry.msg != "hello" { - t.Fatalf("unexpected entry %+v", entry) - } - - logger.closed.Store(true) - logger.log("INFO", "ignored") - close(logger.done) -} - func TestLoggerParsePIDFromLog(t *testing.T) { hugePID := strconv.FormatInt(math.MaxInt64, 10) + "0" tests := []struct { @@ -782,13 +581,13 @@ func TestLoggerParsePIDFromLog(t *testing.T) { pid int ok bool }{ - {"codex-wrapper-123.log", 123, true}, - {"codex-wrapper-999-extra.log", 999, true}, - {"codex-wrapper-.log", 0, false}, + {"codeagent-wrapper-123.log", 123, true}, + {"codeagent-wrapper-999-extra.log", 999, true}, + {"codeagent-wrapper-.log", 0, false}, {"invalid-name.log", 0, false}, - {"codex-wrapper--5.log", 0, false}, - {"codex-wrapper-0.log", 0, false}, - {fmt.Sprintf("codex-wrapper-%s.log", hugePID), 0, false}, + {"codeagent-wrapper--5.log", 0, false}, + {"codeagent-wrapper-0.log", 0, false}, + {fmt.Sprintf("codeagent-wrapper-%s.log", hugePID), 0, false}, } for _, tt := range tests { @@ -827,56 +626,32 @@ func setTempDirEnv(t *testing.T, dir string) string { func stubProcessRunning(t *testing.T, fn func(int) bool) { t.Helper() - original := processRunningCheck - processRunningCheck = fn - t.Cleanup(func() { - processRunningCheck = original - }) + t.Cleanup(SetProcessRunningCheck(fn)) } func stubProcessStartTime(t *testing.T, fn func(int) time.Time) { t.Helper() - original := processStartTimeFn - processStartTimeFn = fn - t.Cleanup(func() { - processStartTimeFn = original - }) + t.Cleanup(SetProcessStartTimeFn(fn)) } func stubRemoveLogFile(t *testing.T, fn func(string) error) { t.Helper() - original := removeLogFileFn - removeLogFileFn = fn - t.Cleanup(func() { - removeLogFileFn = original - }) + t.Cleanup(SetRemoveLogFileFn(fn)) } func stubGlobLogFiles(t *testing.T, fn func(string) ([]string, error)) { t.Helper() - original := globLogFiles - globLogFiles = fn - t.Cleanup(func() { - globLogFiles = original - }) + t.Cleanup(SetGlobLogFilesFn(fn)) } func stubFileStat(t *testing.T, fn func(string) (os.FileInfo, error)) { t.Helper() - original := fileStatFn - fileStatFn = fn - t.Cleanup(func() { - fileStatFn = original - }) + t.Cleanup(SetFileStatFn(fn)) } func stubEvalSymlinks(t *testing.T, fn func(string) (string, error)) { t.Helper() - original := evalSymlinksFn - evalSymlinksFn = fn - t.Cleanup(func() { - evalSymlinksFn = original - }) + t.Cleanup(SetEvalSymlinksFn(fn)) } type fakeFileInfo struct { @@ -960,7 +735,7 @@ func TestLoggerExtractRecentErrors(t *testing.T) { t.Fatalf("NewLoggerWithSuffix() error = %v", err) } defer logger.Close() - defer logger.RemoveLogFile() + defer func() { _ = logger.RemoveLogFile() }() // Write logs using logger methods for _, entry := range tt.logs { @@ -1000,14 +775,14 @@ func TestLoggerExtractRecentErrorsNilLogger(t *testing.T) { } func TestLoggerExtractRecentErrorsEmptyPath(t *testing.T) { - logger := &Logger{path: ""} + logger := &Logger{} if got := logger.ExtractRecentErrors(10); got != nil { t.Fatalf("empty path ExtractRecentErrors() should return nil, got %v", got) } } func TestLoggerExtractRecentErrorsFileNotExist(t *testing.T) { - logger := &Logger{path: "/nonexistent/path/to/log.log"} + logger := &Logger{} if got := logger.ExtractRecentErrors(10); got != nil { t.Fatalf("nonexistent file ExtractRecentErrors() should return nil, got %v", got) } @@ -1049,7 +824,7 @@ func TestExtractRecentErrorsBoundaryCheck(t *testing.T) { t.Fatalf("NewLoggerWithSuffix() error = %v", err) } defer logger.Close() - defer logger.RemoveLogFile() + defer func() { _ = logger.RemoveLogFile() }() // Write some errors logger.Error("error 1") @@ -1082,7 +857,7 @@ func TestErrorEntriesMaxLimit(t *testing.T) { t.Fatalf("NewLoggerWithSuffix() error = %v", err) } defer logger.Close() - defer logger.RemoveLogFile() + defer func() { _ = logger.RemoveLogFile() }() // Write 150 error/warn entries for i := 1; i <= 150; i++ { diff --git a/codeagent-wrapper/internal/logger/process_check.go b/codeagent-wrapper/internal/logger/process_check.go new file mode 100644 index 0000000..20d3010 --- /dev/null +++ b/codeagent-wrapper/internal/logger/process_check.go @@ -0,0 +1,63 @@ +package logger + +import ( + "errors" + "math" + "time" + + "github.com/shirou/gopsutil/v3/process" +) + +func pidToInt32(pid int) (int32, bool) { + if pid <= 0 || pid > math.MaxInt32 { + return 0, false + } + return int32(pid), true +} + +// isProcessRunning reports whether a process with the given pid appears to be running. +// It is intentionally conservative on errors to avoid deleting logs for live processes. +func isProcessRunning(pid int) bool { + pid32, ok := pidToInt32(pid) + if !ok { + return false + } + + exists, err := process.PidExists(pid32) + if err == nil { + return exists + } + + // If we can positively identify that the process doesn't exist, report false. + if errors.Is(err, process.ErrorProcessNotRunning) { + return false + } + + // Permission/inspection failures: assume it's running to be safe. + return true +} + +// getProcessStartTime returns the start time of a process. +// Returns zero time if the start time cannot be determined. +func getProcessStartTime(pid int) time.Time { + pid32, ok := pidToInt32(pid) + if !ok { + return time.Time{} + } + + proc, err := process.NewProcess(pid32) + if err != nil { + return time.Time{} + } + + ms, err := proc.CreateTime() + if err != nil || ms <= 0 { + return time.Time{} + } + + return time.UnixMilli(ms) +} + +func IsProcessRunning(pid int) bool { return isProcessRunning(pid) } + +func GetProcessStartTime(pid int) time.Time { return getProcessStartTime(pid) } diff --git a/codeagent-wrapper/internal/logger/process_check_test.go b/codeagent-wrapper/internal/logger/process_check_test.go new file mode 100644 index 0000000..eabb830 --- /dev/null +++ b/codeagent-wrapper/internal/logger/process_check_test.go @@ -0,0 +1,112 @@ +package logger + +import ( + "math" + "os" + "os/exec" + "runtime" + "strconv" + "testing" + "time" +) + +func TestIsProcessRunning(t *testing.T) { + t.Run("boundary values", func(t *testing.T) { + if isProcessRunning(0) { + t.Fatalf("pid 0 should never be treated as running") + } + if isProcessRunning(-1) { + t.Fatalf("negative pid should never be treated as running") + } + }) + + t.Run("pid out of int32 range", func(t *testing.T) { + if strconv.IntSize <= 32 { + t.Skip("int cannot represent values above int32 range") + } + + pid := int(int64(math.MaxInt32) + 1) + if isProcessRunning(pid) { + t.Fatalf("expected pid %d (out of int32 range) to be treated as not running", pid) + } + }) + + t.Run("current process", func(t *testing.T) { + if !isProcessRunning(os.Getpid()) { + t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) + } + }) + + t.Run("fake pid", func(t *testing.T) { + const nonexistentPID = 1 << 30 + if isProcessRunning(nonexistentPID) { + t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) + } + }) + + t.Run("terminated process", func(t *testing.T) { + pid := exitedProcessPID(t) + if isProcessRunning(pid) { + t.Fatalf("expected exited child process (pid=%d) to be reported as not running", pid) + } + }) +} + +func exitedProcessPID(t *testing.T) int { + t.Helper() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("cmd", "/c", "exit 0") + } else { + cmd = exec.Command("sh", "-c", "exit 0") + } + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start helper process: %v", err) + } + pid := cmd.Process.Pid + + if err := cmd.Wait(); err != nil { + t.Fatalf("helper process did not exit cleanly: %v", err) + } + + time.Sleep(50 * time.Millisecond) + return pid +} + +func TestGetProcessStartTimeReadsProcStat(t *testing.T) { + start := getProcessStartTime(os.Getpid()) + if start.IsZero() { + t.Fatalf("expected non-zero start time for current process") + } + if start.After(time.Now().Add(5 * time.Second)) { + t.Fatalf("start time is unexpectedly in the future: %v", start) + } +} + +func TestGetProcessStartTimeInvalidData(t *testing.T) { + if !getProcessStartTime(0).IsZero() { + t.Fatalf("expected zero time for pid 0") + } + if !getProcessStartTime(-1).IsZero() { + t.Fatalf("expected zero time for negative pid") + } + if !getProcessStartTime(1 << 30).IsZero() { + t.Fatalf("expected zero time for non-existent pid") + } + if strconv.IntSize > 32 { + pid := int(int64(math.MaxInt32) + 1) + if !getProcessStartTime(pid).IsZero() { + t.Fatalf("expected zero time for pid %d (out of int32 range)", pid) + } + } +} + +func TestGetBootTimeParsesBtime(t *testing.T) { + t.Skip("legacy boot-time probing removed; start time now uses gopsutil") +} + +func TestGetBootTimeInvalidData(t *testing.T) { + t.Skip("legacy boot-time probing removed; start time now uses gopsutil") +} diff --git a/codeagent-wrapper/internal/logger/testhooks.go b/codeagent-wrapper/internal/logger/testhooks.go new file mode 100644 index 0000000..370711d --- /dev/null +++ b/codeagent-wrapper/internal/logger/testhooks.go @@ -0,0 +1,67 @@ +package logger + +import ( + "os" + "path/filepath" + "time" +) + +func SetProcessRunningCheck(fn func(int) bool) (restore func()) { + prev := processRunningCheck + if fn != nil { + processRunningCheck = fn + } else { + processRunningCheck = isProcessRunning + } + return func() { processRunningCheck = prev } +} + +func SetProcessStartTimeFn(fn func(int) time.Time) (restore func()) { + prev := processStartTimeFn + if fn != nil { + processStartTimeFn = fn + } else { + processStartTimeFn = getProcessStartTime + } + return func() { processStartTimeFn = prev } +} + +func SetRemoveLogFileFn(fn func(string) error) (restore func()) { + prev := removeLogFileFn + if fn != nil { + removeLogFileFn = fn + } else { + removeLogFileFn = os.Remove + } + return func() { removeLogFileFn = prev } +} + +func SetGlobLogFilesFn(fn func(string) ([]string, error)) (restore func()) { + prev := globLogFiles + if fn != nil { + globLogFiles = fn + } else { + globLogFiles = filepath.Glob + } + return func() { globLogFiles = prev } +} + +func SetFileStatFn(fn func(string) (os.FileInfo, error)) (restore func()) { + prev := fileStatFn + if fn != nil { + fileStatFn = fn + } else { + fileStatFn = os.Lstat + } + return func() { fileStatFn = prev } +} + +func SetEvalSymlinksFn(fn func(string) (string, error)) (restore func()) { + prev := evalSymlinksFn + if fn != nil { + evalSymlinksFn = fn + } else { + evalSymlinksFn = filepath.EvalSymlinks + } + return func() { evalSymlinksFn = prev } +} diff --git a/codeagent-wrapper/internal/logger/wrapper_name.go b/codeagent-wrapper/internal/logger/wrapper_name.go new file mode 100644 index 0000000..2ca2b37 --- /dev/null +++ b/codeagent-wrapper/internal/logger/wrapper_name.go @@ -0,0 +1,13 @@ +package logger + +// WrapperName is the fixed name for this tool. +const WrapperName = "codeagent-wrapper" + +// CurrentWrapperName returns the wrapper name (always "codeagent-wrapper"). +func CurrentWrapperName() string { return WrapperName } + +// LogPrefixes returns the log file name prefixes to look for. +func LogPrefixes() []string { return []string{WrapperName} } + +// PrimaryLogPrefix returns the preferred filename prefix for log files. +func PrimaryLogPrefix() string { return WrapperName } diff --git a/codeagent-wrapper/internal/parser/event.go b/codeagent-wrapper/internal/parser/event.go new file mode 100644 index 0000000..30774b6 --- /dev/null +++ b/codeagent-wrapper/internal/parser/event.go @@ -0,0 +1,74 @@ +package parser + +import "github.com/goccy/go-json" + +// JSONEvent represents a Codex JSON output event. +type JSONEvent struct { + Type string `json:"type"` + ThreadID string `json:"thread_id,omitempty"` + Item *EventItem `json:"item,omitempty"` +} + +// EventItem represents the item field in a JSON event. +type EventItem struct { + Type string `json:"type"` + Text interface{} `json:"text"` +} + +// ClaudeEvent for Claude stream-json format. +type ClaudeEvent struct { + Type string `json:"type"` + Subtype string `json:"subtype,omitempty"` + SessionID string `json:"session_id,omitempty"` + Result string `json:"result,omitempty"` +} + +// GeminiEvent for Gemini stream-json format. +type GeminiEvent struct { + Type string `json:"type"` + SessionID string `json:"session_id,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Delta bool `json:"delta,omitempty"` + Status string `json:"status,omitempty"` +} + +// UnifiedEvent combines all backend event formats into a single structure +// to avoid multiple JSON unmarshal operations per event. +type UnifiedEvent struct { + // Common fields + Type string `json:"type"` + + // Codex-specific fields + ThreadID string `json:"thread_id,omitempty"` + Item json.RawMessage `json:"item,omitempty"` // Lazy parse + + // Claude-specific fields + Subtype string `json:"subtype,omitempty"` + SessionID string `json:"session_id,omitempty"` + Result string `json:"result,omitempty"` + + // Gemini-specific fields + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Delta *bool `json:"delta,omitempty"` + Status string `json:"status,omitempty"` + + // Opencode-specific fields (camelCase sessionID) + OpencodeSessionID string `json:"sessionID,omitempty"` + Part json.RawMessage `json:"part,omitempty"` +} + +// OpencodePart represents the part field in opencode events. +type OpencodePart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Reason string `json:"reason,omitempty"` + SessionID string `json:"sessionID,omitempty"` +} + +// ItemContent represents the parsed item.text field for Codex events. +type ItemContent struct { + Type string `json:"type"` + Text interface{} `json:"text"` +} diff --git a/codeagent-wrapper/parser.go b/codeagent-wrapper/internal/parser/parser.go similarity index 71% rename from codeagent-wrapper/parser.go rename to codeagent-wrapper/internal/parser/parser.go index ceaa18d..edc45ee 100644 --- a/codeagent-wrapper/parser.go +++ b/codeagent-wrapper/internal/parser/parser.go @@ -1,106 +1,65 @@ -package main +package parser import ( "bufio" "bytes" - "encoding/json" "errors" "fmt" "io" "strings" + "sync" + + "github.com/goccy/go-json" ) -// JSONEvent represents a Codex JSON output event -type JSONEvent struct { - Type string `json:"type"` - ThreadID string `json:"thread_id,omitempty"` - Item *EventItem `json:"item,omitempty"` -} - -// EventItem represents the item field in a JSON event -type EventItem struct { - Type string `json:"type"` - Text interface{} `json:"text"` -} - -// ClaudeEvent for Claude stream-json format -type ClaudeEvent struct { - Type string `json:"type"` - Subtype string `json:"subtype,omitempty"` - SessionID string `json:"session_id,omitempty"` - Result string `json:"result,omitempty"` -} - -// GeminiEvent for Gemini stream-json format -type GeminiEvent struct { - Type string `json:"type"` - SessionID string `json:"session_id,omitempty"` - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - Delta bool `json:"delta,omitempty"` - Status string `json:"status,omitempty"` -} - -func parseJSONStream(r io.Reader) (message, threadID string) { - return parseJSONStreamWithLog(r, logWarn, logInfo) -} - -func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) { - return parseJSONStreamWithLog(r, warnFn, logInfo) -} - -func parseJSONStreamWithLog(r io.Reader, warnFn func(string), infoFn func(string)) (message, threadID string) { - return parseJSONStreamInternal(r, warnFn, infoFn, nil, nil) -} - const ( jsonLineReaderSize = 64 * 1024 jsonLineMaxBytes = 10 * 1024 * 1024 jsonLinePreviewBytes = 256 ) -// UnifiedEvent combines all backend event formats into a single structure -// to avoid multiple JSON unmarshal operations per event -type UnifiedEvent struct { - // Common fields - Type string `json:"type"` - - // Codex-specific fields - ThreadID string `json:"thread_id,omitempty"` - Item json.RawMessage `json:"item,omitempty"` // Lazy parse - - // Claude-specific fields - Subtype string `json:"subtype,omitempty"` - SessionID string `json:"session_id,omitempty"` - Result string `json:"result,omitempty"` - - // Gemini-specific fields - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - Delta *bool `json:"delta,omitempty"` - Status string `json:"status,omitempty"` - - // Opencode-specific fields (camelCase sessionID) - OpencodeSessionID string `json:"sessionID,omitempty"` - Part json.RawMessage `json:"part,omitempty"` +type lineScratch struct { + buf []byte + preview []byte } -// OpencodePart represents the part field in opencode events -type OpencodePart struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Reason string `json:"reason,omitempty"` - SessionID string `json:"sessionID,omitempty"` +const maxPooledLineScratchCap = 1 << 20 // 1 MiB + +var lineScratchPool = sync.Pool{ + New: func() any { + return &lineScratch{ + buf: make([]byte, 0, jsonLineReaderSize), + preview: make([]byte, 0, jsonLinePreviewBytes), + } + }, } -// ItemContent represents the parsed item.text field for Codex events -type ItemContent struct { - Type string `json:"type"` - Text interface{} `json:"text"` -} - -func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) { +func ParseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(string), onMessage func(), onComplete func()) (message, threadID string) { reader := bufio.NewReaderSize(r, jsonLineReaderSize) + scratch := lineScratchPool.Get().(*lineScratch) + if scratch.buf == nil { + scratch.buf = make([]byte, 0, jsonLineReaderSize) + } else { + scratch.buf = scratch.buf[:0] + } + if scratch.preview == nil { + scratch.preview = make([]byte, 0, jsonLinePreviewBytes) + } else { + scratch.preview = scratch.preview[:0] + } + defer func() { + if cap(scratch.buf) > maxPooledLineScratchCap { + scratch.buf = nil + } else if scratch.buf != nil { + scratch.buf = scratch.buf[:0] + } + if cap(scratch.preview) > jsonLinePreviewBytes*4 { + scratch.preview = nil + } else if scratch.preview != nil { + scratch.preview = scratch.preview[:0] + } + lineScratchPool.Put(scratch) + }() if warnFn == nil { warnFn = func(string) {} @@ -131,7 +90,7 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin ) for { - line, tooLong, err := readLineWithLimit(reader, jsonLineMaxBytes, jsonLinePreviewBytes) + line, tooLong, err := readLineWithLimit(reader, jsonLineMaxBytes, jsonLinePreviewBytes, scratch) if err != nil { if errors.Is(err, io.EOF) { break @@ -147,14 +106,14 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin totalEvents++ if tooLong { - warnFn(fmt.Sprintf("Skipped overlong JSON line (> %d bytes): %s", jsonLineMaxBytes, truncateBytes(line, 100))) + warnFn(fmt.Sprintf("Skipped overlong JSON line (> %d bytes): %s", jsonLineMaxBytes, TruncateBytes(line, 100))) continue } // Single unmarshal for all backend types var event UnifiedEvent if err := json.Unmarshal(line, &event); err != nil { - warnFn(fmt.Sprintf("Failed to parse event: %s", truncateBytes(line, 100))) + warnFn(fmt.Sprintf("Failed to parse event: %s", TruncateBytes(line, 100))) continue } @@ -253,7 +212,7 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin // Lazy parse: only parse item content when needed var item ItemContent if err := json.Unmarshal(event.Item, &item); err == nil { - normalized := normalizeText(item.Text) + normalized := NormalizeText(item.Text) infoFn(fmt.Sprintf("item.completed event item_type=%s message_len=%d", itemType, len(normalized))) if normalized != "" { codexMessage = normalized @@ -334,12 +293,12 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin return message, threadID } -func hasKey(m map[string]json.RawMessage, key string) bool { +func HasKey(m map[string]json.RawMessage, key string) bool { _, ok := m[key] return ok } -func discardInvalidJSON(decoder *json.Decoder, reader *bufio.Reader) (*bufio.Reader, error) { +func DiscardInvalidJSON(decoder *json.Decoder, reader *bufio.Reader) (*bufio.Reader, error) { var buffered bytes.Buffer if decoder != nil { @@ -365,7 +324,7 @@ func discardInvalidJSON(decoder *json.Decoder, reader *bufio.Reader) (*bufio.Rea return bufio.NewReader(io.MultiReader(bytes.NewReader(remaining), reader)), err } -func readLineWithLimit(r *bufio.Reader, maxBytes int, previewBytes int) (line []byte, tooLong bool, err error) { +func readLineWithLimit(r *bufio.Reader, maxBytes int, previewBytes int, scratch *lineScratch) (line []byte, tooLong bool, err error) { if r == nil { return nil, false, errors.New("reader is nil") } @@ -388,12 +347,22 @@ func readLineWithLimit(r *bufio.Reader, maxBytes int, previewBytes int) (line [] return part, false, nil } - preview := make([]byte, 0, min(previewBytes, len(part))) + if scratch == nil { + scratch = &lineScratch{} + } + if scratch.preview == nil { + scratch.preview = make([]byte, 0, min(previewBytes, len(part))) + } + if scratch.buf == nil { + scratch.buf = make([]byte, 0, min(maxBytes, len(part)*2)) + } + + preview := scratch.preview[:0] if previewBytes > 0 { preview = append(preview, part[:min(previewBytes, len(part))]...) } - buf := make([]byte, 0, min(maxBytes, len(part)*2)) + buf := scratch.buf[:0] total := 0 if len(part) > maxBytes { tooLong = true @@ -423,12 +392,16 @@ func readLineWithLimit(r *bufio.Reader, maxBytes int, previewBytes int) (line [] } if tooLong { + scratch.preview = preview + scratch.buf = buf return preview, true, nil } + scratch.preview = preview + scratch.buf = buf return buf, false, nil } -func truncateBytes(b []byte, maxLen int) string { +func TruncateBytes(b []byte, maxLen int) string { if len(b) <= maxLen { return string(b) } @@ -438,7 +411,7 @@ func truncateBytes(b []byte, maxLen int) string { return string(b[:maxLen]) + "..." } -func normalizeText(text interface{}) string { +func NormalizeText(text interface{}) string { switch v := text.(type) { case string: return v @@ -454,3 +427,10 @@ func normalizeText(text interface{}) string { return "" } } + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/codeagent-wrapper/parser_opencode_test.go b/codeagent-wrapper/internal/parser/parser_opencode_test.go similarity index 88% rename from codeagent-wrapper/parser_opencode_test.go rename to codeagent-wrapper/internal/parser/parser_opencode_test.go index f8ced6a..35ca5f5 100644 --- a/codeagent-wrapper/parser_opencode_test.go +++ b/codeagent-wrapper/internal/parser/parser_opencode_test.go @@ -1,4 +1,4 @@ -package main +package parser import ( "strings" @@ -10,7 +10,7 @@ func TestParseJSONStream_Opencode(t *testing.T) { {"type":"text","timestamp":1768187744432,"sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","part":{"id":"prt_bb0339cb5001QDd0Lh0PzFZpa3","sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","messageID":"msg_bb033866f0011oZxTqvfy0TKtS","type":"text","text":"Hello from opencode"}} {"type":"step_finish","timestamp":1768187744471,"sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","part":{"id":"prt_bb033d0af0019VRZzpO2OVW1na","sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","messageID":"msg_bb033866f0011oZxTqvfy0TKtS","type":"step-finish","reason":"stop","snapshot":"904f0fd58c125b79e60f0993e38f9d9f6200bf47","cost":0}}` - message, threadID := parseJSONStream(strings.NewReader(input)) + message, threadID := ParseJSONStreamInternal(strings.NewReader(input), nil, nil, nil, nil) if threadID != "ses_44fced3c7ffe83sZpzY1rlQka3" { t.Errorf("threadID = %q, want %q", threadID, "ses_44fced3c7ffe83sZpzY1rlQka3") @@ -25,7 +25,7 @@ func TestParseJSONStream_Opencode_MultipleTextEvents(t *testing.T) { {"type":"text","sessionID":"ses_123","part":{"type":"text","text":" Part 2"}} {"type":"step_finish","sessionID":"ses_123","part":{"type":"step-finish","reason":"stop"}}` - message, threadID := parseJSONStream(strings.NewReader(input)) + message, threadID := ParseJSONStreamInternal(strings.NewReader(input), nil, nil, nil, nil) if threadID != "ses_123" { t.Errorf("threadID = %q, want %q", threadID, "ses_123") @@ -39,7 +39,7 @@ func TestParseJSONStream_Opencode_NoStopReason(t *testing.T) { input := `{"type":"text","sessionID":"ses_456","part":{"type":"text","text":"Content"}} {"type":"step_finish","sessionID":"ses_456","part":{"type":"step-finish","reason":"tool-calls"}}` - message, threadID := parseJSONStream(strings.NewReader(input)) + message, threadID := ParseJSONStreamInternal(strings.NewReader(input), nil, nil, nil, nil) if threadID != "ses_456" { t.Errorf("threadID = %q, want %q", threadID, "ses_456") diff --git a/codeagent-wrapper/parser_token_too_long_test.go b/codeagent-wrapper/internal/parser/parser_token_too_long_test.go similarity index 92% rename from codeagent-wrapper/parser_token_too_long_test.go rename to codeagent-wrapper/internal/parser/parser_token_too_long_test.go index 662e443..0de421b 100644 --- a/codeagent-wrapper/parser_token_too_long_test.go +++ b/codeagent-wrapper/internal/parser/parser_token_too_long_test.go @@ -1,4 +1,4 @@ -package main +package parser import ( "strings" @@ -18,7 +18,7 @@ func TestParseJSONStream_SkipsOverlongLineAndContinues(t *testing.T) { var warns []string warnFn := func(msg string) { warns = append(warns, msg) } - gotMessage, gotThreadID := parseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil) + gotMessage, gotThreadID := ParseJSONStreamInternal(strings.NewReader(input), warnFn, nil, nil, nil) if gotMessage != "ok" { t.Fatalf("message=%q, want %q (warns=%v)", gotMessage, "ok", warns) } diff --git a/codeagent-wrapper/parser_unknown_event_test.go b/codeagent-wrapper/internal/parser/parser_unknown_event_test.go similarity index 90% rename from codeagent-wrapper/parser_unknown_event_test.go rename to codeagent-wrapper/internal/parser/parser_unknown_event_test.go index b3a6e5b..cb9163c 100644 --- a/codeagent-wrapper/parser_unknown_event_test.go +++ b/codeagent-wrapper/internal/parser/parser_unknown_event_test.go @@ -1,4 +1,4 @@ -package main +package parser import ( "strings" @@ -16,7 +16,7 @@ func TestBackendParseJSONStream_UnknownEventsAreSilent(t *testing.T) { var infos []string infoFn := func(msg string) { infos = append(infos, msg) } - message, threadID := parseJSONStreamInternal(strings.NewReader(input), nil, infoFn, nil, nil) + message, threadID := ParseJSONStreamInternal(strings.NewReader(input), nil, infoFn, nil, nil) if message != "ok" { t.Fatalf("message=%q, want %q (infos=%v)", message, "ok", infos) } diff --git a/codeagent-wrapper/internal/parser/truncate_bytes_test.go b/codeagent-wrapper/internal/parser/truncate_bytes_test.go new file mode 100644 index 0000000..8b7ce90 --- /dev/null +++ b/codeagent-wrapper/internal/parser/truncate_bytes_test.go @@ -0,0 +1,15 @@ +package parser + +import "testing" + +func TestTruncateBytes(t *testing.T) { + if got := TruncateBytes([]byte("abc"), 3); got != "abc" { + t.Fatalf("TruncateBytes() = %q, want %q", got, "abc") + } + if got := TruncateBytes([]byte("abcd"), 3); got != "abc..." { + t.Fatalf("TruncateBytes() = %q, want %q", got, "abc...") + } + if got := TruncateBytes([]byte("abcd"), -1); got != "" { + t.Fatalf("TruncateBytes() = %q, want empty string", got) + } +} diff --git a/codeagent-wrapper/internal/utils/math.go b/codeagent-wrapper/internal/utils/math.go new file mode 100644 index 0000000..43d30e7 --- /dev/null +++ b/codeagent-wrapper/internal/utils/math.go @@ -0,0 +1,8 @@ +package utils + +func Min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/codeagent-wrapper/internal/utils/math_test.go b/codeagent-wrapper/internal/utils/math_test.go new file mode 100644 index 0000000..d8a99a6 --- /dev/null +++ b/codeagent-wrapper/internal/utils/math_test.go @@ -0,0 +1,36 @@ +package utils + +import "testing" + +func TestMin(t *testing.T) { + tests := []struct { + name string + a, b int + want int + }{ + {"a less than b", 1, 2, 1}, + {"b less than a", 5, 3, 3}, + {"equal values", 7, 7, 7}, + {"negative a", -5, 3, -5}, + {"negative b", 5, -3, -3}, + {"both negative", -5, -3, -5}, + {"zero and positive", 0, 5, 0}, + {"zero and negative", 0, -5, -5}, + {"large values", 1000000, 999999, 999999}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Min(tt.a, tt.b) + if got != tt.want { + t.Errorf("Min(%d, %d) = %d, want %d", tt.a, tt.b, got, tt.want) + } + }) + } +} + +func BenchmarkMin(b *testing.B) { + for i := 0; i < b.N; i++ { + Min(i, i+1) + } +} diff --git a/codeagent-wrapper/internal/utils/strings.go b/codeagent-wrapper/internal/utils/strings.go new file mode 100644 index 0000000..089dc7d --- /dev/null +++ b/codeagent-wrapper/internal/utils/strings.go @@ -0,0 +1,62 @@ +package utils + +import "strings" + +func Truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + if maxLen < 0 { + return "" + } + return s[:maxLen] + "..." +} + +// SafeTruncate safely truncates string to maxLen, avoiding panic and UTF-8 corruption. +func SafeTruncate(s string, maxLen int) string { + if maxLen <= 0 || s == "" { + return "" + } + + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + + if maxLen < 4 { + return string(runes[:1]) + } + + cutoff := maxLen - 3 + if cutoff <= 0 { + return string(runes[:1]) + } + if len(runes) <= cutoff { + return s + } + return string(runes[:cutoff]) + "..." +} + +// SanitizeOutput removes ANSI escape sequences and control characters. +func SanitizeOutput(s string) string { + var result strings.Builder + inEscape := false + for i := 0; i < len(s); i++ { + if s[i] == '\x1b' && i+1 < len(s) && s[i+1] == '[' { + inEscape = true + i++ // skip '[' + continue + } + if inEscape { + if (s[i] >= 'A' && s[i] <= 'Z') || (s[i] >= 'a' && s[i] <= 'z') { + inEscape = false + } + continue + } + // Keep printable chars and common whitespace. + if s[i] >= 32 || s[i] == '\n' || s[i] == '\t' { + result.WriteByte(s[i]) + } + } + return result.String() +} diff --git a/codeagent-wrapper/internal/utils/strings_test.go b/codeagent-wrapper/internal/utils/strings_test.go new file mode 100644 index 0000000..d572035 --- /dev/null +++ b/codeagent-wrapper/internal/utils/strings_test.go @@ -0,0 +1,122 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + s string + maxLen int + want string + }{ + {"empty string", "", 10, ""}, + {"short string", "hello", 10, "hello"}, + {"exact length", "hello", 5, "hello"}, + {"needs truncation", "hello world", 5, "hello..."}, + {"zero maxLen", "hello", 0, "..."}, + {"negative maxLen", "hello", -1, ""}, + {"maxLen 1", "hello", 1, "h..."}, + {"unicode bytes truncate", "你好世界", 10, "你好世\xe7..."}, // Truncate works on bytes, not runes + {"mixed truncate", "hello世界abc", 7, "hello\xe4\xb8..."}, // byte-based truncation + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Truncate(tt.s, tt.maxLen) + if got != tt.want { + t.Errorf("Truncate(%q, %d) = %q, want %q", tt.s, tt.maxLen, got, tt.want) + } + }) + } +} + +func TestSafeTruncate(t *testing.T) { + tests := []struct { + name string + s string + maxLen int + want string + }{ + {"empty string", "", 10, ""}, + {"zero maxLen", "hello", 0, ""}, + {"negative maxLen", "hello", -1, ""}, + {"short string", "hello", 10, "hello"}, + {"exact length", "hello", 5, "hello"}, + {"needs truncation", "hello world", 8, "hello..."}, + {"maxLen 1", "hello", 1, "h"}, + {"maxLen 2", "hello", 2, "h"}, + {"maxLen 3", "hello", 3, "h"}, + {"maxLen 4", "hello", 4, "h..."}, + {"unicode preserved", "你好世界", 10, "你好世界"}, + {"unicode exact", "你好世界", 4, "你好世界"}, + {"unicode truncate", "你好世界test", 6, "你好世..."}, + {"mixed unicode", "ab你好cd", 5, "ab..."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SafeTruncate(tt.s, tt.maxLen) + if got != tt.want { + t.Errorf("SafeTruncate(%q, %d) = %q, want %q", tt.s, tt.maxLen, got, tt.want) + } + }) + } +} + +func TestSanitizeOutput(t *testing.T) { + tests := []struct { + name string + s string + want string + }{ + {"empty string", "", ""}, + {"plain text", "hello world", "hello world"}, + {"with newline", "hello\nworld", "hello\nworld"}, + {"with tab", "hello\tworld", "hello\tworld"}, + {"ANSI color red", "\x1b[31mred\x1b[0m", "red"}, + {"ANSI bold", "\x1b[1mbold\x1b[0m", "bold"}, + {"ANSI complex", "\x1b[1;31;40mtext\x1b[0m", "text"}, + {"control chars", "hello\x00\x01\x02world", "helloworld"}, + {"mixed ANSI and control", "\x1b[32m\x00ok\x1b[0m", "ok"}, + {"multiple ANSI sequences", "\x1b[31mred\x1b[0m \x1b[32mgreen\x1b[0m", "red green"}, + {"incomplete escape", "\x1b[", ""}, + {"escape without bracket", "\x1bA", "A"}, + {"cursor movement", "\x1b[2Aup\x1b[2Bdown", "updown"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeOutput(tt.s) + if got != tt.want { + t.Errorf("SanitizeOutput(%q) = %q, want %q", tt.s, got, tt.want) + } + }) + } +} + +func BenchmarkTruncate(b *testing.B) { + s := strings.Repeat("hello world ", 100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Truncate(s, 50) + } +} + +func BenchmarkSafeTruncate(b *testing.B) { + s := strings.Repeat("你好世界", 100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + SafeTruncate(s, 50) + } +} + +func BenchmarkSanitizeOutput(b *testing.B) { + s := strings.Repeat("\x1b[31mred\x1b[0m text ", 50) + b.ResetTimer() + for i := 0; i < b.N; i++ { + SanitizeOutput(s) + } +} diff --git a/codeagent-wrapper/main.go b/codeagent-wrapper/main.go deleted file mode 100644 index 6ecb7a6..0000000 --- a/codeagent-wrapper/main.go +++ /dev/null @@ -1,627 +0,0 @@ -package main - -import ( - "fmt" - "io" - "os" - "os/exec" - "os/signal" - "path/filepath" - "reflect" - "strings" - "sync/atomic" - "time" -) - -const ( - version = "5.6.4" - defaultWorkdir = "." - defaultTimeout = 7200 // seconds (2 hours) - defaultCoverageTarget = 90.0 - codexLogLineLimit = 1000 - stdinSpecialChars = "\n\\\"'`$" - stderrCaptureLimit = 4 * 1024 - defaultBackendName = "codex" - defaultCodexCommand = "codex" - - // stdout close reasons - stdoutCloseReasonWait = "wait-done" - stdoutCloseReasonDrain = "drain-timeout" - stdoutCloseReasonCtx = "context-cancel" - stdoutDrainTimeout = 100 * time.Millisecond -) - -// Test hooks for dependency injection -var ( - stdinReader io.Reader = os.Stdin - isTerminalFn = defaultIsTerminal - codexCommand = defaultCodexCommand - cleanupHook func() - loggerPtr atomic.Pointer[Logger] - - buildCodexArgsFn = buildCodexArgs - selectBackendFn = selectBackend - commandContext = exec.CommandContext - cleanupLogsFn = cleanupOldLogs - signalNotifyFn = signal.Notify - signalStopFn = signal.Stop - terminateCommandFn = terminateCommand - defaultBuildArgsFn = buildCodexArgs - runTaskFn = runCodexTask - exitFn = os.Exit -) - -var forceKillDelay atomic.Int32 - -func init() { - forceKillDelay.Store(5) // seconds - default value -} - -func runStartupCleanup() { - if cleanupLogsFn == nil { - return - } - defer func() { - if r := recover(); r != nil { - logWarn(fmt.Sprintf("cleanupOldLogs panic: %v", r)) - } - }() - if _, err := cleanupLogsFn(); err != nil { - logWarn(fmt.Sprintf("cleanupOldLogs error: %v", err)) - } -} - -func runCleanupMode() int { - if cleanupLogsFn == nil { - fmt.Fprintln(os.Stderr, "Cleanup failed: log cleanup function not configured") - return 1 - } - - stats, err := cleanupLogsFn() - if err != nil { - fmt.Fprintf(os.Stderr, "Cleanup failed: %v\n", err) - return 1 - } - - fmt.Println("Cleanup completed") - fmt.Printf("Files scanned: %d\n", stats.Scanned) - fmt.Printf("Files deleted: %d\n", stats.Deleted) - if len(stats.DeletedFiles) > 0 { - for _, f := range stats.DeletedFiles { - fmt.Printf(" - %s\n", f) - } - } - fmt.Printf("Files kept: %d\n", stats.Kept) - if len(stats.KeptFiles) > 0 { - for _, f := range stats.KeptFiles { - fmt.Printf(" - %s\n", f) - } - } - if stats.Errors > 0 { - fmt.Printf("Deletion errors: %d\n", stats.Errors) - } - return 0 -} - -func main() { - exitCode := run() - exitFn(exitCode) -} - -// run is the main logic, returns exit code for testability -func run() (exitCode int) { - name := currentWrapperName() - // Handle --version and --help first (no logger needed) - if len(os.Args) > 1 { - switch os.Args[1] { - case "--version", "-v": - fmt.Printf("%s version %s\n", name, version) - return 0 - case "--help", "-h": - printHelp() - return 0 - case "--cleanup": - return runCleanupMode() - } - } - - // Initialize logger for all other commands - logger, err := NewLogger() - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err) - return 1 - } - setLogger(logger) - - defer func() { - logger := activeLogger() - if logger != nil { - logger.Flush() - } - if err := closeLogger(); err != nil { - fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err) - } - // On failure, extract and display recent errors before removing log - if logger != nil { - if exitCode != 0 { - if errors := logger.ExtractRecentErrors(10); len(errors) > 0 { - fmt.Fprintln(os.Stderr, "\n=== Recent Errors ===") - for _, entry := range errors { - fmt.Fprintln(os.Stderr, entry) - } - fmt.Fprintf(os.Stderr, "Log file: %s (deleted)\n", logger.Path()) - } - } - if err := logger.RemoveLogFile(); err != nil && !os.IsNotExist(err) { - // Silently ignore removal errors - } - } - }() - defer runCleanupHook() - - // Clean up stale logs from previous runs. - runStartupCleanup() - - // Handle remaining commands - if len(os.Args) > 1 { - args := os.Args[1:] - parallelIndex := -1 - for i, arg := range args { - if arg == "--parallel" { - parallelIndex = i - break - } - } - - if parallelIndex != -1 { - backendName := defaultBackendName - model := "" - fullOutput := false - skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS") - var extras []string - - for i := 0; i < len(args); i++ { - arg := args[i] - switch { - case arg == "--parallel": - continue - case arg == "--full-output": - fullOutput = true - case arg == "--backend": - if i+1 >= len(args) { - fmt.Fprintln(os.Stderr, "ERROR: --backend flag requires a value") - return 1 - } - backendName = args[i+1] - i++ - case strings.HasPrefix(arg, "--backend="): - value := strings.TrimPrefix(arg, "--backend=") - if value == "" { - fmt.Fprintln(os.Stderr, "ERROR: --backend flag requires a value") - return 1 - } - backendName = value - case arg == "--model": - if i+1 >= len(args) { - fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value") - return 1 - } - model = args[i+1] - i++ - case strings.HasPrefix(arg, "--model="): - value := strings.TrimPrefix(arg, "--model=") - if value == "" { - fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value") - return 1 - } - model = value - case arg == "--skip-permissions", arg == "--dangerously-skip-permissions": - skipPermissions = true - case strings.HasPrefix(arg, "--skip-permissions="): - skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions) - case strings.HasPrefix(arg, "--dangerously-skip-permissions="): - skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--dangerously-skip-permissions="), skipPermissions) - default: - extras = append(extras, arg) - } - } - - if len(extras) > 0 { - fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend, --model, --full-output and --skip-permissions are allowed.") - fmt.Fprintln(os.Stderr, "Usage examples:") - fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name) - fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name) - fmt.Fprintf(os.Stderr, " %s --parallel <<'EOF'\n", name) - fmt.Fprintf(os.Stderr, " %s --parallel --full-output <<'EOF' # include full task output\n", name) - return 1 - } - - backend, err := selectBackendFn(backendName) - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) - return 1 - } - backendName = backend.Name() - - data, err := io.ReadAll(stdinReader) - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: failed to read stdin: %v\n", err) - return 1 - } - - cfg, err := parseParallelConfig(data) - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) - return 1 - } - - cfg.GlobalBackend = backendName - model = strings.TrimSpace(model) - for i := range cfg.Tasks { - if strings.TrimSpace(cfg.Tasks[i].Backend) == "" { - cfg.Tasks[i].Backend = backendName - } - if strings.TrimSpace(cfg.Tasks[i].Model) == "" && model != "" { - cfg.Tasks[i].Model = model - } - cfg.Tasks[i].SkipPermissions = cfg.Tasks[i].SkipPermissions || skipPermissions - } - - timeoutSec := resolveTimeout() - layers, err := topologicalSort(cfg.Tasks) - if err != nil { - fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) - return 1 - } - - results := executeConcurrent(layers, timeoutSec) - - // Extract structured report fields from each result - for i := range results { - results[i].CoverageTarget = defaultCoverageTarget - if results[i].Message == "" { - continue - } - - lines := strings.Split(results[i].Message, "\n") - - // Coverage extraction - results[i].Coverage = extractCoverageFromLines(lines) - results[i].CoverageNum = extractCoverageNum(results[i].Coverage) - - // Files changed - results[i].FilesChanged = extractFilesChangedFromLines(lines) - - // Test results - results[i].TestsPassed, results[i].TestsFailed = extractTestResultsFromLines(lines) - - // Key output summary - results[i].KeyOutput = extractKeyOutputFromLines(lines, 150) - } - - // Default: summary mode (context-efficient) - // --full-output: legacy full output mode - fmt.Println(generateFinalOutputWithMode(results, !fullOutput)) - - exitCode = 0 - for _, res := range results { - if res.ExitCode != 0 { - exitCode = res.ExitCode - } - } - - return exitCode - } - } - - logInfo("Script started") - - cfg, err := parseArgs() - if err != nil { - logError(err.Error()) - return 1 - } - logInfo(fmt.Sprintf("Parsed args: mode=%s, task_len=%d, backend=%s", cfg.Mode, len(cfg.Task), cfg.Backend)) - - backend, err := selectBackendFn(cfg.Backend) - if err != nil { - logError(err.Error()) - return 1 - } - cfg.Backend = backend.Name() - - cmdInjected := codexCommand != defaultCodexCommand - argsInjected := buildCodexArgsFn != nil && reflect.ValueOf(buildCodexArgsFn).Pointer() != reflect.ValueOf(defaultBuildArgsFn).Pointer() - - // Wire selected backend into runtime hooks for the rest of the execution, - // but preserve any injected test hooks for the default backend. - if backend.Name() != defaultBackendName || !cmdInjected { - codexCommand = backend.Command() - } - if backend.Name() != defaultBackendName || !argsInjected { - buildCodexArgsFn = backend.BuildArgs - } - logInfo(fmt.Sprintf("Selected backend: %s", backend.Name())) - - timeoutSec := resolveTimeout() - logInfo(fmt.Sprintf("Timeout: %ds", timeoutSec)) - cfg.Timeout = timeoutSec - - var taskText string - var piped bool - - if cfg.ExplicitStdin { - logInfo("Explicit stdin mode: reading task from stdin") - data, err := io.ReadAll(stdinReader) - if err != nil { - logError("Failed to read stdin: " + err.Error()) - return 1 - } - taskText = string(data) - if taskText == "" { - logError("Explicit stdin mode requires task input from stdin") - return 1 - } - piped = !isTerminal() - } else { - pipedTask, err := readPipedTask() - if err != nil { - logError("Failed to read piped stdin: " + err.Error()) - return 1 - } - piped = pipedTask != "" - if piped { - taskText = pipedTask - } else { - taskText = cfg.Task - } - } - - if strings.TrimSpace(cfg.PromptFile) != "" { - prompt, err := readAgentPromptFile(cfg.PromptFile, cfg.PromptFileExplicit) - if err != nil { - logError("Failed to read prompt file: " + err.Error()) - return 1 - } - taskText = wrapTaskWithAgentPrompt(prompt, taskText) - } - - useStdin := cfg.ExplicitStdin || shouldUseStdin(taskText, piped) - - targetArg := taskText - if useStdin { - targetArg = "-" - } - codexArgs := buildCodexArgsFn(cfg, targetArg) - - // Print startup information to stderr - fmt.Fprintf(os.Stderr, "[%s]\n", name) - fmt.Fprintf(os.Stderr, " Backend: %s\n", cfg.Backend) - fmt.Fprintf(os.Stderr, " Command: %s %s\n", codexCommand, strings.Join(codexArgs, " ")) - fmt.Fprintf(os.Stderr, " PID: %d\n", os.Getpid()) - fmt.Fprintf(os.Stderr, " Log: %s\n", logger.Path()) - - if useStdin { - var reasons []string - if piped { - reasons = append(reasons, "piped input") - } - if cfg.ExplicitStdin { - reasons = append(reasons, "explicit \"-\"") - } - if strings.Contains(taskText, "\n") { - reasons = append(reasons, "newline") - } - if strings.Contains(taskText, "\\") { - reasons = append(reasons, "backslash") - } - if strings.Contains(taskText, "\"") { - reasons = append(reasons, "double-quote") - } - if strings.Contains(taskText, "'") { - reasons = append(reasons, "single-quote") - } - if strings.Contains(taskText, "`") { - reasons = append(reasons, "backtick") - } - if strings.Contains(taskText, "$") { - reasons = append(reasons, "dollar") - } - if len(taskText) > 800 { - reasons = append(reasons, "length>800") - } - if len(reasons) > 0 { - logWarn(fmt.Sprintf("Using stdin mode for task due to: %s", strings.Join(reasons, ", "))) - } - } - - logInfo(fmt.Sprintf("%s running...", cfg.Backend)) - - taskSpec := TaskSpec{ - Task: taskText, - WorkDir: cfg.WorkDir, - Mode: cfg.Mode, - SessionID: cfg.SessionID, - Model: cfg.Model, - ReasoningEffort: cfg.ReasoningEffort, - SkipPermissions: cfg.SkipPermissions, - UseStdin: useStdin, - } - - result := runTaskFn(taskSpec, false, cfg.Timeout) - - if result.ExitCode != 0 { - return result.ExitCode - } - - fmt.Println(result.Message) - if result.SessionID != "" { - fmt.Printf("\n---\nSESSION_ID: %s\n", result.SessionID) - } - - return 0 -} - -func readAgentPromptFile(path string, allowOutsideClaudeDir bool) (string, error) { - raw := strings.TrimSpace(path) - if raw == "" { - return "", nil - } - - expanded := raw - if raw == "~" || strings.HasPrefix(raw, "~/") || strings.HasPrefix(raw, "~\\") { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - if raw == "~" { - expanded = home - } else { - expanded = home + raw[1:] - } - } - - absPath, err := filepath.Abs(expanded) - if err != nil { - return "", err - } - absPath = filepath.Clean(absPath) - - home, err := os.UserHomeDir() - if err != nil { - if !allowOutsideClaudeDir { - return "", err - } - logWarn(fmt.Sprintf("Failed to resolve home directory for prompt file validation: %v; proceeding without restriction", err)) - } else { - allowedDir := filepath.Clean(filepath.Join(home, ".claude")) - allowedAbs, err := filepath.Abs(allowedDir) - if err == nil { - allowedDir = filepath.Clean(allowedAbs) - } - - isWithinDir := func(path, dir string) bool { - rel, err := filepath.Rel(dir, path) - if err != nil { - return false - } - rel = filepath.Clean(rel) - if rel == "." { - return true - } - if rel == ".." { - return false - } - prefix := ".." + string(os.PathSeparator) - return !strings.HasPrefix(rel, prefix) - } - - if !allowOutsideClaudeDir { - if !isWithinDir(absPath, allowedDir) { - logWarn(fmt.Sprintf("Refusing to read prompt file outside %s: %s", allowedDir, absPath)) - return "", fmt.Errorf("prompt file must be under %s", allowedDir) - } - resolvedPath, errPath := filepath.EvalSymlinks(absPath) - resolvedBase, errBase := filepath.EvalSymlinks(allowedDir) - if errPath == nil && errBase == nil { - resolvedPath = filepath.Clean(resolvedPath) - resolvedBase = filepath.Clean(resolvedBase) - if !isWithinDir(resolvedPath, resolvedBase) { - logWarn(fmt.Sprintf("Refusing to read prompt file outside %s (resolved): %s", resolvedBase, resolvedPath)) - return "", fmt.Errorf("prompt file must be under %s", resolvedBase) - } - } - } else if !isWithinDir(absPath, allowedDir) { - logWarn(fmt.Sprintf("Reading prompt file outside %s: %s", allowedDir, absPath)) - } - } - - data, err := os.ReadFile(absPath) - if err != nil { - return "", err - } - return strings.TrimRight(string(data), "\r\n"), nil -} - -func wrapTaskWithAgentPrompt(prompt string, task string) string { - return "\n" + prompt + "\n\n\n" + task -} - -func setLogger(l *Logger) { - loggerPtr.Store(l) -} - -func closeLogger() error { - logger := loggerPtr.Swap(nil) - if logger == nil { - return nil - } - return logger.Close() -} - -func activeLogger() *Logger { - return loggerPtr.Load() -} - -func logInfo(msg string) { - if logger := activeLogger(); logger != nil { - logger.Info(msg) - } -} - -func logWarn(msg string) { - if logger := activeLogger(); logger != nil { - logger.Warn(msg) - } -} - -func logError(msg string) { - if logger := activeLogger(); logger != nil { - logger.Error(msg) - } -} - -func runCleanupHook() { - if logger := activeLogger(); logger != nil { - logger.Flush() - } - if cleanupHook != nil { - cleanupHook() - } -} - -func printHelp() { - name := currentWrapperName() - help := fmt.Sprintf(`%[1]s - Go wrapper for AI CLI backends - -Usage: - %[1]s "task" [workdir] - %[1]s --backend claude "task" [workdir] - %[1]s --prompt-file /path/to/prompt.md "task" [workdir] - %[1]s - [workdir] Read task from stdin - %[1]s resume "task" [workdir] - %[1]s resume - [workdir] - %[1]s --parallel Run tasks in parallel (config from stdin) - %[1]s --parallel --full-output Run tasks in parallel with full output (legacy) - %[1]s --version - %[1]s --help - -Parallel mode examples: - %[1]s --parallel < tasks.txt - echo '...' | %[1]s --parallel - %[1]s --parallel --full-output < tasks.txt - %[1]s --parallel <<'EOF' - -Environment Variables: - CODEX_TIMEOUT Timeout in milliseconds (default: 7200000) - CODEAGENT_ASCII_MODE Use ASCII symbols instead of Unicode (PASS/WARN/FAIL) - -Exit Codes: - 0 Success - 1 General error (missing args, no output) - 124 Timeout - 127 backend command not found - 130 Interrupted (Ctrl+C) - * Passthrough from backend process`, name) - fmt.Println(help) -} diff --git a/codeagent-wrapper/process_check_test.go b/codeagent-wrapper/process_check_test.go deleted file mode 100644 index 9ad661e..0000000 --- a/codeagent-wrapper/process_check_test.go +++ /dev/null @@ -1,217 +0,0 @@ -//go:build unix || darwin || linux -// +build unix darwin linux - -package main - -import ( - "errors" - "fmt" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "testing" - "time" -) - -func TestIsProcessRunning(t *testing.T) { - t.Run("current process", func(t *testing.T) { - if !isProcessRunning(os.Getpid()) { - t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) - } - }) - - t.Run("fake pid", func(t *testing.T) { - const nonexistentPID = 1 << 30 - if isProcessRunning(nonexistentPID) { - t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) - } - }) - - t.Run("terminated process", func(t *testing.T) { - pid := exitedProcessPID(t) - if isProcessRunning(pid) { - t.Fatalf("expected exited child process (pid=%d) to be reported as not running", pid) - } - }) - - t.Run("boundary values", func(t *testing.T) { - if isProcessRunning(0) { - t.Fatalf("pid 0 should never be treated as running") - } - if isProcessRunning(-42) { - t.Fatalf("negative pid should never be treated as running") - } - }) - - t.Run("find process error", func(t *testing.T) { - original := findProcess - defer func() { findProcess = original }() - - mockErr := errors.New("findProcess failure") - findProcess = func(pid int) (*os.Process, error) { - return nil, mockErr - } - - if isProcessRunning(1234) { - t.Fatalf("expected false when os.FindProcess fails") - } - }) -} - -func exitedProcessPID(t *testing.T) int { - t.Helper() - - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - cmd = exec.Command("cmd", "/c", "exit 0") - } else { - cmd = exec.Command("sh", "-c", "exit 0") - } - - if err := cmd.Start(); err != nil { - t.Fatalf("failed to start helper process: %v", err) - } - pid := cmd.Process.Pid - - if err := cmd.Wait(); err != nil { - t.Fatalf("helper process did not exit cleanly: %v", err) - } - - time.Sleep(50 * time.Millisecond) - return pid -} - -func TestRunProcessCheckSmoke(t *testing.T) { - t.Run("current process", func(t *testing.T) { - if !isProcessRunning(os.Getpid()) { - t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) - } - }) - - t.Run("fake pid", func(t *testing.T) { - const nonexistentPID = 1 << 30 - if isProcessRunning(nonexistentPID) { - t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) - } - }) - - t.Run("boundary values", func(t *testing.T) { - if isProcessRunning(0) { - t.Fatalf("pid 0 should never be treated as running") - } - if isProcessRunning(-42) { - t.Fatalf("negative pid should never be treated as running") - } - }) - - t.Run("find process error", func(t *testing.T) { - original := findProcess - defer func() { findProcess = original }() - - mockErr := errors.New("findProcess failure") - findProcess = func(pid int) (*os.Process, error) { - return nil, mockErr - } - - if isProcessRunning(1234) { - t.Fatalf("expected false when os.FindProcess fails") - } - }) -} - -func TestGetProcessStartTimeReadsProcStat(t *testing.T) { - pid := 4321 - boot := time.Unix(1_710_000_000, 0) - startTicks := uint64(4500) - - statFields := make([]string, 25) - for i := range statFields { - statFields[i] = strconv.Itoa(i + 1) - } - statFields[19] = strconv.FormatUint(startTicks, 10) - statContent := fmt.Sprintf("%d (%s) %s", pid, "cmd with space", strings.Join(statFields, " ")) - - stubReadFile(t, func(path string) ([]byte, error) { - switch path { - case fmt.Sprintf("/proc/%d/stat", pid): - return []byte(statContent), nil - case "/proc/stat": - return []byte(fmt.Sprintf("cpu 0 0 0 0\nbtime %d\n", boot.Unix())), nil - default: - return nil, os.ErrNotExist - } - }) - - got := getProcessStartTime(pid) - want := boot.Add(time.Duration(startTicks/100) * time.Second) - if !got.Equal(want) { - t.Fatalf("getProcessStartTime() = %v, want %v", got, want) - } -} - -func TestGetProcessStartTimeInvalidData(t *testing.T) { - pid := 99 - stubReadFile(t, func(path string) ([]byte, error) { - switch path { - case fmt.Sprintf("/proc/%d/stat", pid): - return []byte("garbage"), nil - case "/proc/stat": - return []byte("btime not-a-number\n"), nil - default: - return nil, os.ErrNotExist - } - }) - - if got := getProcessStartTime(pid); !got.IsZero() { - t.Fatalf("invalid /proc data should return zero time, got %v", got) - } -} - -func TestGetBootTimeParsesBtime(t *testing.T) { - const bootSec = 1_711_111_111 - stubReadFile(t, func(path string) ([]byte, error) { - if path != "/proc/stat" { - return nil, os.ErrNotExist - } - content := fmt.Sprintf("intr 0\nbtime %d\n", bootSec) - return []byte(content), nil - }) - - got := getBootTime() - want := time.Unix(bootSec, 0) - if !got.Equal(want) { - t.Fatalf("getBootTime() = %v, want %v", got, want) - } -} - -func TestGetBootTimeInvalidData(t *testing.T) { - cases := []struct { - name string - content string - }{ - {"missing", "cpu 0 0 0 0"}, - {"malformed", "btime abc"}, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - stubReadFile(t, func(string) ([]byte, error) { - return []byte(tt.content), nil - }) - if got := getBootTime(); !got.IsZero() { - t.Fatalf("getBootTime() unexpected value for %s: %v", tt.name, got) - } - }) - } -} - -func stubReadFile(t *testing.T, fn func(string) ([]byte, error)) { - t.Helper() - original := readFileFn - readFileFn = fn - t.Cleanup(func() { - readFileFn = original - }) -} diff --git a/codeagent-wrapper/process_check_unix.go b/codeagent-wrapper/process_check_unix.go deleted file mode 100644 index c235d65..0000000 --- a/codeagent-wrapper/process_check_unix.go +++ /dev/null @@ -1,104 +0,0 @@ -//go:build unix || darwin || linux -// +build unix darwin linux - -package main - -import ( - "errors" - "fmt" - "os" - "strconv" - "strings" - "syscall" - "time" -) - -var findProcess = os.FindProcess -var readFileFn = os.ReadFile - -// isProcessRunning returns true if a process with the given pid is running on Unix-like systems. -func isProcessRunning(pid int) bool { - if pid <= 0 { - return false - } - - proc, err := findProcess(pid) - if err != nil || proc == nil { - return false - } - - err = proc.Signal(syscall.Signal(0)) - if err != nil && (errors.Is(err, syscall.ESRCH) || errors.Is(err, os.ErrProcessDone)) { - return false - } - return true -} - -// getProcessStartTime returns the start time of a process on Unix-like systems. -// Returns zero time if the start time cannot be determined. -func getProcessStartTime(pid int) time.Time { - if pid <= 0 { - return time.Time{} - } - - // Read /proc//stat to get process start time - statPath := fmt.Sprintf("/proc/%d/stat", pid) - data, err := readFileFn(statPath) - if err != nil { - return time.Time{} - } - - // Parse stat file: fields are space-separated, but comm (field 2) can contain spaces - // Find the last ')' to skip comm field safely - content := string(data) - lastParen := strings.LastIndex(content, ")") - if lastParen == -1 { - return time.Time{} - } - - fields := strings.Fields(content[lastParen+1:]) - if len(fields) < 20 { - return time.Time{} - } - - // Field 22 (index 19 after comm) is starttime in clock ticks since boot - startTicks, err := strconv.ParseUint(fields[19], 10, 64) - if err != nil { - return time.Time{} - } - - // Get system boot time - bootTime := getBootTime() - if bootTime.IsZero() { - return time.Time{} - } - - // Convert ticks to duration (typically 100 ticks/sec on most systems) - ticksPerSec := uint64(100) // sysconf(_SC_CLK_TCK), typically 100 - startTime := bootTime.Add(time.Duration(startTicks/ticksPerSec) * time.Second) - - return startTime -} - -// getBootTime returns the system boot time by reading /proc/stat. -func getBootTime() time.Time { - data, err := readFileFn("/proc/stat") - if err != nil { - return time.Time{} - } - - lines := strings.Split(string(data), "\n") - for _, line := range lines { - if strings.HasPrefix(line, "btime ") { - fields := strings.Fields(line) - if len(fields) >= 2 { - bootSec, err := strconv.ParseInt(fields[1], 10, 64) - if err == nil { - return time.Unix(bootSec, 0) - } - } - } - } - - return time.Time{} -} diff --git a/codeagent-wrapper/process_check_windows.go b/codeagent-wrapper/process_check_windows.go deleted file mode 100644 index d303e42..0000000 --- a/codeagent-wrapper/process_check_windows.go +++ /dev/null @@ -1,87 +0,0 @@ -//go:build windows -// +build windows - -package main - -import ( - "errors" - "os" - "syscall" - "time" - "unsafe" -) - -const ( - processQueryLimitedInformation = 0x1000 - stillActive = 259 // STILL_ACTIVE exit code -) - -var ( - findProcess = os.FindProcess - kernel32 = syscall.NewLazyDLL("kernel32.dll") - getProcessTimes = kernel32.NewProc("GetProcessTimes") - fileTimeToUnixFn = fileTimeToUnix -) - -// isProcessRunning returns true if a process with the given pid is running on Windows. -func isProcessRunning(pid int) bool { - if pid <= 0 { - return false - } - - if _, err := findProcess(pid); err != nil { - return false - } - - handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) - if err != nil { - if errors.Is(err, syscall.ERROR_ACCESS_DENIED) { - return true - } - return false - } - defer syscall.CloseHandle(handle) - - var exitCode uint32 - if err := syscall.GetExitCodeProcess(handle, &exitCode); err != nil { - return true - } - - return exitCode == stillActive -} - -// getProcessStartTime returns the start time of a process on Windows. -// Returns zero time if the start time cannot be determined. -func getProcessStartTime(pid int) time.Time { - if pid <= 0 { - return time.Time{} - } - - handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) - if err != nil { - return time.Time{} - } - defer syscall.CloseHandle(handle) - - var creationTime, exitTime, kernelTime, userTime syscall.Filetime - ret, _, _ := getProcessTimes.Call( - uintptr(handle), - uintptr(unsafe.Pointer(&creationTime)), - uintptr(unsafe.Pointer(&exitTime)), - uintptr(unsafe.Pointer(&kernelTime)), - uintptr(unsafe.Pointer(&userTime)), - ) - - if ret == 0 { - return time.Time{} - } - - return fileTimeToUnixFn(creationTime) -} - -// fileTimeToUnix converts Windows FILETIME to Unix time. -func fileTimeToUnix(ft syscall.Filetime) time.Time { - // FILETIME is 100-nanosecond intervals since January 1, 1601 UTC - nsec := ft.Nanoseconds() - return time.Unix(0, nsec) -} diff --git a/codeagent-wrapper/process_check_windows_test.go b/codeagent-wrapper/process_check_windows_test.go deleted file mode 100644 index feb3ce0..0000000 --- a/codeagent-wrapper/process_check_windows_test.go +++ /dev/null @@ -1,64 +0,0 @@ -//go:build windows -// +build windows - -package main - -import ( - "os" - "testing" - "time" -) - -func TestIsProcessRunning(t *testing.T) { - t.Run("boundary values", func(t *testing.T) { - if isProcessRunning(0) { - t.Fatalf("expected pid 0 to be reported as not running") - } - if isProcessRunning(-1) { - t.Fatalf("expected pid -1 to be reported as not running") - } - }) - - t.Run("current process", func(t *testing.T) { - if !isProcessRunning(os.Getpid()) { - t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) - } - }) - - t.Run("fake pid", func(t *testing.T) { - const nonexistentPID = 1 << 30 - if isProcessRunning(nonexistentPID) { - t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) - } - }) -} - -func TestGetProcessStartTimeReadsProcStat(t *testing.T) { - start := getProcessStartTime(os.Getpid()) - if start.IsZero() { - t.Fatalf("expected non-zero start time for current process") - } - if start.After(time.Now().Add(5 * time.Second)) { - t.Fatalf("start time is unexpectedly in the future: %v", start) - } -} - -func TestGetProcessStartTimeInvalidData(t *testing.T) { - if !getProcessStartTime(0).IsZero() { - t.Fatalf("expected zero time for pid 0") - } - if !getProcessStartTime(-1).IsZero() { - t.Fatalf("expected zero time for negative pid") - } - if !getProcessStartTime(1 << 30).IsZero() { - t.Fatalf("expected zero time for non-existent pid") - } -} - -func TestGetBootTimeParsesBtime(t *testing.T) { - t.Skip("getBootTime is only implemented on Unix-like systems") -} - -func TestGetBootTimeInvalidData(t *testing.T) { - t.Skip("getBootTime is only implemented on Unix-like systems") -} diff --git a/codeagent-wrapper/wrapper_name.go b/codeagent-wrapper/wrapper_name.go deleted file mode 100644 index 236df20..0000000 --- a/codeagent-wrapper/wrapper_name.go +++ /dev/null @@ -1,126 +0,0 @@ -package main - -import ( - "os" - "path/filepath" - "strings" -) - -const ( - defaultWrapperName = "codeagent-wrapper" - legacyWrapperName = "codex-wrapper" -) - -var executablePathFn = os.Executable - -func normalizeWrapperName(path string) string { - if path == "" { - return "" - } - - base := filepath.Base(path) - base = strings.TrimSuffix(base, ".exe") // tolerate Windows executables - - switch base { - case defaultWrapperName, legacyWrapperName: - return base - default: - return "" - } -} - -// currentWrapperName resolves the wrapper name based on the invoked binary. -// Only known names are honored to avoid leaking build/test binary names into logs. -func currentWrapperName() string { - if len(os.Args) == 0 { - return defaultWrapperName - } - - if name := normalizeWrapperName(os.Args[0]); name != "" { - return name - } - - execPath, err := executablePathFn() - if err == nil { - if name := normalizeWrapperName(execPath); name != "" { - return name - } - - if resolved, err := filepath.EvalSymlinks(execPath); err == nil { - if name := normalizeWrapperName(resolved); name != "" { - return name - } - if alias := resolveAlias(execPath, resolved); alias != "" { - return alias - } - } - - if alias := resolveAlias(execPath, ""); alias != "" { - return alias - } - } - - return defaultWrapperName -} - -// logPrefixes returns the set of accepted log name prefixes, including the -// current wrapper name and legacy aliases. -func logPrefixes() []string { - prefixes := []string{currentWrapperName(), defaultWrapperName, legacyWrapperName} - seen := make(map[string]struct{}, len(prefixes)) - var unique []string - for _, prefix := range prefixes { - if prefix == "" { - continue - } - if _, ok := seen[prefix]; ok { - continue - } - seen[prefix] = struct{}{} - unique = append(unique, prefix) - } - return unique -} - -// primaryLogPrefix returns the preferred filename prefix for log files. -// Defaults to the current wrapper name when available, otherwise falls back -// to the canonical default name. -func primaryLogPrefix() string { - prefixes := logPrefixes() - if len(prefixes) == 0 { - return defaultWrapperName - } - return prefixes[0] -} - -func resolveAlias(execPath string, target string) string { - if execPath == "" { - return "" - } - - dir := filepath.Dir(execPath) - for _, candidate := range []string{defaultWrapperName, legacyWrapperName} { - aliasPath := filepath.Join(dir, candidate) - info, err := os.Lstat(aliasPath) - if err != nil { - continue - } - if info.Mode()&os.ModeSymlink == 0 { - continue - } - - resolved, err := filepath.EvalSymlinks(aliasPath) - if err != nil { - continue - } - if target != "" && resolved != target { - continue - } - - if name := normalizeWrapperName(aliasPath); name != "" { - return name - } - } - - return "" -} diff --git a/codeagent-wrapper/wrapper_name_test.go b/codeagent-wrapper/wrapper_name_test.go deleted file mode 100644 index b133d95..0000000 --- a/codeagent-wrapper/wrapper_name_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "os" - "path/filepath" - "testing" -) - -func TestCurrentWrapperNameFallsBackToExecutable(t *testing.T) { - defer resetTestHooks() - - tempDir := t.TempDir() - execPath := filepath.Join(tempDir, "codeagent-wrapper") - if err := os.WriteFile(execPath, []byte("#!/bin/true\n"), 0o755); err != nil { - t.Fatalf("failed to write fake binary: %v", err) - } - - os.Args = []string{filepath.Join(tempDir, "custom-name")} - executablePathFn = func() (string, error) { - return execPath, nil - } - - if got := currentWrapperName(); got != defaultWrapperName { - t.Fatalf("currentWrapperName() = %q, want %q", got, defaultWrapperName) - } -} - -func TestCurrentWrapperNameDetectsLegacyAliasSymlink(t *testing.T) { - defer resetTestHooks() - - tempDir := t.TempDir() - execPath := filepath.Join(tempDir, "wrapper") - aliasPath := filepath.Join(tempDir, legacyWrapperName) - - if err := os.WriteFile(execPath, []byte("#!/bin/true\n"), 0o755); err != nil { - t.Fatalf("failed to write fake binary: %v", err) - } - if err := os.Symlink(execPath, aliasPath); err != nil { - t.Fatalf("failed to create alias: %v", err) - } - - os.Args = []string{filepath.Join(tempDir, "unknown-runner")} - executablePathFn = func() (string, error) { - return execPath, nil - } - - if got := currentWrapperName(); got != legacyWrapperName { - t.Fatalf("currentWrapperName() = %q, want %q", got, legacyWrapperName) - } -}