From 13465b12e56f21f77cca1c56f184670c0d99a2d0 Mon Sep 17 00:00:00 2001 From: cexll Date: Tue, 6 Jan 2026 15:03:21 +0800 Subject: [PATCH] fix: support model parameter for all backends, auto-inject from settings (#105) - Add Model field to Config and TaskSpec for per-task model selection - Parse --model flag and model: metadata in parallel tasks - Auto-inject model from ~/.claude/settings.json for claude backend in new mode - Pass --model to claude CLI, -m to gemini CLI, --model to codex CLI - Preserve --setting-sources "" isolation while reading minimal safe subset - Add comprehensive tests for model parsing, propagation, and settings injection Fixes #105 Generated with SWE-Agent.ai Co-Authored-By: SWE-Agent.ai --- codeagent-wrapper/backend.go | 54 +++++-- codeagent-wrapper/backend_test.go | 36 +++++ codeagent-wrapper/config.go | 21 ++- codeagent-wrapper/executor.go | 20 ++- codeagent-wrapper/main.go | 22 ++- codeagent-wrapper/main_test.go | 237 ++++++++++++++++++++++++++++++ 6 files changed, 374 insertions(+), 16 deletions(-) diff --git a/codeagent-wrapper/backend.go b/codeagent-wrapper/backend.go index bcb6ecf..e0b2c12 100644 --- a/codeagent-wrapper/backend.go +++ b/codeagent-wrapper/backend.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" ) // Backend defines the contract for invoking different AI CLI backends. @@ -37,33 +38,48 @@ func (ClaudeBackend) BuildArgs(cfg *Config, targetArg string) []string { const maxClaudeSettingsBytes = 1 << 20 // 1MB -// loadMinimalEnvSettings 从 ~/.claude/settings.json 只提取 env 配置。 -// 只接受字符串类型的值;文件缺失/解析失败/超限都返回空。 -func loadMinimalEnvSettings() map[string]string { +type minimalClaudeSettings struct { + Env map[string]string + Model string +} + +// loadMinimalClaudeSettings 从 ~/.claude/settings.json 只提取安全的最小子集: +// - env: 只接受字符串类型的值 +// - model: 只接受字符串类型的值 +// 文件缺失/解析失败/超限都返回空。 +func loadMinimalClaudeSettings() minimalClaudeSettings { home, err := os.UserHomeDir() if err != nil || home == "" { - return nil + return minimalClaudeSettings{} } settingPath := filepath.Join(home, ".claude", "settings.json") info, err := os.Stat(settingPath) if err != nil || info.Size() > maxClaudeSettingsBytes { - return nil + return minimalClaudeSettings{} } data, err := os.ReadFile(settingPath) if err != nil { - return nil + return minimalClaudeSettings{} } var cfg struct { - Env map[string]any `json:"env"` + Env map[string]any `json:"env"` + Model any `json:"model"` } if err := json.Unmarshal(data, &cfg); err != nil { - return nil + return minimalClaudeSettings{} } + + out := minimalClaudeSettings{} + + if model, ok := cfg.Model.(string); ok { + out.Model = strings.TrimSpace(model) + } + if len(cfg.Env) == 0 { - return nil + return out } env := make(map[string]string, len(cfg.Env)) @@ -75,9 +91,19 @@ func loadMinimalEnvSettings() map[string]string { env[k] = s } if len(env) == 0 { + return out + } + out.Env = env + return out +} + +// loadMinimalEnvSettings is kept for backwards tests; prefer loadMinimalClaudeSettings. +func loadMinimalEnvSettings() map[string]string { + settings := loadMinimalClaudeSettings() + if len(settings.Env) == 0 { return nil } - return env + return settings.Env } func buildClaudeArgs(cfg *Config, targetArg string) []string { @@ -93,6 +119,10 @@ func buildClaudeArgs(cfg *Config, targetArg string) []string { // This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent args = append(args, "--setting-sources", "") + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "--model", model) + } + if cfg.Mode == "resume" { if cfg.SessionID != "" { // Claude CLI uses -r for resume. @@ -122,6 +152,10 @@ func buildGeminiArgs(cfg *Config, targetArg string) []string { } args := []string{"-o", "stream-json", "-y"} + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "-m", model) + } + if cfg.Mode == "resume" { if cfg.SessionID != "" { args = append(args, "-r", cfg.SessionID) diff --git a/codeagent-wrapper/backend_test.go b/codeagent-wrapper/backend_test.go index 92faf1b..1b2ad77 100644 --- a/codeagent-wrapper/backend_test.go +++ b/codeagent-wrapper/backend_test.go @@ -63,6 +63,42 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) { }) } +func TestBackendBuildArgs_Model(t *testing.T) { + t.Run("claude includes --model when set", func(t *testing.T) { + backend := ClaudeBackend{} + cfg := &Config{Mode: "new", Model: "opus"} + got := backend.BuildArgs(cfg, "todo") + want := []string{"-p", "--setting-sources", "", "--model", "opus", "--output-format", "stream-json", "--verbose", "todo"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("gemini includes -m when set", func(t *testing.T) { + backend := GeminiBackend{} + cfg := &Config{Mode: "new", Model: "gemini-3-pro-preview"} + got := backend.BuildArgs(cfg, "task") + want := []string{"-o", "stream-json", "-y", "-m", "gemini-3-pro-preview", "-p", "task"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("codex includes --model when set", func(t *testing.T) { + const key = "CODEX_BYPASS_SANDBOX" + t.Cleanup(func() { os.Unsetenv(key) }) + os.Unsetenv(key) + + backend := CodexBackend{} + cfg := &Config{Mode: "new", WorkDir: "/tmp", Model: "o3"} + got := backend.BuildArgs(cfg, "task") + want := []string{"e", "--model", "o3", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } + }) +} + func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) { t.Run("gemini new mode defaults workdir", func(t *testing.T) { backend := GeminiBackend{} diff --git a/codeagent-wrapper/config.go b/codeagent-wrapper/config.go index f7ad663..814e758 100644 --- a/codeagent-wrapper/config.go +++ b/codeagent-wrapper/config.go @@ -15,6 +15,7 @@ type Config struct { Task string SessionID string WorkDir string + Model string ExplicitStdin bool Timeout int Backend string @@ -36,6 +37,7 @@ type TaskSpec struct { Dependencies []string `json:"dependencies,omitempty"` SessionID string `json:"session_id,omitempty"` Backend string `json:"backend,omitempty"` + Model string `json:"model,omitempty"` Mode string `json:"-"` UseStdin bool `json:"-"` Context context.Context `json:"-"` @@ -152,6 +154,8 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) { task.Mode = "resume" case "backend": task.Backend = value + case "model": + task.Model = value case "dependencies": for _, dep := range strings.Split(value, ",") { dep = strings.TrimSpace(dep) @@ -198,6 +202,7 @@ func parseArgs() (*Config, error) { } backendName := defaultBackendName + model := "" skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS") filtered := make([]string, 0, len(args)) for i := 0; i < len(args); i++ { @@ -220,6 +225,20 @@ func parseArgs() (*Config, error) { case arg == "--skip-permissions", arg == "--dangerously-skip-permissions": skipPermissions = true continue + case arg == "--model": + if i+1 >= len(args) { + return nil, fmt.Errorf("--model flag requires a value") + } + model = args[i+1] + i++ + continue + case strings.HasPrefix(arg, "--model="): + value := strings.TrimPrefix(arg, "--model=") + if value == "" { + return nil, fmt.Errorf("--model flag requires a value") + } + model = value + continue case strings.HasPrefix(arg, "--skip-permissions="): skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions) continue @@ -235,7 +254,7 @@ func parseArgs() (*Config, error) { } args = filtered - cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions} + cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions, Model: strings.TrimSpace(model)} cfg.MaxParallelWorkers = resolveMaxParallelWorkers() if args[0] == "resume" { diff --git a/codeagent-wrapper/executor.go b/codeagent-wrapper/executor.go index db34887..2812a69 100644 --- a/codeagent-wrapper/executor.go +++ b/codeagent-wrapper/executor.go @@ -744,6 +744,10 @@ func buildCodexArgs(cfg *Config, targetArg string) []string { args = append(args, "--dangerously-bypass-approvals-and-sandbox") } + if model := strings.TrimSpace(cfg.Model); model != "" { + args = append(args, "--model", model) + } + args = append(args, "--skip-git-repo-check") if isResume { @@ -788,6 +792,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe Task: taskSpec.Task, SessionID: taskSpec.SessionID, WorkDir: taskSpec.WorkDir, + Model: taskSpec.Model, Backend: defaultBackendName, } @@ -816,6 +821,15 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe return result } + var claudeEnv map[string]string + if cfg.Backend == "claude" { + settings := loadMinimalClaudeSettings() + claudeEnv = settings.Env + if cfg.Mode != "resume" && strings.TrimSpace(cfg.Model) == "" && settings.Model != "" { + cfg.Model = settings.Model + } + } + useStdin := taskSpec.UseStdin targetArg := taskSpec.Task if useStdin { @@ -915,10 +929,8 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe cmd := newCommandRunner(ctx, commandName, codexArgs...) - if cfg.Backend == "claude" { - if env := loadMinimalEnvSettings(); len(env) > 0 { - cmd.SetEnv(env) - } + if cfg.Backend == "claude" && len(claudeEnv) > 0 { + cmd.SetEnv(claudeEnv) } // For backends that don't support -C flag (claude, gemini), set working directory via cmd.Dir diff --git a/codeagent-wrapper/main.go b/codeagent-wrapper/main.go index 39ce81e..ffe251b 100644 --- a/codeagent-wrapper/main.go +++ b/codeagent-wrapper/main.go @@ -178,6 +178,7 @@ func run() (exitCode int) { if parallelIndex != -1 { backendName := defaultBackendName + model := "" fullOutput := false var extras []string @@ -202,13 +203,27 @@ func run() (exitCode int) { return 1 } backendName = value + case arg == "--model": + if i+1 >= len(args) { + fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value") + return 1 + } + model = args[i+1] + i++ + case strings.HasPrefix(arg, "--model="): + value := strings.TrimPrefix(arg, "--model=") + if value == "" { + fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value") + return 1 + } + model = value default: extras = append(extras, arg) } } if len(extras) > 0 { - fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend and --full-output are allowed.") + fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend, --model and --full-output are allowed.") fmt.Fprintln(os.Stderr, "Usage examples:") fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name) fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name) @@ -237,10 +252,14 @@ func run() (exitCode int) { } cfg.GlobalBackend = backendName + model = strings.TrimSpace(model) for i := range cfg.Tasks { if strings.TrimSpace(cfg.Tasks[i].Backend) == "" { cfg.Tasks[i].Backend = backendName } + if strings.TrimSpace(cfg.Tasks[i].Model) == "" && model != "" { + cfg.Tasks[i].Model = model + } } timeoutSec := resolveTimeout() @@ -409,6 +428,7 @@ func run() (exitCode int) { WorkDir: cfg.WorkDir, Mode: cfg.Mode, SessionID: cfg.SessionID, + Model: cfg.Model, UseStdin: useStdin, } diff --git a/codeagent-wrapper/main_test.go b/codeagent-wrapper/main_test.go index b4ac34a..18b271f 100644 --- a/codeagent-wrapper/main_test.go +++ b/codeagent-wrapper/main_test.go @@ -1139,6 +1139,65 @@ func TestBackendParseArgs_BackendFlag(t *testing.T) { } } +func TestBackendParseArgs_ModelFlag(t *testing.T) { + tests := []struct { + name string + args []string + want string + wantErr bool + }{ + { + name: "model flag", + args: []string{"codeagent-wrapper", "--model", "opus", "task"}, + want: "opus", + }, + { + name: "model equals syntax", + args: []string{"codeagent-wrapper", "--model=opus", "task"}, + want: "opus", + }, + { + name: "model trimmed", + args: []string{"codeagent-wrapper", "--model", " opus ", "task"}, + want: "opus", + }, + { + name: "model with resume mode", + args: []string{"codeagent-wrapper", "--model", "sonnet", "resume", "sid", "task"}, + want: "sonnet", + }, + { + name: "missing model value", + args: []string{"codeagent-wrapper", "--model"}, + wantErr: true, + }, + { + name: "model equals missing value", + args: []string{"codeagent-wrapper", "--model=", "task"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Args = tt.args + cfg, err := parseArgs() + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Model != tt.want { + t.Fatalf("Model = %q, want %q", cfg.Model, tt.want) + } + }) + } +} + func TestBackendParseArgs_SkipPermissions(t *testing.T) { const envKey = "CODEAGENT_SKIP_PERMISSIONS" t.Cleanup(func() { os.Unsetenv(envKey) }) @@ -1276,6 +1335,26 @@ do something` } } +func TestParallelParseConfig_Model(t *testing.T) { + input := `---TASK--- +id: task-1 +model: opus +---CONTENT--- +do something` + + cfg, err := parseParallelConfig([]byte(input)) + if err != nil { + t.Fatalf("parseParallelConfig() unexpected error: %v", err) + } + if len(cfg.Tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(cfg.Tasks)) + } + task := cfg.Tasks[0] + if task.Model != "opus" { + t.Fatalf("model = %q, want opus", task.Model) + } +} + func TestParallelParseConfig_EmptySessionID(t *testing.T) { input := `---TASK--- id: task-1 @@ -1358,6 +1437,120 @@ code with special chars: $var "quotes"` } } +func TestClaudeModel_DefaultsFromSettings(t *testing.T) { + defer resetTestHooks() + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + dir := filepath.Join(home, ".claude") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("MkdirAll: %v", err) + } + + settingsModel := "claude-opus-4-5-20250929" + path := filepath.Join(dir, "settings.json") + data := []byte(fmt.Sprintf(`{"model":%q,"env":{"FOO":"bar"}}`, settingsModel)) + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + makeRunner := func(gotName *string, gotArgs *[]string, fake **fakeCmd) func(context.Context, string, ...string) commandRunner { + return func(ctx context.Context, name string, args ...string) commandRunner { + *gotName = name + *gotArgs = append([]string(nil), args...) + cmd := newFakeCmd(fakeCmdConfig{ + PID: 123, + StdoutPlan: []fakeStdoutEvent{ + {Data: "{\"type\":\"result\",\"session_id\":\"sid\",\"result\":\"ok\"}\n"}, + }, + }) + *fake = cmd + return cmd + } + } + + t.Run("new mode inherits model when unset", func(t *testing.T) { + var ( + gotName string + gotArgs []string + fake *fakeCmd + ) + origRunner := newCommandRunner + newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) + t.Cleanup(func() { newCommandRunner = origRunner }) + + res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5) + if res.ExitCode != 0 || res.Message != "ok" { + t.Fatalf("unexpected result: %+v", res) + } + if gotName != "claude" { + t.Fatalf("command = %q, want claude", gotName) + } + found := false + for i := 0; i+1 < len(gotArgs); i++ { + if gotArgs[i] == "--model" && gotArgs[i+1] == settingsModel { + found = true + break + } + } + if !found { + t.Fatalf("expected --model %q in args, got %v", settingsModel, gotArgs) + } + if fake == nil || fake.env["FOO"] != "bar" { + t.Fatalf("expected env to include FOO=bar, got %v", fake.env) + } + }) + + t.Run("explicit model overrides settings", func(t *testing.T) { + var ( + gotName string + gotArgs []string + fake *fakeCmd + ) + origRunner := newCommandRunner + newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) + t.Cleanup(func() { newCommandRunner = origRunner }) + + res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir, Model: "sonnet"}, ClaudeBackend{}, nil, false, true, 5) + if res.ExitCode != 0 || res.Message != "ok" { + t.Fatalf("unexpected result: %+v", res) + } + found := false + for i := 0; i+1 < len(gotArgs); i++ { + if gotArgs[i] == "--model" && gotArgs[i+1] == "sonnet" { + found = true + break + } + } + if !found { + t.Fatalf("expected --model sonnet in args, got %v", gotArgs) + } + }) + + t.Run("resume mode does not inherit model by default", func(t *testing.T) { + var ( + gotName string + gotArgs []string + fake *fakeCmd + ) + origRunner := newCommandRunner + newCommandRunner = makeRunner(&gotName, &gotArgs, &fake) + t.Cleanup(func() { newCommandRunner = origRunner }) + + res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "resume", SessionID: "sid-123", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5) + if res.ExitCode != 0 || res.Message != "ok" { + t.Fatalf("unexpected result: %+v", res) + } + for i := 0; i < len(gotArgs); i++ { + if gotArgs[i] == "--model" { + t.Fatalf("did not expect --model in resume args, got %v", gotArgs) + } + } + }) +} + func TestRunShouldUseStdin(t *testing.T) { tests := []struct { name string @@ -2947,6 +3140,50 @@ do two`) } } +func TestParallelModelPropagation(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = func() (CleanupStats, error) { return CleanupStats{}, nil } + + orig := runCodexTaskFn + var mu sync.Mutex + seen := make(map[string]string) + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + mu.Lock() + seen[task.ID] = task.Model + mu.Unlock() + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "ok"} + } + t.Cleanup(func() { runCodexTaskFn = orig }) + + stdinReader = strings.NewReader(`---TASK--- +id: first +---CONTENT--- +do one + +---TASK--- +id: second +model: opus +---CONTENT--- +do two`) + os.Args = []string{"codeagent-wrapper", "--parallel", "--model", "sonnet"} + + if code := run(); code != 0 { + t.Fatalf("run exit = %d, want 0", code) + } + + mu.Lock() + firstModel, firstOK := seen["first"] + secondModel, secondOK := seen["second"] + mu.Unlock() + + if !firstOK || firstModel != "sonnet" { + t.Fatalf("first model = %q (present=%v), want sonnet", firstModel, firstOK) + } + if !secondOK || secondModel != "opus" { + t.Fatalf("second model = %q (present=%v), want opus", secondModel, secondOK) + } +} + func TestParallelFlag(t *testing.T) { oldArgs := os.Args defer func() { os.Args = oldArgs }()