fix: support model parameter for all backends, auto-inject from settings (#105)

- Add Model field to Config and TaskSpec for per-task model selection
- Parse --model flag and model: metadata in parallel tasks
- Auto-inject model from ~/.claude/settings.json for claude backend in new mode
- Pass --model to claude CLI, -m to gemini CLI, --model to codex CLI
- Preserve --setting-sources "" isolation while reading minimal safe subset
- Add comprehensive tests for model parsing, propagation, and settings injection

Fixes #105

Generated with SWE-Agent.ai

Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
cexll
2026-01-06 15:03:21 +08:00
parent cf93a0ada9
commit 66df48ea76
6 changed files with 374 additions and 16 deletions

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"os"
"path/filepath"
"strings"
)
// Backend defines the contract for invoking different AI CLI backends.
@@ -37,33 +38,48 @@ func (ClaudeBackend) BuildArgs(cfg *Config, targetArg string) []string {
const maxClaudeSettingsBytes = 1 << 20 // 1MB
// loadMinimalEnvSettings 从 ~/.claude/settings.json 只提取 env 配置。
// 只接受字符串类型的值;文件缺失/解析失败/超限都返回空。
func loadMinimalEnvSettings() map[string]string {
// minimalClaudeSettings is the small subset of ~/.claude/settings.json that
// loadMinimalClaudeSettings extracts. The zero value means nothing usable was
// found.
type minimalClaudeSettings struct {
// Env holds the entries of the settings file's "env" object; it is left nil
// when no entries were collected.
Env map[string]string
// Model is the settings file's "model" value with surrounding whitespace
// trimmed; it stays empty when the value is absent or not a string.
Model string
}
// loadMinimalClaudeSettings 从 ~/.claude/settings.json 只提取安全的最小子集:
// - env: 只接受字符串类型的值
// - model: 只接受字符串类型的值
// 文件缺失/解析失败/超限都返回空。
func loadMinimalClaudeSettings() minimalClaudeSettings {
home, err := os.UserHomeDir()
if err != nil || home == "" {
return nil
return minimalClaudeSettings{}
}
settingPath := filepath.Join(home, ".claude", "settings.json")
info, err := os.Stat(settingPath)
if err != nil || info.Size() > maxClaudeSettingsBytes {
return nil
return minimalClaudeSettings{}
}
data, err := os.ReadFile(settingPath)
if err != nil {
return nil
return minimalClaudeSettings{}
}
var cfg struct {
Env map[string]any `json:"env"`
Env map[string]any `json:"env"`
Model any `json:"model"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return nil
return minimalClaudeSettings{}
}
out := minimalClaudeSettings{}
if model, ok := cfg.Model.(string); ok {
out.Model = strings.TrimSpace(model)
}
if len(cfg.Env) == 0 {
return nil
return out
}
env := make(map[string]string, len(cfg.Env))
@@ -75,9 +91,19 @@ func loadMinimalEnvSettings() map[string]string {
env[k] = s
}
if len(env) == 0 {
return out
}
out.Env = env
return out
}
// loadMinimalEnvSettings is kept for backwards tests; prefer loadMinimalClaudeSettings.
func loadMinimalEnvSettings() map[string]string {
settings := loadMinimalClaudeSettings()
if len(settings.Env) == 0 {
return nil
}
return env
return settings.Env
}
func buildClaudeArgs(cfg *Config, targetArg string) []string {
@@ -93,6 +119,10 @@ func buildClaudeArgs(cfg *Config, targetArg string) []string {
// This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent
args = append(args, "--setting-sources", "")
if model := strings.TrimSpace(cfg.Model); model != "" {
args = append(args, "--model", model)
}
if cfg.Mode == "resume" {
if cfg.SessionID != "" {
// Claude CLI uses -r <session_id> for resume.
@@ -122,6 +152,10 @@ func buildGeminiArgs(cfg *Config, targetArg string) []string {
}
args := []string{"-o", "stream-json", "-y"}
if model := strings.TrimSpace(cfg.Model); model != "" {
args = append(args, "-m", model)
}
if cfg.Mode == "resume" {
if cfg.SessionID != "" {
args = append(args, "-r", cfg.SessionID)

View File

@@ -63,6 +63,42 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) {
})
}
// TestBackendBuildArgs_Model verifies that each backend forwards a configured
// model to its CLI using that backend's own flag: --model for claude and
// codex, -m for gemini.
func TestBackendBuildArgs_Model(t *testing.T) {
	t.Run("claude includes --model when set", func(t *testing.T) {
		backend := ClaudeBackend{}
		cfg := &Config{Mode: "new", Model: "opus"}
		got := backend.BuildArgs(cfg, "todo")
		want := []string{"-p", "--setting-sources", "", "--model", "opus", "--output-format", "stream-json", "--verbose", "todo"}
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("got %v, want %v", got, want)
		}
	})

	t.Run("gemini includes -m when set", func(t *testing.T) {
		backend := GeminiBackend{}
		cfg := &Config{Mode: "new", Model: "gemini-3-pro-preview"}
		got := backend.BuildArgs(cfg, "task")
		want := []string{"-o", "stream-json", "-y", "-m", "gemini-3-pro-preview", "-p", "task"}
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("got %v, want %v", got, want)
		}
	})

	t.Run("codex includes --model when set", func(t *testing.T) {
		const key = "CODEX_BYPASS_SANDBOX"

		// Run with the variable unset so the expected argument list below does
		// not depend on the caller's environment, then put back whatever value
		// was there before instead of leaving it unset for later tests.
		// NOTE(review): assumes the codex arg builder consults this variable.
		if prev, ok := os.LookupEnv(key); ok {
			t.Cleanup(func() { os.Setenv(key, prev) })
		} else {
			t.Cleanup(func() { os.Unsetenv(key) })
		}
		os.Unsetenv(key)

		backend := CodexBackend{}
		cfg := &Config{Mode: "new", WorkDir: "/tmp", Model: "o3"}
		got := backend.BuildArgs(cfg, "task")
		want := []string{"e", "--model", "o3", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"}
		if !reflect.DeepEqual(got, want) {
			t.Fatalf("got %v, want %v", got, want)
		}
	})
}
func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) {
t.Run("gemini new mode defaults workdir", func(t *testing.T) {
backend := GeminiBackend{}

View File

@@ -15,6 +15,7 @@ type Config struct {
Task string
SessionID string
WorkDir string
Model string
ExplicitStdin bool
Timeout int
Backend string
@@ -36,6 +37,7 @@ type TaskSpec struct {
Dependencies []string `json:"dependencies,omitempty"`
SessionID string `json:"session_id,omitempty"`
Backend string `json:"backend,omitempty"`
Model string `json:"model,omitempty"`
Mode string `json:"-"`
UseStdin bool `json:"-"`
Context context.Context `json:"-"`
@@ -152,6 +154,8 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) {
task.Mode = "resume"
case "backend":
task.Backend = value
case "model":
task.Model = value
case "dependencies":
for _, dep := range strings.Split(value, ",") {
dep = strings.TrimSpace(dep)
@@ -198,6 +202,7 @@ func parseArgs() (*Config, error) {
}
backendName := defaultBackendName
model := ""
skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS")
filtered := make([]string, 0, len(args))
for i := 0; i < len(args); i++ {
@@ -220,6 +225,20 @@ func parseArgs() (*Config, error) {
case arg == "--skip-permissions", arg == "--dangerously-skip-permissions":
skipPermissions = true
continue
case arg == "--model":
if i+1 >= len(args) {
return nil, fmt.Errorf("--model flag requires a value")
}
model = args[i+1]
i++
continue
case strings.HasPrefix(arg, "--model="):
value := strings.TrimPrefix(arg, "--model=")
if value == "" {
return nil, fmt.Errorf("--model flag requires a value")
}
model = value
continue
case strings.HasPrefix(arg, "--skip-permissions="):
skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions)
continue
@@ -235,7 +254,7 @@ func parseArgs() (*Config, error) {
}
args = filtered
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions}
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions, Model: strings.TrimSpace(model)}
cfg.MaxParallelWorkers = resolveMaxParallelWorkers()
if args[0] == "resume" {

View File

@@ -744,6 +744,10 @@ func buildCodexArgs(cfg *Config, targetArg string) []string {
args = append(args, "--dangerously-bypass-approvals-and-sandbox")
}
if model := strings.TrimSpace(cfg.Model); model != "" {
args = append(args, "--model", model)
}
args = append(args, "--skip-git-repo-check")
if isResume {
@@ -788,6 +792,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
Task: taskSpec.Task,
SessionID: taskSpec.SessionID,
WorkDir: taskSpec.WorkDir,
Model: taskSpec.Model,
Backend: defaultBackendName,
}
@@ -816,6 +821,15 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
return result
}
var claudeEnv map[string]string
if cfg.Backend == "claude" {
settings := loadMinimalClaudeSettings()
claudeEnv = settings.Env
if cfg.Mode != "resume" && strings.TrimSpace(cfg.Model) == "" && settings.Model != "" {
cfg.Model = settings.Model
}
}
useStdin := taskSpec.UseStdin
targetArg := taskSpec.Task
if useStdin {
@@ -915,10 +929,8 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
cmd := newCommandRunner(ctx, commandName, codexArgs...)
if cfg.Backend == "claude" {
if env := loadMinimalEnvSettings(); len(env) > 0 {
cmd.SetEnv(env)
}
if cfg.Backend == "claude" && len(claudeEnv) > 0 {
cmd.SetEnv(claudeEnv)
}
// For backends that don't support -C flag (claude, gemini), set working directory via cmd.Dir

View File

@@ -178,6 +178,7 @@ func run() (exitCode int) {
if parallelIndex != -1 {
backendName := defaultBackendName
model := ""
fullOutput := false
var extras []string
@@ -202,13 +203,27 @@ func run() (exitCode int) {
return 1
}
backendName = value
case arg == "--model":
if i+1 >= len(args) {
fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value")
return 1
}
model = args[i+1]
i++
case strings.HasPrefix(arg, "--model="):
value := strings.TrimPrefix(arg, "--model=")
if value == "" {
fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value")
return 1
}
model = value
default:
extras = append(extras, arg)
}
}
if len(extras) > 0 {
fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend and --full-output are allowed.")
fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend, --model and --full-output are allowed.")
fmt.Fprintln(os.Stderr, "Usage examples:")
fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name)
fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name)
@@ -237,10 +252,14 @@ func run() (exitCode int) {
}
cfg.GlobalBackend = backendName
model = strings.TrimSpace(model)
for i := range cfg.Tasks {
if strings.TrimSpace(cfg.Tasks[i].Backend) == "" {
cfg.Tasks[i].Backend = backendName
}
if strings.TrimSpace(cfg.Tasks[i].Model) == "" && model != "" {
cfg.Tasks[i].Model = model
}
}
timeoutSec := resolveTimeout()
@@ -409,6 +428,7 @@ func run() (exitCode int) {
WorkDir: cfg.WorkDir,
Mode: cfg.Mode,
SessionID: cfg.SessionID,
Model: cfg.Model,
UseStdin: useStdin,
}

View File

@@ -1139,6 +1139,65 @@ func TestBackendParseArgs_BackendFlag(t *testing.T) {
}
}
// TestBackendParseArgs_ModelFlag checks how parseArgs handles the --model
// flag: the space-separated and --model=value spellings, trimming of
// surrounding whitespace, use together with resume mode, and the error
// returned when the value is missing.
func TestBackendParseArgs_ModelFlag(t *testing.T) {
	// Every case overwrites the process-wide os.Args; restore the original so
	// tests that run afterwards do not see the last case's arguments.
	oldArgs := os.Args
	defer func() { os.Args = oldArgs }()

	tests := []struct {
		name    string
		args    []string
		want    string
		wantErr bool
	}{
		{
			name: "model flag",
			args: []string{"codeagent-wrapper", "--model", "opus", "task"},
			want: "opus",
		},
		{
			name: "model equals syntax",
			args: []string{"codeagent-wrapper", "--model=opus", "task"},
			want: "opus",
		},
		{
			name: "model trimmed",
			args: []string{"codeagent-wrapper", "--model", " opus ", "task"},
			want: "opus",
		},
		{
			name: "model with resume mode",
			args: []string{"codeagent-wrapper", "--model", "sonnet", "resume", "sid", "task"},
			want: "sonnet",
		},
		{
			name:    "missing model value",
			args:    []string{"codeagent-wrapper", "--model"},
			wantErr: true,
		},
		{
			name:    "model equals missing value",
			args:    []string{"codeagent-wrapper", "--model=", "task"},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			os.Args = tt.args
			cfg, err := parseArgs()
			if tt.wantErr {
				if err == nil {
					t.Fatalf("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if cfg.Model != tt.want {
				t.Fatalf("Model = %q, want %q", cfg.Model, tt.want)
			}
		})
	}
}
func TestBackendParseArgs_SkipPermissions(t *testing.T) {
const envKey = "CODEAGENT_SKIP_PERMISSIONS"
t.Cleanup(func() { os.Unsetenv(envKey) })
@@ -1276,6 +1335,26 @@ do something`
}
}
// TestParallelParseConfig_Model confirms that a "model:" metadata line in a
// parallel task header ends up in the parsed task's Model field.
func TestParallelParseConfig_Model(t *testing.T) {
	const input = `---TASK---
id: task-1
model: opus
---CONTENT---
do something`

	parsed, err := parseParallelConfig([]byte(input))
	if err != nil {
		t.Fatalf("parseParallelConfig() unexpected error: %v", err)
	}

	// Exactly one task block was supplied, so exactly one must come back.
	if n := len(parsed.Tasks); n != 1 {
		t.Fatalf("expected 1 task, got %d", n)
	}

	if got := parsed.Tasks[0].Model; got != "opus" {
		t.Fatalf("model = %q, want opus", got)
	}
}
func TestParallelParseConfig_EmptySessionID(t *testing.T) {
input := `---TASK---
id: task-1
@@ -1358,6 +1437,120 @@ code with special chars: $var "quotes"`
}
}
// TestClaudeModel_DefaultsFromSettings exercises the claude backend's model
// defaulting against a temporary ~/.claude/settings.json that declares both a
// model and an env entry. It covers three cases: a new-mode task with no model
// picks up the settings model, an explicit task model is the one passed on the
// command line, and a resume-mode task gets no --model flag at all.
func TestClaudeModel_DefaultsFromSettings(t *testing.T) {
	defer resetTestHooks()

	// Point the home directory at a throwaway location. HOME covers Unix-like
	// systems and USERPROFILE covers Windows for os.UserHomeDir.
	home := t.TempDir()
	t.Setenv("HOME", home)
	t.Setenv("USERPROFILE", home)

	dir := filepath.Join(home, ".claude")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}

	settingsModel := "claude-opus-4-5-20250929"
	path := filepath.Join(dir, "settings.json")
	data := []byte(fmt.Sprintf(`{"model":%q,"env":{"FOO":"bar"}}`, settingsModel))
	if err := os.WriteFile(path, data, 0o600); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}

	// makeRunner builds a replacement for newCommandRunner that records the
	// command name, a copy of its arguments, and the fake command it returned.
	makeRunner := func(gotName *string, gotArgs *[]string, fake **fakeCmd) func(context.Context, string, ...string) commandRunner {
		return func(ctx context.Context, name string, args ...string) commandRunner {
			*gotName = name
			*gotArgs = append([]string(nil), args...)
			cmd := newFakeCmd(fakeCmdConfig{
				PID: 123,
				StdoutPlan: []fakeStdoutEvent{
					{Data: "{\"type\":\"result\",\"session_id\":\"sid\",\"result\":\"ok\"}\n"},
				},
			})
			*fake = cmd
			return cmd
		}
	}

	// hasModelArg reports whether args contains "--model" immediately followed
	// by the given value.
	hasModelArg := func(args []string, value string) bool {
		for i := 0; i+1 < len(args); i++ {
			if args[i] == "--model" && args[i+1] == value {
				return true
			}
		}
		return false
	}

	t.Run("new mode inherits model when unset", func(t *testing.T) {
		var (
			gotName string
			gotArgs []string
			fake    *fakeCmd
		)
		origRunner := newCommandRunner
		newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
		t.Cleanup(func() { newCommandRunner = origRunner })

		res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5)
		if res.ExitCode != 0 || res.Message != "ok" {
			t.Fatalf("unexpected result: %+v", res)
		}
		if gotName != "claude" {
			t.Fatalf("command = %q, want claude", gotName)
		}
		if !hasModelArg(gotArgs, settingsModel) {
			t.Fatalf("expected --model %q in args, got %v", settingsModel, gotArgs)
		}

		// Check for nil separately: formatting fake.env on a nil fake would
		// panic instead of producing a readable failure.
		if fake == nil {
			t.Fatalf("expected command runner to be invoked, but no command was created")
		}
		if fake.env["FOO"] != "bar" {
			t.Fatalf("expected env to include FOO=bar, got %v", fake.env)
		}
	})

	t.Run("explicit model overrides settings", func(t *testing.T) {
		var (
			gotName string
			gotArgs []string
			fake    *fakeCmd
		)
		origRunner := newCommandRunner
		newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
		t.Cleanup(func() { newCommandRunner = origRunner })

		res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir, Model: "sonnet"}, ClaudeBackend{}, nil, false, true, 5)
		if res.ExitCode != 0 || res.Message != "ok" {
			t.Fatalf("unexpected result: %+v", res)
		}
		if !hasModelArg(gotArgs, "sonnet") {
			t.Fatalf("expected --model sonnet in args, got %v", gotArgs)
		}
	})

	t.Run("resume mode does not inherit model by default", func(t *testing.T) {
		var (
			gotName string
			gotArgs []string
			fake    *fakeCmd
		)
		origRunner := newCommandRunner
		newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
		t.Cleanup(func() { newCommandRunner = origRunner })

		res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "resume", SessionID: "sid-123", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5)
		if res.ExitCode != 0 || res.Message != "ok" {
			t.Fatalf("unexpected result: %+v", res)
		}
		// Any --model flag is wrong here, whatever value follows it.
		for _, arg := range gotArgs {
			if arg == "--model" {
				t.Fatalf("did not expect --model in resume args, got %v", gotArgs)
			}
		}
	})
}
func TestRunShouldUseStdin(t *testing.T) {
tests := []struct {
name string
@@ -2947,6 +3140,50 @@ do two`)
}
}
// TestParallelModelPropagation runs the --parallel entry point with a global
// --model flag and two tasks, one of which declares its own model. It expects
// the task without a model to receive the global value and the task with a
// model to keep its own.
func TestParallelModelPropagation(t *testing.T) {
	defer resetTestHooks()

	// run() is driven through the process-wide os.Args; restore the original
	// afterwards so later tests are not affected.
	oldArgs := os.Args
	defer func() { os.Args = oldArgs }()

	cleanupLogsFn = func() (CleanupStats, error) { return CleanupStats{}, nil }

	orig := runCodexTaskFn
	// The mutex guards the map because the stub may be called from several
	// goroutines. NOTE(review): assumes parallel tasks run concurrently.
	var mu sync.Mutex
	seen := make(map[string]string)
	runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
		mu.Lock()
		seen[task.ID] = task.Model
		mu.Unlock()
		return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "ok"}
	}
	t.Cleanup(func() { runCodexTaskFn = orig })

	stdinReader = strings.NewReader(`---TASK---
id: first
---CONTENT---
do one
---TASK---
id: second
model: opus
---CONTENT---
do two`)
	os.Args = []string{"codeagent-wrapper", "--parallel", "--model", "sonnet"}

	if code := run(); code != 0 {
		t.Fatalf("run exit = %d, want 0", code)
	}

	mu.Lock()
	firstModel, firstOK := seen["first"]
	secondModel, secondOK := seen["second"]
	mu.Unlock()

	if !firstOK || firstModel != "sonnet" {
		t.Fatalf("first model = %q (present=%v), want sonnet", firstModel, firstOK)
	}
	if !secondOK || secondModel != "opus" {
		t.Fatalf("second model = %q (present=%v), want opus", secondModel, secondOK)
	}
}
func TestParallelFlag(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()