feat(codeagent-wrapper): add multi-agent support with yolo mode

- Add --agent parameter for agent-based backend/model resolution
- Add --prompt-file parameter for agent prompt injection
- Add opencode backend support with JSON output parsing
- Add yolo field in agent config for auto-enabling dangerous flags
  - claude: --dangerously-skip-permissions
  - codex: --dangerously-bypass-approvals-and-sandbox
- Add develop agent for code development tasks
- Add omo skill for multi-agent orchestration with Sisyphus coordinator
- Bump version to 5.5.0

Generated with SWE-Agent.ai

Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
cexll
2026-01-12 14:11:15 +08:00
parent 55246ce9c4
commit 17e52d78d2
27 changed files with 3220 additions and 59 deletions

View File

@@ -0,0 +1,79 @@
package main
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
type AgentModelConfig struct {
Backend string `json:"backend"`
Model string `json:"model"`
PromptFile string `json:"prompt_file,omitempty"`
Description string `json:"description,omitempty"`
Yolo bool `json:"yolo,omitempty"`
}
type ModelsConfig struct {
DefaultBackend string `json:"default_backend"`
DefaultModel string `json:"default_model"`
Agents map[string]AgentModelConfig `json:"agents"`
}
var defaultModelsConfig = ModelsConfig{
DefaultBackend: "opencode",
DefaultModel: "opencode/grok-code",
Agents: map[string]AgentModelConfig{
"sisyphus": {Backend: "claude", Model: "claude-sonnet-4-20250514", PromptFile: "~/.claude/skills/omo/references/sisyphus.md", Description: "Primary orchestrator"},
"oracle": {Backend: "claude", Model: "claude-sonnet-4-20250514", PromptFile: "~/.claude/skills/omo/references/oracle.md", Description: "Technical advisor"},
"librarian": {Backend: "claude", Model: "claude-sonnet-4-5-20250514", PromptFile: "~/.claude/skills/omo/references/librarian.md", Description: "Researcher"},
"explore": {Backend: "opencode", Model: "opencode/grok-code", PromptFile: "~/.claude/skills/omo/references/explore.md", Description: "Code search"},
"develop": {Backend: "codex", Model: "", PromptFile: "~/.claude/skills/omo/references/develop.md", Description: "Code development"},
"frontend-ui-ux-engineer": {Backend: "gemini", Model: "gemini-3-pro-preview", PromptFile: "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md", Description: "Frontend engineer"},
"document-writer": {Backend: "gemini", Model: "gemini-3-flash-preview", PromptFile: "~/.claude/skills/omo/references/document-writer.md", Description: "Documentation"},
},
}
func loadModelsConfig() *ModelsConfig {
home, err := os.UserHomeDir()
if err != nil {
logWarn(fmt.Sprintf("Failed to resolve home directory for models config: %v; using defaults", err))
return &defaultModelsConfig
}
configPath := filepath.Join(home, ".codeagent", "models.json")
data, err := os.ReadFile(configPath)
if err != nil {
if !os.IsNotExist(err) {
logWarn(fmt.Sprintf("Failed to read models config %s: %v; using defaults", configPath, err))
}
return &defaultModelsConfig
}
var cfg ModelsConfig
if err := json.Unmarshal(data, &cfg); err != nil {
logWarn(fmt.Sprintf("Failed to parse models config %s: %v; using defaults", configPath, err))
return &defaultModelsConfig
}
// Merge with defaults
for name, agent := range defaultModelsConfig.Agents {
if _, exists := cfg.Agents[name]; !exists {
if cfg.Agents == nil {
cfg.Agents = make(map[string]AgentModelConfig)
}
cfg.Agents[name] = agent
}
}
return &cfg
}
func resolveAgentConfig(agentName string) (backend, model, promptFile string, yolo bool) {
cfg := loadModelsConfig()
if agent, ok := cfg.Agents[agentName]; ok {
return agent.Backend, agent.Model, agent.PromptFile, agent.Yolo
}
return cfg.DefaultBackend, cfg.DefaultModel, "", false
}

View File

