mirror of
https://github.com/cexll/myclaude.git
synced 2026-02-15 03:32:43 +08:00
fix: support model parameter for all backends, auto-inject from settings (#105)
- Add Model field to Config and TaskSpec for per-task model selection - Parse --model flag and model: metadata in parallel tasks - Auto-inject model from ~/.claude/settings.json for claude backend in new mode - Pass --model to claude CLI, -m to gemini CLI, --model to codex CLI - Preserve --setting-sources "" isolation while reading minimal safe subset - Add comprehensive tests for model parsing, propagation, and settings injection Fixes #105 Generated with SWE-Agent.ai Co-Authored-By: SWE-Agent.ai <noreply@swe-agent.ai>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Backend defines the contract for invoking different AI CLI backends.
|
// Backend defines the contract for invoking different AI CLI backends.
|
||||||
@@ -37,33 +38,48 @@ func (ClaudeBackend) BuildArgs(cfg *Config, targetArg string) []string {
|
|||||||
|
|
||||||
const maxClaudeSettingsBytes = 1 << 20 // 1MB
|
const maxClaudeSettingsBytes = 1 << 20 // 1MB
|
||||||
|
|
||||||
// loadMinimalEnvSettings 从 ~/.claude/settings.json 只提取 env 配置。
|
type minimalClaudeSettings struct {
|
||||||
// 只接受字符串类型的值;文件缺失/解析失败/超限都返回空。
|
Env map[string]string
|
||||||
func loadMinimalEnvSettings() map[string]string {
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadMinimalClaudeSettings 从 ~/.claude/settings.json 只提取安全的最小子集:
|
||||||
|
// - env: 只接受字符串类型的值
|
||||||
|
// - model: 只接受字符串类型的值
|
||||||
|
// 文件缺失/解析失败/超限都返回空。
|
||||||
|
func loadMinimalClaudeSettings() minimalClaudeSettings {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil || home == "" {
|
if err != nil || home == "" {
|
||||||
return nil
|
return minimalClaudeSettings{}
|
||||||
}
|
}
|
||||||
|
|
||||||
settingPath := filepath.Join(home, ".claude", "settings.json")
|
settingPath := filepath.Join(home, ".claude", "settings.json")
|
||||||
info, err := os.Stat(settingPath)
|
info, err := os.Stat(settingPath)
|
||||||
if err != nil || info.Size() > maxClaudeSettingsBytes {
|
if err != nil || info.Size() > maxClaudeSettingsBytes {
|
||||||
return nil
|
return minimalClaudeSettings{}
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.ReadFile(settingPath)
|
data, err := os.ReadFile(settingPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return minimalClaudeSettings{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
var cfg struct {
|
||||||
Env map[string]any `json:"env"`
|
Env map[string]any `json:"env"`
|
||||||
|
Model any `json:"model"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
return nil
|
return minimalClaudeSettings{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out := minimalClaudeSettings{}
|
||||||
|
|
||||||
|
if model, ok := cfg.Model.(string); ok {
|
||||||
|
out.Model = strings.TrimSpace(model)
|
||||||
|
}
|
||||||
|
|
||||||
if len(cfg.Env) == 0 {
|
if len(cfg.Env) == 0 {
|
||||||
return nil
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
env := make(map[string]string, len(cfg.Env))
|
env := make(map[string]string, len(cfg.Env))
|
||||||
@@ -75,9 +91,19 @@ func loadMinimalEnvSettings() map[string]string {
|
|||||||
env[k] = s
|
env[k] = s
|
||||||
}
|
}
|
||||||
if len(env) == 0 {
|
if len(env) == 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
out.Env = env
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadMinimalEnvSettings is kept for backwards tests; prefer loadMinimalClaudeSettings.
|
||||||
|
func loadMinimalEnvSettings() map[string]string {
|
||||||
|
settings := loadMinimalClaudeSettings()
|
||||||
|
if len(settings.Env) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return env
|
return settings.Env
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildClaudeArgs(cfg *Config, targetArg string) []string {
|
func buildClaudeArgs(cfg *Config, targetArg string) []string {
|
||||||
@@ -93,6 +119,10 @@ func buildClaudeArgs(cfg *Config, targetArg string) []string {
|
|||||||
// This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent
|
// This ensures a clean execution environment without CLAUDE.md or skills that would trigger codeagent
|
||||||
args = append(args, "--setting-sources", "")
|
args = append(args, "--setting-sources", "")
|
||||||
|
|
||||||
|
if model := strings.TrimSpace(cfg.Model); model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Mode == "resume" {
|
if cfg.Mode == "resume" {
|
||||||
if cfg.SessionID != "" {
|
if cfg.SessionID != "" {
|
||||||
// Claude CLI uses -r <session_id> for resume.
|
// Claude CLI uses -r <session_id> for resume.
|
||||||
@@ -122,6 +152,10 @@ func buildGeminiArgs(cfg *Config, targetArg string) []string {
|
|||||||
}
|
}
|
||||||
args := []string{"-o", "stream-json", "-y"}
|
args := []string{"-o", "stream-json", "-y"}
|
||||||
|
|
||||||
|
if model := strings.TrimSpace(cfg.Model); model != "" {
|
||||||
|
args = append(args, "-m", model)
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.Mode == "resume" {
|
if cfg.Mode == "resume" {
|
||||||
if cfg.SessionID != "" {
|
if cfg.SessionID != "" {
|
||||||
args = append(args, "-r", cfg.SessionID)
|
args = append(args, "-r", cfg.SessionID)
|
||||||
|
|||||||
@@ -63,6 +63,42 @@ func TestClaudeBuildArgs_ModesAndPermissions(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackendBuildArgs_Model(t *testing.T) {
|
||||||
|
t.Run("claude includes --model when set", func(t *testing.T) {
|
||||||
|
backend := ClaudeBackend{}
|
||||||
|
cfg := &Config{Mode: "new", Model: "opus"}
|
||||||
|
got := backend.BuildArgs(cfg, "todo")
|
||||||
|
want := []string{"-p", "--setting-sources", "", "--model", "opus", "--output-format", "stream-json", "--verbose", "todo"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gemini includes -m when set", func(t *testing.T) {
|
||||||
|
backend := GeminiBackend{}
|
||||||
|
cfg := &Config{Mode: "new", Model: "gemini-3-pro-preview"}
|
||||||
|
got := backend.BuildArgs(cfg, "task")
|
||||||
|
want := []string{"-o", "stream-json", "-y", "-m", "gemini-3-pro-preview", "-p", "task"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("codex includes --model when set", func(t *testing.T) {
|
||||||
|
const key = "CODEX_BYPASS_SANDBOX"
|
||||||
|
t.Cleanup(func() { os.Unsetenv(key) })
|
||||||
|
os.Unsetenv(key)
|
||||||
|
|
||||||
|
backend := CodexBackend{}
|
||||||
|
cfg := &Config{Mode: "new", WorkDir: "/tmp", Model: "o3"}
|
||||||
|
got := backend.BuildArgs(cfg, "task")
|
||||||
|
want := []string{"e", "--model", "o3", "--skip-git-repo-check", "-C", "/tmp", "--json", "task"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("got %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) {
|
func TestClaudeBuildArgs_GeminiAndCodexModes(t *testing.T) {
|
||||||
t.Run("gemini new mode defaults workdir", func(t *testing.T) {
|
t.Run("gemini new mode defaults workdir", func(t *testing.T) {
|
||||||
backend := GeminiBackend{}
|
backend := GeminiBackend{}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ type Config struct {
|
|||||||
Task string
|
Task string
|
||||||
SessionID string
|
SessionID string
|
||||||
WorkDir string
|
WorkDir string
|
||||||
|
Model string
|
||||||
ExplicitStdin bool
|
ExplicitStdin bool
|
||||||
Timeout int
|
Timeout int
|
||||||
Backend string
|
Backend string
|
||||||
@@ -36,6 +37,7 @@ type TaskSpec struct {
|
|||||||
Dependencies []string `json:"dependencies,omitempty"`
|
Dependencies []string `json:"dependencies,omitempty"`
|
||||||
SessionID string `json:"session_id,omitempty"`
|
SessionID string `json:"session_id,omitempty"`
|
||||||
Backend string `json:"backend,omitempty"`
|
Backend string `json:"backend,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
Mode string `json:"-"`
|
Mode string `json:"-"`
|
||||||
UseStdin bool `json:"-"`
|
UseStdin bool `json:"-"`
|
||||||
Context context.Context `json:"-"`
|
Context context.Context `json:"-"`
|
||||||
@@ -152,6 +154,8 @@ func parseParallelConfig(data []byte) (*ParallelConfig, error) {
|
|||||||
task.Mode = "resume"
|
task.Mode = "resume"
|
||||||
case "backend":
|
case "backend":
|
||||||
task.Backend = value
|
task.Backend = value
|
||||||
|
case "model":
|
||||||
|
task.Model = value
|
||||||
case "dependencies":
|
case "dependencies":
|
||||||
for _, dep := range strings.Split(value, ",") {
|
for _, dep := range strings.Split(value, ",") {
|
||||||
dep = strings.TrimSpace(dep)
|
dep = strings.TrimSpace(dep)
|
||||||
@@ -198,6 +202,7 @@ func parseArgs() (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
backendName := defaultBackendName
|
backendName := defaultBackendName
|
||||||
|
model := ""
|
||||||
skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS")
|
skipPermissions := envFlagEnabled("CODEAGENT_SKIP_PERMISSIONS")
|
||||||
filtered := make([]string, 0, len(args))
|
filtered := make([]string, 0, len(args))
|
||||||
for i := 0; i < len(args); i++ {
|
for i := 0; i < len(args); i++ {
|
||||||
@@ -220,6 +225,20 @@ func parseArgs() (*Config, error) {
|
|||||||
case arg == "--skip-permissions", arg == "--dangerously-skip-permissions":
|
case arg == "--skip-permissions", arg == "--dangerously-skip-permissions":
|
||||||
skipPermissions = true
|
skipPermissions = true
|
||||||
continue
|
continue
|
||||||
|
case arg == "--model":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
return nil, fmt.Errorf("--model flag requires a value")
|
||||||
|
}
|
||||||
|
model = args[i+1]
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(arg, "--model="):
|
||||||
|
value := strings.TrimPrefix(arg, "--model=")
|
||||||
|
if value == "" {
|
||||||
|
return nil, fmt.Errorf("--model flag requires a value")
|
||||||
|
}
|
||||||
|
model = value
|
||||||
|
continue
|
||||||
case strings.HasPrefix(arg, "--skip-permissions="):
|
case strings.HasPrefix(arg, "--skip-permissions="):
|
||||||
skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions)
|
skipPermissions = parseBoolFlag(strings.TrimPrefix(arg, "--skip-permissions="), skipPermissions)
|
||||||
continue
|
continue
|
||||||
@@ -235,7 +254,7 @@ func parseArgs() (*Config, error) {
|
|||||||
}
|
}
|
||||||
args = filtered
|
args = filtered
|
||||||
|
|
||||||
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions}
|
cfg := &Config{WorkDir: defaultWorkdir, Backend: backendName, SkipPermissions: skipPermissions, Model: strings.TrimSpace(model)}
|
||||||
cfg.MaxParallelWorkers = resolveMaxParallelWorkers()
|
cfg.MaxParallelWorkers = resolveMaxParallelWorkers()
|
||||||
|
|
||||||
if args[0] == "resume" {
|
if args[0] == "resume" {
|
||||||
|
|||||||
@@ -744,6 +744,10 @@ func buildCodexArgs(cfg *Config, targetArg string) []string {
|
|||||||
args = append(args, "--dangerously-bypass-approvals-and-sandbox")
|
args = append(args, "--dangerously-bypass-approvals-and-sandbox")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if model := strings.TrimSpace(cfg.Model); model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
|
||||||
args = append(args, "--skip-git-repo-check")
|
args = append(args, "--skip-git-repo-check")
|
||||||
|
|
||||||
if isResume {
|
if isResume {
|
||||||
@@ -788,6 +792,7 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
Task: taskSpec.Task,
|
Task: taskSpec.Task,
|
||||||
SessionID: taskSpec.SessionID,
|
SessionID: taskSpec.SessionID,
|
||||||
WorkDir: taskSpec.WorkDir,
|
WorkDir: taskSpec.WorkDir,
|
||||||
|
Model: taskSpec.Model,
|
||||||
Backend: defaultBackendName,
|
Backend: defaultBackendName,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -816,6 +821,15 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var claudeEnv map[string]string
|
||||||
|
if cfg.Backend == "claude" {
|
||||||
|
settings := loadMinimalClaudeSettings()
|
||||||
|
claudeEnv = settings.Env
|
||||||
|
if cfg.Mode != "resume" && strings.TrimSpace(cfg.Model) == "" && settings.Model != "" {
|
||||||
|
cfg.Model = settings.Model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
useStdin := taskSpec.UseStdin
|
useStdin := taskSpec.UseStdin
|
||||||
targetArg := taskSpec.Task
|
targetArg := taskSpec.Task
|
||||||
if useStdin {
|
if useStdin {
|
||||||
@@ -915,10 +929,8 @@ func runCodexTaskWithContext(parentCtx context.Context, taskSpec TaskSpec, backe
|
|||||||
|
|
||||||
cmd := newCommandRunner(ctx, commandName, codexArgs...)
|
cmd := newCommandRunner(ctx, commandName, codexArgs...)
|
||||||
|
|
||||||
if cfg.Backend == "claude" {
|
if cfg.Backend == "claude" && len(claudeEnv) > 0 {
|
||||||
if env := loadMinimalEnvSettings(); len(env) > 0 {
|
cmd.SetEnv(claudeEnv)
|
||||||
cmd.SetEnv(env)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For backends that don't support -C flag (claude, gemini), set working directory via cmd.Dir
|
// For backends that don't support -C flag (claude, gemini), set working directory via cmd.Dir
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ func run() (exitCode int) {
|
|||||||
|
|
||||||
if parallelIndex != -1 {
|
if parallelIndex != -1 {
|
||||||
backendName := defaultBackendName
|
backendName := defaultBackendName
|
||||||
|
model := ""
|
||||||
fullOutput := false
|
fullOutput := false
|
||||||
var extras []string
|
var extras []string
|
||||||
|
|
||||||
@@ -202,13 +203,27 @@ func run() (exitCode int) {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
backendName = value
|
backendName = value
|
||||||
|
case arg == "--model":
|
||||||
|
if i+1 >= len(args) {
|
||||||
|
fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value")
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
model = args[i+1]
|
||||||
|
i++
|
||||||
|
case strings.HasPrefix(arg, "--model="):
|
||||||
|
value := strings.TrimPrefix(arg, "--model=")
|
||||||
|
if value == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "ERROR: --model flag requires a value")
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
model = value
|
||||||
default:
|
default:
|
||||||
extras = append(extras, arg)
|
extras = append(extras, arg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(extras) > 0 {
|
if len(extras) > 0 {
|
||||||
fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend and --full-output are allowed.")
|
fmt.Fprintln(os.Stderr, "ERROR: --parallel reads its task configuration from stdin; only --backend, --model and --full-output are allowed.")
|
||||||
fmt.Fprintln(os.Stderr, "Usage examples:")
|
fmt.Fprintln(os.Stderr, "Usage examples:")
|
||||||
fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name)
|
fmt.Fprintf(os.Stderr, " %s --parallel < tasks.txt\n", name)
|
||||||
fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name)
|
fmt.Fprintf(os.Stderr, " echo '...' | %s --parallel\n", name)
|
||||||
@@ -237,10 +252,14 @@ func run() (exitCode int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg.GlobalBackend = backendName
|
cfg.GlobalBackend = backendName
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
for i := range cfg.Tasks {
|
for i := range cfg.Tasks {
|
||||||
if strings.TrimSpace(cfg.Tasks[i].Backend) == "" {
|
if strings.TrimSpace(cfg.Tasks[i].Backend) == "" {
|
||||||
cfg.Tasks[i].Backend = backendName
|
cfg.Tasks[i].Backend = backendName
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(cfg.Tasks[i].Model) == "" && model != "" {
|
||||||
|
cfg.Tasks[i].Model = model
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
timeoutSec := resolveTimeout()
|
timeoutSec := resolveTimeout()
|
||||||
@@ -409,6 +428,7 @@ func run() (exitCode int) {
|
|||||||
WorkDir: cfg.WorkDir,
|
WorkDir: cfg.WorkDir,
|
||||||
Mode: cfg.Mode,
|
Mode: cfg.Mode,
|
||||||
SessionID: cfg.SessionID,
|
SessionID: cfg.SessionID,
|
||||||
|
Model: cfg.Model,
|
||||||
UseStdin: useStdin,
|
UseStdin: useStdin,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1139,6 +1139,65 @@ func TestBackendParseArgs_BackendFlag(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackendParseArgs_ModelFlag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model flag",
|
||||||
|
args: []string{"codeagent-wrapper", "--model", "opus", "task"},
|
||||||
|
want: "opus",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model equals syntax",
|
||||||
|
args: []string{"codeagent-wrapper", "--model=opus", "task"},
|
||||||
|
want: "opus",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model trimmed",
|
||||||
|
args: []string{"codeagent-wrapper", "--model", " opus ", "task"},
|
||||||
|
want: "opus",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with resume mode",
|
||||||
|
args: []string{"codeagent-wrapper", "--model", "sonnet", "resume", "sid", "task"},
|
||||||
|
want: "sonnet",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing model value",
|
||||||
|
args: []string{"codeagent-wrapper", "--model"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model equals missing value",
|
||||||
|
args: []string{"codeagent-wrapper", "--model=", "task"},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
os.Args = tt.args
|
||||||
|
cfg, err := parseArgs()
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Model != tt.want {
|
||||||
|
t.Fatalf("Model = %q, want %q", cfg.Model, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBackendParseArgs_SkipPermissions(t *testing.T) {
|
func TestBackendParseArgs_SkipPermissions(t *testing.T) {
|
||||||
const envKey = "CODEAGENT_SKIP_PERMISSIONS"
|
const envKey = "CODEAGENT_SKIP_PERMISSIONS"
|
||||||
t.Cleanup(func() { os.Unsetenv(envKey) })
|
t.Cleanup(func() { os.Unsetenv(envKey) })
|
||||||
@@ -1276,6 +1335,26 @@ do something`
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParallelParseConfig_Model(t *testing.T) {
|
||||||
|
input := `---TASK---
|
||||||
|
id: task-1
|
||||||
|
model: opus
|
||||||
|
---CONTENT---
|
||||||
|
do something`
|
||||||
|
|
||||||
|
cfg, err := parseParallelConfig([]byte(input))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseParallelConfig() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(cfg.Tasks) != 1 {
|
||||||
|
t.Fatalf("expected 1 task, got %d", len(cfg.Tasks))
|
||||||
|
}
|
||||||
|
task := cfg.Tasks[0]
|
||||||
|
if task.Model != "opus" {
|
||||||
|
t.Fatalf("model = %q, want opus", task.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParallelParseConfig_EmptySessionID(t *testing.T) {
|
func TestParallelParseConfig_EmptySessionID(t *testing.T) {
|
||||||
input := `---TASK---
|
input := `---TASK---
|
||||||
id: task-1
|
id: task-1
|
||||||
@@ -1358,6 +1437,120 @@ code with special chars: $var "quotes"`
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClaudeModel_DefaultsFromSettings(t *testing.T) {
|
||||||
|
defer resetTestHooks()
|
||||||
|
|
||||||
|
home := t.TempDir()
|
||||||
|
t.Setenv("HOME", home)
|
||||||
|
t.Setenv("USERPROFILE", home)
|
||||||
|
|
||||||
|
dir := filepath.Join(home, ".claude")
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatalf("MkdirAll: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
settingsModel := "claude-opus-4-5-20250929"
|
||||||
|
path := filepath.Join(dir, "settings.json")
|
||||||
|
data := []byte(fmt.Sprintf(`{"model":%q,"env":{"FOO":"bar"}}`, settingsModel))
|
||||||
|
if err := os.WriteFile(path, data, 0o600); err != nil {
|
||||||
|
t.Fatalf("WriteFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
makeRunner := func(gotName *string, gotArgs *[]string, fake **fakeCmd) func(context.Context, string, ...string) commandRunner {
|
||||||
|
return func(ctx context.Context, name string, args ...string) commandRunner {
|
||||||
|
*gotName = name
|
||||||
|
*gotArgs = append([]string(nil), args...)
|
||||||
|
cmd := newFakeCmd(fakeCmdConfig{
|
||||||
|
PID: 123,
|
||||||
|
StdoutPlan: []fakeStdoutEvent{
|
||||||
|
{Data: "{\"type\":\"result\",\"session_id\":\"sid\",\"result\":\"ok\"}\n"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
*fake = cmd
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("new mode inherits model when unset", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
gotName string
|
||||||
|
gotArgs []string
|
||||||
|
fake *fakeCmd
|
||||||
|
)
|
||||||
|
origRunner := newCommandRunner
|
||||||
|
newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
|
||||||
|
t.Cleanup(func() { newCommandRunner = origRunner })
|
||||||
|
|
||||||
|
res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5)
|
||||||
|
if res.ExitCode != 0 || res.Message != "ok" {
|
||||||
|
t.Fatalf("unexpected result: %+v", res)
|
||||||
|
}
|
||||||
|
if gotName != "claude" {
|
||||||
|
t.Fatalf("command = %q, want claude", gotName)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for i := 0; i+1 < len(gotArgs); i++ {
|
||||||
|
if gotArgs[i] == "--model" && gotArgs[i+1] == settingsModel {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected --model %q in args, got %v", settingsModel, gotArgs)
|
||||||
|
}
|
||||||
|
if fake == nil || fake.env["FOO"] != "bar" {
|
||||||
|
t.Fatalf("expected env to include FOO=bar, got %v", fake.env)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("explicit model overrides settings", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
gotName string
|
||||||
|
gotArgs []string
|
||||||
|
fake *fakeCmd
|
||||||
|
)
|
||||||
|
origRunner := newCommandRunner
|
||||||
|
newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
|
||||||
|
t.Cleanup(func() { newCommandRunner = origRunner })
|
||||||
|
|
||||||
|
res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "new", WorkDir: defaultWorkdir, Model: "sonnet"}, ClaudeBackend{}, nil, false, true, 5)
|
||||||
|
if res.ExitCode != 0 || res.Message != "ok" {
|
||||||
|
t.Fatalf("unexpected result: %+v", res)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for i := 0; i+1 < len(gotArgs); i++ {
|
||||||
|
if gotArgs[i] == "--model" && gotArgs[i+1] == "sonnet" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected --model sonnet in args, got %v", gotArgs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("resume mode does not inherit model by default", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
gotName string
|
||||||
|
gotArgs []string
|
||||||
|
fake *fakeCmd
|
||||||
|
)
|
||||||
|
origRunner := newCommandRunner
|
||||||
|
newCommandRunner = makeRunner(&gotName, &gotArgs, &fake)
|
||||||
|
t.Cleanup(func() { newCommandRunner = origRunner })
|
||||||
|
|
||||||
|
res := runCodexTaskWithContext(context.Background(), TaskSpec{Task: "hi", Mode: "resume", SessionID: "sid-123", WorkDir: defaultWorkdir}, ClaudeBackend{}, nil, false, true, 5)
|
||||||
|
if res.ExitCode != 0 || res.Message != "ok" {
|
||||||
|
t.Fatalf("unexpected result: %+v", res)
|
||||||
|
}
|
||||||
|
for i := 0; i < len(gotArgs); i++ {
|
||||||
|
if gotArgs[i] == "--model" {
|
||||||
|
t.Fatalf("did not expect --model in resume args, got %v", gotArgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRunShouldUseStdin(t *testing.T) {
|
func TestRunShouldUseStdin(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -2947,6 +3140,50 @@ do two`)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParallelModelPropagation(t *testing.T) {
|
||||||
|
defer resetTestHooks()
|
||||||
|
cleanupLogsFn = func() (CleanupStats, error) { return CleanupStats{}, nil }
|
||||||
|
|
||||||
|
orig := runCodexTaskFn
|
||||||
|
var mu sync.Mutex
|
||||||
|
seen := make(map[string]string)
|
||||||
|
runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult {
|
||||||
|
mu.Lock()
|
||||||
|
seen[task.ID] = task.Model
|
||||||
|
mu.Unlock()
|
||||||
|
return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "ok"}
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { runCodexTaskFn = orig })
|
||||||
|
|
||||||
|
stdinReader = strings.NewReader(`---TASK---
|
||||||
|
id: first
|
||||||
|
---CONTENT---
|
||||||
|
do one
|
||||||
|
|
||||||
|
---TASK---
|
||||||
|
id: second
|
||||||
|
model: opus
|
||||||
|
---CONTENT---
|
||||||
|
do two`)
|
||||||
|
os.Args = []string{"codeagent-wrapper", "--parallel", "--model", "sonnet"}
|
||||||
|
|
||||||
|
if code := run(); code != 0 {
|
||||||
|
t.Fatalf("run exit = %d, want 0", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
firstModel, firstOK := seen["first"]
|
||||||
|
secondModel, secondOK := seen["second"]
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
if !firstOK || firstModel != "sonnet" {
|
||||||
|
t.Fatalf("first model = %q (present=%v), want sonnet", firstModel, firstOK)
|
||||||
|
}
|
||||||
|
if !secondOK || secondModel != "opus" {
|
||||||
|
t.Fatalf("second model = %q (present=%v), want opus", secondModel, secondOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParallelFlag(t *testing.T) {
|
func TestParallelFlag(t *testing.T) {
|
||||||
oldArgs := os.Args
|
oldArgs := os.Args
|
||||||
defer func() { os.Args = oldArgs }()
|
defer func() { os.Args = oldArgs }()
|
||||||
|
|||||||
Reference in New Issue
Block a user