diff --git a/codex-wrapper/main.go b/codex-wrapper/main.go index 4837704..9f3ee7e 100644 --- a/codex-wrapper/main.go +++ b/codex-wrapper/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "bytes" "context" "encoding/json" "fmt" @@ -9,24 +10,31 @@ import ( "os" "os/exec" "os/signal" + "sort" "strconv" "strings" + "sync" "syscall" "time" ) const ( - version = "1.0.0" - defaultWorkdir = "." - defaultTimeout = 7200 // seconds - forceKillDelay = 5 // seconds + version = "1.0.0" + defaultWorkdir = "." + defaultTimeout = 7200 // seconds + forceKillDelay = 5 // seconds + stdinSpecialChars = "\n\\\"'`$" ) + // Test hooks for dependency injection var ( - stdinReader io.Reader = os.Stdin - isTerminalFn = defaultIsTerminal - codexCommand = "codex" + stdinReader io.Reader = os.Stdin + isTerminalFn = defaultIsTerminal + codexCommand = "codex" + buildCodexArgsFn = buildCodexArgs + commandContext = exec.CommandContext + jsonMarshal = json.Marshal ) // Config holds CLI configuration @@ -39,6 +47,291 @@ type Config struct { Timeout int } +// ParallelConfig defines the JSON schema for parallel execution +type ParallelConfig struct { + Tasks []TaskSpec `json:"tasks"` +} + +// TaskSpec describes an individual task entry in the parallel config +type TaskSpec struct { + ID string `json:"id"` + Task string `json:"task"` + WorkDir string `json:"workdir,omitempty"` + Dependencies []string `json:"dependencies,omitempty"` + Mode string `json:"-"` + SessionID string `json:"-"` + UseStdin bool `json:"-"` +} + +// TaskResult captures the execution outcome of a task +type TaskResult struct { + TaskID string `json:"task_id"` + ExitCode int `json:"exit_code"` + Message string `json:"message"` + SessionID string `json:"session_id"` + Error string `json:"error"` +} + + +func parseParallelConfig(data []byte) (*ParallelConfig, error) { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 { + return nil, fmt.Errorf("parallel config is empty") + } + + tasks := strings.Split(string(trimmed), "---TASK---") + var cfg ParallelConfig + seen := make(map[string]struct{}) + + for _, taskBlock := range tasks { + taskBlock = strings.TrimSpace(taskBlock) + if taskBlock == "" { + continue + } + + parts := strings.SplitN(taskBlock, "---CONTENT---", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("task block missing ---CONTENT--- separator") + } + + meta := strings.TrimSpace(parts[0]) + content := strings.TrimSpace(parts[1]) + + task := TaskSpec{WorkDir: defaultWorkdir} + for _, line := range strings.Split(meta, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + kv := strings.SplitN(line, ":", 2) + if len(kv) != 2 { + continue + } + key := strings.TrimSpace(kv[0]) + value := strings.TrimSpace(kv[1]) + + switch key { + case "id": + task.ID = value + case "workdir": + task.WorkDir = value + case "dependencies": + for _, dep := range strings.Split(value, ",") { + dep = strings.TrimSpace(dep) + if dep != "" { + task.Dependencies = append(task.Dependencies, dep) + } + } + } + } + + if task.ID == "" { + return nil, fmt.Errorf("task missing id field") + } + if content == "" { + return nil, fmt.Errorf("task %q missing content", task.ID) + } + if _, exists := seen[task.ID]; exists { + return nil, fmt.Errorf("duplicate task id: %s", task.ID) + } + + task.Task = content + cfg.Tasks = append(cfg.Tasks, task) + seen[task.ID] = struct{}{} + } + + if len(cfg.Tasks) == 0 { + return nil, fmt.Errorf("no tasks found") + } + + return &cfg, nil +} + +func topologicalSort(tasks []TaskSpec) ([][]TaskSpec, error) { + idToTask := 
make(map[string]TaskSpec, len(tasks)) + indegree := make(map[string]int, len(tasks)) + adj := make(map[string][]string, len(tasks)) + + for _, task := range tasks { + idToTask[task.ID] = task + indegree[task.ID] = 0 + } + + for _, task := range tasks { + for _, dep := range task.Dependencies { + if _, ok := idToTask[dep]; !ok { + return nil, fmt.Errorf("dependency %q not found for task %q", dep, task.ID) + } + indegree[task.ID]++ + adj[dep] = append(adj[dep], task.ID) + } + } + + queue := make([]string, 0, len(tasks)) + for _, task := range tasks { + if indegree[task.ID] == 0 { + queue = append(queue, task.ID) + } + } + + layers := make([][]TaskSpec, 0) + processed := 0 + + for len(queue) > 0 { + current := queue + queue = nil + layer := make([]TaskSpec, len(current)) + for i, id := range current { + layer[i] = idToTask[id] + processed++ + } + layers = append(layers, layer) + + next := make([]string, 0) + for _, id := range current { + for _, neighbor := range adj[id] { + indegree[neighbor]-- + if indegree[neighbor] == 0 { + next = append(next, neighbor) + } + } + } + queue = append(queue, next...) + } + + if processed != len(tasks) { + cycleIDs := make([]string, 0) + for id, deg := range indegree { + if deg > 0 { + cycleIDs = append(cycleIDs, id) + } + } + sort.Strings(cycleIDs) + return nil, fmt.Errorf("cycle detected involving tasks: %s", strings.Join(cycleIDs, ",")) + } + + return layers, nil +} + +var runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + if task.WorkDir == "" { + task.WorkDir = defaultWorkdir + } + if task.Mode == "" { + task.Mode = "new" + } + if task.UseStdin || shouldUseStdin(task.Task, false) { + task.UseStdin = true + } + + return runCodexTask(task, true, timeout) +} + +func executeConcurrent(layers [][]TaskSpec, timeout int) []TaskResult { + totalTasks := 0 + for _, layer := range layers { + totalTasks += len(layer) + } + + results := make([]TaskResult, 0, totalTasks) + failed := make(map[string]TaskResult, totalTasks) + resultsCh := make(chan TaskResult, totalTasks) + + for _, layer := range layers { + var wg sync.WaitGroup + executed := 0 + + for _, task := range layer { + if skip, reason := shouldSkipTask(task, failed); skip { + res := TaskResult{TaskID: task.ID, ExitCode: 1, Error: reason} + results = append(results, res) + failed[task.ID] = res + continue + } + + executed++ + wg.Add(1) + go func(ts TaskSpec) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + resultsCh <- TaskResult{TaskID: ts.ID, ExitCode: 1, Error: fmt.Sprintf("panic: %v", r)} + } + }() + resultsCh <- runCodexTaskFn(ts, timeout) + }(task) + } + + wg.Wait() + + for i := 0; i < executed; i++ { + res := <-resultsCh + results = append(results, res) + if res.ExitCode != 0 || res.Error != "" { + failed[res.TaskID] = res + } + } + } + + return results +} + +func shouldSkipTask(task TaskSpec, failed map[string]TaskResult) (bool, string) { + if len(task.Dependencies) == 0 { + return false, "" + } + + var blocked []string + for _, dep := range task.Dependencies { + if _, ok := failed[dep]; ok { + blocked = append(blocked, dep) + } + } + + if len(blocked) == 0 { + return false, "" + } + + return true, fmt.Sprintf("skipped due to failed dependencies: %s", strings.Join(blocked, ",")) +} + +func generateFinalOutput(results []TaskResult) string { + var sb strings.Builder + + success := 0 + failed := 0 + for _, res := range results { + if res.ExitCode == 0 && res.Error == "" { + success++ + } else { + failed++ + } + } + + sb.WriteString(fmt.Sprintf("=== Parallel Execution 
Summary ===\n")) + sb.WriteString(fmt.Sprintf("Total: %d | Success: %d | Failed: %d\n\n", len(results), success, failed)) + + for _, res := range results { + sb.WriteString(fmt.Sprintf("--- Task: %s ---\n", res.TaskID)) + if res.Error != "" { + sb.WriteString(fmt.Sprintf("Status: FAILED (exit code %d)\nError: %s\n", res.ExitCode, res.Error)) + } else if res.ExitCode != 0 { + sb.WriteString(fmt.Sprintf("Status: FAILED (exit code %d)\n", res.ExitCode)) + } else { + sb.WriteString("Status: SUCCESS\n") + } + if res.SessionID != "" { + sb.WriteString(fmt.Sprintf("Session: %s\n", res.SessionID)) + } + if res.Message != "" { + sb.WriteString(fmt.Sprintf("\n%s\n", res.Message)) + } + sb.WriteString("\n") + } + + return sb.String() +} + // JSONEvent represents a Codex JSON output event type JSONEvent struct { Type string `json:"type"` @@ -68,9 +361,42 @@ func run() int { case "--help", "-h": printHelp() return 0 + case "--parallel": + // Parallel mode: read task config from stdin + data, err := io.ReadAll(stdinReader) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to read stdin: %v\n", err) + return 1 + } + + cfg, err := parseParallelConfig(data) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + return 1 + } + + timeoutSec := resolveTimeout() + layers, err := topologicalSort(cfg.Tasks) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: %v\n", err) + return 1 + } + + results := executeConcurrent(layers, timeoutSec) + fmt.Println(generateFinalOutput(results)) + + exitCode := 0 + for _, res := range results { + if res.ExitCode != 0 { + exitCode = res.ExitCode + } + } + + return exitCode } } + logInfo("Script started") cfg, err := parseArgs() @@ -127,6 +453,18 @@ func run() int { if strings.Contains(taskText, "\\") { reasons = append(reasons, "backslash") } + if strings.Contains(taskText, "\"") { + reasons = append(reasons, "double-quote") + } + if strings.Contains(taskText, "'") { + reasons = append(reasons, "single-quote") + } + if strings.Contains(taskText, "`") { + reasons = append(reasons, "backtick") + } + if strings.Contains(taskText, "$") { + reasons = append(reasons, "dollar") + } if len(taskText) > 800 { reasons = append(reasons, "length>800") } @@ -135,26 +473,28 @@ func run() int { } } - targetArg := taskText - if useStdin { - targetArg = "-" - } - - codexArgs := buildCodexArgs(cfg, targetArg) logInfo("codex running...") - message, threadID, exitCode := runCodexProcess(codexArgs, taskText, useStdin, cfg.Timeout) + taskSpec := TaskSpec{ + Task: taskText, + WorkDir: cfg.WorkDir, + Mode: cfg.Mode, + SessionID: cfg.SessionID, + UseStdin: useStdin, + } - if exitCode != 0 { - return exitCode + result := runCodexTask(taskSpec, false, cfg.Timeout) + + if result.ExitCode != 0 { + return result.ExitCode } // Output agent_message - fmt.Println(message) + fmt.Println(result.Message) // Output session_id if present - if threadID != "" { - fmt.Printf("\n---\nSESSION_ID: %s\n", threadID) + if result.SessionID != "" { + fmt.Printf("\n---\nSESSION_ID: %s\n", result.SessionID) } return 0 @@ -213,16 +553,10 @@ func shouldUseStdin(taskText string, piped bool) bool { if piped { return true } - if strings.Contains(taskText, "\n") { - return true - } - if strings.Contains(taskText, "\\") { - return true - } if len(taskText) > 800 { return true } - return false + return strings.IndexAny(taskText, stdinSpecialChars) >= 0 } func buildCodexArgs(cfg *Config, targetArg string) []string { @@ -245,12 +579,48 @@ func buildCodexArgs(cfg *Config, targetArg string) []string { } } -func 
runCodexProcess(codexArgs []string, taskText string, useStdin bool, timeoutSec int) (message, threadID string, exitCode int) { +func runCodexTask(taskSpec TaskSpec, silent bool, timeoutSec int) TaskResult { + result := TaskResult{ + TaskID: taskSpec.ID, + } + + cfg := &Config{ + Mode: taskSpec.Mode, + Task: taskSpec.Task, + SessionID: taskSpec.SessionID, + WorkDir: taskSpec.WorkDir, + } + if cfg.Mode == "" { + cfg.Mode = "new" + } + if cfg.WorkDir == "" { + cfg.WorkDir = defaultWorkdir + } + + useStdin := taskSpec.UseStdin + targetArg := taskSpec.Task + if useStdin { + targetArg = "-" + } + + codexArgs := buildCodexArgsFn(cfg, targetArg) + + logInfoFn := logInfo + logWarnFn := logWarn + logErrorFn := logError + stderrWriter := io.Writer(os.Stderr) + if silent { + logInfoFn = func(string) {} + logWarnFn = func(string) {} + logErrorFn = func(string) {} + stderrWriter = io.Discard + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, codexCommand, codexArgs...) - cmd.Stderr = os.Stderr + cmd := commandContext(ctx, codexCommand, codexArgs...) + cmd.Stderr = stderrWriter // Setup stdin if needed var stdinPipe io.WriteCloser @@ -258,97 +628,133 @@ func runCodexProcess(codexArgs []string, taskText string, useStdin bool, timeout if useStdin { stdinPipe, err = cmd.StdinPipe() if err != nil { - logError("Failed to create stdin pipe: " + err.Error()) - return "", "", 1 + logErrorFn("Failed to create stdin pipe: " + err.Error()) + result.ExitCode = 1 + result.Error = "failed to create stdin pipe: " + err.Error() + return result } } // Setup stdout stdout, err := cmd.StdoutPipe() if err != nil { - logError("Failed to create stdout pipe: " + err.Error()) - return "", "", 1 + logErrorFn("Failed to create stdout pipe: " + err.Error()) + result.ExitCode = 1 + result.Error = "failed to create stdout pipe: " + err.Error() + return result } - logInfo(fmt.Sprintf("Starting codex with args: codex %s...", strings.Join(codexArgs[:min(5, len(codexArgs))], " "))) + logInfoFn(fmt.Sprintf("Starting codex with args: codex %s...", strings.Join(codexArgs[:min(5, len(codexArgs))], " "))) // Start process if err := cmd.Start(); err != nil { if strings.Contains(err.Error(), "executable file not found") { - logError("codex command not found in PATH") - return "", "", 127 + logErrorFn("codex command not found in PATH") + result.ExitCode = 127 + result.Error = "codex command not found in PATH" + return result } - logError("Failed to start codex: " + err.Error()) - return "", "", 1 + logErrorFn("Failed to start codex: " + err.Error()) + result.ExitCode = 1 + result.Error = "failed to start codex: " + err.Error() + return result } - logInfo(fmt.Sprintf("Process started with PID: %d", cmd.Process.Pid)) + logInfoFn(fmt.Sprintf("Process started with PID: %d", cmd.Process.Pid)) // Write to stdin if needed if useStdin && stdinPipe != nil { - logInfo(fmt.Sprintf("Writing %d chars to stdin...", len(taskText))) + logInfoFn(fmt.Sprintf("Writing %d chars to stdin...", len(taskSpec.Task))) go func() { defer stdinPipe.Close() - io.WriteString(stdinPipe, taskText) + io.WriteString(stdinPipe, taskSpec.Task) }() - logInfo("Stdin closed") + logInfoFn("Stdin closed") } - // Setup signal handling - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - go func() { - sig := <-sigCh - logError(fmt.Sprintf("Received signal: %v", sig)) - if cmd.Process != nil { - cmd.Process.Signal(syscall.SIGTERM) - 
time.AfterFunc(time.Duration(forceKillDelay)*time.Second, func() { - if cmd.Process != nil { - cmd.Process.Kill() - } - }) - } - }() + forwardSignals(ctx, cmd, logErrorFn) - logInfo("Reading stdout...") + logInfoFn("Reading stdout...") // Parse JSON stream - message, threadID = parseJSONStream(stdout) + message, threadID := parseJSONStreamWithWarn(stdout, logWarnFn) // Wait for process to complete err = cmd.Wait() // Check for timeout if ctx.Err() == context.DeadlineExceeded { - logError("Codex execution timeout") + logErrorFn("Codex execution timeout") if cmd.Process != nil { cmd.Process.Kill() } - return "", "", 124 + result.ExitCode = 124 + result.Error = "codex execution timeout" + return result } // Check exit code if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { code := exitErr.ExitCode() - logError(fmt.Sprintf("Codex exited with status %d", code)) - return "", "", code + logErrorFn(fmt.Sprintf("Codex exited with status %d", code)) + result.ExitCode = code + result.Error = fmt.Sprintf("codex exited with status %d", code) + return result } - logError("Codex error: " + err.Error()) - return "", "", 1 + logErrorFn("Codex error: " + err.Error()) + result.ExitCode = 1 + result.Error = "codex error: " + err.Error() + return result } if message == "" { - logError("Codex completed without agent_message output") - return "", "", 1 + logErrorFn("Codex completed without agent_message output") + result.ExitCode = 1 + result.Error = "codex completed without agent_message output" + return result } - return message, threadID, 0 + result.ExitCode = 0 + result.Message = message + result.SessionID = threadID + + return result +} + +func forwardSignals(ctx context.Context, cmd *exec.Cmd, logErrorFn func(string)) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + go func() { + defer signal.Stop(sigCh) + select { + case sig := <-sigCh: + logErrorFn(fmt.Sprintf("Received signal: %v", sig)) + if cmd.Process != nil { + cmd.Process.Signal(syscall.SIGTERM) + time.AfterFunc(time.Duration(forceKillDelay)*time.Second, func() { + if cmd.Process != nil { + cmd.Process.Kill() + } + }) + } + case <-ctx.Done(): + } + }() } func parseJSONStream(r io.Reader) (message, threadID string) { + return parseJSONStreamWithWarn(r, logWarn) +} + +func parseJSONStreamWithWarn(r io.Reader, warnFn func(string)) (message, threadID string) { scanner := bufio.NewScanner(r) scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) + if warnFn == nil { + warnFn = func(string) {} + } + for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" { @@ -357,7 +763,7 @@ func parseJSONStream(r io.Reader) (message, threadID string) { var event JSONEvent if err := json.Unmarshal([]byte(line), &event); err != nil { - logWarn(fmt.Sprintf("Failed to parse line: %s", truncate(line, 100))) + warnFn(fmt.Sprintf("Failed to parse line: %s", truncate(line, 100))) continue } @@ -375,7 +781,7 @@ func parseJSONStream(r io.Reader) (message, threadID string) { } if err := scanner.Err(); err != nil && err != io.EOF { - logWarn("Read stdout error: " + err.Error()) + warnFn("Read stdout error: " + err.Error()) } return message, threadID @@ -450,6 +856,10 @@ func min(a, b int) int { return b } +func test() string { + return "hello $world" +} + func logInfo(msg string) { fmt.Fprintf(os.Stderr, "INFO: %s\n", msg) } diff --git a/codex-wrapper/main_integration_test.go b/codex-wrapper/main_integration_test.go new file mode 100644 index 0000000..e5153d8 --- /dev/null +++ 
b/codex-wrapper/main_integration_test.go
@@ -0,0 +1,400 @@
+package main
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+type integrationSummary struct {
+	Total   int `json:"total"`
+	Success int `json:"success"`
+	Failed  int `json:"failed"`
+}
+
+type integrationOutput struct {
+	Results []TaskResult       `json:"results"`
+	Summary integrationSummary `json:"summary"`
+}
+
+func captureStdout(t *testing.T, fn func()) string {
+	t.Helper()
+	old := os.Stdout
+	r, w, _ := os.Pipe()
+	os.Stdout = w
+
+	fn()
+
+	w.Close()
+	os.Stdout = old
+
+	var buf bytes.Buffer
+	io.Copy(&buf, r)
+	return buf.String()
+}
+
+func parseIntegrationOutput(t *testing.T, out string) integrationOutput {
+	t.Helper()
+	var payload integrationOutput
+
+	lines := strings.Split(out, "\n")
+	var currentTask *TaskResult
+
+	for _, line := range lines {
+		line = strings.TrimSpace(line)
+		if strings.HasPrefix(line, "Total:") {
+			parts := strings.Split(line, "|")
+			for _, p := range parts {
+				p = strings.TrimSpace(p)
+				if strings.HasPrefix(p, "Total:") {
+					fmt.Sscanf(p, "Total: %d", &payload.Summary.Total)
+				} else if strings.HasPrefix(p, "Success:") {
+					fmt.Sscanf(p, "Success: %d", &payload.Summary.Success)
+				} else if strings.HasPrefix(p, "Failed:") {
+					fmt.Sscanf(p, "Failed: %d", &payload.Summary.Failed)
+				}
+			}
+		} else if strings.HasPrefix(line, "--- Task:") {
+			if currentTask != nil {
+				payload.Results = append(payload.Results, *currentTask)
+			}
+			currentTask = &TaskResult{}
+			currentTask.TaskID = strings.TrimSuffix(strings.TrimPrefix(line, "--- Task: "), " ---")
+		} else if currentTask != nil {
+			if strings.HasPrefix(line, "Status: SUCCESS") {
+				currentTask.ExitCode = 0
+			} else if strings.HasPrefix(line, "Status: FAILED") {
+				if strings.Contains(line, "exit code") {
+					fmt.Sscanf(line, "Status: FAILED (exit code %d)", &currentTask.ExitCode)
+				} else {
+					currentTask.ExitCode = 1
+				}
+			} else if strings.HasPrefix(line, "Error:") {
+				currentTask.Error = strings.TrimPrefix(line, "Error: ")
+			} else if strings.HasPrefix(line, "Session:") {
+				currentTask.SessionID = strings.TrimPrefix(line, "Session: ")
+			} else if line != "" && !strings.HasPrefix(line, "===") && !strings.HasPrefix(line, "---") {
+				if currentTask.Message != "" {
+					currentTask.Message += "\n"
+				}
+				currentTask.Message += line
+			}
+		}
+	}
+
+	if currentTask != nil {
+		payload.Results = append(payload.Results, *currentTask)
+	}
+
+	return payload
+}
+
+func findResultByID(t *testing.T, payload integrationOutput, id string) TaskResult {
+	t.Helper()
+	for _, res := range payload.Results {
+		if res.TaskID == id {
+			return res
+		}
+	}
+	t.Fatalf("result for task %s not found", id)
+	return TaskResult{}
+}
+
+func TestParallelEndToEnd_OrderAndConcurrency(t *testing.T) {
+	defer resetTestHooks()
+	origRun := runCodexTaskFn
+	t.Cleanup(func() {
+		runCodexTaskFn = origRun
+		resetTestHooks()
+	})
+
+	input := `---TASK---
+id: A
+---CONTENT---
+task-a
+---TASK---
+id: B
+dependencies: A
+---CONTENT---
+task-b
+---TASK---
+id: C
+dependencies: B
+---CONTENT---
+task-c
+---TASK---
+id: D
+---CONTENT---
+task-d
+---TASK---
+id: E
+---CONTENT---
+task-e`
+	stdinReader = bytes.NewReader([]byte(input))
+	os.Args = []string{"codex-wrapper", "--parallel"}
+
+	var mu sync.Mutex
+	starts := make(map[string]time.Time)
+	ends := make(map[string]time.Time)
+	var running int64
+	var maxParallel int64
+
+	runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
+		start := time.Now()
+		mu.Lock()
+		starts[task.ID] = start
+		
mu.Unlock() + + cur := atomic.AddInt64(&running, 1) + for { + prev := atomic.LoadInt64(&maxParallel) + if cur <= prev { + break + } + if atomic.CompareAndSwapInt64(&maxParallel, prev, cur) { + break + } + } + + time.Sleep(40 * time.Millisecond) + + mu.Lock() + ends[task.ID] = time.Now() + mu.Unlock() + + atomic.AddInt64(&running, -1) + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: task.Task} + } + + var exitCode int + output := captureStdout(t, func() { + exitCode = run() + }) + + if exitCode != 0 { + t.Fatalf("run() exit = %d, want 0", exitCode) + } + + payload := parseIntegrationOutput(t, output) + if payload.Summary.Failed != 0 || payload.Summary.Total != 5 || payload.Summary.Success != 5 { + t.Fatalf("unexpected summary: %+v", payload.Summary) + } + + aEnd := ends["A"] + bStart := starts["B"] + cStart := starts["C"] + bEnd := ends["B"] + if aEnd.IsZero() || bStart.IsZero() || bEnd.IsZero() || cStart.IsZero() { + t.Fatalf("missing timestamps, starts=%v ends=%v", starts, ends) + } + if !aEnd.Before(bStart) && !aEnd.Equal(bStart) { + t.Fatalf("B should start after A ends: A_end=%v B_start=%v", aEnd, bStart) + } + if !bEnd.Before(cStart) && !bEnd.Equal(cStart) { + t.Fatalf("C should start after B ends: B_end=%v C_start=%v", bEnd, cStart) + } + + dStart := starts["D"] + eStart := starts["E"] + if dStart.IsZero() || eStart.IsZero() { + t.Fatalf("missing D/E start times: %v", starts) + } + delta := dStart.Sub(eStart) + if delta < 0 { + delta = -delta + } + if delta > 25*time.Millisecond { + t.Fatalf("D and E should run in parallel, delta=%v", delta) + } + if maxParallel < 2 { + t.Fatalf("expected at least 2 concurrent tasks, got %d", maxParallel) + } +} + +func TestParallelCycleDetectionStopsExecution(t *testing.T) { + defer resetTestHooks() + origRun := runCodexTaskFn + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + t.Fatalf("task %s should not execute on cycle", task.ID) + return TaskResult{} + } + t.Cleanup(func() { + runCodexTaskFn = origRun + resetTestHooks() + }) + + input := `---TASK--- +id: A +dependencies: B +---CONTENT--- +a +---TASK--- +id: B +dependencies: A +---CONTENT--- +b` + stdinReader = bytes.NewReader([]byte(input)) + os.Args = []string{"codex-wrapper", "--parallel"} + + exitCode := 0 + output := captureStdout(t, func() { + exitCode = run() + }) + + if exitCode == 0 { + t.Fatalf("cycle should cause non-zero exit, got %d", exitCode) + } + if strings.TrimSpace(output) != "" { + t.Fatalf("expected no JSON output on cycle, got %q", output) + } +} + +func TestParallelPartialFailureBlocksDependents(t *testing.T) { + defer resetTestHooks() + origRun := runCodexTaskFn + t.Cleanup(func() { + runCodexTaskFn = origRun + resetTestHooks() + }) + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + if task.ID == "A" { + return TaskResult{TaskID: "A", ExitCode: 2, Error: "boom"} + } + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: task.Task} + } + + input := `---TASK--- +id: A +---CONTENT--- +fail +---TASK--- +id: B +dependencies: A +---CONTENT--- +blocked +---TASK--- +id: D +---CONTENT--- +ok-d +---TASK--- +id: E +---CONTENT--- +ok-e` + stdinReader = bytes.NewReader([]byte(input)) + os.Args = []string{"codex-wrapper", "--parallel"} + + var exitCode int + output := captureStdout(t, func() { + exitCode = run() + }) + + payload := parseIntegrationOutput(t, output) + if exitCode == 0 { + t.Fatalf("expected non-zero exit when a task fails, got %d", exitCode) + } + + resA := findResultByID(t, payload, "A") + resB := findResultByID(t, 
payload, "B") + resD := findResultByID(t, payload, "D") + resE := findResultByID(t, payload, "E") + + if resA.ExitCode == 0 { + t.Fatalf("task A should fail, got %+v", resA) + } + if resB.ExitCode == 0 || !strings.Contains(resB.Error, "dependencies") { + t.Fatalf("task B should be skipped due to dependency failure, got %+v", resB) + } + if resD.ExitCode != 0 || resE.ExitCode != 0 { + t.Fatalf("independent tasks should run successfully, D=%+v E=%+v", resD, resE) + } + if payload.Summary.Failed != 2 || payload.Summary.Total != 4 { + t.Fatalf("unexpected summary after partial failure: %+v", payload.Summary) + } +} + +func TestParallelTimeoutPropagation(t *testing.T) { + defer resetTestHooks() + origRun := runCodexTaskFn + t.Cleanup(func() { + runCodexTaskFn = origRun + resetTestHooks() + os.Unsetenv("CODEX_TIMEOUT") + }) + + var receivedTimeout int + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + receivedTimeout = timeout + return TaskResult{TaskID: task.ID, ExitCode: 124, Error: "timeout"} + } + + os.Setenv("CODEX_TIMEOUT", "1") + input := `---TASK--- +id: T +---CONTENT--- +slow` + stdinReader = bytes.NewReader([]byte(input)) + os.Args = []string{"codex-wrapper", "--parallel"} + + exitCode := 0 + output := captureStdout(t, func() { + exitCode = run() + }) + + payload := parseIntegrationOutput(t, output) + if receivedTimeout != 1 { + t.Fatalf("expected timeout 1s to propagate, got %d", receivedTimeout) + } + if exitCode != 124 { + t.Fatalf("expected timeout exit code 124, got %d", exitCode) + } + if payload.Summary.Failed != 1 || payload.Summary.Total != 1 { + t.Fatalf("unexpected summary for timeout case: %+v", payload.Summary) + } + res := findResultByID(t, payload, "T") + if res.Error == "" || res.ExitCode != 124 { + t.Fatalf("timeout result not propagated, got %+v", res) + } +} + +func TestConcurrentSpeedupBenchmark(t *testing.T) { + defer resetTestHooks() + origRun := runCodexTaskFn + t.Cleanup(func() { + runCodexTaskFn = origRun + resetTestHooks() + }) + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + time.Sleep(50 * time.Millisecond) + return TaskResult{TaskID: task.ID} + } + + tasks := make([]TaskSpec, 10) + for i := range tasks { + tasks[i] = TaskSpec{ID: fmt.Sprintf("task-%d", i)} + } + layers := [][]TaskSpec{tasks} + + serialStart := time.Now() + for _, task := range tasks { + _ = runCodexTaskFn(task, 5) + } + serialElapsed := time.Since(serialStart) + + concurrentStart := time.Now() + _ = executeConcurrent(layers, 5) + concurrentElapsed := time.Since(concurrentStart) + + if concurrentElapsed >= serialElapsed/5 { + t.Fatalf("expected concurrent time <20%% of serial, serial=%v concurrent=%v", serialElapsed, concurrentElapsed) + } + ratio := float64(concurrentElapsed) / float64(serialElapsed) + t.Logf("speedup ratio (concurrent/serial)=%.3f", ratio) +} diff --git a/codex-wrapper/main_test.go b/codex-wrapper/main_test.go index ab123cb..d55b9ef 100644 --- a/codex-wrapper/main_test.go +++ b/codex-wrapper/main_test.go @@ -2,10 +2,19 @@ package main import ( "bytes" + "context" + "encoding/json" + "fmt" "io" "os" + "os/exec" + "os/signal" "strings" + "sync" + "sync/atomic" + "syscall" "testing" + "time" ) // Helper to reset test hooks @@ -13,6 +22,9 @@ func resetTestHooks() { stdinReader = os.Stdin isTerminalFn = defaultIsTerminal codexCommand = "codex" + buildCodexArgsFn = buildCodexArgs + commandContext = exec.CommandContext + jsonMarshal = json.Marshal } func TestParseArgs_NewMode(t *testing.T) { @@ -192,6 +204,113 @@ func TestParseArgs_ResumeMode(t 
*testing.T) { } } +func TestParseParallelConfig_Success(t *testing.T) { + input := `---TASK--- +id: task-1 +dependencies: task-0 +---CONTENT--- +do something` + + cfg, err := parseParallelConfig([]byte(input)) + if err != nil { + t.Fatalf("parseParallelConfig() unexpected error: %v", err) + } + + if len(cfg.Tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(cfg.Tasks)) + } + + task := cfg.Tasks[0] + if task.ID != "task-1" { + t.Errorf("task.ID = %q, want %q", task.ID, "task-1") + } + if task.Task != "do something" { + t.Errorf("task.Task = %q, want %q", task.Task, "do something") + } + if task.WorkDir != defaultWorkdir { + t.Errorf("task.WorkDir = %q, want %q", task.WorkDir, defaultWorkdir) + } + if len(task.Dependencies) != 1 || task.Dependencies[0] != "task-0" { + t.Errorf("dependencies = %v, want [task-0]", task.Dependencies) + } +} + +func TestParseParallelConfig_InvalidFormat(t *testing.T) { + if _, err := parseParallelConfig([]byte("invalid format")); err == nil { + t.Fatalf("expected error for invalid format, got nil") + } +} + +func TestParseParallelConfig_EmptyTasks(t *testing.T) { + input := `---TASK--- +id: empty +---CONTENT--- +` + if _, err := parseParallelConfig([]byte(input)); err == nil { + t.Fatalf("expected error for empty tasks array, got nil") + } +} + +func TestParseParallelConfig_MissingID(t *testing.T) { + input := `---TASK--- +---CONTENT--- +do something` + if _, err := parseParallelConfig([]byte(input)); err == nil { + t.Fatalf("expected error for missing id, got nil") + } +} + +func TestParseParallelConfig_MissingTask(t *testing.T) { + input := `---TASK--- +id: task-1 +---CONTENT--- +` + if _, err := parseParallelConfig([]byte(input)); err == nil { + t.Fatalf("expected error for missing task, got nil") + } +} + +func TestParseParallelConfig_DuplicateID(t *testing.T) { + input := `---TASK--- +id: dup +---CONTENT--- +one +---TASK--- +id: dup +---CONTENT--- +two` + if _, err := parseParallelConfig([]byte(input)); err == nil { + t.Fatalf("expected error for duplicate id, got nil") + } +} + +func TestParseParallelConfig_DelimiterFormat(t *testing.T) { + input := `---TASK--- +id: T1 +workdir: /tmp +---CONTENT--- +echo 'test' +---TASK--- +id: T2 +dependencies: T1 +---CONTENT--- +code with special chars: $var "quotes"` + + cfg, err := parseParallelConfig([]byte(input)) + if err != nil { + t.Fatalf("parseParallelConfig() error = %v", err) + } + if len(cfg.Tasks) != 2 { + t.Fatalf("expected 2 tasks, got %d", len(cfg.Tasks)) + } + if cfg.Tasks[0].ID != "T1" || cfg.Tasks[0].Task != "echo 'test'" { + t.Errorf("task T1 mismatch") + } + if cfg.Tasks[1].ID != "T2" || len(cfg.Tasks[1].Dependencies) != 1 { + t.Errorf("task T2 mismatch") + } +} + func TestShouldUseStdin(t *testing.T) { tests := []struct { name string @@ -203,6 +322,10 @@ func TestShouldUseStdin(t *testing.T) { {"piped input", "analyze code", true, true}, {"contains newline", "line1\nline2", false, true}, {"contains backslash", "path\\to\\file", false, true}, + {"contains double quote", `say "hi"`, false, true}, + {"contains single quote", "it's tricky", false, true}, + {"contains backtick", "use `code`", false, true}, + {"contains dollar", "price is $5", false, true}, {"long task", strings.Repeat("a", 801), false, true}, {"exactly 800 chars", strings.Repeat("a", 800), false, false}, } @@ -411,6 +534,21 @@ func TestParseJSONStream(t *testing.T) { } } +func TestParseJSONStreamWithWarn_InvalidLine(t *testing.T) { + var warnings []string + warnFn := func(msg string) { + warnings = append(warnings, msg) + } + + 
message, threadID := parseJSONStreamWithWarn(strings.NewReader("not-json"), warnFn) + if message != "" || threadID != "" { + t.Fatalf("expected empty output for invalid json, got message=%q thread=%q", message, threadID) + } + if len(warnings) == 0 { + t.Fatalf("expected warning to be emitted") + } +} + func TestGetEnv(t *testing.T) { tests := []struct { name string @@ -596,82 +734,270 @@ func TestReadPipedTask(t *testing.T) { } } -// Tests for runCodexProcess with mock command -func TestRunCodexProcess_CommandNotFound(t *testing.T) { +// Tests for runCodexTask with mock command +func TestRunCodexTask_CommandNotFound(t *testing.T) { defer resetTestHooks() codexCommand = "nonexistent-command-xyz" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } - _, _, exitCode := runCodexProcess([]string{"arg1"}, "task", false, 10) + res := runCodexTask(TaskSpec{Task: "task"}, false, 10) - if exitCode != 127 { - t.Errorf("runCodexProcess() exitCode = %d, want 127 for command not found", exitCode) + if res.ExitCode != 127 { + t.Errorf("runCodexTask() exitCode = %d, want 127 for command not found", res.ExitCode) + } + if res.Error == "" { + t.Errorf("runCodexTask() expected error message for missing command") } } -func TestRunCodexProcess_WithEcho(t *testing.T) { +func TestRunCodexTask_StartError(t *testing.T) { + defer resetTestHooks() + + tmpFile, err := os.CreateTemp("", "start-error") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + codexCommand = tmpFile.Name() + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } + + res := runCodexTask(TaskSpec{Task: "task"}, false, 1) + + if res.ExitCode != 1 { + t.Fatalf("runCodexTask() exitCode = %d, want 1 for start error", res.ExitCode) + } + if !strings.Contains(res.Error, "failed to start codex") { + t.Fatalf("runCodexTask() unexpected error: %s", res.Error) + } +} + +func TestRunCodexTask_WithEcho(t *testing.T) { defer resetTestHooks() - // Use echo to simulate codex output codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } jsonOutput := `{"type":"thread.started","thread_id":"test-session"} {"type":"item.completed","item":{"type":"agent_message","text":"Test output"}}` - message, threadID, exitCode := runCodexProcess([]string{jsonOutput}, "", false, 10) + res := runCodexTask(TaskSpec{Task: jsonOutput}, false, 10) - if exitCode != 0 { - t.Errorf("runCodexProcess() exitCode = %d, want 0", exitCode) + if res.ExitCode != 0 { + t.Errorf("runCodexTask() exitCode = %d, want 0", res.ExitCode) } - if message != "Test output" { - t.Errorf("runCodexProcess() message = %q, want %q", message, "Test output") + if res.Message != "Test output" { + t.Errorf("runCodexTask() message = %q, want %q", res.Message, "Test output") } - if threadID != "test-session" { - t.Errorf("runCodexProcess() threadID = %q, want %q", threadID, "test-session") + if res.SessionID != "test-session" { + t.Errorf("runCodexTask() sessionID = %q, want %q", res.SessionID, "test-session") } } -func TestRunCodexProcess_NoMessage(t *testing.T) { +func TestRunCodexTask_NoMessage(t *testing.T) { defer resetTestHooks() codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } - // Output without agent_message jsonOutput := `{"type":"thread.started","thread_id":"test-session"}` - _, _, exitCode := runCodexProcess([]string{jsonOutput}, "", false, 
10) + res := runCodexTask(TaskSpec{Task: jsonOutput}, false, 10) - if exitCode != 1 { - t.Errorf("runCodexProcess() exitCode = %d, want 1 for no message", exitCode) + if res.ExitCode != 1 { + t.Errorf("runCodexTask() exitCode = %d, want 1 for no message", res.ExitCode) + } + if res.Error == "" { + t.Errorf("runCodexTask() expected error for missing agent_message output") } } -func TestRunCodexProcess_WithStdin(t *testing.T) { +func TestRunCodexTask_WithStdin(t *testing.T) { defer resetTestHooks() - // Use cat to echo stdin back codexCommand = "cat" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } - message, _, exitCode := runCodexProcess([]string{}, `{"type":"item.completed","item":{"type":"agent_message","text":"from stdin"}}`, true, 10) + jsonInput := `{"type":"item.completed","item":{"type":"agent_message","text":"from stdin"}}` - if exitCode != 0 { - t.Errorf("runCodexProcess() exitCode = %d, want 0", exitCode) + res := runCodexTask(TaskSpec{Task: jsonInput, UseStdin: true}, false, 10) + + if res.ExitCode != 0 { + t.Errorf("runCodexTask() exitCode = %d, want 0", res.ExitCode) } - if message != "from stdin" { - t.Errorf("runCodexProcess() message = %q, want %q", message, "from stdin") + if res.Message != "from stdin" { + t.Errorf("runCodexTask() message = %q, want %q", res.Message, "from stdin") } } -func TestRunCodexProcess_ExitError(t *testing.T) { +func TestRunCodexTask_ExitError(t *testing.T) { defer resetTestHooks() - // Use false command which exits with code 1 codexCommand = "false" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } - _, _, exitCode := runCodexProcess([]string{}, "", false, 10) + res := runCodexTask(TaskSpec{Task: "noop"}, false, 10) - if exitCode == 0 { - t.Errorf("runCodexProcess() exitCode = 0, want non-zero for failed command") + if res.ExitCode == 0 { + t.Errorf("runCodexTask() exitCode = 0, want non-zero for failed command") + } + if res.Error == "" { + t.Errorf("runCodexTask() expected error message for failed command") + } +} + +func TestRunCodexTask_StdinPipeError(t *testing.T) { + defer resetTestHooks() + + commandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + cmd := exec.CommandContext(ctx, "cat") + cmd.Stdin = os.Stdin + return cmd + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } + + res := runCodexTask(TaskSpec{Task: "data", UseStdin: true}, false, 1) + if res.ExitCode != 1 || !strings.Contains(res.Error, "stdin pipe") { + t.Fatalf("expected stdin pipe error, got %+v", res) + } +} + +func TestRunCodexTask_StdoutPipeError(t *testing.T) { + defer resetTestHooks() + + commandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + cmd := exec.CommandContext(ctx, "echo", "noop") + cmd.Stdout = os.Stdout + return cmd + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } + + res := runCodexTask(TaskSpec{Task: "noop"}, false, 1) + if res.ExitCode != 1 || !strings.Contains(res.Error, "stdout pipe") { + t.Fatalf("expected stdout pipe error, got %+v", res) + } +} + +func TestRunCodexTask_Timeout(t *testing.T) { + defer resetTestHooks() + + codexCommand = "sleep" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{"2"} } + + res := runCodexTask(TaskSpec{Task: "ignored"}, false, 1) + if res.ExitCode != 124 || !strings.Contains(res.Error, "timeout") { + t.Fatalf("expected timeout exit, got %+v", res) + } +} + +func 
TestRunCodexTask_SignalHandling(t *testing.T) { + defer resetTestHooks() + + codexCommand = "sleep" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{"5"} } + + resultCh := make(chan TaskResult, 1) + go func() { + resultCh <- runCodexTask(TaskSpec{Task: "ignored"}, false, 5) + }() + + time.Sleep(200 * time.Millisecond) + syscall.Kill(os.Getpid(), syscall.SIGTERM) + + res := <-resultCh + signal.Reset(syscall.SIGINT, syscall.SIGTERM) + + if res.ExitCode == 0 { + t.Fatalf("expected non-zero exit after signal, got %+v", res) + } + if res.Error == "" { + t.Fatalf("expected error after signal, got %+v", res) + } +} + +func TestSilentMode(t *testing.T) { + defer resetTestHooks() + + jsonOutput := `{"type":"thread.started","thread_id":"silent-session"} +{"type":"item.completed","item":{"type":"agent_message","text":"quiet"}}` + + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } + + capture := func(silent bool) string { + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + res := runCodexTask(TaskSpec{Task: jsonOutput}, silent, 10) + if res.ExitCode != 0 { + t.Fatalf("runCodexTask() unexpected exitCode %d", res.ExitCode) + } + + w.Close() + os.Stderr = oldStderr + + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() + } + + verbose := capture(false) + quiet := capture(true) + + if quiet != "" { + t.Fatalf("silent mode should suppress stderr, got: %q", quiet) + } + if !strings.Contains(verbose, "INFO: Starting codex") { + t.Fatalf("non-silent mode should log to stderr, got: %q", verbose) + } +} + +func TestGenerateFinalOutput(t *testing.T) { + results := []TaskResult{ + {TaskID: "a", ExitCode: 0, Message: "ok"}, + {TaskID: "b", ExitCode: 1, Error: "boom"}, + {TaskID: "c", ExitCode: 0}, + } + + out := generateFinalOutput(results) + if out == "" { + t.Fatalf("generateFinalOutput() returned empty string") + } + + if !strings.Contains(out, "Total: 3") { + t.Errorf("output missing 'Total: 3'") + } + if !strings.Contains(out, "Success: 2") { + t.Errorf("output missing 'Success: 2'") + } + if !strings.Contains(out, "Failed: 1") { + t.Errorf("output missing 'Failed: 1'") + } + if !strings.Contains(out, "Task: a") { + t.Errorf("output missing task a") + } + if !strings.Contains(out, "Task: b") { + t.Errorf("output missing task b") + } + if !strings.Contains(out, "Status: SUCCESS") { + t.Errorf("output missing success status") + } + if !strings.Contains(out, "Status: FAILED") { + t.Errorf("output missing failed status") + } +} + +func TestGenerateFinalOutput_MarshalError(t *testing.T) { + // This test is no longer relevant since we don't use JSON marshaling + // generateFinalOutput now uses string building + out := generateFinalOutput([]TaskResult{{TaskID: "x"}}) + if out == "" { + t.Fatalf("generateFinalOutput() should not return empty string") + } + if !strings.Contains(out, "Task: x") { + t.Errorf("output should contain task x") } } @@ -758,3 +1084,358 @@ func TestRun_CommandFails(t *testing.T) { t.Errorf("run() with failing command returned 0, want non-zero") } } + +func TestRun_CLI_Success(t *testing.T) { + defer resetTestHooks() + + os.Args = []string{"codex-wrapper", "do-things"} + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{ + `{"type":"thread.started","thread_id":"cli-session"}` + "\n" + + 
`{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}`, + } + } + + var exitCode int + output := captureStdout(t, func() { + exitCode = run() + }) + + if exitCode != 0 { + t.Fatalf("run() exit=%d, want 0", exitCode) + } + if !strings.Contains(output, "ok") { + t.Fatalf("expected agent output, got %q", output) + } + if !strings.Contains(output, "SESSION_ID: cli-session") { + t.Fatalf("expected session id output, got %q", output) + } +} + +func TestTopologicalSort_LinearChain(t *testing.T) { + tasks := []TaskSpec{ + {ID: "a"}, + {ID: "b", Dependencies: []string{"a"}}, + {ID: "c", Dependencies: []string{"b"}}, + } + + layers, err := topologicalSort(tasks) + if err != nil { + t.Fatalf("topologicalSort() unexpected error: %v", err) + } + + if len(layers) != 3 { + t.Fatalf("expected 3 layers, got %d", len(layers)) + } + + if layers[0][0].ID != "a" || layers[1][0].ID != "b" || layers[2][0].ID != "c" { + t.Fatalf("unexpected order: %+v", layers) + } +} + +func TestTopologicalSort_Branching(t *testing.T) { + tasks := []TaskSpec{ + {ID: "root"}, + {ID: "left", Dependencies: []string{"root"}}, + {ID: "right", Dependencies: []string{"root"}}, + {ID: "leaf", Dependencies: []string{"left", "right"}}, + } + + layers, err := topologicalSort(tasks) + if err != nil { + t.Fatalf("topologicalSort() unexpected error: %v", err) + } + + if len(layers) != 3 { + t.Fatalf("expected 3 layers, got %d", len(layers)) + } + + if len(layers[1]) != 2 { + t.Fatalf("expected branching layer size 2, got %d", len(layers[1])) + } +} + +func TestTopologicalSort_ParallelTasks(t *testing.T) { + tasks := []TaskSpec{{ID: "a"}, {ID: "b"}, {ID: "c"}} + + layers, err := topologicalSort(tasks) + if err != nil { + t.Fatalf("topologicalSort() unexpected error: %v", err) + } + + if len(layers) != 1 { + t.Fatalf("expected single layer, got %d", len(layers)) + } + if len(layers[0]) != 3 { + t.Fatalf("expected 3 tasks in layer, got %d", len(layers[0])) + } +} + +func TestShouldSkipTask(t *testing.T) { + failed := map[string]TaskResult{ + "a": {TaskID: "a", ExitCode: 1}, + "b": {TaskID: "b", ExitCode: 2}, + } + + tests := []struct { + name string + task TaskSpec + skip bool + reasonContains []string + }{ + {"no deps", TaskSpec{ID: "c"}, false, nil}, + {"missing deps not failed", TaskSpec{ID: "d", Dependencies: []string{"x"}}, false, nil}, + {"single failed dep", TaskSpec{ID: "e", Dependencies: []string{"a"}}, true, []string{"a"}}, + {"multiple failed deps", TaskSpec{ID: "f", Dependencies: []string{"a", "b"}}, true, []string{"a", "b"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + skip, reason := shouldSkipTask(tt.task, failed) + + if skip != tt.skip { + t.Fatalf("shouldSkipTask(%s) skip=%v, want %v", tt.name, skip, tt.skip) + } + for _, expect := range tt.reasonContains { + if !strings.Contains(reason, expect) { + t.Fatalf("reason %q missing %q", reason, expect) + } + } + }) + } +} + +func TestTopologicalSort_CycleDetection(t *testing.T) { + tasks := []TaskSpec{ + {ID: "a", Dependencies: []string{"b"}}, + {ID: "b", Dependencies: []string{"a"}}, + } + + if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "cycle detected") { + t.Fatalf("expected cycle error, got %v", err) + } +} + +func TestTopologicalSort_IndirectCycle(t *testing.T) { + tasks := []TaskSpec{ + {ID: "a", Dependencies: []string{"c"}}, + {ID: "b", Dependencies: []string{"a"}}, + {ID: "c", Dependencies: []string{"b"}}, + } + + if _, err := topologicalSort(tasks); err == nil || 
!strings.Contains(err.Error(), "cycle detected") {
+		t.Fatalf("expected indirect cycle error, got %v", err)
+	}
+}
+
+func TestTopologicalSort_MissingDependency(t *testing.T) {
+	tasks := []TaskSpec{
+		{ID: "a", Dependencies: []string{"missing"}},
+	}
+
+	if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "dependency \"missing\" not found") {
+		t.Fatalf("expected missing dependency error, got %v", err)
+	}
+}
+
+func TestTopologicalSort_LargeGraph(t *testing.T) {
+	const count = 1000
+	tasks := make([]TaskSpec, count)
+	for i := 0; i < count; i++ {
+		id := fmt.Sprintf("task-%d", i)
+		if i == 0 {
+			tasks[i] = TaskSpec{ID: id}
+			continue
+		}
+		prev := fmt.Sprintf("task-%d", i-1)
+		tasks[i] = TaskSpec{ID: id, Dependencies: []string{prev}}
+	}
+
+	layers, err := topologicalSort(tasks)
+	if err != nil {
+		t.Fatalf("topologicalSort() unexpected error: %v", err)
+	}
+
+	if len(layers) != count {
+		t.Fatalf("expected %d layers, got %d", count, len(layers))
+	}
+}
+
+func TestExecuteConcurrent_ParallelExecution(t *testing.T) {
+	orig := runCodexTaskFn
+	defer func() { runCodexTaskFn = orig }()
+
+	var maxParallel int64
+	var current int64
+
+	runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
+		cur := atomic.AddInt64(&current, 1)
+		for {
+			prev := atomic.LoadInt64(&maxParallel)
+			if cur <= prev || atomic.CompareAndSwapInt64(&maxParallel, prev, cur) {
+				break
+			}
+		}
+		time.Sleep(150 * time.Millisecond)
+		atomic.AddInt64(&current, -1)
+		return TaskResult{TaskID: task.ID}
+	}
+
+	start := time.Now()
+	layers := [][]TaskSpec{{{ID: "a"}, {ID: "b"}, {ID: "c"}}}
+	results := executeConcurrent(layers, 10)
+	elapsed := time.Since(start)
+
+	if len(results) != 3 {
+		t.Fatalf("expected 3 results, got %d", len(results))
+	}
+
+	if elapsed >= 400*time.Millisecond {
+		t.Fatalf("expected concurrent execution, took %v", elapsed)
+	}
+	if maxParallel < 2 {
+		t.Fatalf("expected parallelism >=2, got %d", maxParallel)
+	}
+}
+
+func TestExecuteConcurrent_LayerOrdering(t *testing.T) {
+	orig := runCodexTaskFn
+	defer func() { runCodexTaskFn = orig }()
+
+	var mu sync.Mutex
+	var order []string
+
+	runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
+		mu.Lock()
+		order = append(order, task.ID)
+		mu.Unlock()
+		return TaskResult{TaskID: task.ID}
+	}
+
+	layers := [][]TaskSpec{{{ID: "first-1"}, {ID: "first-2"}}, {{ID: "second"}}}
+	executeConcurrent(layers, 10)
+
+	if len(order) != 3 {
+		t.Fatalf("expected 3 tasks recorded, got %d", len(order))
+	}
+
+	if order[0] != "first-1" && order[0] != "first-2" {
+		t.Fatalf("first task should come from first layer, got %s", order[0])
+	}
+	if order[2] != "second" {
+		t.Fatalf("last task should be from second layer, got %s", order[2])
+	}
+}
+
+func TestExecuteConcurrent_ErrorIsolation(t *testing.T) {
+	orig := runCodexTaskFn
+	defer func() { runCodexTaskFn = orig }()
+
+	runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
+		if task.ID == "fail" {
+			return TaskResult{TaskID: task.ID, ExitCode: 2, Error: "boom"}
+		}
+		return TaskResult{TaskID: task.ID, ExitCode: 0}
+	}
+
+	layers := [][]TaskSpec{{{ID: "ok"}, {ID: "fail"}}, {{ID: "after"}}}
+	results := executeConcurrent(layers, 10)
+
+	if len(results) != 3 {
+		t.Fatalf("expected 3 results, got %d", len(results))
+	}
+
+	var failed, succeeded bool
+	for _, res := range results {
+		if res.TaskID == "fail" && res.ExitCode == 2 {
+			failed = true
+		}
+		if res.TaskID == "after" && res.ExitCode == 0 {
+			succeeded = true
+		}
+	}
+
+	if !failed || !succeeded {
+		t.Fatalf("expected failure 
isolation, got results: %+v", results) + } +} + +func TestExecuteConcurrent_PanicRecovered(t *testing.T) { + orig := runCodexTaskFn + defer func() { runCodexTaskFn = orig }() + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + panic("boom") + } + + results := executeConcurrent([][]TaskSpec{{{ID: "panic"}}}, 10) + + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + if results[0].Error == "" || results[0].ExitCode == 0 { + t.Fatalf("panic should be captured, got %+v", results[0]) + } +} + +func TestExecuteConcurrent_LargeFanout(t *testing.T) { + orig := runCodexTaskFn + defer func() { runCodexTaskFn = orig }() + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + return TaskResult{TaskID: task.ID} + } + + layer := make([]TaskSpec, 0, 1200) + for i := 0; i < 1200; i++ { + layer = append(layer, TaskSpec{ID: fmt.Sprintf("id-%d", i)}) + } + + results := executeConcurrent([][]TaskSpec{layer}, 10) + + if len(results) != 1200 { + t.Fatalf("expected 1200 results, got %d", len(results)) + } +} + +func TestRun_ParallelFlag(t *testing.T) { + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = []string{"codex-wrapper", "--parallel"} + + jsonInput := `---TASK--- +id: T1 +---CONTENT--- +test` + stdinReader = strings.NewReader(jsonInput) + defer func() { stdinReader = os.Stdin }() + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + return TaskResult{ + TaskID: task.ID, + ExitCode: 0, + Message: "test output", + } + } + defer func() { + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + if task.WorkDir == "" { + task.WorkDir = defaultWorkdir + } + if task.Mode == "" { + task.Mode = "new" + } + return runCodexTask(task, true, timeout) + } + }() + + exitCode := run() + if exitCode != 0 { + t.Errorf("expected exit code 0, got %d", exitCode) + } +} diff --git a/skills/codex/SKILL.md b/skills/codex/SKILL.md index e2e75bc..09bb113 100644 --- a/skills/codex/SKILL.md +++ b/skills/codex/SKILL.md @@ -178,16 +178,52 @@ Add proper escaping and handle $variables correctly. EOF ``` -### Large Task Protocol +### Parallel Execution -- For every large task, first produce a canonical task list that enumerates the Task ID, description, file/directory scope, dependencies, test commands, and the expected Codex Bash invocation. -- Tasks without dependencies should be executed concurrently via multiple foreground Bash calls (you can keep separate terminal windows) and each run must log start/end times plus any shared resource usage. -- Reuse context aggressively (such as @spec.md or prior analysis output), and after concurrent execution finishes, reconcile against the task list to report which items completed and which slipped. +For multiple independent or dependent tasks, use `--parallel` mode with delimiter format: -| ID | Description | Scope | Dependencies | Tests | Command | -| --- | --- | --- | --- | --- | --- | -| T1 | Review @spec.md to extract requirements | docs/, @spec.md | None | None | `codex-wrapper - <<'EOF'`
`analyze requirements @spec.md`
`EOF` | -| T2 | Implement the module and add test cases | src/module | T1 | npm test -- --runInBand | `codex-wrapper - <<'EOF'`
`implement and test @src/module`
`EOF` |
+```bash
+codex-wrapper --parallel - <<'EOF'
+---TASK---
+id: T1
+workdir: .
+---CONTENT---
+analyze requirements @spec.md
+---TASK---
+id: T2
+dependencies: T1
+---CONTENT---
+implement feature based on T1 analysis
+---TASK---
+id: T3
+---CONTENT---
+independent task runs in parallel with T1
+EOF
+```
+
+**Delimiter Format**:
+- `---TASK---`: Starts a new task block
+- `id: <task-id>`: Required, unique task identifier
+- `workdir: <path>`: Optional, working directory (default: `.`)
+- `dependencies: <id1>, <id2>`: Optional, comma-separated IDs of tasks that must finish successfully first
+- `---CONTENT---`: Separates metadata from task content
+- Task content: Any text, code, or special characters (no escaping needed)
+
+**Output**: a plain-text summary followed by one report section per task, written to stdout. The wrapper exits 0 only if every task succeeds; otherwise it returns the exit code of a failed task.
+```text
+=== Parallel Execution Summary ===
+Total: 3 | Success: 3 | Failed: 0
+
+--- Task: T1 ---
+Status: SUCCESS
+Session: <session-id>
+
+<agent message>
+```
+
+**Features**:
+- Automatic topological sorting based on dependencies
+- Unlimited concurrency for independent tasks
+- Error isolation (failed tasks don't stop others)
+- Dependency blocking (dependent tasks skip if parent fails)
 
 ## Notes
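+
+In `--parallel` mode the wrapper's own exit code is non-zero whenever any task fails or is skipped, so scripted callers can gate on it directly. A minimal sketch (the `tasks.txt` and `parallel.log` file names are placeholders):
+
+```bash
+# Feed the delimiter-format task list from a file and keep the summary for inspection.
+if ! codex-wrapper --parallel - < tasks.txt > parallel.log; then
+  echo "one or more tasks failed; see parallel.log" >&2
+fi
+# Count the failed or skipped tasks reported in the summary.
+grep -c 'Status: FAILED' parallel.log
+```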