@@ -0,0 +1,209 @@
package main
import (
"os"
"path/filepath"
"reflect"
"testing"
)
func TestResolveAgentConfig_Defaults(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
// Test that default agents resolve correctly without config file
tests := []struct {
agent string
wantBackend string
wantModel string
wantPromptFile string
}{
{"sisyphus", "claude", "claude-sonnet-4-20250514", "~/.claude/skills/omo/references/sisyphus.md"},
{"oracle", "claude", "claude-sonnet-4-20250514", "~/.claude/skills/omo/references/oracle.md"},
{"librarian", "claude", "claude-sonnet-4-5-20250514", "~/.claude/skills/omo/references/librarian.md"},
{"explore", "opencode", "opencode/grok-code", "~/.claude/skills/omo/references/explore.md"},
{"frontend-ui-ux-engineer", "gemini", "gemini-3-pro-preview", "~/.claude/skills/omo/references/frontend-ui-ux-engineer.md"},
{"document-writer", "gemini", "gemini-3-flash-preview", "~/.claude/skills/omo/references/document-writer.md"},
}
for _, tt := range tests {
t.Run(tt.agent, func(t *testing.T) {
backend, model, promptFile, _ := resolveAgentConfig(tt.agent)
if backend != tt.wantBackend {
t.Errorf("backend = %q, want %q", backend, tt.wantBackend)
}
if model != tt.wantModel {
t.Errorf("model = %q, want %q", model, tt.wantModel)
}
if promptFile != tt.wantPromptFile {
t.Errorf("promptFile = %q, want %q", promptFile, tt.wantPromptFile)
}
})
}
}
func TestResolveAgentConfig_UnknownAgent(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
backend, model, promptFile, _ := resolveAgentConfig("unknown-agent")
if backend != "opencode" {
t.Errorf("unknown agent backend = %q, want %q", backend, "opencode")
}
if model != "opencode/grok-code" {
t.Errorf("unknown agent model = %q, want %q", model, "opencode/grok-code")
}
if promptFile != "" {
t.Errorf("unknown agent promptFile = %q, want empty", promptFile)
}
}
func TestLoadModelsConfig_NoFile(t *testing.T) {
home := "/nonexistent/path/that/does/not/exist"
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
cfg := loadModelsConfig()
if cfg.DefaultBackend != "opencode" {
t.Errorf("DefaultBackend = %q, want %q", cfg.DefaultBackend, "opencode")
}
if len(cfg.Agents) != 7 {
t.Errorf("len(Agents) = %d, want 7", len(cfg.Agents))
}
}
func TestLoadModelsConfig_WithFile(t *testing.T) {
// Create temp dir and config file
tmpDir := t.TempDir()
configDir := filepath.Join(tmpDir, ".codeagent")
if err := os.MkdirAll(configDir, 0755); err != nil {
t.Fatal(err)
}
configContent := `{
"default_backend": "claude",
"default_model": "claude-opus-4",
"agents": {
"custom-agent": {
"backend": "codex",
"model": "gpt-4o",
"description": "Custom agent"
}
}
}`
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
t.Fatal(err)
}
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
cfg := loadModelsConfig()
if cfg.DefaultBackend != "claude" {
t.Errorf("DefaultBackend = %q, want %q", cfg.DefaultBackend, "claude")
}
if cfg.DefaultModel != "claude-opus-4" {
t.Errorf("DefaultModel = %q, want %q", cfg.DefaultModel, "claude-opus-4")
}
// Check custom agent
if agent, ok := cfg.Agents["custom-agent"]; !ok {
t.Error("custom-agent not found")
} else {
if agent.Backend != "codex" {
t.Errorf("custom-agent.Backend = %q, want %q", agent.Backend, "codex")
}
if agent.Model != "gpt-4o" {
t.Errorf("custom-agent.Model = %q, want %q", agent.Model, "gpt-4o")
}
}
// Check that defaults are merged
if _, ok := cfg.Agents["sisyphus"]; !ok {
t.Error("default agent sisyphus should be merged")
}
}
func TestLoadModelsConfig_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
configDir := filepath.Join(tmpDir, ".codeagent")
if err := os.MkdirAll(configDir, 0755); err != nil {
t.Fatal(err)
}
// Write invalid JSON
configPath := filepath.Join(configDir, "models.json")
if err := os.WriteFile(configPath, []byte("invalid json {"), 0644); err != nil {
t.Fatal(err)
}
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
cfg := loadModelsConfig()
// Should fall back to defaults
if cfg.DefaultBackend != "opencode" {
t.Errorf("invalid JSON should fallback, got DefaultBackend = %q", cfg.DefaultBackend)
}
}
func TestOpencodeBackend_BuildArgs(t *testing.T) {
backend := OpencodeBackend{}
t.Run("basic", func(t *testing.T) {
cfg := &Config{Mode: "new"}
got := backend.BuildArgs(cfg, "hello")
want := []string{"run", "--format", "json", "hello"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
})
t.Run("with model", func(t *testing.T) {
cfg := &Config{Mode: "new", Model: "opencode/grok-code"}
got := backend.BuildArgs(cfg, "task")
want := []string{"run", "-m", "opencode/grok-code", "--format", "json", "task"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
})
t.Run("resume mode", func(t *testing.T) {
cfg := &Config{Mode: "resume", SessionID: "ses_123", Model: "opencode/grok-code"}
got := backend.BuildArgs(cfg, "follow-up")
want := []string{"run", "-m", "opencode/grok-code", "-s", "ses_123", "--format", "json", "follow-up"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
})
t.Run("resume without session", func(t *testing.T) {
cfg := &Config{Mode: "resume"}
got := backend.BuildArgs(cfg, "task")
want := []string{"run", "--format", "json", "task"}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
})
}
func TestOpencodeBackend_Interface(t *testing.T) {
backend := OpencodeBackend{}
if backend.Name() != "opencode" {
t.Errorf("Name() = %q, want %q", backend.Name(), "opencode")
}
if backend.Command() != "opencode" {
t.Errorf("Command() = %q, want %q", backend.Command(), "opencode")
}
}
func TestBackendRegistry_IncludesOpencode(t *testing.T) {
if _, ok := backendRegistry["opencode"]; !ok {
t.Error("backendRegistry should include opencode")
}
}

View File

@@ -0,0 +1,147 @@
package main
import (
"context"
"os"
"path/filepath"
"testing"
"time"
)
func TestValidateAgentName(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{name: "simple", input: "sisyphus", wantErr: false},
{name: "upper", input: "ABC", wantErr: false},
{name: "digits", input: "a1", wantErr: false},
{name: "dash underscore", input: "a-b_c", wantErr: false},
{name: "empty", input: "", wantErr: true},
{name: "space", input: "a b", wantErr: true},
{name: "slash", input: "a/b", wantErr: true},
{name: "dotdot", input: "../evil", wantErr: true},
{name: "unicode", input: "中文", wantErr: true},
{name: "symbol", input: "a$b", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateAgentName(tt.input)
if (err != nil) != tt.wantErr {
t.Fatalf("validateAgentName(%q) err=%v, wantErr=%v", tt.input, err, tt.wantErr)
}
})
}
}
func TestParseArgs_InvalidAgentNameRejected(t *testing.T) {
defer resetTestHooks()
os.Args = []string{"codeagent-wrapper", "--agent", "../evil", "task"}
if _, err := parseArgs(); err == nil {
t.Fatalf("expected parseArgs to reject invalid agent name")
}
}
func TestParseParallelConfig_InvalidAgentNameRejected(t *testing.T) {
input := `---TASK---
id: task-1
agent: ../evil
---CONTENT---
do something`
if _, err := parseParallelConfig([]byte(input)); err == nil {
t.Fatalf("expected parseParallelConfig to reject invalid agent name")
}
}
func TestParseParallelConfig_ResolvesAgentPromptFile(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
configDir := filepath.Join(home, ".codeagent")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
if err := os.WriteFile(filepath.Join(configDir, "models.json"), []byte(`{
"default_backend": "codex",
"default_model": "gpt-test",
"agents": {
"custom-agent": {
"backend": "codex",
"model": "gpt-test",
"prompt_file": "~/.claude/prompt.md"
}
}
}`), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
input := `---TASK---
id: task-1
agent: custom-agent
---CONTENT---
do something`
cfg, err := parseParallelConfig([]byte(input))
if err != nil {
t.Fatalf("parseParallelConfig() unexpected error: %v", err)
}
if len(cfg.Tasks) != 1 {
t.Fatalf("expected 1 task, got %d", len(cfg.Tasks))
}
if got := cfg.Tasks[0].PromptFile; got != "~/.claude/prompt.md" {
t.Fatalf("PromptFile = %q, want %q", got, "~/.claude/prompt.md")
}
}
func TestDefaultRunCodexTaskFn_AppliesAgentPromptFile(t *testing.T) {
defer resetTestHooks()
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
claudeDir := filepath.Join(home, ".claude")
if err := os.MkdirAll(claudeDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
if err := os.WriteFile(filepath.Join(claudeDir, "prompt.md"), []byte("P\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
fake := newFakeCmd(fakeCmdConfig{
StdoutPlan: []fakeStdoutEvent{
{Data: `{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}` + "\n"},
},
WaitDelay: 2 * time.Millisecond,
})
newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner {
return fake
}
selectBackendFn = func(name string) (Backend, error) {
return testBackend{
name: name,
command: "fake-cmd",
argsFn: func(cfg *Config, targetArg string) []string {
return []string{targetArg}
},
}, nil
}
res := defaultRunCodexTaskFn(TaskSpec{
ID: "t",
Task: "do",
Backend: "codex",
PromptFile: "~/.claude/prompt.md",
}, 5)
if res.ExitCode != 0 {
t.Fatalf("unexpected result: %+v", res)
}
want := "<agent-prompt>\nP\n</agent-prompt>\n\ndo"
if got := fake.StdinContents(); got != want {
t.Fatalf("stdin mismatch:\n got=%q\nwant=%q", got, want)
}
}

View File

@@ -111,7 +111,7 @@ func buildClaudeArgs(cfg *Config, targetArg string) []string {
return nil
}
args := []string{"-p"}
if cfg.SkipPermissions {
if cfg.SkipPermissions || cfg.Yolo {
args = append(args, "--dangerously-skip-permissions")
}
@@ -146,6 +146,22 @@ func (GeminiBackend) BuildArgs(cfg *Config, targetArg string) []string {
return buildGeminiArgs(cfg, targetArg)
}
type OpencodeBackend struct{}
func (OpencodeBackend) Name() string { return "opencode" }
func (OpencodeBackend) Command() string { return "opencode" }
func (OpencodeBackend) BuildArgs(cfg *Config, targetArg string) []string {
args := []string{"run"}
if model := strings.TrimSpace(cfg.Model); model != "" {
args = append(args, "-m", model)
}
if cfg.Mode == "resume" && cfg.SessionID != "" {
args = append(args, "-s", cfg.SessionID)
}
args = append(args, "--format", "json", targetArg)
return args
}
func buildGeminiArgs(cfg *Config, targetArg string) []string {
if cfg == nil {
return nil

View File

@@ -86,8 +86,7 @@ func TestBackendBuildArgs_Model(t *testing.T) {
t.Run("codex includes --model when set", func(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
backend := CodexBackend{}
cfg := &Config{Mode: "new", WorkDir: "/tmp", Model: "o3"}
@@ -139,8 +138,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) {
t.Run("codex build args omits bypass flag by default", func(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
backend := CodexBackend{}
cfg := &Config{Mode: "new", WorkDir: "/tmp"}
@@ -153,8 +151,7 @@ func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) {
t.Run("codex build args includes bypass flag when enabled", func(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Setenv(key, "true")
t.Setenv(key, "true")
backend := CodexBackend{}
cfg := &Config{Mode: "new", WorkDir: "/tmp"}

View File

@@ -19,7 +19,11 @@ type Config struct {
ExplicitStdin bool
Timeout int
Backend string
Agent string
PromptFile string
PromptFileExplicit bool
SkipPermissions bool
Yolo bool
MaxParallelWorkers int
}
@@ -38,6 +42,8 @@ type TaskSpec struct {
SessionID string `json:"session_id,omitempty"`
Backend string `json:"backend,omitempty"`
Model string `json:"model,omitempty"`
Agent string `json:"agent,omitempty"`
PromptFile string `json:"prompt_file,omitempty"`
Mode string `json:"-"`
UseStdin bool `json:"-"`
Context context.Context `json:"-"`
@@ -63,9 +69,10 @@ type TaskResult struct {
}
var backendRegistry = map[string]Backend{
"codex": CodexBackend{},
"claude": ClaudeBackend{},
"gemini": GeminiBackend{},
"codex": CodexBackend{},
"claude": ClaudeBackend{},
"gemini": GeminiBackend{},
"opencode": OpencodeBackend{},
}
func selectBackend(name string) (Backend, error) {
@@ -105,6 +112,23 @@ func parseBoolFlag(val string, defaultValue bool) bool {
}
}
func validateAgentName(name string) error {
if strings.TrimSpace(name) == "" {
return fmt.Errorf("agent name is empty")
}
for _, r := range name {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case r == '-', r == '_':
default:
return fmt.Errorf("agent name %q contains invalid character %q", name, r)
}
}
return nil
}
func parseParallelConfig(data []byte) (*ParallelConfig, error) {
trimmed := bytes.TrimSpace(data)
if len(trimmed) == 0 {
@@ -132,6 +156,7 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) {
content := strings.TrimSpace(parts[1])
task := TaskSpec{WorkDir: defaultWorkdir}
agentSpecified := false
for _, line := range strings.Split(meta, "\n") {
line = strings.TrimSpace(line)
if line == "" {
@@ -156,6 +181,9 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) {
task.Backend = value
case "model":
task.Model = value
case "agent":
agentSpecified = true
task.Agent = value
case "dependencies":
for _, dep := range strings.Split(value, ",") {
dep = strings.TrimSpace(dep)
@@ -170,6 +198,23 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) {
task.Mode = "new"
}
if agentSpecified {
if strings.TrimSpace(task.Agent) == "" {
return nil, fmt.Errorf("task block #%d has empty agent field", taskIndex)
}
if err := validateAgentName(task.Agent); err != nil {
return nil, fmt.Errorf("task block #%d invalid agent name: %w", taskIndex, err)
}
backend, model, promptFile, _ := resolveAgentConfig(task.Agent)
if task.Backend == "" {
task.Backend = backend
}
if task.Model == "" {
task.Model = model
}
task.PromptFile = promptFile
}
if task.ID == "" {
return nil, fmt.Errorf("task block #%d missing id field", taskIndex)
}
@@ -203,11 +248,73 @@ func parseArgs() (*Config, error) {
backendName := defaultBackendName
model := ""
agentName := ""
promptFile := ""
promptFileExplicit := false
yolo := false
skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS")
filtered := make([]string, 0, len(args))
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case arg == "--agent":
if i+1 >= len(args) {
return nil, fmt.Errorf("--agent flag requires a value")
}
value := strings.TrimSpace(args[i+1])
if value == "" {
return nil, fmt.Errorf("--agent flag requires a value")
}
if err := validateAgentName(value); err != nil {
return nil, fmt.Errorf("--agent flag invalid value: %w", err)
}
resolvedBackend, resolvedModel, resolvedPromptFile, resolvedYolo := resolveAgentConfig(value)
backendName = resolvedBackend
model = resolvedModel
if !promptFileExplicit {
promptFile = resolvedPromptFile
}
yolo = resolvedYolo
agentName = value
i++
continue
case strings.HasPrefix(arg, "--agent="):
value := strings.TrimSpace(strings.TrimPrefix(arg, "--agent="))
if value == "" {
return nil, fmt.Errorf("--agent flag requires a value")
}
if err := validateAgentName(value); err != nil {
return nil, fmt.Errorf("--agent flag invalid value: %w", err)
}
resolvedBackend, resolvedModel, resolvedPromptFile, resolvedYolo := resolveAgentConfig(value)
backendName = resolvedBackend
model = resolvedModel
if !promptFileExplicit {
promptFile = resolvedPromptFile
}
yolo = resolvedYolo
agentName = value
continue
case arg == "--prompt-file":
if i+1 >= len(args) {
return nil, fmt.Errorf("--prompt-file flag requires a value")
}
value := strings.TrimSpace(args[i+1])
if value == "" {
return nil, fmt.Errorf("--prompt-file flag requires a value")
}
promptFile = value
promptFileExplicit = true
i++
continue
case strings.HasPrefix(arg, "--prompt-file="):
value := strings.TrimSpace(strings.TrimPrefix(arg, "--prompt-file="))
if value == "" {
return nil, fmt.Errorf("--prompt-file flag requires a value")
}
promptFile = value
promptFileExplicit = true
continue
case arg == "--backend":
if i+1 >= len(args) {
return nil, fmt.Errorf("--backend flag requires a value")
@@ -254,7 +361,7 @@ func parseArgs() (*Config, error) {
}
args = filtered
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions, Model: strings.TrimSpace(model)}
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, Agent: agentName, PromptFile: promptFile, PromptFileExplicit: promptFileExplicit, SkipPermissions: skipPermissions, Yolo: yolo, Model: strings.TrimSpace(model)}
cfg.MaxParallelWorkers = resolveMaxParallelWorkers()
if args[0] == "resume" {

View File

@@ -236,6 +236,13 @@ func defaultRunCodexTaskFn(task TaskSpec, timeout int) TaskResult {
if task.Mode == "" {
task.Mode = "new"
}
if strings.TrimSpace(task.PromptFile) != "" {
prompt, err := readAgentPromptFile(task.PromptFile, false)
if err != nil {
return TaskResult{TaskID: task.ID, ExitCode: 1, Error: "failed to read prompt file: " + err.Error()}
}
task.Task = wrapTaskWithAgentPrompt(prompt, task.Task)
}
if task.UseStdin || shouldUseStdin(task.Task, false) {
task.UseStdin = true
}
@@ -747,8 +754,8 @@ func buildCodexArgs(cfg *Config, targetArg string) []string {
args := []string{"e"}
if envFlagEnabled("CODEX_BYPASS_SANDBOX") {
logWarn("CODEX_BYPASS_SANDBOX=true: running without approval/sandbox protection")
if cfg.Yolo || envFlagEnabled("CODEX_BYPASS_SANDBOX") {
logWarn("YOLO mode or CODEX_BYPASS_SANDBOX=true: running without approval/sandbox protection")
args = append(args, "--dangerously-bypass-approvals-and-sandbox")
}

View File

@@ -282,8 +282,7 @@ func TestExecutorHelperCoverage(t *testing.T) {
t.Run("generateFinalOutputAndArgs", func(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
out := generateFinalOutput([]TaskResult{
{TaskID: "ok", ExitCode: 0},
@@ -358,8 +357,7 @@ func TestExecutorHelperCoverage(t *testing.T) {
runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "done"}
}
os.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", "1")
defer os.Unsetenv("CODEAGENT_MAX_PARALLEL_WORKERS")
t.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", "1")
results := executeConcurrent([][]TaskSpec{{{ID: "wrap"}}}, 1)
if len(results) != 1 || results[0].TaskID != "wrap" {

View File

@@ -36,4 +36,3 @@ func TestLogWriterWriteLimitsBuffer(t *testing.T) {
t.Fatalf("log output missing truncated entry, got %q", string(data))
}
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"os/signal"
"path/filepath"
"reflect"
"strings"
"sync/atomic"
@@ -14,7 +15,7 @@ import (
)
const (
version = "5.4.0"
version = "5.5.0"
defaultWorkdir = "."
defaultTimeout = 7200 // seconds (2 hours)
defaultCoverageTarget = 90.0
@@ -372,6 +373,15 @@ func run() (exitCode int) {
}
}
if strings.TrimSpace(cfg.PromptFile) != "" {
prompt, err := readAgentPromptFile(cfg.PromptFile, cfg.PromptFileExplicit)
if err != nil {
logError("Failed to read prompt file: " + err.Error())
return 1
}
taskText = wrapTaskWithAgentPrompt(prompt, taskText)
}
useStdin := cfg.ExplicitStdin || shouldUseStdin(taskText, piped)
targetArg := taskText
@@ -446,6 +456,91 @@ func run() (exitCode int) {
return 0
}
func readAgentPromptFile(path string, allowOutsideClaudeDir bool) (string, error) {
raw := strings.TrimSpace(path)
if raw == "" {
return "", nil
}
expanded := raw
if raw == "~" || strings.HasPrefix(raw, "~/") || strings.HasPrefix(raw, "~\\") {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
if raw == "~" {
expanded = home
} else {
expanded = home + raw[1:]
}
}
absPath, err := filepath.Abs(expanded)
if err != nil {
return "", err
}
absPath = filepath.Clean(absPath)
home, err := os.UserHomeDir()
if err != nil {
if !allowOutsideClaudeDir {
return "", err
}
logWarn(fmt.Sprintf("Failed to resolve home directory for prompt file validation: %v; proceeding without restriction", err))
} else {
allowedDir := filepath.Clean(filepath.Join(home, ".claude"))
allowedAbs, err := filepath.Abs(allowedDir)
if err == nil {
allowedDir = filepath.Clean(allowedAbs)
}
isWithinDir := func(path, dir string) bool {
rel, err := filepath.Rel(dir, path)
if err != nil {
return false
}
rel = filepath.Clean(rel)
if rel == "." {
return true
}
if rel == ".." {
return false
}
prefix := ".." + string(os.PathSeparator)
return !strings.HasPrefix(rel, prefix)
}
if !allowOutsideClaudeDir {
if !isWithinDir(absPath, allowedDir) {
logWarn(fmt.Sprintf("Refusing to read prompt file outside %s: %s", allowedDir, absPath))
return "", fmt.Errorf("prompt file must be under %s", allowedDir)
}
resolvedPath, errPath := filepath.EvalSymlinks(absPath)
resolvedBase, errBase := filepath.EvalSymlinks(allowedDir)
if errPath == nil && errBase == nil {
resolvedPath = filepath.Clean(resolvedPath)
resolvedBase = filepath.Clean(resolvedBase)
if !isWithinDir(resolvedPath, resolvedBase) {
logWarn(fmt.Sprintf("Refusing to read prompt file outside %s (resolved): %s", resolvedBase, resolvedPath))
return "", fmt.Errorf("prompt file must be under %s", resolvedBase)
}
}
} else if !isWithinDir(absPath, allowedDir) {
logWarn(fmt.Sprintf("Reading prompt file outside %s: %s", allowedDir, absPath))
}
}
data, err := os.ReadFile(absPath)
if err != nil {
return "", err
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func wrapTaskWithAgentPrompt(prompt string, task string) string {
return "<agent-prompt>\n" + prompt + "\n</agent-prompt>\n\n" + task
}
func setLogger(l *Logger) {
loggerPtr.Store(l)
}
@@ -496,6 +591,7 @@ func printHelp() {
Usage:
%[1]s "task" [workdir]
%[1]s --backend claude "task" [workdir]
%[1]s --prompt-file /path/to/prompt.md "task" [workdir]
%[1]s - [workdir] Read task from stdin
%[1]s resume <session_id> "task" [workdir]
%[1]s resume <session_id> - [workdir]

View File

@@ -641,7 +641,6 @@ func TestRunParallelTimeoutPropagation(t *testing.T) {
t.Cleanup(func() {
runCodexTaskFn = origRun
resetTestHooks()
os.Unsetenv("CODEX_TIMEOUT")
})
var receivedTimeout int
@@ -650,7 +649,7 @@ func TestRunParallelTimeoutPropagation(t *testing.T) {
return TaskResult{TaskID: task.ID, ExitCode: 124, Error: "timeout"}
}
os.Setenv("CODEX_TIMEOUT", "1")
t.Setenv("CODEX_TIMEOUT", "1")
input := `---TASK---
id: T
---CONTENT---

View File

@@ -1290,11 +1290,85 @@ func TestBackendParseArgs_ModelFlag(t *testing.T) {
}
}
func TestBackendParseArgs_PromptFileFlag(t *testing.T) {
tests := []struct {
name string
args []string
want string
wantErr bool
}{
{
name: "prompt file flag",
args: []string{"codeagent-wrapper", "--prompt-file", "/tmp/prompt.md", "task"},
want: "/tmp/prompt.md",
},
{
name: "prompt file equals syntax",
args: []string{"codeagent-wrapper", "--prompt-file=/tmp/prompt.md", "task"},
want: "/tmp/prompt.md",
},
{
name: "prompt file trimmed",
args: []string{"codeagent-wrapper", "--prompt-file", " /tmp/prompt.md ", "task"},
want: "/tmp/prompt.md",
},
{
name: "prompt file missing value",
args: []string{"codeagent-wrapper", "--prompt-file"},
wantErr: true,
},
{
name: "prompt file equals missing value",
args: []string{"codeagent-wrapper", "--prompt-file=", "task"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Args = tt.args
cfg, err := parseArgs()
if tt.wantErr {
if err == nil {
t.Fatalf("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.PromptFile != tt.want {
t.Fatalf("PromptFile = %q, want %q", cfg.PromptFile, tt.want)
}
})
}
}
func TestBackendParseArgs_PromptFileOverridesAgent(t *testing.T) {
defer resetTestHooks()
os.Args = []string{"codeagent-wrapper", "--prompt-file", "/tmp/custom.md", "--agent", "sisyphus", "task"}
cfg, err := parseArgs()
if err != nil {
t.Fatalf("parseArgs() unexpected error: %v", err)
}
if cfg.PromptFile != "/tmp/custom.md" {
t.Fatalf("PromptFile = %q, want %q", cfg.PromptFile, "/tmp/custom.md")
}
os.Args = []string{"codeagent-wrapper", "--agent", "sisyphus", "--prompt-file", "/tmp/custom.md", "task"}
cfg, err = parseArgs()
if err != nil {
t.Fatalf("parseArgs() unexpected error: %v", err)
}
if cfg.PromptFile != "/tmp/custom.md" {
t.Fatalf("PromptFile = %q, want %q", cfg.PromptFile, "/tmp/custom.md")
}
}
func TestBackendParseArgs_SkipPermissions(t *testing.T) {
const envKey = "CODEAGENT_SKIP_PERMISSIONS"
t.Cleanup(func() { os.Unsetenv(envKey) })
os.Setenv(envKey, "true")
t.Setenv(envKey, "true")
os.Args = []string{"codeagent-wrapper", "task"}
cfg, err := parseArgs()
if err != nil {
@@ -1365,19 +1439,17 @@ func TestBackendParseBoolFlag(t *testing.T) {
func TestBackendEnvFlagEnabled(t *testing.T) {
const key = "TEST_FLAG_ENABLED"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "")
if envFlagEnabled(key) {
t.Fatalf("envFlagEnabled should be false when unset")
}
os.Setenv(key, "true")
t.Setenv(key, "true")
if !envFlagEnabled(key) {
t.Fatalf("envFlagEnabled should be true for 'true'")
}
os.Setenv(key, "no")
t.Setenv(key, "no")
if envFlagEnabled(key) {
t.Fatalf("envFlagEnabled should be false for 'no'")
}
@@ -1672,10 +1744,94 @@ func TestRunShouldUseStdin(t *testing.T) {
}
}
func TestRun_PromptFilePrefixesTask(t *testing.T) {
t.Run("absolute path", func(t *testing.T) {
defer resetTestHooks()
cleanupLogsFn = func() (CleanupStats, error) { return CleanupStats{}, nil }
selectBackendFn = func(name string) (Backend, error) {
return testBackend{
name: name,
command: "echo",
argsFn: func(cfg *Config, targetArg string) []string {
return []string{targetArg}
},
}, nil
}
var gotTask string
runTaskFn = func(task TaskSpec, silent bool, timeout int) TaskResult {
gotTask = task.Task
return TaskResult{ExitCode: 0, Message: "ok"}
}
isTerminalFn = func() bool { return true }
stdinReader = strings.NewReader("")
promptPath := filepath.Join(t.TempDir(), "prompt.md")
prompt := "LINE1\nLINE2\n"
if err := os.WriteFile(promptPath, []byte(prompt), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
os.Args = []string{"codeagent-wrapper", "--prompt-file", promptPath, "do"}
if code := run(); code != 0 {
t.Fatalf("run() exit=%d, want 0", code)
}
want := "<agent-prompt>\nLINE1\nLINE2\n</agent-prompt>\n\ndo"
if gotTask != want {
t.Fatalf("task mismatch:\n got=%q\nwant=%q", gotTask, want)
}
})
t.Run("tilde expansion", func(t *testing.T) {
defer resetTestHooks()
cleanupLogsFn = func() (CleanupStats, error) { return CleanupStats{}, nil }
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
selectBackendFn = func(name string) (Backend, error) {
return testBackend{
name: name,
command: "echo",
argsFn: func(cfg *Config, targetArg string) []string {
return []string{targetArg}
},
}, nil
}
var gotTask string
runTaskFn = func(task TaskSpec, silent bool, timeout int) TaskResult {
gotTask = task.Task
return TaskResult{ExitCode: 0, Message: "ok"}
}
isTerminalFn = func() bool { return true }
stdinReader = strings.NewReader("")
promptPath := filepath.Join(home, "prompt.md")
if err := os.WriteFile(promptPath, []byte("P\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
os.Args = []string{"codeagent-wrapper", "--prompt-file", "~/prompt.md", "do"}
if code := run(); code != 0 {
t.Fatalf("run() exit=%d, want 0", code)
}
want := "<agent-prompt>\nP\n</agent-prompt>\n\ndo"
if gotTask != want {
t.Fatalf("task mismatch:\n got=%q\nwant=%q", gotTask, want)
}
})
}
func TestRunBuildCodexArgs_NewMode(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
cfg := &Config{Mode: "new", WorkDir: "/test/dir"}
args := buildCodexArgs(cfg, "my task")
@@ -1698,8 +1854,7 @@ func TestRunBuildCodexArgs_NewMode(t *testing.T) {
func TestRunBuildCodexArgs_ResumeMode(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
cfg := &Config{Mode: "resume", SessionID: "session-abc"}
args := buildCodexArgs(cfg, "-")
@@ -1723,8 +1878,7 @@ func TestRunBuildCodexArgs_ResumeMode(t *testing.T) {
func TestRunBuildCodexArgs_ResumeMode_EmptySessionHandledGracefully(t *testing.T) {
const key = "CODEX_BYPASS_SANDBOX"
t.Cleanup(func() { os.Unsetenv(key) })
os.Unsetenv(key)
t.Setenv(key, "false")
cfg := &Config{Mode: "resume", SessionID: " ", WorkDir: "/test/dir"}
args := buildCodexArgs(cfg, "task")
@@ -1964,8 +2118,7 @@ func TestRunResolveTimeout(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("CODEX_TIMEOUT", tt.envVal)
defer os.Unsetenv("CODEX_TIMEOUT")
t.Setenv("CODEX_TIMEOUT", tt.envVal)
got := resolveTimeout()
if got != tt.want {
t.Errorf("resolveTimeout() with env=%q = %v, want %v", tt.envVal, got, tt.want)
@@ -2305,10 +2458,10 @@ func TestRunGetEnv(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Unsetenv(tt.key)
if tt.setEnv {
os.Setenv(tt.key, tt.envVal)
defer os.Unsetenv(tt.key)
t.Setenv(tt.key, tt.envVal)
} else {
t.Setenv(tt.key, "")
}
got := getEnv(tt.key, tt.defaultVal)
@@ -3412,7 +3565,7 @@ func TestVersionFlag(t *testing.T) {
}
})
want := "codeagent-wrapper version 5.4.0\n"
want := "codeagent-wrapper version 5.5.0\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
@@ -3428,7 +3581,7 @@ func TestVersionShortFlag(t *testing.T) {
}
})
want := "codeagent-wrapper version 5.4.0\n"
want := "codeagent-wrapper version 5.5.0\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
@@ -3444,7 +3597,7 @@ func TestVersionLegacyAlias(t *testing.T) {
}
})
want := "codex-wrapper version 5.4.0\n"
want := "codex-wrapper version 5.5.0\n"
if output != want {
t.Fatalf("output = %q, want %q", output, want)
@@ -4638,12 +4791,7 @@ func TestResolveMaxParallelWorkers(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue != "" {
os.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", tt.envValue)
} else {
os.Unsetenv("CODEAGENT_MAX_PARALLEL_WORKERS")
}
defer os.Unsetenv("CODEAGENT_MAX_PARALLEL_WORKERS")
t.Setenv("CODEAGENT_MAX_PARALLEL_WORKERS", tt.envValue)
got := resolveMaxParallelWorkers()
if got != tt.want {

View File

@@ -87,6 +87,18 @@ type UnifiedEvent struct {
Content string `json:"content,omitempty"`
Delta *bool `json:"delta,omitempty"`
Status string `json:"status,omitempty"`
// Opencode-specific fields (camelCase sessionID)
OpencodeSessionID string `json:"sessionID,omitempty"`
Part json.RawMessage `json:"part,omitempty"`
}
// OpencodePart represents the part field in opencode events
type OpencodePart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Reason string `json:"reason,omitempty"`
SessionID string `json:"sessionID,omitempty"`
}
// ItemContent represents the parsed item.text field for Codex events
@@ -120,9 +132,10 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
totalEvents := 0
var (
codexMessage string
claudeMessage string
geminiBuffer strings.Builder
codexMessage string
claudeMessage string
geminiBuffer strings.Builder
opencodeMessage strings.Builder
)
for {
@@ -172,6 +185,37 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
isClaude = true
}
isGemini := (event.Type == "init" && event.SessionID != "") || event.Role != "" || event.Delta != nil || event.Status != ""
isOpencode := event.OpencodeSessionID != "" && len(event.Part) > 0
// Handle Opencode events first (most specific detection)
if isOpencode {
if threadID == "" {
threadID = event.OpencodeSessionID
}
var part OpencodePart
if err := json.Unmarshal(event.Part, &part); err != nil {
warnFn(fmt.Sprintf("Failed to parse opencode part: %s", err.Error()))
continue
}
// Extract sessionID from part if available
if part.SessionID != "" && threadID == "" {
threadID = part.SessionID
}
infoFn(fmt.Sprintf("Parsed Opencode event #%d type=%s part_type=%s", totalEvents, event.Type, part.Type))
if event.Type == "text" && part.Text != "" {
opencodeMessage.WriteString(part.Text)
notifyMessage()
}
if part.Type == "step-finish" && part.Reason == "stop" {
notifyComplete()
}
continue
}
// Handle Codex events
if isCodex {
@@ -284,6 +328,8 @@ func parseJSONStreamInternal(r io.Reader, warnFn func(string), infoFn func(strin
}
switch {
case opencodeMessage.Len() > 0:
message = opencodeMessage.String()
case geminiBuffer.Len() > 0:
message = geminiBuffer.String()
case claudeMessage != "":

View File

@@ -0,0 +1,50 @@
package main
import (
"strings"
"testing"
)
func TestParseJSONStream_Opencode(t *testing.T) {
input := `{"type":"step_start","timestamp":1768187730683,"sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","part":{"id":"prt_bb0339afa001NTqoJ2NS8x91zP","sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","messageID":"msg_bb033866f0011oZxTqvfy0TKtS","type":"step-start","snapshot":"904f0fd58c125b79e60f0993e38f9d9f6200bf47"}}
{"type":"text","timestamp":1768187744432,"sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","part":{"id":"prt_bb0339cb5001QDd0Lh0PzFZpa3","sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","messageID":"msg_bb033866f0011oZxTqvfy0TKtS","type":"text","text":"Hello from opencode"}}
{"type":"step_finish","timestamp":1768187744471,"sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","part":{"id":"prt_bb033d0af0019VRZzpO2OVW1na","sessionID":"ses_44fced3c7ffe83sZpzY1rlQka3","messageID":"msg_bb033866f0011oZxTqvfy0TKtS","type":"step-finish","reason":"stop","snapshot":"904f0fd58c125b79e60f0993e38f9d9f6200bf47","cost":0}}`
message, threadID := parseJSONStream(strings.NewReader(input))
if threadID != "ses_44fced3c7ffe83sZpzY1rlQka3" {
t.Errorf("threadID = %q, want %q", threadID, "ses_44fced3c7ffe83sZpzY1rlQka3")
}
if message != "Hello from opencode" {
t.Errorf("message = %q, want %q", message, "Hello from opencode")
}
}
func TestParseJSONStream_Opencode_MultipleTextEvents(t *testing.T) {
input := `{"type":"text","sessionID":"ses_123","part":{"type":"text","text":"Part 1"}}
{"type":"text","sessionID":"ses_123","part":{"type":"text","text":" Part 2"}}
{"type":"step_finish","sessionID":"ses_123","part":{"type":"step-finish","reason":"stop"}}`
message, threadID := parseJSONStream(strings.NewReader(input))
if threadID != "ses_123" {
t.Errorf("threadID = %q, want %q", threadID, "ses_123")
}
if message != "Part 1 Part 2" {
t.Errorf("message = %q, want %q", message, "Part 1 Part 2")
}
}
func TestParseJSONStream_Opencode_NoStopReason(t *testing.T) {
input := `{"type":"text","sessionID":"ses_456","part":{"type":"text","text":"Content"}}
{"type":"step_finish","sessionID":"ses_456","part":{"type":"step-finish","reason":"tool-calls"}}`
message, threadID := parseJSONStream(strings.NewReader(input))
if threadID != "ses_456" {
t.Errorf("threadID = %q, want %q", threadID, "ses_456")
}
if message != "Content" {
t.Errorf("message = %q, want %q", message, "Content")
}
}

View File

@@ -30,4 +30,3 @@ func TestBackendParseJSONStream_UnknownEventsAreSilent(t *testing.T) {
}
}
}

View File

@@ -17,10 +17,10 @@ const (
)
var (
findProcess = os.FindProcess
kernel32 = syscall.NewLazyDLL("kernel32.dll")
getProcessTimes = kernel32.NewProc("GetProcessTimes")
fileTimeToUnixFn = fileTimeToUnix
findProcess = os.FindProcess
kernel32 = syscall.NewLazyDLL("kernel32.dll")
getProcessTimes = kernel32.NewProc("GetProcessTimes")
fileTimeToUnixFn = fileTimeToUnix
)
// isProcessRunning returns true if a process with the given pid is running on Windows.

View File

@@ -0,0 +1,163 @@
package main
import (
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
func TestWrapTaskWithAgentPrompt(t *testing.T) {
got := wrapTaskWithAgentPrompt("P", "do")
want := "<agent-prompt>\nP\n</agent-prompt>\n\ndo"
if got != want {
t.Fatalf("wrapTaskWithAgentPrompt mismatch:\n got=%q\nwant=%q", got, want)
}
}
func TestReadAgentPromptFile_EmptyPath(t *testing.T) {
for _, allowOutside := range []bool{false, true} {
got, err := readAgentPromptFile(" ", allowOutside)
if err != nil {
t.Fatalf("unexpected error (allowOutside=%v): %v", allowOutside, err)
}
if got != "" {
t.Fatalf("expected empty result (allowOutside=%v), got %q", allowOutside, got)
}
}
}
func TestReadAgentPromptFile_ExplicitAbsolutePath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "prompt.md")
if err := os.WriteFile(path, []byte("LINE1\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
got, err := readAgentPromptFile(path, true)
if err != nil {
t.Fatalf("readAgentPromptFile error: %v", err)
}
if got != "LINE1" {
t.Fatalf("got %q, want %q", got, "LINE1")
}
}
func TestReadAgentPromptFile_ExplicitTildeExpansion(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
path := filepath.Join(home, "prompt.md")
if err := os.WriteFile(path, []byte("P\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
got, err := readAgentPromptFile("~/prompt.md", true)
if err != nil {
t.Fatalf("readAgentPromptFile error: %v", err)
}
if got != "P" {
t.Fatalf("got %q, want %q", got, "P")
}
}
func TestReadAgentPromptFile_RestrictedAllowsClaudeDir(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
claudeDir := filepath.Join(home, ".claude")
if err := os.MkdirAll(claudeDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
path := filepath.Join(claudeDir, "prompt.md")
if err := os.WriteFile(path, []byte("OK\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
got, err := readAgentPromptFile("~/.claude/prompt.md", false)
if err != nil {
t.Fatalf("readAgentPromptFile error: %v", err)
}
if got != "OK" {
t.Fatalf("got %q, want %q", got, "OK")
}
}
func TestReadAgentPromptFile_RestrictedRejectsOutsideClaudeDir(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
path := filepath.Join(home, "prompt.md")
if err := os.WriteFile(path, []byte("NO\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if _, err := readAgentPromptFile("~/prompt.md", false); err == nil {
t.Fatalf("expected error for prompt file outside ~/.claude, got nil")
}
}
func TestReadAgentPromptFile_RestrictedRejectsTraversal(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
path := filepath.Join(home, "secret.md")
if err := os.WriteFile(path, []byte("SECRET\n"), 0o644); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if _, err := readAgentPromptFile("~/.claude/../secret.md", false); err == nil {
t.Fatalf("expected traversal to be rejected, got nil")
}
}
func TestReadAgentPromptFile_NotFound(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
claudeDir := filepath.Join(home, ".claude")
if err := os.MkdirAll(claudeDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
_, err := readAgentPromptFile("~/.claude/missing.md", false)
if err == nil || !os.IsNotExist(err) {
t.Fatalf("expected not-exist error, got %v", err)
}
}
func TestReadAgentPromptFile_PermissionDenied(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("chmod-based permission test is not reliable on Windows")
}
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
claudeDir := filepath.Join(home, ".claude")
if err := os.MkdirAll(claudeDir, 0o755); err != nil {
t.Fatalf("MkdirAll: %v", err)
}
path := filepath.Join(claudeDir, "private.md")
if err := os.WriteFile(path, []byte("PRIVATE\n"), 0o600); err != nil {
t.Fatalf("WriteFile: %v", err)
}
if err := os.Chmod(path, 0o000); err != nil {
t.Fatalf("Chmod: %v", err)
}
_, err := readAgentPromptFile("~/.claude/private.md", false)
if err == nil {
t.Fatalf("expected permission error, got nil")
}
if !os.IsPermission(err) && !strings.Contains(strings.ToLower(err.Error()), "permission") {
t.Fatalf("expected permission denied, got: %v", err)
}
}