Compare commits


17 Commits

Author SHA1 Message Date
catlog22
88ff109ac4 chore: bump version to 6.3.48 2026-01-24 15:06:36 +08:00
catlog22
261196a804 docs: update CCW CLI commands with recommended commands and usage examples 2026-01-24 15:05:37 +08:00
catlog22
ea6cb8440f chore: bump version to 6.3.47
- Update ccw-coordinator.md with clarified CLI execution format
- Command-first prompt structure: /workflow:<command> -y <parameters>
- Simplified documentation with universal prompt template
- Clarify that -y is a prompt parameter, not a ccw cli parameter
2026-01-24 14:52:09 +08:00
catlog22
bf896342f4 refactor: adjust prompt structure for command execution clarity 2026-01-24 14:51:05 +08:00
catlog22
f2b0a5bbc9 Refactor code structure and remove redundant changes 2026-01-24 14:47:47 +08:00
catlog22
cf5fecd66d fix(codex-lens): resolve installation issues from frontend
- Add missing README.md file required by setuptools
- Fix deprecated license format in pyproject.toml (use SPDX string instead of TOML table)
- Add MIT LICENSE file for proper packaging
- Verified successful local installation and import

Fixes permission denied error during npm-based installation on macOS
2026-01-24 14:43:39 +08:00
catlog22
86d469ccc9 build: exclude test files from TypeScript compilation 2026-01-24 14:35:05 +08:00
catlog22
357d3524f5 chore: bump version to 6.3.46 2026-01-24 14:31:29 +08:00
catlog22
4334162ddf refactor: remove unused command definitions from ccw-coordinator 2026-01-24 14:29:51 +08:00
catlog22
2dcd1637f0 refactor: enhance documentation on Minimum Execution Units and command grouping in CCW 2026-01-24 14:27:58 +08:00
catlog22
38e1cdc737 chore(release): publish 6.3.45
## Features

- New `ccw` command: Main process workflow orchestrator with auto intent-based workflow selection
- New CommandRegistry for dynamic command discovery and metadata management

## Improvements

- Optimize ccw-coordinator: Serial blocking execution model with hook-based continuation
- Refactor execution flow: Stop after CLI launch, wait for hook callbacks (no polling)
- Add task_id tracking and state.json checkpoints for resumable execution
- Consolidate documentation: Reduce report verbosity while maintaining all core information

## Documentation

- Add Execution Model comparison (main process vs external CLI)
- Add State Management section with TodoWrite tracking examples
- Update Type Comparison table highlighting ccw vs ccw-coordinator differences
- Simplify code examples with inline comments

## Changes Summary

- ccw-coordinator.md: +272/-26 (serial blocking), -143 docs (consolidation)
- ccw.md: +121/-352 (state management, execution model)
- Rename: CCW-COORDINATOR.md → ccw-coordinator.md (lowercase)
2026-01-24 14:09:52 +08:00
catlog22
097a7346b9 refactor: optimize ccw.md with streamlined documentation and state management
- Add Execution Model section (Synchronous vs Async blocking comparison)
- Add State Management section (TodoWrite-based tracking)
- Simplify Phase 1-5 code (remove verbose comments, consolidate logic)
- Consolidate Pipeline Examples into table format (5 examples → 1 table)
- Update Type Comparison table (highlight ccw vs ccw-coordinator differences)
- Maintain all core information (no content loss)

Changes:
- -352 lines (verbose explanations, redundant code)
- +121 lines (consolidated content, new sections)
- net: -231 lines (35% reduction: 665→433 lines)

Key additions:
- Execution Model flow diagram
- State Management with TodoWrite example
- Type Comparison: Synchronous (main) vs Async (external CLI)
2026-01-24 14:06:31 +08:00
catlog22
9df8063fbd refactor: reduce documentation report, consolidate overlapping content
- Eliminate redundant Stop-Action explanations (moved to CLI Execution Model)
- Remove verbose hook/error handling examples (keep in code only)
- Consolidate 5-step CLI example into 1-line pattern
- Simplify handleCliCompletion function comments
- Streamline executor loop exit notes
- Maintain all core information (no content loss)
- Reduce report from ~1000 lines to ~900 lines

Changes:
- -143 lines (old verbose explanations)
- +21 lines (consolidated content)
- net: -122 lines
2026-01-24 14:00:34 +08:00
catlog22
d00f0bc7ca refactor: improve CCW orchestrator with serial blocking execution and hook-based continuation
- Rename file to lowercase: CCW-COORDINATOR.md → ccw-coordinator.md
- Replace polling waitForTaskCompletion with stop-action blocking model
- CLI commands execute in background with immediate stop (no polling)
- Hook callbacks (handleCliCompletion) trigger continuation to next command
- Add task_id and completed_at fields to execution_results
- Maintain state checkpoint after each command launch
- Add status flow documentation (running → waiting → completed)
- Include CLI invocation example with hook configuration
- Separate concerns: orchestrator launches, hooks handle callbacks
- Support serial execution: one command at a time with break after launch
2026-01-24 13:57:08 +08:00
catlog22
24efef7f17 feat: Add main workflow orchestrator (ccw) with intent analysis and command execution
- Implemented the ccw command as a main workflow orchestrator.
- Added a 5-phase workflow including intent analysis, requirement clarification, workflow selection, user confirmation, and command execution.
- Developed functions for analyzing user input, selecting workflows, and executing command chains.
- Integrated TODO tracking for command execution progress.
- Created comprehensive tests for the CommandRegistry, covering YAML parsing, command retrieval, and error handling.
2026-01-24 13:43:47 +08:00
catlog22
44b8269a74 feat: add CommandRegistry for command management and direct imports 2026-01-24 13:29:50 +08:00
catlog22
dd51837bbc Enhance CCW Coordinator: Refactor command execution flow, improve prompt generation, and update documentation
- Refactored the command execution process to support dynamic command chaining and intelligent prompt generation.
- Updated the architecture overview to reflect changes in the orchestrator and command execution logic.
- Improved the prompt generation strategy to directly include complete command calls, enhancing clarity and usability.
- Added detailed examples and templates for command prompts in the documentation.
- Enhanced error handling and user decision-making during command execution failures.
- Introduced logging for command execution details and state updates for better traceability.
- Updated specifications and README files to align with the new command execution and prompt generation logic.
2026-01-24 12:44:40 +08:00
142 changed files with 45679 additions and 3196 deletions


@@ -0,0 +1,948 @@
---
name: ccw-coordinator
description: Command orchestration tool - analyze requirements, recommend chain, execute sequentially with state persistence
argument-hint: "[task description]"
allowed-tools: Task(*), AskUserQuestion(*), Read(*), Write(*), Bash(*), Glob(*), Grep(*)
---
# CCW Coordinator Command
Interactive orchestration tool: analyze task → discover commands → recommend chain → execute sequentially → track state.
**Execution Model**: Pseudocode guidance. Claude intelligently executes each phase based on context.
## Core Concept: Minimum Execution Units
### What is a Minimum Execution Unit?
**Definition**: A set of commands that must execute together as an atomic group to achieve a meaningful workflow milestone. Splitting these commands breaks the logical flow and creates incomplete states.
**Why This Matters**:
- **Prevents Incomplete States**: Avoid stopping after task generation without execution
- **User Experience**: User gets complete results, not intermediate artifacts requiring manual follow-up
- **Workflow Integrity**: Maintains logical coherence of multi-step operations
### Minimum Execution Units
**Planning + Execution Units**:
| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Quick Implementation** | lite-plan → lite-execute | Lightweight plan and immediate execution | Working code |
| **Multi-CLI Planning** | multi-cli-plan → lite-execute | Multi-perspective analysis and execution | Working code |
| **Bug Fix** | lite-fix → lite-execute | Quick bug diagnosis and fix execution | Fixed code |
| **Full Planning + Execution** | plan → execute | Detailed planning and execution | Working code |
| **Verified Planning + Execution** | plan → plan-verify → execute | Planning with verification and execution | Working code |
| **Replanning + Execution** | replan → execute | Update plan and execute changes | Working code |
| **TDD Planning + Execution** | tdd-plan → execute | Test-driven development planning and execution | Working code |
| **Test Generation + Execution** | test-gen → execute | Generate test suite and execute | Generated tests |
**Testing Units**:
| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Test Validation** | test-fix-gen → test-cycle-execute | Generate test tasks and execute test-fix cycle | Tests passed |
**Review Units**:
| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Code Review (Session)** | review-session-cycle → review-fix | Complete review cycle and apply fixes | Fixed code |
| **Code Review (Module)** | review-module-cycle → review-fix | Module review cycle and apply fixes | Fixed code |
### Command-to-Unit Mapping
| Command | Can Precede | Atomic Units |
|---------|-----------|--------------|
| lite-plan | lite-execute | Quick Implementation |
| multi-cli-plan | lite-execute | Multi-CLI Planning |
| lite-fix | lite-execute | Bug Fix |
| plan | plan-verify, execute | Full Planning + Execution, Verified Planning + Execution |
| plan-verify | execute | Verified Planning + Execution |
| replan | execute | Replanning + Execution |
| test-gen | execute | Test Generation + Execution |
| tdd-plan | execute | TDD Planning + Execution |
| review-session-cycle | review-fix | Code Review (Session) |
| review-module-cycle | review-fix | Code Review (Module) |
| test-fix-gen | test-cycle-execute | Test Validation |
### Atomic Group Rules
1. **Never Split Units**: Coordinator must recommend complete units, not partial chains
2. **Multi-Unit Participation**: Some commands can participate in multiple units (e.g., plan → execute or plan → plan-verify → execute)
3. **User Override**: User can explicitly request partial execution (advanced mode)
4. **Visualization**: Pipeline view shows unit boundaries with `【 】` markers
5. **Validation**: Before execution, verify all unit commands are included
**Example Pipeline with Units**:
```
Requirement → 【lite-plan → lite-execute】→ Code → 【test-fix-gen → test-cycle-execute】→ Tests passed
              └──── Quick Implementation ────┘       └────── Test Validation ──────┘
```
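Rule 5 (validation) can be made concrete with a small helper. A minimal sketch, assuming a hypothetical `unitDefinitions` map that mirrors the unit tables above (subset shown; not part of the command spec):
```javascript
// Sketch only: verify that a recommended chain never splits a Minimum Execution Unit.
const unitDefinitions = {
  'quick-implementation': ['lite-plan', 'lite-execute'],
  'multi-cli-planning': ['multi-cli-plan', 'lite-execute'],
  'bug-fix': ['lite-fix', 'lite-execute'],
  'full-planning-execution': ['plan', 'execute'],
  'verified-planning-execution': ['plan', 'plan-verify', 'execute'],
  'test-validation': ['test-fix-gen', 'test-cycle-execute'],
  'code-review': ['review-session-cycle', 'review-fix']
};

function validateUnits(chain) {
  const names = new Set(chain.map(cmd => cmd.name));
  const violations = [];
  for (const name of names) {
    // Units this command belongs to; commands outside any unit are always fine
    const units = Object.entries(unitDefinitions).filter(([, members]) => members.includes(name));
    if (units.length === 0) continue;
    // Valid if at least one of its units is fully present in the chain
    const satisfied = units.some(([, members]) => members.every(m => names.has(m)));
    if (!satisfied) violations.push({ command: name, expectedUnits: units.map(([unit]) => unit) });
  }
  return { valid: violations.length === 0, violations };
}
```
Chains that fail this check should be expanded to the full unit before being shown in Phase 2.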
## 3-Phase Workflow
### Phase 1: Analyze Requirements
Parse task to extract: goal, scope, constraints, complexity, and task type.
```javascript
function analyzeRequirements(taskDescription) {
return {
goal: extractMainGoal(taskDescription), // e.g., "Implement user registration"
scope: extractScope(taskDescription), // e.g., ["auth", "user_management"]
constraints: extractConstraints(taskDescription), // e.g., ["no breaking changes"]
complexity: determineComplexity(taskDescription), // 'simple' | 'medium' | 'complex'
task_type: detectTaskType(taskDescription) // See task type patterns below
};
}
// Task Type Detection Patterns
function detectTaskType(text) {
// Priority order (first match wins)
if (/fix|bug|error|crash|fail|debug|diagnose/.test(text)) return 'bugfix';
if (/tdd|test-driven|先写测试|test first/.test(text)) return 'tdd';
if (/测试失败|test fail|fix test|failing test/.test(text)) return 'test-fix';
if (/generate test|写测试|add test|补充测试/.test(text)) return 'test-gen';
if (/review|审查|code review/.test(text)) return 'review';
if (/不确定|explore|研究|what if|brainstorm|权衡/.test(text)) return 'brainstorm';
if (/多视角|比较方案|cross-verify|multi-cli/.test(text)) return 'multi-cli';
return 'feature'; // Default
}
// Complexity Assessment
function determineComplexity(text) {
let score = 0;
if (/refactor|重构|migrate|迁移|architect|架构|system|系统/.test(text)) score += 2;
if (/multiple|多个|across|跨|all|所有|entire|整个/.test(text)) score += 2;
if (/integrate|集成|api|database|数据库/.test(text)) score += 1;
if (/security|安全|performance|性能|scale|扩展/.test(text)) score += 1;
return score >= 4 ? 'complex' : score >= 2 ? 'medium' : 'simple';
}
```
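For illustration, a hypothetical run of the analyzer (the extraction helpers above are pseudocode, so the resulting values are indicative only):
```javascript
// Illustrative only: how the heuristics above combine for one input
const analysis = analyzeRequirements(
  'fix login timeout across multiple services, no breaking changes'
);
// detectTaskType      → 'bugfix'  (first match: /fix|bug|error|.../)
// determineComplexity → 'medium'  (score 2 from /multiple|...|across|.../)
// analysis ≈ {
//   goal: 'Fix login timeout',
//   scope: ['auth'],
//   constraints: ['no breaking changes'],
//   complexity: 'medium',
//   task_type: 'bugfix'
// }
```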
**Display to user**:
```
Analysis Complete:
Goal: [extracted goal]
Scope: [identified areas]
Constraints: [identified constraints]
Complexity: [level]
Task Type: [detected type]
```
### Phase 2: Discover Commands & Recommend Chain
Dynamic command chain assembly using port-based matching.
#### Command Port Definition
Each command has input/output ports (tags) for pipeline composition:
```javascript
// Port labels represent data types flowing through the pipeline
const commandPorts = {
'lite-plan': {
name: 'lite-plan',
input: ['requirement'], // input port: requirement
output: ['plan'], // output port: plan
tags: ['planning'],
atomic_group: 'quick-implementation' // minimum unit: bound to lite-execute
},
'lite-execute': {
name: 'lite-execute',
input: ['plan', 'multi-cli-plan', 'lite-fix'], // input ports: accepts several planning outputs
output: ['code'], // output port: code
tags: ['execution'],
atomic_groups: [ // can participate in multiple minimum units
'quick-implementation', // lite-plan → lite-execute
'multi-cli-planning', // multi-cli-plan → lite-execute
'bug-fix' // lite-fix → lite-execute
]
},
'plan': {
name: 'plan',
input: ['requirement'],
output: ['detailed-plan'],
tags: ['planning'],
atomic_groups: [ // can participate in multiple minimum units
'full-planning-execution', // plan → execute
'verified-planning-execution' // plan → plan-verify → execute
]
},
'plan-verify': {
name: 'plan-verify',
input: ['detailed-plan'],
output: ['verified-plan'],
tags: ['planning'],
atomic_group: 'verified-planning-execution' // minimum unit: plan → plan-verify → execute
},
'replan': {
name: 'replan',
input: ['session', 'feedback'], // input ports: session or feedback
output: ['replan'], // output port: updated plan (consumed by execute)
tags: ['planning'],
atomic_group: 'replanning-execution' // minimum unit: bound to execute
},
'execute': {
name: 'execute',
input: ['detailed-plan', 'verified-plan', 'replan', 'test-tasks', 'tdd-tasks'], // accepts several planning outputs
output: ['code'],
tags: ['execution'],
atomic_groups: [ // can participate in multiple minimum units
'full-planning-execution', // plan → execute
'verified-planning-execution', // plan → plan-verify → execute
'replanning-execution', // replan → execute
'test-generation-execution', // test-gen → execute
'tdd-planning-execution' // tdd-plan → execute
]
},
'test-cycle-execute': {
name: 'test-cycle-execute',
input: ['test-tasks'], // input port: test tasks (generated first by test-fix-gen)
output: ['test-passed'], // output port: tests passed
tags: ['testing'],
atomic_group: 'test-validation', // minimum unit: bound to test-fix-gen
note: 'Run test-fix-gen first to generate test tasks; this command then executes the test cycle'
},
'tdd-plan': {
name: 'tdd-plan',
input: ['requirement'],
output: ['tdd-tasks'], // TDD tasks (consumed by execute)
tags: ['planning', 'tdd'],
atomic_group: 'tdd-planning-execution' // minimum unit: bound to execute
},
'tdd-verify': {
name: 'tdd-verify',
input: ['code'],
output: ['tdd-verified'],
tags: ['testing']
},
'lite-fix': {
name: 'lite-fix',
input: ['bug-report'], // input port: bug report
output: ['lite-fix'], // output port: fix plan (consumed by lite-execute)
tags: ['bugfix'],
atomic_group: 'bug-fix' // minimum unit: bound to lite-execute
},
'debug': {
name: 'debug',
input: ['bug-report'],
output: ['debug-log'],
tags: ['bugfix']
},
'test-gen': {
name: 'test-gen',
input: ['code', 'session'], // accepts code or a session
output: ['test-tasks'], // outputs test tasks (IMPL-001, IMPL-002) consumed by execute
tags: ['testing'],
atomic_group: 'test-generation-execution' // minimum unit: bound to execute
},
'test-fix-gen': {
name: 'test-fix-gen',
input: ['failing-tests', 'session'],
output: ['test-tasks'], // outputs test tasks targeting specific issues, fixed during the test cycle
tags: ['testing'],
atomic_group: 'test-validation', // minimum unit: bound to test-cycle-execute
note: 'Generates test tasks for test-cycle-execute to run'
},
'review': {
name: 'review',
input: ['code', 'session'],
output: ['review-findings'],
tags: ['review']
},
'review-fix': {
name: 'review-fix',
input: ['review-findings', 'review-verified'], // Accept output from review-session-cycle or review-module-cycle
output: ['fixed-code'],
tags: ['review'],
atomic_group: 'code-review' // minimum unit: bound to review-session-cycle/review-module-cycle
},
'brainstorm:auto-parallel': {
name: 'brainstorm:auto-parallel',
input: ['exploration-topic'], // input port: exploration topic
output: ['brainstorm-analysis'],
tags: ['brainstorm']
},
'multi-cli-plan': {
name: 'multi-cli-plan',
input: ['requirement'],
output: ['multi-cli-plan'], // comparative analysis plan (consumed by lite-execute)
tags: ['planning', 'multi-cli'],
atomic_group: 'multi-cli-planning' // minimum unit: bound to lite-execute
},
'review-session-cycle': {
name: 'review-session-cycle',
input: ['code', 'session'], // accepts code or a session
output: ['review-verified'], // output port: review passed
tags: ['review'],
atomic_group: 'code-review' // minimum unit: bound to review-fix
},
'review-module-cycle': {
name: 'review-module-cycle',
input: ['module-pattern'], // input port: module pattern
output: ['review-verified'], // output port: review passed
tags: ['review'],
atomic_group: 'code-review' // minimum unit: bound to review-fix
}
};
```
#### Recommendation Algorithm
```javascript
async function recommendCommandChain(analysis) {
// Step 1: determine the start and target ports from the task type
const { inputPort, outputPort } = determinePortFlow(analysis.task_type, analysis.constraints);
// Step 2: Claude selects the command sequence from the command port definitions and task characteristics
// Priority: simple tasks → lite-* commands, complex tasks → full commands, special constraints → adjust the flow
const chain = selectChainByPorts(inputPort, outputPort, analysis);
return chain;
}
// Port flow for each task type
function determinePortFlow(taskType, constraints) {
const flows = {
'bugfix': { inputPort: 'bug-report', outputPort: constraints?.includes('skip-tests') ? 'fixed-code' : 'test-passed' },
'tdd': { inputPort: 'requirement', outputPort: 'tdd-verified' },
'test-fix': { inputPort: 'failing-tests', outputPort: 'test-passed' },
'test-gen': { inputPort: 'code', outputPort: 'test-passed' },
'review': { inputPort: 'code', outputPort: 'review-verified' },
'brainstorm': { inputPort: 'exploration-topic', outputPort: 'test-passed' },
'multi-cli': { inputPort: 'requirement', outputPort: 'test-passed' },
'feature': { inputPort: 'requirement', outputPort: constraints?.includes('skip-tests') ? 'code' : 'test-passed' }
};
return flows[taskType] || flows['feature'];
}
// Claude selects the command chain along the port flow
function selectChainByPorts(inputPort, outputPort, analysis) {
// Using the command port definition table and the execution examples, Claude selects a suitable command sequence
// Example return value: [lite-plan, lite-execute, test-cycle-execute]
}
```
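`selectChainByPorts` is deliberately left open for Claude to fill in. As a sketch only, a deterministic fallback could do a breadth-first walk over the `commandPorts` graph, preferring `lite-*` commands for simple tasks (the search strategy here is an assumption, not the prescribed algorithm):
```javascript
// Sketch of a deterministic fallback for port-based chain selection.
function selectChainByPortsFallback(inputPort, outputPort, analysis) {
  const preferLite = analysis.complexity === 'simple';
  const queue = [{ port: inputPort, chain: [] }];
  const visited = new Set([inputPort]);
  while (queue.length > 0) {
    const { port, chain } = queue.shift();
    if (port === outputPort) return chain;
    // Candidate commands whose input ports accept the current port
    let candidates = Object.values(commandPorts).filter(c => c.input.includes(port));
    if (preferLite && candidates.some(c => c.name.startsWith('lite-'))) {
      candidates = candidates.filter(c => c.name.startsWith('lite-'));
    }
    for (const cmd of candidates) {
      for (const out of cmd.output) {
        if (!visited.has(out)) {
          visited.add(out);
          queue.push({ port: out, chain: [...chain, cmd] });
        }
      }
    }
  }
  return []; // no path found: fall back to manual selection by Claude
}
```
Any chain produced this way must still be checked against the Minimum Execution Unit rules before it is presented to the user.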
#### Display to User
```
Recommended Command Chain:
Pipeline:
Requirement → lite-plan → Plan → lite-execute → Code → test-cycle-execute → Tests passed
Commands:
1. /workflow:lite-plan
2. /workflow:lite-execute
3. /workflow:test-cycle-execute
Proceed? [Confirm / Show Details / Adjust / Cancel]
```
### Phase 2b: Get User Confirmation
```javascript
async function getUserConfirmation(chain) {
const response = await AskUserQuestion({
questions: [{
question: 'Proceed with this command chain?',
header: 'Confirm',
options: [
{ label: 'Confirm and execute', description: 'Proceed with commands' },
{ label: 'Show details', description: 'View each command' },
{ label: 'Adjust chain', description: 'Remove or reorder' },
{ label: 'Cancel', description: 'Abort' }
]
}]
});
if (response.confirm === 'Cancel') throw new Error('Cancelled');
if (response.confirm === 'Show details') {
displayCommandDetails(chain);
return getUserConfirmation(chain);
}
if (response.confirm === 'Adjust chain') {
return await adjustChain(chain);
}
return chain;
}
```
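`adjustChain` is referenced but not defined here. A minimal sketch, assuming removal-only adjustment and that AskUserQuestion keys its answer by the question header, as in the other examples in this document:
```javascript
// Sketch only: remove commands one at a time until the user is done.
// Reordering and insertion are left out for brevity.
async function adjustChain(chain) {
  let adjusted = [...chain];
  while (true) {
    const response = await AskUserQuestion({
      questions: [{
        question: 'Which command should be removed?',
        header: 'Adjust',
        options: [
          ...adjusted.map((cmd, i) => ({ label: `${i + 1}. ${cmd.command}`, description: cmd.description || '' })),
          { label: 'Done', description: 'Keep the remaining chain' }
        ]
      }]
    });
    if (response.adjust === 'Done') break;
    const index = parseInt(response.adjust, 10) - 1; // labels start with "N."
    adjusted = adjusted.filter((_, i) => i !== index);
  }
  return adjusted;
}
```
The adjusted chain should be re-validated against the Minimum Execution Unit rules before execution.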
### Phase 3: Execute Sequential Command Chain
```javascript
async function executeCommandChain(chain, analysis) {
const sessionId = `ccw-coord-${Date.now()}`;
const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
Bash(`mkdir -p "${stateDir}"`);
const state = {
session_id: sessionId,
status: 'running',
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
analysis: analysis,
command_chain: chain.map((cmd, idx) => ({ ...cmd, index: idx, status: 'pending' })),
execution_results: [],
prompts_used: []
};
// Save initial state immediately after confirmation
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
for (let i = 0; i < chain.length; i++) {
const cmd = chain[i];
console.log(`[${i+1}/${chain.length}] ${cmd.command}`);
// Update command_chain status to running
state.command_chain[i].status = 'running';
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
// Assemble prompt: Command first, then context
let promptContent = formatCommand(cmd, state.execution_results, analysis);
// Build full prompt: Command → Task → Previous Results
let prompt = `${promptContent}\n\nTask: ${analysis.goal}`;
if (state.execution_results.length > 0) {
prompt += '\n\nPrevious results:\n';
state.execution_results.forEach(r => {
if (r.session_id) {
prompt += `- ${r.command}: ${r.session_id} (${r.artifacts?.join(', ') || 'completed'})\n`;
}
});
}
// Record prompt used
state.prompts_used.push({
index: i,
command: cmd.command,
prompt: prompt
});
// Execute CLI command in background and stop
// Format: ccw cli -p "PROMPT" --tool <tool> --mode <mode>
// Note: -y is a command parameter INSIDE the prompt, not a ccw cli parameter
// Example prompt: "/workflow:plan -y \"task description here\""
try {
const taskId = Bash(
`ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
{ run_in_background: true }
).task_id;
// Save checkpoint
state.execution_results.push({
index: i,
command: cmd.command,
status: 'in-progress',
task_id: taskId,
session_id: null,
artifacts: [],
timestamp: new Date().toISOString()
});
state.command_chain[i].status = 'running';
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
console.log(`[${i+1}/${chain.length}] ${cmd.command}\n`);
break; // Stop, wait for hook callback
} catch (error) {
state.command_chain[i].status = 'failed';
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
const action = await AskUserQuestion({
questions: [{
question: `${cmd.command} failed to start: ${error.message}. What to do?`,
header: 'Error',
options: [
{ label: 'Retry', description: 'Try again' },
{ label: 'Skip', description: 'Continue next command' },
{ label: 'Abort', description: 'Stop execution' }
]
}]
});
if (action.error === 'Retry') {
state.command_chain[i].status = 'pending';
i--; // nothing was pushed to execution_results for the failed launch, so only reset and retry
} else if (action.error === 'Skip') {
// Record the skipped command explicitly (the failed launch produced no result entry)
state.execution_results.push({ index: i, command: cmd.command, status: 'skipped', task_id: null, session_id: null, artifacts: [], timestamp: new Date().toISOString() });
} else if (action.error === 'Abort') {
state.status = 'failed';
break;
}
}
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
}
// Hook callbacks handle completion
if (state.status !== 'failed') state.status = 'waiting';
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
console.log(`\n📋 Orchestrator paused: ${state.session_id}\n`);
return state;
}
// Smart parameter assembly
// Returns prompt content to be used with: ccw cli -p "RETURNED_VALUE" --tool claude --mode write
function formatCommand(cmd, previousResults, analysis) {
// Format: /workflow:<command> -y <parameters>
let prompt = `/workflow:${cmd.name} -y`;
const name = cmd.name;
// Planning commands - take task description
if (['lite-plan', 'plan', 'tdd-plan', 'multi-cli-plan'].includes(name)) {
prompt += ` "${analysis.goal}"`;
// Lite execution - use --in-memory if plan exists
} else if (name === 'lite-execute') {
const hasPlan = previousResults.some(r => r.command.includes('plan'));
prompt += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
// Standard execution - resume from planning session
} else if (name === 'execute') {
const plan = previousResults.find(r => r.command.includes('plan'));
if (plan?.session_id) prompt += ` --resume-session="${plan.session_id}"`;
// Bug fix commands - take bug description
} else if (['lite-fix', 'debug'].includes(name)) {
prompt += ` "${analysis.goal}"`;
// Brainstorm - take topic description
} else if (name === 'brainstorm:auto-parallel' || name === 'auto-parallel') {
prompt += ` "${analysis.goal}"`;
// Test generation from session - needs source session
} else if (name === 'test-gen') {
const impl = previousResults.find(r =>
r.command.includes('execute') || r.command.includes('lite-execute')
);
if (impl?.session_id) prompt += ` "${impl.session_id}"`;
else prompt += ` "${analysis.goal}"`;
// Test fix generation - session or description
} else if (name === 'test-fix-gen') {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) prompt += ` "${latest.session_id}"`;
else prompt += ` "${analysis.goal}"`;
// Review commands - take session or use latest
} else if (name === 'review') {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
// Review fix - takes session from review
} else if (name === 'review-fix') {
const review = previousResults.find(r => r.command.includes('review'));
const latest = review || previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
// TDD verify - takes execution session
} else if (name === 'tdd-verify') {
const exec = previousResults.find(r => r.command.includes('execute'));
if (exec?.session_id) prompt += ` --session="${exec.session_id}"`;
// Session-based commands (test-cycle, review-session, plan-verify)
} else if (name.includes('test') || name.includes('review') || name.includes('verify')) {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
}
return prompt;
}
// Hook callback: Called when background CLI completes
async function handleCliCompletion(sessionId, taskId, output) {
const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
const state = JSON.parse(Read(`${stateDir}/state.json`));
const pendingIdx = state.execution_results.findIndex(r => r.task_id === taskId);
if (pendingIdx === -1) {
console.error(`Unknown task_id: ${taskId}`);
return;
}
const parsed = parseOutput(output);
const cmdIdx = state.execution_results[pendingIdx].index;
// Update result
state.execution_results[pendingIdx] = {
...state.execution_results[pendingIdx],
status: parsed.sessionId ? 'completed' : 'failed',
session_id: parsed.sessionId,
artifacts: parsed.artifacts,
completed_at: new Date().toISOString()
};
state.command_chain[cmdIdx].status = parsed.sessionId ? 'completed' : 'failed';
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
// Trigger next command or complete
const nextIdx = cmdIdx + 1;
if (nextIdx < state.command_chain.length) {
await resumeChainExecution(sessionId, nextIdx);
} else {
state.status = 'completed';
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
console.log(`✅ Completed: ${sessionId}\n`);
}
}
// Parse command output
function parseOutput(output) {
const sessionMatch = output.match(/WFS-[\w-]+/);
const artifacts = [];
for (const m of output.matchAll(/\.workflow\/[^\s]+/g)) artifacts.push(m[0]);
return { sessionId: sessionMatch?.[0] || null, artifacts };
}
```
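Two helpers used above, `escapePrompt` and `resumeChainExecution`, are not defined in this document. The sketches below show one plausible shape for each; both are assumptions, not part of the specification:
```javascript
// Assumed helper: escape a prompt so it survives double-quoted shell interpolation.
function escapePrompt(prompt) {
  return prompt
    .replace(/\\/g, '\\\\')   // escape backslashes first
    .replace(/"/g, '\\"')     // escape double quotes
    .replace(/\$/g, '\\$')    // prevent shell variable expansion
    .replace(/`/g, '\\`');    // prevent command substitution
}

// Assumed helper: resume the chain at nextIdx after a hook callback.
// Mirrors the launch logic in executeCommandChain: build the prompt, launch in background, stop.
// (The previous-results context block is omitted here for brevity.)
async function resumeChainExecution(sessionId, nextIdx) {
  const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
  const state = JSON.parse(Read(`${stateDir}/state.json`));
  const cmd = state.command_chain[nextIdx];
  const prompt = `${formatCommand(cmd, state.execution_results, state.analysis)}\n\nTask: ${state.analysis.goal}`;
  const taskId = Bash(
    `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
    { run_in_background: true }
  ).task_id;
  state.command_chain[nextIdx].status = 'running';
  state.execution_results.push({
    index: nextIdx, command: cmd.command, status: 'in-progress',
    task_id: taskId, session_id: null, artifacts: [], timestamp: new Date().toISOString()
  });
  state.updated_at = new Date().toISOString();
  Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
  // Stop here; the next hook callback continues the chain.
}
```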
## State File Structure
**Location**: `.workflow/.ccw-coordinator/{session_id}/state.json`
```json
{
"session_id": "ccw-coord-20250124-143025",
"status": "running|waiting|completed|failed",
"created_at": "2025-01-24T14:30:25Z",
"updated_at": "2025-01-24T14:35:45Z",
"analysis": {
"goal": "Implement user registration",
"scope": ["authentication", "user_management"],
"constraints": ["no breaking changes"],
"complexity": "medium"
},
"command_chain": [
{
"index": 0,
"command": "/workflow:plan",
"name": "plan",
"description": "Detailed planning",
"argumentHint": "[--explore] \"task\"",
"status": "completed"
},
{
"index": 1,
"command": "/workflow:execute",
"name": "execute",
"description": "Execute with state resume",
"argumentHint": "[--resume-session=\"WFS-xxx\"]",
"status": "completed"
},
{
"index": 2,
"command": "/workflow:test-cycle-execute",
"name": "test-cycle-execute",
"status": "pending"
}
],
"execution_results": [
{
"index": 0,
"command": "/workflow:plan",
"status": "completed",
"task_id": "task-001",
"session_id": "WFS-plan-20250124",
"artifacts": ["IMPL_PLAN.md", "exploration-architecture.json"],
"timestamp": "2025-01-24T14:30:25Z",
"completed_at": "2025-01-24T14:30:45Z"
},
{
"index": 1,
"command": "/workflow:execute",
"status": "in-progress",
"task_id": "task-002",
"session_id": null,
"artifacts": [],
"timestamp": "2025-01-24T14:32:00Z",
"completed_at": null
}
],
"prompts_used": [
{
"index": 0,
"command": "/workflow:plan",
"prompt": "/workflow:plan -y \"Implement user registration...\"\n\nTask: Implement user registration..."
},
{
"index": 1,
"command": "/workflow:execute",
"prompt": "/workflow:execute -y --resume-session=\"WFS-plan-20250124\"\n\nTask: Implement user registration\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)"
}
]
}
```
### Status Flow
```
running → waiting → [hook callback] → waiting → [hook callback] → completed
   ↓                                                                  ↑
failed ←──────────────────────────────────────────────────────────────┘
```
**Status Values**:
- `running`: Orchestrator actively executing (launching CLI commands)
- `waiting`: Paused, waiting for hook callbacks to trigger continuation
- `completed`: All commands finished successfully
- `failed`: User aborted or unrecoverable error
### Field Descriptions
**execution_results[] fields**:
- `index`: Command position in chain (0-indexed)
- `command`: Full command string (e.g., `/workflow:plan`)
- `status`: `in-progress` | `completed` | `skipped` | `failed`
- `task_id`: Background task identifier (from Bash tool)
- `session_id`: Workflow session ID (e.g., `WFS-*`) or null if failed
- `artifacts`: Generated files/directories
- `timestamp`: Command start time (ISO 8601)
- `completed_at`: Command completion time or null if pending
**command_chain[] status values**:
- `pending`: Not started yet
- `running`: Currently executing
- `completed`: Successfully finished
- `failed`: Failed to execute
## CommandRegistry Integration
The CommandRegistry is the sole CCW tool for command discovery:
```javascript
import { CommandRegistry } from 'ccw/tools/command-registry';
const registry = new CommandRegistry();
// Get all commands
const allCommands = registry.getAllCommandsSummary();
// Map<"/workflow:lite-plan" => {name, description}>
// Get categorized
const byCategory = registry.getAllCommandsByCategory();
// {planning, execution, testing, review, other}
// Get single command metadata
const cmd = registry.getCommand('lite-plan');
// {name, command, description, argumentHint, allowedTools, filePath}
```
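As an illustration, the coordinator can derive its Phase 2 candidate list from the registry using only the methods shown above (the category-to-task-type mapping is an assumption, and each category key is assumed to hold an array of command summaries):
```javascript
import { CommandRegistry } from 'ccw/tools/command-registry';

// Illustrative sketch: build a candidate command list per task type from registry categories.
const registry = new CommandRegistry();
const byCategory = registry.getAllCommandsByCategory(); // {planning, execution, testing, review, other}

function candidateCommands(taskType) {
  // Hypothetical mapping from task type to the categories most likely to contribute to its pipeline
  const relevant = {
    bugfix: ['planning', 'execution', 'testing'],
    review: ['review', 'testing'],
    tdd: ['planning', 'execution', 'testing'],
    feature: ['planning', 'execution', 'testing', 'review']
  }[taskType] || ['planning', 'execution', 'testing'];
  return relevant.flatMap(category => byCategory[category] || []);
}
```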
## Universal Prompt Template
### Standard Format
```bash
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
```
### Prompt Content Template
```
/workflow:<command> -y <command_parameters>
Task: <task_description>
<optional_previous_results>
```
### Template Variables
| Variable | Description | Examples |
|----------|-------------|----------|
| `<command>` | Workflow command name | `plan`, `lite-execute`, `test-cycle-execute` |
| `-y` | Auto-confirm flag (inside prompt) | Always include for automation |
| `<command_parameters>` | Command-specific parameters | Task description, session ID, flags |
| `<task_description>` | Brief task description | "Implement user authentication", "Fix memory leak" |
| `<optional_previous_results>` | Context from previous commands | "Previous results:\n- /workflow:plan: WFS-xxx" |
### Command Parameter Patterns
| Command Type | Parameter Pattern | Example |
|--------------|------------------|---------|
| **Planning** | `"task description"` | `/workflow:plan -y "Implement OAuth2"` |
| **Execution (with plan)** | `--resume-session="WFS-xxx"` | `/workflow:execute -y --resume-session="WFS-plan-001"` |
| **Execution (standalone)** | `--in-memory` or `"task"` | `/workflow:lite-execute -y --in-memory` |
| **Session-based** | `--session="WFS-xxx"` | `/workflow:test-fix-gen -y --session="WFS-impl-001"` |
| **Fix/Debug** | `"problem description"` | `/workflow:lite-fix -y "Fix timeout bug"` |
### Complete Examples
**Planning Command**:
```bash
ccw cli -p '/workflow:plan -y "Implement user registration with email validation"
Task: Implement user registration' --tool claude --mode write
```
**Execution with Context**:
```bash
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
Task: Implement user registration
Previous results:
- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)' --tool claude --mode write
```
**Standalone Lite Execution**:
```bash
ccw cli -p '/workflow:lite-fix -y "Fix login timeout in auth module"
Task: Fix login timeout' --tool claude --mode write
```
## Execution Flow
```javascript
// Main entry point
async function ccwCoordinator(taskDescription) {
// Phase 1
const analysis = await analyzeRequirements(taskDescription);
// Phase 2
const chain = await recommendCommandChain(analysis);
const confirmedChain = await getUserConfirmation(chain);
// Phase 3
const state = await executeCommandChain(confirmedChain, analysis);
console.log(`✅ Complete! Session: ${state.session_id}`);
console.log(`State: .workflow/.ccw-coordinator/${state.session_id}/state.json`);
}
```
## Key Design Principles
1. **No Fixed Logic** - Claude intelligently decides based on analysis
2. **Dynamic Discovery** - CommandRegistry retrieves available commands
3. **Smart Parameters** - Command args assembled based on previous results
4. **Full State Tracking** - All execution recorded to state.json
5. **User Control** - Confirmation + error handling with user choice
6. **Context Passing** - Each prompt includes previous results
7. **Resumable** - Can load state.json to continue
8. **Serial Blocking** - Commands execute one-by-one with hook-based continuation
## CLI Execution Model
### CLI Invocation Format
**IMPORTANT**: The `ccw cli` command executes prompts through external tools. The format is:
```bash
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
```
**Parameters**:
- `-p "PROMPT_CONTENT"`: The prompt content to execute (required)
- `--tool <tool>`: CLI tool to use (e.g., `claude`, `gemini`, `qwen`)
- `--mode <mode>`: Execution mode (`analysis` or `write`)
**Note**: `-y` is a **command parameter inside the prompt**, NOT a `ccw cli` parameter.
### Prompt Assembly
The prompt content MUST start with the workflow command, followed by task context:
```
/workflow:<command> -y <parameters>
Task: <description>
<optional_context>
```
**Examples**:
```bash
# Planning command
ccw cli -p '/workflow:plan -y "Implement user registration feature"
Task: Implement user registration' --tool claude --mode write
# Execution command (with session reference)
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
Task: Implement user registration
Previous results:
- /workflow:plan: WFS-plan-20250124' --tool claude --mode write
# Lite execution (in-memory from previous plan)
ccw cli -p '/workflow:lite-execute -y --in-memory
Task: Implement user registration' --tool claude --mode write
```
### Serial Blocking
**CRITICAL**: Commands execute one-by-one. After launching CLI in background:
1. Orchestrator stops immediately (`break`)
2. Wait for hook callback - **DO NOT use TaskOutput polling**
3. Hook callback triggers next command
**Prompt Structure**: Command must be first in prompt content
```javascript
// Example: Execute command and stop
const prompt = '/workflow:plan -y "Implement user authentication"\n\nTask: Implement user auth system';
const taskId = Bash(`ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`, { run_in_background: true }).task_id;
state.execution_results.push({ status: 'in-progress', task_id: taskId, ... });
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
break; // ⚠️ STOP HERE - DO NOT use TaskOutput polling
// Hook callback will call handleCliCompletion(sessionId, taskId, output) when done
// → Updates state → Triggers next command via resumeChainExecution()
```
## Available Commands
All from `~/.claude/commands/workflow/`:
**Planning**: lite-plan, plan, multi-cli-plan, plan-verify, tdd-plan
**Execution**: lite-execute, execute, develop-with-file
**Testing**: test-cycle-execute, test-gen, test-fix-gen, tdd-verify
**Review**: review, review-session-cycle, review-module-cycle, review-fix
**Bug Fixes**: lite-fix, debug, debug-with-file
**Brainstorming**: brainstorm:auto-parallel, brainstorm:artifacts, brainstorm:synthesis
**Design**: ui-design:*, animation-extract, layout-extract, style-extract, codify-style
**Session Management**: session:start, session:resume, session:complete, session:solidify, session:list
**Tools**: context-gather, test-context-gather, task-generate, conflict-resolution, action-plan-verify
**Utility**: clean, init, replan
### Testing Commands Distinction
| Command | Purpose | Output | Follow-up |
|---------|---------|--------|-----------|
| **test-gen** | Generate a broad test suite and run it | test-tasks (IMPL-001, IMPL-002) | `/workflow:execute` |
| **test-fix-gen** | Generate tests targeting specific issues and fix them during the test cycle | test-tasks | `/workflow:test-cycle-execute` |
| **test-cycle-execute** | Run the test cycle (iterate testing and fixing) | test-passed | N/A (endpoint) |
**Flow notes**:
- **test-gen → execute**: generates a comprehensive test suite; execute performs the generation and testing
- **test-fix-gen → test-cycle-execute**: generates fix tasks for specific issues; test-cycle-execute iterates testing and fixing until the tests pass
### Task Type Routing (Pipeline Summary)
**Note**: `【 】` marks Minimum Execution Units - these commands must execute together.
| Task Type | Pipeline | Minimum Units |
|-----------|----------|---|
| **feature** (simple) | Requirement →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests passed | Quick Implementation + Test Validation |
| **feature** (complex) | Requirement →【plan → plan-verify】→ Verified plan → execute → Code → review → fix | Full Planning + Code Review + Testing |
| **bugfix** | Bug report →【lite-fix → lite-execute】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests passed | Bug Fix + Test Validation |
| **tdd** | Requirement → tdd-plan → TDD tasks → execute → Code → tdd-verify | TDD Planning + Execution |
| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests passed | Test Validation |
| **test-gen** | Code/Session →【test-gen → execute】→ Tests passed | Test Generation + Execution |
| **review** | Code →【review-* → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests passed | Code Review + Testing |
| **brainstorm** | Exploration topic → brainstorm → Analysis →【plan → plan-verify】→ execute → test | Exploration + Planning + Execution |
| **multi-cli** | Requirement → multi-cli-plan → Comparative analysis → lite-execute → test | Multi-Perspective + Testing |
Use `CommandRegistry.getAllCommandsSummary()` to discover all commands dynamically.

.claude/commands/ccw.md Normal file

@@ -0,0 +1,486 @@
---
name: ccw
description: Main workflow orchestrator - analyze intent, select workflow, execute command chain in main process
argument-hint: "\"task description\""
allowed-tools: SlashCommand(*), TodoWrite(*), AskUserQuestion(*), Read(*), Grep(*), Glob(*)
---
# CCW Command - Main Workflow Orchestrator
Main process orchestrator: intent analysis → workflow selection → command chain execution.
## Core Concept: Minimum Execution Units
**Definition**: A set of commands that must execute together as an atomic group to achieve a meaningful workflow milestone.
**Why This Matters**:
- **Prevents Incomplete States**: Avoid stopping after task generation without execution
- **User Experience**: User gets complete results, not intermediate artifacts requiring manual follow-up
- **Workflow Integrity**: Maintains logical coherence of multi-step operations
**Key Units in CCW**:
| Unit Type | Pattern | Example |
|-----------|---------|---------|
| **Planning + Execution** | plan-cmd → execute-cmd | lite-plan → lite-execute |
| **Testing** | test-gen-cmd → test-exec-cmd | test-fix-gen → test-cycle-execute |
| **Review** | review-cmd → fix-cmd | review-session-cycle → review-fix |
**Atomic Rules**:
1. CCW automatically groups commands into minimum units - never splits them
2. Pipeline visualization shows units with `【 】` markers
3. Error handling preserves unit boundaries (retry/skip affects whole unit)
## Execution Model
**Synchronous (Main Process)**: Commands execute via SlashCommand in main process, blocking until complete.
```
User Input → Analyze Intent → Select Workflow → [Confirm] → Execute Chain
                                                              ↓
                                                   SlashCommand (blocking)
                                                              ↓
                                                     Update TodoWrite
                                                              ↓
                                                      Next Command...
```
**vs ccw-coordinator**: External CLI execution with background tasks and hook callbacks.
## 5-Phase Workflow
### Phase 1: Analyze Intent
```javascript
function analyzeIntent(input) {
return {
goal: extractGoal(input),
scope: extractScope(input),
constraints: extractConstraints(input),
task_type: detectTaskType(input), // bugfix|feature|tdd|review|exploration|...
complexity: assessComplexity(input), // low|medium|high
clarity_score: calculateClarity(input) // 0-3 (>=2 = clear)
};
}
// Task type detection (priority order)
function detectTaskType(text) {
const patterns = {
'bugfix-hotfix': [/urgent|production|critical/, /fix|bug/],
'bugfix': [/fix|bug|error|crash|fail|debug/],
'issue-batch': [/issues?|batch/, /fix|resolve/],
'exploration': [/uncertain|explore|research|what if/],
'multi-perspective': [/multi-perspective|compare|cross-verify/],
'quick-task': [/quick|simple|small/, /feature|function/],
'ui-design': [/ui|design|component|style/],
'tdd': [/tdd|test-driven|test first/],
'test-fix': [/test fail|fix test|failing test/],
'review': [/review|code review/],
'documentation': [/docs|documentation|readme/]
};
// Composite types require every listed pattern to match (regexA && regexB would silently drop the first check)
for (const [type, regexes] of Object.entries(patterns)) {
if (regexes.every(re => re.test(text))) return type;
}
return 'feature';
}
```
**Output**: `Type: [task_type] | Goal: [goal] | Complexity: [complexity] | Clarity: [clarity_score]/3`
---
### Phase 1.5: Requirement Clarification (if clarity_score < 2)
```javascript
async function clarifyRequirements(analysis) {
if (analysis.clarity_score >= 2) return analysis;
const questions = generateClarificationQuestions(analysis); // Goal, Scope, Constraints
const answers = await AskUserQuestion({ questions });
return updateAnalysis(analysis, answers);
}
```
**Questions**: Goal (Create/Fix/Optimize/Analyze), Scope (Single file/Module/Cross-module/System), Constraints (Backward compat/Skip tests/Urgent hotfix)
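`generateClarificationQuestions` is not shown; a minimal sketch built from the three categories listed above (the option wording is illustrative):
```javascript
// Sketch: build the three clarification questions named above (Goal, Scope, Constraints).
function generateClarificationQuestions(analysis) {
  return [
    {
      question: `What is the primary goal? (detected: ${analysis.goal || 'unclear'})`,
      header: 'Goal',
      options: [
        { label: 'Create', description: 'Build a new feature or component' },
        { label: 'Fix', description: 'Repair a defect or regression' },
        { label: 'Optimize', description: 'Improve performance or structure' },
        { label: 'Analyze', description: 'Investigate or document behavior' }
      ]
    },
    {
      question: 'How wide is the scope?',
      header: 'Scope',
      options: [
        { label: 'Single file', description: 'One file or function' },
        { label: 'Module', description: 'One module or package' },
        { label: 'Cross-module', description: 'Several modules' },
        { label: 'System', description: 'System-wide change' }
      ]
    },
    {
      question: 'Any constraints?',
      header: 'Constraints',
      options: [
        { label: 'Backward compat', description: 'No breaking changes allowed' },
        { label: 'Skip tests', description: 'Do not add test steps' },
        { label: 'Urgent hotfix', description: 'Fastest possible path' }
      ]
    }
  ];
}
```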
---
### Phase 2: Select Workflow & Build Command Chain
```javascript
function selectWorkflow(analysis) {
const levelMap = {
'bugfix-hotfix': { level: 2, flow: 'bugfix.hotfix' },
'bugfix': { level: 2, flow: 'bugfix.standard' },
'issue-batch': { level: 'Issue', flow: 'issue' },
'exploration': { level: 4, flow: 'full' },
'quick-task': { level: 1, flow: 'lite-lite-lite' },
'ui-design': { level: analysis.complexity === 'high' ? 4 : 3, flow: 'ui' },
'tdd': { level: 3, flow: 'tdd' },
'test-fix': { level: 3, flow: 'test-fix-gen' },
'review': { level: 3, flow: 'review-fix' },
'documentation': { level: 2, flow: 'docs' },
'feature': { level: analysis.complexity === 'high' ? 3 : 2, flow: analysis.complexity === 'high' ? 'coupled' : 'rapid' }
};
const selected = levelMap[analysis.task_type] || levelMap['feature'];
return buildCommandChain(selected, analysis);
}
// Build command chain (port-based matching with Minimum Execution Units)
function buildCommandChain(workflow, analysis) {
const chains = {
// Level 1 - Rapid
'lite-lite-lite': [
{ cmd: '/workflow:lite-lite-lite', args: `"${analysis.goal}"` }
],
// Level 2 - Lightweight
'rapid': [
// Unit: Quick Implementation【lite-plan → lite-execute】
{ cmd: '/workflow:lite-plan', args: `"${analysis.goal}"`, unit: 'quick-impl' },
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'quick-impl' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
...(analysis.constraints?.includes('skip-tests') ? [] : [
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
])
],
'bugfix.standard': [
// Unit: Bug Fix【lite-fix → lite-execute】
{ cmd: '/workflow:lite-fix', args: `"${analysis.goal}"`, unit: 'bug-fix' },
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'bug-fix' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
...(analysis.constraints?.includes('skip-tests') ? [] : [
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
])
],
'bugfix.hotfix': [
{ cmd: '/workflow:lite-fix', args: `--hotfix "${analysis.goal}"` }
],
'multi-cli-plan': [
// Unit: Multi-CLI Planning【multi-cli-plan → lite-execute】
{ cmd: '/workflow:multi-cli-plan', args: `"${analysis.goal}"`, unit: 'multi-cli' },
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'multi-cli' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
...(analysis.constraints?.includes('skip-tests') ? [] : [
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
])
],
'docs': [
// Unit: Quick Implementation【lite-plan → lite-execute】
{ cmd: '/workflow:lite-plan', args: `"${analysis.goal}"`, unit: 'quick-impl' },
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'quick-impl' }
],
// Level 3 - Standard
'coupled': [
// Unit: Verified Planning【plan → plan-verify】
{ cmd: '/workflow:plan', args: `"${analysis.goal}"`, unit: 'verified-planning' },
{ cmd: '/workflow:plan-verify', args: '', unit: 'verified-planning' },
// Execution
{ cmd: '/workflow:execute', args: '' },
// Unit: Code Review【review-session-cycle → review-fix】
{ cmd: '/workflow:review-session-cycle', args: '', unit: 'code-review' },
{ cmd: '/workflow:review-fix', args: '', unit: 'code-review' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
...(analysis.constraints?.includes('skip-tests') ? [] : [
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
])
],
'tdd': [
// Unit: TDD Planning + Execution【tdd-plan → execute】
{ cmd: '/workflow:tdd-plan', args: `"${analysis.goal}"`, unit: 'tdd-planning' },
{ cmd: '/workflow:execute', args: '', unit: 'tdd-planning' },
// TDD Verification
{ cmd: '/workflow:tdd-verify', args: '' }
],
'test-fix-gen': [
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
{ cmd: '/workflow:test-fix-gen', args: `"${analysis.goal}"`, unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
],
'review-fix': [
// Unit: Code Review【review-session-cycle → review-fix】
{ cmd: '/workflow:review-session-cycle', args: '', unit: 'code-review' },
{ cmd: '/workflow:review-fix', args: '', unit: 'code-review' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
],
'ui': [
{ cmd: '/workflow:ui-design:explore-auto', args: `"${analysis.goal}"` },
// Unit: Planning + Execution【plan → execute】
{ cmd: '/workflow:plan', args: '', unit: 'plan-execute' },
{ cmd: '/workflow:execute', args: '', unit: 'plan-execute' }
],
// Level 4 - Brainstorm
'full': [
{ cmd: '/workflow:brainstorm:auto-parallel', args: `"${analysis.goal}"` },
// Unit: Verified Planning【plan → plan-verify】
{ cmd: '/workflow:plan', args: '', unit: 'verified-planning' },
{ cmd: '/workflow:plan-verify', args: '', unit: 'verified-planning' },
// Execution
{ cmd: '/workflow:execute', args: '' },
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
],
// Issue Workflow
'issue': [
{ cmd: '/issue:discover', args: '' },
{ cmd: '/issue:plan', args: '--all-pending' },
{ cmd: '/issue:queue', args: '' },
{ cmd: '/issue:execute', args: '' }
]
};
return chains[workflow.flow] || chains['rapid'];
}
```
**Output**: `Level [X] - [flow] | Pipeline: [...] | Commands: [1. /cmd1 2. /cmd2 ...]`
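For example (illustrative values), a plain bug report with no constraints resolves to Level 2, flow `bugfix.standard`, and yields a chain that keeps both units intact:
```javascript
// Illustrative usage of selectWorkflow/buildCommandChain defined above
const analysis = { task_type: 'bugfix', complexity: 'low', goal: 'Fix login timeout', constraints: [] };
const chain = selectWorkflow(analysis);
// chain ≈ [
//   { cmd: '/workflow:lite-fix',           args: '"Fix login timeout"', unit: 'bug-fix' },
//   { cmd: '/workflow:lite-execute',       args: '--in-memory',         unit: 'bug-fix' },
//   { cmd: '/workflow:test-fix-gen',       args: '',                    unit: 'test-validation' },
//   { cmd: '/workflow:test-cycle-execute', args: '',                    unit: 'test-validation' }
// ]
```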
---
### Phase 3: User Confirmation
```javascript
async function getUserConfirmation(chain) {
const response = await AskUserQuestion({
questions: [{
question: "Execute this command chain?",
header: "Confirm",
options: [
{ label: "Confirm", description: "Start" },
{ label: "Adjust", description: "Modify" },
{ label: "Cancel", description: "Abort" }
]
}]
});
if (response.error === "Cancel") throw new Error("Cancelled");
if (response.error === "Adjust") return await adjustChain(chain);
return chain;
}
```
---
### Phase 4: Setup TODO Tracking
```javascript
function setupTodoTracking(chain, workflow) {
const todos = chain.map((step, i) => ({
content: `CCW:${workflow}: [${i + 1}/${chain.length}] ${step.cmd}`,
status: i === 0 ? 'in_progress' : 'pending',
activeForm: `Executing ${step.cmd}`
}));
TodoWrite({ todos });
}
```
**Output**: `-> CCW:rapid: [1/3] /workflow:lite-plan | CCW:rapid: [2/3] /workflow:lite-execute | ...`
---
### Phase 5: Execute Command Chain
```javascript
async function executeCommandChain(chain, workflow) {
let previousResult = null;
for (let i = 0; i < chain.length; i++) {
try {
const fullCommand = assembleCommand(chain[i], previousResult);
const result = await SlashCommand({ command: fullCommand });
previousResult = { ...result, success: true };
updateTodoStatus(i, chain.length, workflow, 'completed');
} catch (error) {
const action = await handleError(chain[i], error, i);
if (action === 'retry') {
i--; // Retry
} else if (action === 'abort') {
return { success: false, error: error.message };
}
// 'skip' - continue
}
}
return { success: true, completed: chain.length };
}
// Assemble full command with session/plan parameters
function assembleCommand(step, previousResult) {
let command = step.cmd;
if (step.args) {
command += ` ${step.args}`;
} else if (previousResult?.session_id) {
command += ` --session="${previousResult.session_id}"`;
}
return command;
}
// Update TODO: mark current as complete, next as in-progress
function updateTodoStatus(index, total, workflow, status) {
const todos = getAllCurrentTodos();
const updated = todos.map(todo => {
if (todo.content.startsWith(`CCW:${workflow}:`)) {
const stepNum = extractStepIndex(todo.content);
if (stepNum === index + 1) return { ...todo, status };
if (stepNum === index + 2 && status === 'completed') return { ...todo, status: 'in_progress' };
}
return todo;
});
TodoWrite({ todos: updated });
}
// Error handling: Retry/Skip/Abort
async function handleError(step, error, index) {
const response = await AskUserQuestion({
questions: [{
question: `${step.cmd} failed: ${error.message}`,
header: "Error",
options: [
{ label: "Retry", description: "Re-execute" },
{ label: "Skip", description: "Continue next" },
{ label: "Abort", description: "Stop" }
]
}]
});
return { Retry: 'retry', Skip: 'skip', Abort: 'abort' }[response.error] || 'abort';
}
```
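`getAllCurrentTodos` and `extractStepIndex` are referenced above but not defined; assumed shapes, for illustration only:
```javascript
// Assumed helper: parse the step number out of a "CCW:<flow>: [N/M] /workflow:..." todo entry.
function extractStepIndex(content) {
  const match = content.match(/\[(\d+)\/(\d+)\]/);
  return match ? parseInt(match[1], 10) : -1;
}

// getAllCurrentTodos() is assumed to return the current todo list in the same shape that TodoWrite
// accepts ([{ content, status, activeForm }, ...]); how it reads that state is not specified here.
```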
---
## Execution Flow Summary
```
User Input
|
Phase 1: Analyze Intent
|-- Extract: goal, scope, constraints, task_type, complexity, clarity
+-- If clarity < 2 -> Phase 1.5: Clarify Requirements
|
Phase 2: Select Workflow & Build Chain
|-- Map task_type -> Level (1/2/3/4/Issue)
|-- Select flow based on complexity
+-- Build command chain (port-based)
|
Phase 3: User Confirmation (optional)
|-- Show pipeline visualization
+-- Allow adjustment
|
Phase 4: Setup TODO Tracking
+-- Create todos with CCW prefix
|
Phase 5: Execute Command Chain
|-- For each command:
| |-- Assemble full command
| |-- Execute via SlashCommand
| |-- Update TODO status
| +-- Handle errors (retry/skip/abort)
+-- Return workflow result
```
---
## Pipeline Examples (with Minimum Execution Units)
**Note**: `【 】` marks Minimum Execution Units - commands execute together as atomic groups.
| Input | Type | Level | Pipeline (with Units) |
|-------|------|-------|-----------------------|
| "Add API endpoint" | feature (low) | 2 |【lite-plan → lite-execute】→【test-fix-gen → test-cycle-execute】|
| "Fix login timeout" | bugfix | 2 |【lite-fix → lite-execute】→【test-fix-gen → test-cycle-execute】|
| "OAuth2 system" | feature (high) | 3 |【plan → plan-verify】→ execute →【review-session-cycle → review-fix】→【test-fix-gen → test-cycle-execute】|
| "Implement with TDD" | tdd | 3 |【tdd-plan → execute】→ tdd-verify |
| "Uncertain: real-time arch" | exploration | 4 | brainstorm:auto-parallel →【plan → plan-verify】→ execute →【test-fix-gen → test-cycle-execute】|
---
## Key Design Principles
1. **Main Process Execution** - Use SlashCommand in main process, no external CLI
2. **Intent-Driven** - Auto-select workflow based on task intent
3. **Port-Based Chaining** - Build command chain using port matching
4. **Minimum Execution Units** - Commands grouped into atomic units, never split (e.g., lite-plan → lite-execute)
5. **Progressive Clarification** - Low clarity triggers clarification phase
6. **TODO Tracking** - Use CCW prefix to isolate workflow todos
7. **Unit-Aware Error Handling** - Retry/skip/abort affects whole unit, not individual commands
8. **User Control** - Optional user confirmation at each phase
---
## State Management
**TodoWrite-Based Tracking**: All execution state tracked via TodoWrite with `CCW:` prefix.
```javascript
// Initial state
todos = [
{ content: "CCW:rapid: [1/3] /workflow:lite-plan", status: "in_progress" },
{ content: "CCW:rapid: [2/3] /workflow:lite-execute", status: "pending" },
{ content: "CCW:rapid: [3/3] /workflow:test-cycle-execute", status: "pending" }
];
// After command 1 completes
todos = [
{ content: "CCW:rapid: [1/3] /workflow:lite-plan", status: "completed" },
{ content: "CCW:rapid: [2/3] /workflow:lite-execute", status: "in_progress" },
{ content: "CCW:rapid: [3/3] /workflow:test-cycle-execute", status: "pending" }
];
```
**vs ccw-coordinator**: Extensive state.json with task_id, status transitions, hook callbacks.
---
## Type Comparison: ccw vs ccw-coordinator
| Aspect | ccw | ccw-coordinator |
|--------|-----|-----------------|
| **Type** | Main process (SlashCommand) | External CLI (ccw cli + hook callbacks) |
| **Execution** | Synchronous blocking | Async background with hook completion |
| **Workflow** | Auto intent-based selection | Manual chain building |
| **Intent Analysis** | 5-phase clarity check | 3-phase requirement analysis |
| **State** | TodoWrite only (in-memory) | state.json + checkpoint/resume |
| **Error Handling** | Retry/skip/abort (interactive) | Retry/skip/abort (via AskUser) |
| **Use Case** | Auto workflow for any task | Manual orchestration, large chains |
---
## Usage
```bash
# Auto-select workflow
ccw "Add user authentication"
# Complex requirement (triggers clarification)
ccw "Optimize system performance"
# Bug fix
ccw "Fix memory leak in WebSocket handler"
# TDD development
ccw "Implement user registration with TDD"
# Exploratory task
ccw "Uncertain about architecture for real-time notifications"
```


@@ -1,45 +0,0 @@
# CCW Coordinator
Interactive command orchestration tool
## Usage
```
/ccw-coordinator
/coordinator
```
## Flow
1. The user describes the task
2. Claude recommends a command chain
3. The user confirms or adjusts
4. The command chain is executed
5. A report is generated
## Examples
**Bug fix**
```
Task: fix the login bug
Recommended: lite-fix → test-cycle-execute
```
**New feature**
```
Task: implement the registration feature
Recommended: plan → execute → test-cycle-execute
```
## Files
| File | Purpose |
|------|------|
| SKILL.md | Skill entry point |
| phases/orchestrator.md | Orchestration logic |
| phases/state-schema.md | State schema definition |
| phases/actions/*.md | Action implementations |
| specs/specs.md | Command library, validation rules, registry |
| tools/chain-validate.cjs | Chain validation tool |
| tools/command-registry.cjs | Command registry tool |


@@ -1,320 +0,0 @@
---
name: ccw-coordinator
description: Interactive command orchestration tool for building and executing Claude CLI command chains. Triggers on "coordinator", "ccw-coordinator", "命令编排", "command chain", "orchestrate commands", "编排CLI命令".
allowed-tools: Task, AskUserQuestion, Read, Write, Bash, Glob, Grep
---
# CCW Coordinator
An interactive command orchestration tool that lets the user select commands in sequence to form a command chain, then invokes the claude CLI to execute the whole chain step by step.
It supports flexible workflow composition and provides an interactive interface for command selection, orchestration, and execution management.
## Architecture Overview
```
┌─────────────────────────────────────────────────────────────────┐
│ Orchestrator (state-driven decisions)                           │
│ Arranges commands and the execution flow from user choices      │
└───────────────┬─────────────────────────────────────────────────┘
    ┌───────────┼──────────────┬──────────────┐
    ↓           ↓              ↓              ↓
┌─────────┐ ┌──────────────┐ ┌────────────┐ ┌──────────┐
│ Init    │ │ Command      │ │ Command    │ │ Execute  │
│         │ │ Selection    │ │ Build      │ │          │
│         │ │              │ │            │ │          │
│ Set up  │ │ Pick commands│ │ Arrange /  │ │ Run the  │
│         │ │              │ │ adjust     │ │ chain    │
└─────────┘ └──────────────┘ └────────────┘ └──────────┘
     │            │               │              │
     └────────────┼───────────────┴──────────────┘
                  ↓
          ┌──────────────┐
          │ Complete     │
          │ Final report │
          └──────────────┘
```
## Key Design Principles
1. **Smart recommendation**: Claude automatically recommends the optimal command chain from the user's task description
2. **Interactive orchestration**: the user selects and arranges commands through an interactive interface with real-time feedback
3. **Stateless actions**: each action runs independently and communicates through shared state
4. **Flexible command library**: supports ccw workflow commands and standard claude CLI commands
5. **Execution transparency**: shows execution progress, results, and any errors
6. **Session persistence**: orchestration sessions are saved and can be paused and resumed midway
7. **Intelligent prompt generation**: ccw cli prompts are generated automatically from the task context and previous artifacts
8. **Auto-confirm**: every command automatically gets the `-y` flag, skipping interactive confirmation for unattended execution
## Intelligent Prompt Generation
When a command is executed, the system assembles the `ccw cli -p` prompt from the following information:
### Prompt Composition
```javascript
// Integrates the command registry (~/.claude/tools/command-registry.js)
const registry = new CommandRegistry();
registry.buildRegistry();
function generatePrompt(cmd, state) {
const cmdMeta = registry.getCommand(cmd.command);
let prompt = `任务: ${state.task_description}\n`;
if (state.execution_results.length > 0) {
const previousOutputs = state.execution_results
.filter(r => r.status === 'success')
.map(r => {
if (r.summary?.session) {
return `- ${r.command}: ${r.summary.session} (${r.summary.files?.join(', ')})`;
}
return `- ${r.command}: 已完成`;
})
.join('\n');
prompt += `\n前序完成:\n${previousOutputs}\n`;
}
// 从 YAML 头提取命令元数据
if (cmdMeta) {
prompt += `\n命令: ${cmd.command}`;
if (cmdMeta.argumentHint) {
prompt += ` ${cmdMeta.argumentHint}`;
}
}
return prompt;
}
```
### 产物追踪
每个命令执行后自动提取关键产物:
```javascript
{
command: "/workflow:lite-plan",
status: "success",
output: "...",
summary: {
session: "WFS-plan-20250123", // 会话ID
files: [".workflow/IMPL_PLAN.md"], // 产物文件
timestamp: "2025-01-23T10:30:00Z"
}
}
```
### 命令调用示例
```bash
# 自动生成的智能提示词
ccw cli -p "任务: 实现用户认证功能
前序完成:
- /workflow:lite-plan: WFS-plan-20250123 (.workflow/IMPL_PLAN.md)
命令: /workflow:lite-execute [--resume-session=\"session-id\"]" /workflow:lite-execute
```
### 命令注册表集成
- **位置**: `tools/command-registry.js` (skill 内置)
- **工作模式**: 按需提取(只提取用户任务链中的命令)
- **功能**: 自动查找全局 `.claude/commands/workflow` 目录,解析命令 YAML 头元数据
- **作用**: 确保提示词包含准确的命令参数和上下文
详见 `tools/README.md`
---
## Execution Flow
### Orchestrator Execution Loop
```javascript
1. Initialize the session
2. Loop until complete:
   - read the current state
   - select the next action based on state and user intent
   - execute the action and update state
   - check termination conditions
3. Generate the final report
```
### Action Sequence (Typical)
```
action-init
↓ (status: pending → running)
action-command-selection (repeatable)
↓ (add commands to the chain)
action-command-build (optional)
↓ (adjust command order)
action-command-execute
↓ (execute all commands in order)
action-complete
↓ (status: running → completed)
```
## State Management
### Initial State
```json
{
"status": "pending",
"task_description": "",
"command_chain": [],
"confirmed": false,
"error_count": 0,
"execution_results": [],
"current_command_index": 0,
"started_at": null
}
```
### State Transitions
```
pending → running (init)
running → completed (execute)
running → aborted (error or user exit)
```
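The state-management constraints below require every update to be persisted immediately via read-modify-write, and the action documents call `updateState(...)` without defining it. A minimal sketch of such a helper, assuming Node's `fs` module and that `workDir` points at the session directory (both assumptions, not part of the original spec):
```javascript
const fs = require('fs');
const path = require('path');

// Hypothetical helper: merge a partial update into state.json using read-modify-write,
// so every change is written to disk immediately.
function updateState(partial, workDir) {
  const statePath = path.join(workDir, 'state.json');
  const current = JSON.parse(fs.readFileSync(statePath, 'utf8')); // read
  const next = { ...current, ...partial };                        // modify
  fs.writeFileSync(statePath, JSON.stringify(next, null, 2));     // write
  return next;
}
```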
## Directory Setup
```javascript
const timestamp = new Date().toISOString().slice(0,19).replace(/[-:T]/g, '');
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
Bash(`mkdir -p "${workDir}"`);
Bash(`mkdir -p "${workDir}/commands"`);
Bash(`mkdir -p "${workDir}/logs"`);
```
## Output Structure
```
.workflow/.ccw-coordinator/{timestamp}/
├── state.json # current session state
├── command-chain.json # the full orchestrated command chain
├── execution-log.md # execution log
├── final-summary.md # final report
├── commands/ # per-command execution details
│ ├── 01-command.log
│ ├── 02-command.log
│ └── ...
└── logs/ # error and warning logs
├── errors.log
└── warnings.log
```
## Reference Documents
| Document | Purpose |
|----------|---------|
| [phases/orchestrator.md](phases/orchestrator.md) | Orchestrator implementation |
| [phases/state-schema.md](phases/state-schema.md) | State schema definition |
| [phases/actions/action-init.md](phases/actions/action-init.md) | Init action |
| [phases/actions/action-command-selection.md](phases/actions/action-command-selection.md) | Command selection action |
| [phases/actions/action-command-build.md](phases/actions/action-command-build.md) | Command build action |
| [phases/actions/action-command-execute.md](phases/actions/action-command-execute.md) | Command execute action |
| [phases/actions/action-complete.md](phases/actions/action-complete.md) | Complete action |
| [phases/actions/action-abort.md](phases/actions/action-abort.md) | Abort action |
| [specs/specs.md](specs/specs.md) | Command library, validation rules, registry |
| [tools/chain-validate.cjs](tools/chain-validate.cjs) | Validation tool |
| [tools/command-registry.cjs](tools/command-registry.cjs) | Command registry tool |
---
## Usage Examples
### Quick Command Chain
The user wants: plan → execute → test
```
1. Trigger /ccw-coordinator
2. Describe the task: "implement the user registration feature"
3. Claude recommends: plan → execute → test-cycle-execute
4. The user confirms
5. The command chain is executed
```
### Complex Workflow
The user wants: plan → execute → review → fix
```
1. Trigger /ccw-coordinator
2. Describe the task: "refactor the authentication module"
3. Claude recommends: plan → execute → review-session-cycle → review-fix
4. The user may reorder the commands
5. Confirm and execute
6. Watch execution progress in real time
```
### Emergency Fix
The user wants to fix a bug quickly
```
1. Trigger /ccw-coordinator
2. Describe the task: "fix the production login bug"
3. Claude recommends: lite-fix --hotfix → test-cycle-execute
4. The user confirms
5. The fix runs immediately
```
## Constraints and Rules
### 1. Recommendation Constraints
- **Smart recommendation first**: recommendations must be derived from the user's task description, not from simply listing the command library
- **No static mapping**: table lookups and hard-coded recommendation logic are forbidden (e.g. `if task == bug then recommend lite-fix`)
- **Recommendations must give reasons**: when Claude recommends a chain, it must explain why
- **The user keeps the final say**: after a recommendation, the user may accept it, adjust it, or select commands manually
### 2. Validation Constraints
- **Validate before execution**: use `chain-validate.cjs` to validate that the chain is legal
- **Report invalid chains**: if validation fails, the cause and the fix must be reported clearly to the user
- **Allow user override**: when validation fails, ask whether to execute anyway (warning mode)
### 3. Execution Constraints
- **Sequential execution**: commands must run strictly in the order given by command_chain
- **Error handling**: when a single command fails, ask the user: retry / skip / abort
- **Error ceiling**: three consecutive errors abort the session automatically
- **Real-time feedback**: show progress while each command runs (e.g. `[2/5] Executing: lite-execute`)
### 4. State Management Constraints
- **State persistence**: every state update must be written to disk immediately
- **Single source of truth**: state lives only in `state.json`; multiple state files are forbidden
- **Atomic operations**: state updates must use a read-modify-write pattern to avoid concurrent conflicts
### 5. User Experience Constraints
- **Minimal interaction**: default to a smart recommendation plus one confirmation; avoid repeated prompts
- **Clear output**: every step's output must include the current state, the available options, and a suggested action
- **Recoverability**: after an interruption, the user can resume from the last saved state
### 6. Forbidden Behaviors
- **Do not skip the recommendation step**: manual selection must not start before a recommendation has been attempted
- **No static recommendations**: a recommended-chains.json lookup table must not be used
- **No unvalidated execution**: the chain must not be executed without validation
- **No silent failures**: errors must be reported explicitly, never skipped silently
## Notes
- The orchestrator uses a state-machine pattern to keep the execution flow reliable
- All command chains and execution results are persisted for later inspection and debugging
- The user may modify the command chain before execution starts
- Execution errors are recorded automatically, and retries are supported
- Claude's recommendations come from task analysis, not a static lookup table

View File

@@ -1,9 +0,0 @@
# action-abort
Abort the session and save its state
```javascript
updateState({ status: 'aborted' });
console.log(`Session aborted: ${workDir}`);
```

View File

@@ -1,40 +0,0 @@
# action-command-build
Adjust the command chain order or remove commands
## Process
1. Show the current command chain
2. Let the user adjust it (reorder or delete commands)
3. Confirm execution
## Pseudocode
```javascript
// Show the chain
console.log('Command chain:');
state.command_chain.forEach((cmd, i) => {
  console.log(`${i+1}. ${cmd.command}`);
});
// Ask the user
const action = await AskUserQuestion({
  options: [
    'Continue execution',
    'Delete a command',
    'Reorder',
    'Back to selection'
  ]
});
// Handle the user's choice
if (action === 'Continue execution') {
  updateState({confirmed: true, status: 'executing'});
}
// ... other branches (delete / reorder handling; a sketch follows below)
```
## State Updates
- command_chain (possibly modified)
- when confirmed = true, the status switches to executing
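A minimal sketch of the elided delete/reorder branches, assuming `AskUserQuestion` returns the chosen option label (or free text) directly; the prompts and labels here are illustrative, not part of the original action spec:
```javascript
if (action === 'Delete a command') {
  const pick = await AskUserQuestion({
    question: 'Which command should be removed?',
    options: state.command_chain.map((cmd, i) => `${i + 1}. ${cmd.command}`)
  });
  const index = parseInt(pick, 10) - 1;                        // labels start with the 1-based position
  state.command_chain.splice(index, 1);                        // drop the selected command
  state.command_chain.forEach((c, i) => { c.order = i + 1; }); // keep order fields consecutive
  updateState({ command_chain: state.command_chain });
}
if (action === 'Reorder') {
  const answer = await AskUserQuestion({ question: 'Enter the new order as indices, e.g. "2,1,3"' });
  const indices = answer.split(',').map(s => parseInt(s.trim(), 10) - 1);
  const reordered = indices.map(i => state.command_chain[i]);
  reordered.forEach((c, i) => { c.order = i + 1; });
  updateState({ command_chain: reordered });
}
```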

View File

@@ -1,124 +0,0 @@
# action-command-execute
Execute the command chain in order, generating ccw cli prompts intelligently
## Command Registry Integration
```javascript
// Extract command metadata on demand from ./tools/command-registry.cjs
const CommandRegistry = require('./tools/command-registry.cjs');
const registry = new CommandRegistry();
// Only extract the commands in the current task chain
const commandNames = command_chain.map(cmd => cmd.command);
const commandMeta = registry.getCommands(commandNames);
```
## Prompt Generation Strategy
```javascript
function generatePrompt(cmd, state, commandMeta) {
const { task_description, execution_results } = state;
// Get the command metadata (from the pre-extracted commandMeta)
const cmdInfo = commandMeta[cmd.command];
// Extract artifact info from previously completed commands
const previousOutputs = execution_results
.filter(r => r.status === 'success')
.map(r => {
const summary = r.summary;
if (summary?.session) {
return `- ${r.command}: ${summary.session} (${summary.files?.join(', ') || 'done'})`;
}
return `- ${r.command}: completed`;
})
.join('\n');
// Build the prompt for this command
let prompt = `Task: ${task_description}\n`;
if (previousOutputs) {
prompt += `\nPreviously completed:\n${previousOutputs}\n`;
}
// Append command metadata context
if (cmdInfo) {
prompt += `\nCommand: ${cmd.command}`;
if (cmdInfo.argumentHint) {
prompt += ` ${cmdInfo.argumentHint}`;
}
}
return prompt;
}
```
## Execution Logic
```javascript
for (let i = current_command_index; i < command_chain.length; i++) {
const cmd = command_chain[i];
console.log(`[${i+1}/${command_chain.length}] Executing: ${cmd.command}`);
// Generate the prompt for this command
const prompt = generatePrompt(cmd, state, commandMeta);
try {
// Execute via ccw cli (append -y to skip confirmation)
const result = Bash(`ccw cli -p "${prompt.replace(/"/g, '\\"')}" ${cmd.command} -y`, {
run_in_background: true
});
execution_results.push({
command: cmd.command,
status: result.exit_code === 0 ? 'success' : 'failed',
exit_code: result.exit_code,
output: result.stdout,
summary: extractSummary(result.stdout) // extract key artifacts
});
command_chain[i].status = 'completed';
current_command_index = i + 1;
} catch (error) {
error_count++;
command_chain[i].status = 'failed';
if (error_count >= 3) break;
const action = await AskUserQuestion({
options: ['Retry', 'Skip', 'Abort']
});
if (action === 'Retry') i--;
if (action === 'Abort') break;
}
updateState({ command_chain, execution_results, current_command_index, error_count });
}
```
## Artifact Extraction
```javascript
function extractSummary(output) {
// Extract key artifact info from the output
// e.g. session ID, file paths, task completion status
const sessionMatch = output.match(/WFS-\w+-\d+/);
const fileMatch = output.match(/\.workflow\/[^\s]+/g);
return {
session: sessionMatch?.[0],
files: fileMatch || [],
timestamp: new Date().toISOString()
};
}
```
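For illustration, a hypothetical output line and what `extractSummary` derives from it (the sample text is invented; only the regexes above drive the result):
```javascript
// Hypothetical CLI output; matches the WFS-\w+-\d+ and .workflow/... patterns above
const sample = 'Session WFS-plan-20250123 created, plan written to .workflow/IMPL_PLAN.md';
extractSummary(sample);
// => { session: 'WFS-plan-20250123', files: ['.workflow/IMPL_PLAN.md'], timestamp: '2025-...' }
```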
## State Updates
- execution_results (including the summary artifact info)
- command_chain[].status
- current_command_index

View File

@@ -1,48 +0,0 @@
# action-command-selection
## Process
1. Ask the user for the task
2. Claude recommends a command chain
3. The user confirms or selects manually
4. Commands are added to command_chain
## Pseudocode
```javascript
// 1. Get the user's task description
const taskInput = await AskUserQuestion({
  question: 'Please describe your task',
  options: [
    { label: 'Select commands manually', value: 'manual' }
  ]
});
// Save the task description to state
updateState({ task_description: taskInput.text || taskInput.value });
// 2. If the user described a task, Claude recommends a chain
if (taskInput.text) {
  console.log('Recommended: ', recommendChain(taskInput.text));
  const confirm = await AskUserQuestion({
    question: 'Use the recommended chain?',
    options: ['Use recommendation', 'Adjust', 'Select manually']
  });
  if (confirm === 'Use recommendation') {
    addCommandsToChain(recommendedChain);
    updateState({ confirmed: true });
    return;
  }
}
// 3. Manual selection
const commands = loadCommandLibrary();
const selected = await AskUserQuestion(commands);
addToChain(selected);
```
## State Updates
- task_description = the user's task description
- command_chain.push(newCommand)
- if the user confirms: confirmed = true

View File

@@ -1,25 +0,0 @@
# action-complete
Generate the execution report
```javascript
const success = execution_results.filter(r => r.status === 'success').length;
const failed = execution_results.filter(r => r.status === 'failed').length;
const duration = Date.now() - new Date(started_at).getTime();
const report = `
# Execution Report
- Session: ${session_id}
- Duration: ${Math.round(duration/1000)}s
- Succeeded: ${success}
- Failed: ${failed}
## Command Details
${command_chain.map((c, i) => `${i+1}. ${c.command} - ${c.status}`).join('\n')}
`;
Write(`${workDir}/final-report.md`, report);
updateState({ status: 'completed' });
```

View File

@@ -1,26 +0,0 @@
# action-init
Initialize the orchestration session
```javascript
const timestamp = Date.now();
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
Bash(`mkdir -p "${workDir}"`);
const state = {
session_id: `coord-${timestamp}`,
status: 'running',
started_at: new Date().toISOString(),
task_description: '',
command_chain: [],
current_command_index: 0,
execution_results: [],
confirmed: false,
error_count: 0
};
Write(`${workDir}/state.json`, JSON.stringify(state, null, 2));
console.log(`Session initialized: ${workDir}`);
```

View File

@@ -1,59 +0,0 @@
# Orchestrator
State-driven orchestration: read state → select action → execute → update state
## Decision Logic
```javascript
function selectNextAction(state) {
if (['completed', 'aborted'].includes(state.status)) return null;
if (state.error_count >= 3) return 'action-abort';
switch (state.status) {
case 'pending':
return 'action-init';
case 'running':
return state.confirmed && state.command_chain.length > 0
? 'action-command-execute'
: 'action-command-selection';
case 'executing':
const pending = state.command_chain.filter(c => c.status === 'pending');
return pending.length === 0 ? 'action-complete' : 'action-command-execute';
default:
return 'action-abort';
}
}
```
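A quick illustration of the routing above, with hypothetical state values:
```javascript
// Chain confirmed and non-empty while running → execute it
selectNextAction({ status: 'running', confirmed: true, command_chain: [{ command: '/workflow:lite-plan', status: 'pending' }], error_count: 0 });
// => 'action-command-execute'

// Still collecting commands → keep selecting
selectNextAction({ status: 'running', confirmed: false, command_chain: [], error_count: 0 });
// => 'action-command-selection'

// Error budget exhausted → abort regardless of the current phase
selectNextAction({ status: 'executing', confirmed: true, command_chain: [], error_count: 3 });
// => 'action-abort'
```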
## Execution Loop
```javascript
const timestamp = Date.now();
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
Bash(`mkdir -p "${workDir}"`);
const state = {
session_id: `coord-${timestamp}`,
status: 'pending',
started_at: new Date().toISOString(),
task_description: '', // obtained from action-command-selection
command_chain: [],
current_command_index: 0,
execution_results: [],
confirmed: false,
error_count: 0
};
Write(`${workDir}/state.json`, JSON.stringify(state, null, 2));
let iterations = 0;
while (iterations < 50) {
const state = JSON.parse(Read(`${workDir}/state.json`));
const nextAction = selectNextAction(state);
if (!nextAction) break;
console.log(`[${nextAction}]`);
// Execute phases/actions/{nextAction}.md
iterations++;
}
```

View File

@@ -1,66 +0,0 @@
# State Schema
```typescript
interface State {
session_id: string;
status: 'pending' | 'running' | 'executing' | 'completed' | 'aborted';
started_at: string;
task_description: string; // the user's task description
command_chain: Command[];
current_command_index: number;
execution_results: ExecutionResult[];
confirmed: boolean;
error_count: number;
}
interface Command {
id: string;
order: number;
command: string;
status: 'pending' | 'running' | 'completed' | 'failed';
result?: ExecutionResult;
}
interface ExecutionResult {
command: string;
status: 'success' | 'failed';
exit_code: number;
output?: string;
summary?: { // extracted key artifacts
session?: string;
files?: string[];
timestamp: string;
};
}
```
## State Transitions
```
pending → running → executing → completed
             ↓           ↓
          (abort)   (error → abort)
```
## Initialization
```javascript
{
session_id: generateId(),
status: 'pending',
started_at: new Date().toISOString(),
task_description: '', // obtained from user input
command_chain: [],
current_command_index: 0,
execution_results: [],
confirmed: false,
error_count: 0
}
```
## Updates
- Add a command: `command_chain.push(cmd)`
- Confirm execution: `confirmed = true, status = 'executing'`
- Record execution: `execution_results.push(...), current_command_index++`
- Count an error: `error_count++`

View File

@@ -1,66 +0,0 @@
{
"skill_name": "ccw-coordinator",
"display_name": "CCW Coordinator",
"description": "Interactive command orchestration - select, build, and execute workflow command chains",
"execution_mode": "autonomous",
"version": "1.0.0",
"triggers": [
"coordinator",
"ccw-coordinator",
"命令编排",
"command chain"
],
"allowed_tools": [
"Task",
"AskUserQuestion",
"Read",
"Write",
"Bash"
],
"actions": [
{
"id": "action-init",
"name": "Init",
"description": "Initialize orchestration session"
},
{
"id": "action-command-selection",
"name": "Select Commands",
"description": "Interactive command selection from library"
},
{
"id": "action-command-build",
"name": "Build Chain",
"description": "Adjust and confirm command chain"
},
{
"id": "action-command-execute",
"name": "Execute",
"description": "Execute command chain sequentially"
},
{
"id": "action-complete",
"name": "Complete",
"description": "Generate final report"
},
{
"id": "action-abort",
"name": "Abort",
"description": "Abort session and save state"
}
],
"termination_conditions": [
"user_exit",
"task_completed",
"error"
],
"output": {
"location": ".workflow/.ccw-coordinator/{timestamp}",
"artifacts": [
"state.json",
"command-chain.json",
"execution-log.md",
"final-report.md"
]
}
}

View File

@@ -1,169 +0,0 @@
# Command Library
The command library supported by CCW Coordinator, based on the CCW workflow command system.
## Command Categories
### Planning Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-plan` | Lightweight planning | L2 |
| `/workflow:plan` | Standard planning | L3 |
| `/workflow:multi-cli-plan` | Multi-CLI collaborative planning | L2 |
| `/workflow:brainstorm:auto-parallel` | Brainstorm planning | L4 |
| `/workflow:tdd-plan` | TDD planning | L3 |
### Execution Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-execute` | Lightweight execution | L2 |
| `/workflow:execute` | Standard execution | L3 |
| `/workflow:test-cycle-execute` | Test-cycle execution | L3 |
### BugFix Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-fix` | Lightweight fix | L2 |
| `/workflow:lite-fix --hotfix` | Emergency hotfix | L2 |
### Testing Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:test-gen` | Test generation | L3 |
| `/workflow:test-fix-gen` | Test-fix generation | L3 |
| `/workflow:tdd-verify` | TDD verification | L3 |
### Review Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:review-session-cycle` | Session review | L3 |
| `/workflow:review-module-cycle` | Module review | L3 |
| `/workflow:review-fix` | Review fix | L3 |
| `/workflow:plan-verify` | Plan verification | L3 |
### Documentation Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/memory:docs` | Generate documentation | L2 |
| `/memory:update-related` | Update related docs | L2 |
| `/memory:update-full` | Full documentation update | L2 |
### Issue Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/issue:discover` | Discover issues | Supplementary |
| `/issue:discover-by-prompt` | Discover issues from a prompt | Supplementary |
| `/issue:plan --all-pending` | Plan all pending issues | Supplementary |
| `/issue:queue` | Queue issues | Supplementary |
| `/issue:execute` | Execute issues | Supplementary |
## Command Chains (Recommended)
### Standard Development Flow
```
1. /workflow:lite-plan
2. /workflow:lite-execute
3. /workflow:test-cycle-execute
```
### Full Planning Flow
```
1. /workflow:plan
2. /workflow:plan-verify
3. /workflow:execute
4. /workflow:review-session-cycle
```
### TDD Flow
```
1. /workflow:tdd-plan
2. /workflow:execute
3. /workflow:tdd-verify
```
### Issue Batch Flow
```
1. /issue:plan --all-pending
2. /issue:queue
3. /issue:execute
```
## JSON Format
```json
{
"workflow_commands": [
{
"category": "Planning",
"commands": [
{ "name": "/workflow:lite-plan", "description": "轻量级规划" },
{ "name": "/workflow:plan", "description": "标准规划" },
{ "name": "/workflow:multi-cli-plan", "description": "多CLI协作规划" },
{ "name": "/workflow:brainstorm:auto-parallel", "description": "头脑风暴" },
{ "name": "/workflow:tdd-plan", "description": "TDD规划" }
]
},
{
"category": "Execution",
"commands": [
{ "name": "/workflow:lite-execute", "description": "轻量级执行" },
{ "name": "/workflow:execute", "description": "标准执行" },
{ "name": "/workflow:test-cycle-execute", "description": "测试循环执行" }
]
},
{
"category": "BugFix",
"commands": [
{ "name": "/workflow:lite-fix", "description": "轻量级修复" },
{ "name": "/workflow:lite-fix --hotfix", "description": "紧急修复" }
]
},
{
"category": "Testing",
"commands": [
{ "name": "/workflow:test-gen", "description": "测试生成" },
{ "name": "/workflow:test-fix-gen", "description": "测试修复" },
{ "name": "/workflow:tdd-verify", "description": "TDD验证" }
]
},
{
"category": "Review",
"commands": [
{ "name": "/workflow:review-session-cycle", "description": "会话审查" },
{ "name": "/workflow:review-module-cycle", "description": "模块审查" },
{ "name": "/workflow:review-fix", "description": "审查修复" },
{ "name": "/workflow:plan-verify", "description": "计划验证" }
]
},
{
"category": "Documentation",
"commands": [
{ "name": "/memory:docs", "description": "生成文档" },
{ "name": "/memory:update-related", "description": "更新相关文档" },
{ "name": "/memory:update-full", "description": "全面更新文档" }
]
},
{
"category": "Issues",
"commands": [
{ "name": "/issue:discover", "description": "发现Issue" },
{ "name": "/issue:discover-by-prompt", "description": "基于提示发现Issue" },
{ "name": "/issue:plan --all-pending", "description": "规划所有待处理Issue" },
{ "name": "/issue:queue", "description": "排队Issue" },
{ "name": "/issue:execute", "description": "执行Issue" }
]
}
]
}
```

View File

@@ -1,362 +0,0 @@
# CCW Coordinator Specifications
Unified specification for the command library, validation rules, and registry.
---
## Command Library
### Planning Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-plan` | Lightweight planning | L2 |
| `/workflow:plan` | Standard planning | L3 |
| `/workflow:multi-cli-plan` | Multi-CLI collaborative planning | L2 |
| `/workflow:brainstorm:auto-parallel` | Brainstorm planning | L4 |
| `/workflow:tdd-plan` | TDD planning | L3 |
### Execution Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-execute` | Lightweight execution | L2 |
| `/workflow:execute` | Standard execution | L3 |
| `/workflow:test-cycle-execute` | Test-cycle execution | L3 |
### BugFix Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:lite-fix` | Lightweight fix | L2 |
| `/workflow:lite-fix --hotfix` | Emergency hotfix | L2 |
### Testing Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:test-gen` | Test generation | L3 |
| `/workflow:test-fix-gen` | Test-fix generation | L3 |
| `/workflow:tdd-verify` | TDD verification | L3 |
### Review Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/workflow:review-session-cycle` | Session review | L3 |
| `/workflow:review-module-cycle` | Module review | L3 |
| `/workflow:review-fix` | Review fix | L3 |
| `/workflow:plan-verify` | Plan verification | L3 |
### Documentation Commands
| Command | Description | Level |
|---------|-------------|-------|
| `/memory:docs` | Generate documentation | L2 |
| `/memory:update-related` | Update related docs | L2 |
| `/memory:update-full` | Full documentation update | L2 |
### Issue Commands
| Command | Description |
|---------|-------------|
| `/issue:discover` | Discover issues |
| `/issue:discover-by-prompt` | Discover issues from a prompt |
| `/issue:plan --all-pending` | Plan all pending issues |
| `/issue:queue` | Queue issues |
| `/issue:execute` | Execute issues |
---
## Recommended Command Chains
### Standard Development Flow
```
1. /workflow:lite-plan
2. /workflow:lite-execute
3. /workflow:test-cycle-execute
```
### Full Planning Flow
```
1. /workflow:plan
2. /workflow:plan-verify
3. /workflow:execute
4. /workflow:review-session-cycle
```
### TDD Flow
```
1. /workflow:tdd-plan
2. /workflow:execute
3. /workflow:tdd-verify
```
### Issue Batch Flow
```
1. /issue:plan --all-pending
2. /issue:queue
3. /issue:execute
```
---
## Validation Rules
### Rule 1: Single Planning Command
Each chain may contain at most one planning command.
| Valid | Invalid |
|------|------|
| `plan → execute` | `plan → lite-plan → execute` |
### Rule 2: Compatible Pairs
Planning and execution commands must be compatible.
| Planning | Execution | Compatible |
|----------|-----------|------|
| lite-plan | lite-execute | ✓ |
| lite-plan | execute | ✗ |
| multi-cli-plan | lite-execute | ✓ |
| multi-cli-plan | execute | ✓ |
| plan | execute | ✓ |
| plan | lite-execute | ✗ |
| tdd-plan | execute | ✓ |
| tdd-plan | lite-execute | ✗ |
### Rule 3: Testing After Execution
Testing commands must come after execution commands.
| Valid | Invalid |
|------|------|
| `execute → test-cycle-execute` | `test-cycle-execute → execute` |
### Rule 4: Review After Execution
Review commands must come after execution commands.
| Valid | Invalid |
|------|------|
| `execute → review-session-cycle` | `review-session-cycle → execute` |
### Rule 5: BugFix Standalone
`lite-fix` must run standalone and cannot be combined with other commands.
| Valid | Invalid |
|------|------|
| `lite-fix` | `plan → lite-fix → execute` |
| `lite-fix --hotfix` | `lite-fix → test-cycle-execute` |
### Rule 6: Dependency Satisfaction
Each command's dependencies must appear earlier in the chain.
| Valid | Invalid |
|------|------|
| `test-fix-gen → test-cycle-execute` | `test-cycle-execute` (dependency `test-fix-gen` missing) |
### Rule 7: No Redundancy
A chain must not contain duplicate commands.
| Valid | Invalid |
|------|------|
| `plan → execute → test` | `plan → plan → execute` |
### Rule 8: Command Exists
All commands must be defined in this specification.
---
## Anti-Patterns (Avoid)
### ❌ Pattern 1: Multiple Planning
```
plan → lite-plan → execute
```
**Problem**: duplicate analysis wastes time
**Fix**: pick a single planning command
### ❌ Pattern 2: Test Without Context
```
test-cycle-execute (run standalone)
```
**Problem**: no execution context to work against
**Fix**: run `execute` or `test-fix-gen` first
### ❌ Pattern 3: BugFix with Planning
```
plan → execute → lite-fix
```
**Problem**: lite-fix is a standalone command and should not be mixed with planning
**Fix**: use `lite-fix` on its own, or `plan → execute` for larger changes
### ❌ Pattern 4: Review Without Changes
```
review-session-cycle (run standalone)
```
**Problem**: no git changes to review
**Fix**: run `execute` first to produce changes
### ❌ Pattern 5: TDD Misuse
```
tdd-plan → lite-execute
```
**Problem**: lite-execute cannot handle the TDD task structure
**Fix**: use `tdd-plan → execute → tdd-verify`
---
## Command Registry
### Command Metadata Structure
```json
{
"command_name": {
"category": "Planning|Execution|Testing|Review|BugFix|Maintenance",
"level": "L0|L1|L2|L3",
"description": "Command description",
"inputs": ["input1", "input2"],
"outputs": ["output1", "output2"],
"dependencies": ["commands it depends on"],
"parameters": [
{"name": "--flag", "type": "string|boolean|number", "default": "value"}
],
"chain_position": "start|middle|middle_or_end|end|standalone",
"next_recommended": ["recommended next commands"]
}
}
```
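For illustration, a hypothetical entry for `/workflow:lite-plan` in this shape, written as a JavaScript object; the concrete field values are assumptions chosen to demonstrate the schema, not authoritative registry content:
```javascript
// Hypothetical example entry; values are illustrative only.
const litePlanEntry = {
  category: 'Planning',
  level: 'L2',
  description: 'Lightweight planning',
  inputs: ['task description'],
  outputs: ['memory://plan'],
  dependencies: [],
  parameters: [{ name: '--explore', type: 'boolean', default: 'false' }],
  chain_position: 'start',
  next_recommended: ['/workflow:lite-execute']
};
```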
### Command Groups
| Group | Commands |
|-------|----------|
| planning | lite-plan, multi-cli-plan, plan, tdd-plan |
| execution | lite-execute, execute, develop-with-file |
| testing | test-gen, test-fix-gen, test-cycle-execute, tdd-verify |
| review | review-session-cycle, review-module-cycle, review-fix |
| bugfix | lite-fix, debug, debug-with-file |
| maintenance | clean, replan |
| verification | plan-verify, tdd-verify |
### Compatibility Matrix
| Combination | Status |
|------|------|
| lite-plan + lite-execute | ✓ compatible |
| lite-plan + execute | ✗ incompatible - use plan |
| multi-cli-plan + lite-execute | ✓ compatible |
| plan + execute | ✓ compatible |
| plan + lite-execute | ✗ incompatible - use lite-plan |
| tdd-plan + execute | ✓ compatible |
| execute + test-cycle-execute | ✓ compatible |
| lite-execute + test-cycle-execute | ✓ compatible |
| test-fix-gen + test-cycle-execute | ✓ required |
| review-session-cycle + review-fix | ✓ compatible |
| lite-fix + test-cycle-execute | ✗ incompatible - lite-fix standalone |
---
## Validation Tool
### chain-validate.cjs
Location: `tools/chain-validate.cjs`
Validates command-chain legality:
```bash
node tools/chain-validate.cjs plan execute test-cycle-execute
```
Output:
```
{
"valid": true,
"errors": [],
"warnings": []
}
```
## Command Registry
### Tool Location
Location: `tools/command-registry.cjs` (bundled with the skill)
### Operating Mode
**On-demand extraction**: only the commands in the user's confirmed task chain are extracted; there is no full scan.
```javascript
// User task chain: [lite-plan, lite-execute]
const commandNames = command_chain.map(cmd => cmd.command);
const commandMeta = registry.getCommands(commandNames);
// Only the metadata of these 2 commands is extracted
```
### Features
- Automatically locates the global `.claude/commands/workflow` directory (project-relative path > user home)
- Extracts YAML frontmatter metadata for the requested commands on demand
- Caches results to avoid repeated reads
- Provides a batch query interface
### Integration
Integrated automatically in action-command-execute:
```javascript
const CommandRegistry = require('./tools/command-registry.cjs');
const registry = new CommandRegistry();
// Only extract the commands in the task chain
const commandNames = command_chain.map(cmd => cmd.command);
const commandMeta = registry.getCommands(commandNames);
// Use the metadata to generate the prompt
const cmdInfo = commandMeta[cmd.command];
// {
// name: 'lite-plan',
// description: 'Lightweight planning...',
// argumentHint: '[-e|--explore] "task description"',
// allowedTools: [...],
// filePath: '...'
// }
```
### Prompt Generation
The generated prompt automatically includes:
1. **Task context**: the user's task description
2. **Prior artifacts**: artifact info from completed commands
3. **Command metadata**: the command's argument hint and description
```
Task: implement the user registration feature
Previously completed:
- /workflow:lite-plan: WFS-plan-001 (IMPL_PLAN.md)
Command: /workflow:lite-execute [--resume-session="session-id"]
```
See `tools/README.md` for details.

View File

@@ -1,95 +0,0 @@
# CCW Coordinator Tools
## command-registry.cjs
命令注册表工具:获取和提取命令元数据。
### 功能
- **按需提取**: 只提取指定命令的完整信息name, description, argumentHint, allowedTools 等)
- **全量获取**: 获取所有命令的名称和描述(快速查询)
- **自动查找**: 从全局 `.claude/commands/workflow` 目录读取(项目相对路径 > 用户 home
- **缓存机制**: 避免重复读取文件
### 编程接口
```javascript
const CommandRegistry = require('./tools/command-registry.cjs');
const registry = new CommandRegistry();
// 1. 获取所有命令的名称和描述(快速)
const allCommands = registry.getAllCommandsSummary();
// {
// "/workflow:lite-plan": {
// name: 'lite-plan',
// description: '轻量级规划...'
// },
// "/workflow:lite-execute": { ... }
// }
// 2. 按需提取指定命令的完整信息
const commands = registry.getCommands([
'/workflow:lite-plan',
'/workflow:lite-execute'
]);
// {
// "/workflow:lite-plan": {
// name: 'lite-plan',
// description: '...',
// argumentHint: '[-e|--explore] "task description"',
// allowedTools: [...],
// filePath: '...'
// },
// ...
// }
```
### 命令行接口
```bash
# 获取所有命令的名称和描述
node .claude/skills/ccw-coordinator/tools/command-registry.cjs
node .claude/skills/ccw-coordinator/tools/command-registry.cjs --all
# 输出: 23 个命令的简明列表 (name + description)
```
```bash
# 按需提取指定命令的完整信息
node .claude/skills/ccw-coordinator/tools/command-registry.cjs lite-plan lite-execute
# 输出: 完整信息 (name, description, argumentHint, allowedTools, filePath)
```
### 集成用途
`action-command-execute` 中使用:
```javascript
// 1. 初始化时只提取任务链中的命令(完整信息)
const commandNames = command_chain.map(cmd => cmd.command);
const commandMeta = registry.getCommands(commandNames);
// 2. 生成提示词时使用
function generatePrompt(cmd, state, commandMeta) {
const cmdInfo = commandMeta[cmd.command];
let prompt = `任务: ${state.task_description}\n`;
if (cmdInfo?.argumentHint) {
prompt += `命令: ${cmd.command} ${cmdInfo.argumentHint}`;
}
return prompt;
}
```
确保 `ccw cli -p "..."` 提示词包含准确的命令参数提示。
### 目录查找逻辑
自动查找顺序:
1. `.claude/commands/workflow` (相对于当前工作目录)
2. `~/.claude/commands/workflow` (用户 home 目录)

View File

@@ -1,320 +0,0 @@
#!/usr/bin/env node
/**
* Chain Validation Tool
*
* Validates workflow command chains against defined rules.
*
* Usage:
* node chain-validate.cjs plan execute test-cycle-execute
* node chain-validate.cjs --json "plan,execute,test-cycle-execute"
* node chain-validate.cjs --file custom-chain.json
*/
const fs = require('fs');
const path = require('path');
// Optional registry loading - gracefully degrade if not found
let registry = null;
try {
const registryPath = path.join(__dirname, '..', 'specs', 'chain-registry.json');
if (fs.existsSync(registryPath)) {
registry = JSON.parse(fs.readFileSync(registryPath, 'utf8'));
}
} catch (error) {
// Registry not available - dependency validation will be skipped
}
class ChainValidator {
constructor(registry) {
this.registry = registry;
this.errors = [];
this.warnings = [];
}
validate(chain) {
this.errors = [];
this.warnings = [];
this.validateSinglePlanning(chain);
this.validateCompatiblePairs(chain);
this.validateTestingPosition(chain);
this.validateReviewPosition(chain);
this.validateBugfixStandalone(chain);
this.validateDependencies(chain);
this.validateNoRedundancy(chain);
this.validateCommandExistence(chain);
return {
valid: this.errors.length === 0,
errors: this.errors,
warnings: this.warnings
};
}
validateSinglePlanning(chain) {
const planningCommands = chain.filter(cmd =>
['plan', 'lite-plan', 'multi-cli-plan', 'tdd-plan'].includes(cmd)
);
if (planningCommands.length > 1) {
this.errors.push({
rule: 'Single Planning Command',
message: `Too many planning commands: ${planningCommands.join(', ')}`,
severity: 'error'
});
}
}
validateCompatiblePairs(chain) {
const compatibility = {
'lite-plan': ['lite-execute'],
'multi-cli-plan': ['lite-execute', 'execute'],
'plan': ['execute'],
'tdd-plan': ['execute']
};
const planningCmd = chain.find(cmd =>
['plan', 'lite-plan', 'multi-cli-plan', 'tdd-plan'].includes(cmd)
);
const executionCmd = chain.find(cmd =>
['execute', 'lite-execute'].includes(cmd)
);
if (planningCmd && executionCmd) {
const compatible = compatibility[planningCmd] || [];
if (!compatible.includes(executionCmd)) {
this.errors.push({
rule: 'Compatible Pairs',
message: `${planningCmd} incompatible with ${executionCmd}`,
fix: `Use ${planningCmd} with ${compatible.join(' or ')}`,
severity: 'error'
});
}
}
}
validateTestingPosition(chain) {
const executionIdx = chain.findIndex(cmd =>
['execute', 'lite-execute', 'develop-with-file'].includes(cmd)
);
const testingIdx = chain.findIndex(cmd =>
['test-cycle-execute', 'tdd-verify', 'test-gen', 'test-fix-gen'].includes(cmd)
);
if (testingIdx !== -1 && executionIdx !== -1 && executionIdx > testingIdx) {
this.errors.push({
rule: 'Testing After Execution',
message: 'Testing commands must come after execution',
severity: 'error'
});
}
if (testingIdx !== -1 && executionIdx === -1) {
const hasTestGen = chain.some(cmd => ['test-gen', 'test-fix-gen'].includes(cmd));
if (!hasTestGen) {
this.warnings.push({
rule: 'Testing After Execution',
message: 'test-cycle-execute without execution context - needs test-gen or execute first',
severity: 'warning'
});
}
}
}
validateReviewPosition(chain) {
const executionIdx = chain.findIndex(cmd =>
['execute', 'lite-execute'].includes(cmd)
);
const reviewIdx = chain.findIndex(cmd =>
cmd.includes('review')
);
if (reviewIdx !== -1 && executionIdx !== -1 && executionIdx > reviewIdx) {
this.errors.push({
rule: 'Review After Changes',
message: 'Review commands must come after execution',
severity: 'error'
});
}
if (reviewIdx !== -1 && executionIdx === -1) {
const isModuleReview = chain[reviewIdx] === 'review-module-cycle';
if (!isModuleReview) {
this.warnings.push({
rule: 'Review After Changes',
message: 'Review without execution - needs git changes to review',
severity: 'warning'
});
}
}
}
validateBugfixStandalone(chain) {
if (chain.includes('lite-fix')) {
const others = chain.filter(cmd => cmd !== 'lite-fix');
if (others.length > 0) {
this.errors.push({
rule: 'BugFix Standalone',
message: 'lite-fix must be standalone, cannot combine with other commands',
fix: 'Use lite-fix alone OR use plan + execute for larger changes',
severity: 'error'
});
}
}
}
validateDependencies(chain) {
// Skip if registry not available
if (!this.registry || !this.registry.commands) {
return;
}
for (let i = 0; i < chain.length; i++) {
const cmd = chain[i];
const cmdMeta = this.registry.commands[cmd];
if (!cmdMeta) continue;
const deps = cmdMeta.dependencies || [];
const depsOptional = cmdMeta.dependencies_optional || false;
if (deps.length > 0 && !depsOptional) {
const hasDependency = deps.some(dep => {
const depIdx = chain.indexOf(dep);
return depIdx !== -1 && depIdx < i;
});
if (!hasDependency) {
this.errors.push({
rule: 'Dependency Satisfaction',
message: `${cmd} requires ${deps.join(' or ')} before it`,
severity: 'error'
});
}
}
}
}
validateNoRedundancy(chain) {
const seen = new Set();
const duplicates = [];
for (const cmd of chain) {
if (seen.has(cmd)) {
duplicates.push(cmd);
}
seen.add(cmd);
}
if (duplicates.length > 0) {
this.errors.push({
rule: 'No Redundant Commands',
message: `Duplicate commands: ${duplicates.join(', ')}`,
severity: 'error'
});
}
}
validateCommandExistence(chain) {
// Skip if registry not available
if (!this.registry || !this.registry.commands) {
return;
}
for (const cmd of chain) {
if (!this.registry.commands[cmd]) {
this.errors.push({
rule: 'Command Existence',
message: `Unknown command: ${cmd}`,
severity: 'error'
});
}
}
}
}
function main() {
const args = process.argv.slice(2);
if (args.length === 0) {
console.log('Usage:');
console.log(' chain-validate.cjs <command1> <command2> ...');
console.log(' chain-validate.cjs --json "cmd1,cmd2,cmd3"');
console.log(' chain-validate.cjs --file chain.json');
process.exit(1);
}
let chain;
if (args[0] === '--json') {
chain = args[1].split(',').map(s => s.trim());
} else if (args[0] === '--file') {
const filePath = args[1];
// SEC-001: Path traversal guard - only allow files under the current working directory
const resolvedPath = path.resolve(filePath);
const workDir = path.resolve('.');
if (!resolvedPath.startsWith(workDir)) {
console.error('Error: File path must be within current working directory');
process.exit(1);
}
// CORR-001: JSON parse error handling
let fileContent;
try {
fileContent = JSON.parse(fs.readFileSync(resolvedPath, 'utf8'));
} catch (error) {
console.error(`Error: Failed to parse JSON file ${filePath}: ${error.message}`);
process.exit(1);
}
// CORR-002: Null check on nested properties
chain = fileContent.chain || fileContent.steps?.map(s => s.command) || [];
if (chain.length === 0) {
console.error('Error: No valid chain found in file (expected "chain" array or "steps" with "command" fields)');
process.exit(1);
}
} else {
chain = args;
}
const validator = new ChainValidator(registry);
const result = validator.validate(chain);
console.log('\n=== Chain Validation Report ===\n');
console.log('Chain:', chain.join(' → '));
console.log('');
if (result.valid) {
console.log('✓ Chain is valid!\n');
} else {
console.log('✗ Chain has errors:\n');
result.errors.forEach(err => {
console.log(` [${err.rule}] ${err.message}`);
if (err.fix) {
console.log(` Fix: ${err.fix}`);
}
});
console.log('');
}
if (result.warnings.length > 0) {
console.log('⚠ Warnings:\n');
result.warnings.forEach(warn => {
console.log(` [${warn.rule}] ${warn.message}`);
});
console.log('');
}
process.exit(result.valid ? 0 : 1);
}
if (require.main === module) {
main();
}
module.exports = { ChainValidator };

View File

@@ -1,255 +0,0 @@
#!/usr/bin/env node
/**
* Command Registry Tool
*
* Features:
* 1. Look up a command by name and extract its YAML frontmatter
* 2. Read from the global .claude/commands/workflow directory
* 3. Support on-demand extraction (no full scan)
*/
const fs = require('fs');
const path = require('path');
const os = require('os');
class CommandRegistry {
constructor(commandDir = null) {
// Use the directory passed in, if any
if (commandDir) {
this.commandDir = commandDir;
} else {
// Otherwise auto-locate .claude/commands/workflow
this.commandDir = this.findCommandDir();
}
this.cache = {};
}
/**
* Auto-locate the .claude/commands/workflow directory
* Supports: project-relative path, user home directory
*/
findCommandDir() {
// 1. Try relative to the current working directory
const relativePath = path.join('.claude', 'commands', 'workflow');
if (fs.existsSync(relativePath)) {
return path.resolve(relativePath);
}
// 2. Try the user's home directory
const homeDir = os.homedir();
const homeCommandDir = path.join(homeDir, '.claude', 'commands', 'workflow');
if (fs.existsSync(homeCommandDir)) {
return homeCommandDir;
}
// Return null when not found; subsequent operations will fail with a message
return null;
}
/**
* Parse the YAML frontmatter (simplified)
*
* Limitations:
* - Only simple key: value pairs (single-line values) are supported
* - Multi-line values, nested objects, and complex lists are not supported
* - The allowed-tools field accepts a comma-separated string and is converted to an array
*
* Example:
* ---
* name: lite-plan
* description: "Lightweight planning workflow"
* allowed-tools: Read, Write, Bash
* ---
*/
parseYamlHeader(content) {
// Handle Windows line endings (\r\n)
const match = content.match(/^---[\r\n]+([\s\S]*?)[\r\n]+---/);
if (!match) return null;
const yamlContent = match[1];
const result = {};
try {
const lines = yamlContent.split(/[\r\n]+/);
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed || trimmed.startsWith('#')) continue; // skip blank lines and comments
const colonIndex = trimmed.indexOf(':');
if (colonIndex === -1) continue;
const key = trimmed.substring(0, colonIndex).trim();
const value = trimmed.substring(colonIndex + 1).trim();
if (!key) continue; // skip invalid lines
// Strip surrounding quotes (single or double)
let cleanValue = value.replace(/^["']|["']$/g, '');
// Special handling for the allowed-tools field: convert to an array
// Supported formats: "Read, Write, Bash" or "Read,Write,Bash"
if (key === 'allowed-tools') {
cleanValue = Array.isArray(cleanValue)
? cleanValue
: cleanValue.split(',').map(t => t.trim()).filter(t => t);
}
result[key] = cleanValue;
}
} catch (error) {
console.error('YAML parsing error:', error.message);
return null;
}
return result;
}
/**
* Get the metadata of a single command
* @param {string} commandName Command name (e.g., "lite-plan" or "/workflow:lite-plan")
* @returns {object|null} Command info, or null
*/
getCommand(commandName) {
if (!this.commandDir) {
console.error('ERROR: .claude/commands/workflow directory not found');
return null;
}
// Normalize the command name
const normalized = commandName.startsWith('/workflow:')
? commandName.substring('/workflow:'.length)
: commandName;
// Check the cache
if (this.cache[normalized]) {
return this.cache[normalized];
}
// Read the command file
const filePath = path.join(this.commandDir, `${normalized}.md`);
if (!fs.existsSync(filePath)) {
return null;
}
try {
const content = fs.readFileSync(filePath, 'utf-8');
const header = this.parseYamlHeader(content);
if (header && header.name) {
const result = {
name: header.name,
command: `/workflow:${header.name}`,
description: header.description || '',
argumentHint: header['argument-hint'] || '',
allowedTools: Array.isArray(header['allowed-tools'])
? header['allowed-tools']
: (header['allowed-tools'] ? [header['allowed-tools']] : []),
filePath: filePath
};
// Cache the result
this.cache[normalized] = result;
return result;
}
} catch (error) {
console.error(`Failed to read command ${filePath}:`, error.message);
}
return null;
}
/**
* Get metadata for multiple commands in one call
* @param {array} commandNames Array of command names
* @returns {object} Map of command info
*/
getCommands(commandNames) {
const result = {};
for (const name of commandNames) {
const cmd = this.getCommand(name);
if (cmd) {
result[cmd.command] = cmd;
}
}
return result;
}
/**
* Get the names and descriptions of all commands
* @returns {object} Map of command names to descriptions
*/
getAllCommandsSummary() {
const result = {};
const commandDir = this.commandDir;
if (!commandDir) {
return result;
}
try {
const files = fs.readdirSync(commandDir);
for (const file of files) {
if (!file.endsWith('.md')) continue;
const filePath = path.join(commandDir, file);
const stat = fs.statSync(filePath);
if (stat.isDirectory()) continue;
try {
const content = fs.readFileSync(filePath, 'utf-8');
const header = this.parseYamlHeader(content);
if (header && header.name) {
const commandName = `/workflow:${header.name}`;
result[commandName] = {
name: header.name,
description: header.description || ''
};
}
} catch (error) {
// Skip files that fail to read
continue;
}
}
} catch (error) {
// Directory read failed
return result;
}
return result;
}
/**
* Generate the registry JSON
*/
toJSON(commands = null) {
const data = commands || this.cache;
return JSON.stringify(data, null, 2);
}
}
// CLI mode
if (require.main === module) {
const args = process.argv.slice(2);
if (args.length === 0 || args[0] === '--all') {
// Fetch the names and descriptions of all commands
const registry = new CommandRegistry();
const commands = registry.getAllCommandsSummary();
console.log(JSON.stringify(commands, null, 2));
process.exit(0);
}
const registry = new CommandRegistry();
const commands = registry.getCommands(args);
console.log(JSON.stringify(commands, null, 2));
}
module.exports = CommandRegistry;

View File

@@ -1,522 +0,0 @@
---
name: ccw
description: Stateless workflow orchestrator. Auto-selects optimal workflow based on task intent. Triggers "ccw", "workflow".
allowed-tools: Task(*), SlashCommand(*), AskUserQuestion(*), Read(*), Bash(*), Grep(*), TodoWrite(*)
---
# CCW - Claude Code Workflow Orchestrator
A stateless workflow orchestrator that automatically selects the optimal workflow based on task intent.
## Workflow System Overview
CCW provides two workflow systems, **Main Workflow** and **Issue Workflow**, which together cover the full software development lifecycle.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Main Workflow │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Level 1 │ → │ Level 2 │ → │ Level 3 │ → │ Level 4 │ │
│ │ Rapid │ │ Lightweight │ │ Standard │ │ Brainstorm │ │
│ │ │ │ │ │ │ │ │ │
│ │ lite-lite- │ │ lite-plan │ │ plan │ │ brainstorm │ │
│ │ lite │ │ lite-fix │ │ tdd-plan │ │ :auto- │ │
│ │ │ │ multi-cli- │ │ test-fix- │ │ parallel │ │
│ │ │ │ plan │ │ gen │ │ ↓ │ │
│ │ │ │ │ │ │ │ plan │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ Complexity: ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━▶ │
│ Low High │
└─────────────────────────────────────────────────────────────────────────────┘
│ After development
┌─────────────────────────────────────────────────────────────────────────────┐
│ Issue Workflow │
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Accumulate │ → │ Plan │ → │ Execute │ │
│ │ Discover & │ │ Batch │ │ Parallel │ │
│ │ Collect │ │ Planning │ │ Execution │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │
│ Supplementary role: Maintain main branch stability, worktree isolation │
└─────────────────────────────────────────────────────────────────────────────┘
```
## Architecture
```
┌─────────────────────────────────────────────────────────────────┐
│ CCW Orchestrator (CLI-Enhanced + Requirement Analysis) │
├─────────────────────────────────────────────────────────────────┤
│ Phase 1 │ Input Analysis (rule-based, fast path) │
│ Phase 1.5 │ CLI Classification (semantic, smart path) │
│ Phase 1.75 │ Requirement Clarification (clarity < 2) │
│ Phase 2 │ Level Selection (intent → level → workflow) │
│ Phase 2.5 │ CLI Action Planning (high complexity) │
│ Phase 3 │ User Confirmation (optional) │
│ Phase 4 │ TODO Tracking Setup │
│ Phase 5 │ Execution Loop │
└─────────────────────────────────────────────────────────────────┘
```
## Level Quick Reference
| Level | Name | Workflows | Artifacts | Execution |
|-------|------|-----------|-----------|-----------|
| **1** | Rapid | `lite-lite-lite` | None | Direct execute |
| **2** | Lightweight | `lite-plan`, `lite-fix`, `multi-cli-plan` | Memory/Lightweight files | → `lite-execute` |
| **3** | Standard | `plan`, `tdd-plan`, `test-fix-gen` | Session persistence | → `execute` / `test-cycle-execute` |
| **4** | Brainstorm | `brainstorm:auto-parallel` → `plan` | Multi-role analysis + Session | → `execute` |
| **-** | Issue | `discover` → `plan` → `queue` → `execute` | Issue records | Worktree isolation (optional) |
## Workflow Selection Decision Tree
```
Start
├─ Is it post-development maintenance?
│ ├─ Yes → Issue Workflow
│ └─ No ↓
├─ Are requirements clear?
│ ├─ Uncertain → Level 4 (brainstorm:auto-parallel)
│ └─ Clear ↓
├─ Need persistent Session?
│ ├─ Yes → Level 3 (plan / tdd-plan / test-fix-gen)
│ └─ No ↓
├─ Need multi-perspective / solution comparison?
│ ├─ Yes → Level 2 (multi-cli-plan)
│ └─ No ↓
├─ Is it a bug fix?
│ ├─ Yes → Level 2 (lite-fix)
│ └─ No ↓
├─ Need planning?
│ ├─ Yes → Level 2 (lite-plan)
│ └─ No → Level 1 (lite-lite-lite)
```
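A minimal sketch of the decision tree as code, assuming boolean signals (`isMaintenance`, `requirementsClear`, and so on) have already been derived from intent classification; illustrative only, not the shipped selection logic:
```javascript
// Hypothetical condensed form of the decision tree above.
function selectWorkflowLevel(s) {
  if (s.isMaintenance) return { level: 'Issue', flow: 'issue' }
  if (!s.requirementsClear) return { level: 'L4', flow: 'brainstorm:auto-parallel → plan → execute' }
  if (s.needsSession) return { level: 'L3', flow: 'plan / tdd-plan / test-fix-gen' }
  if (s.needsMultiPerspective) return { level: 'L2', flow: 'multi-cli-plan → lite-execute' }
  if (s.isBugfix) return { level: 'L2', flow: 'lite-fix' }
  if (s.needsPlanning) return { level: 'L2', flow: 'lite-plan → lite-execute' }
  return { level: 'L1', flow: 'lite-lite-lite' }
}
```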
## Intent Classification
### Priority Order (with Level Mapping)
| Priority | Intent | Patterns | Level | Flow |
|----------|--------|----------|-------|------|
| 1 | bugfix/hotfix | `urgent,production,critical` + bug | L2 | `bugfix.hotfix` |
| 1 | bugfix | `fix,bug,error,crash,fail` | L2 | `bugfix.standard` |
| 2 | issue batch | `issues,batch` + `fix,resolve` | Issue | `issue` |
| 3 | exploration | `不确定,explore,研究,what if` | L4 | `full` |
| 3 | multi-perspective | `多视角,权衡,比较方案,cross-verify` | L2 | `multi-cli-plan` |
| 4 | quick-task | `快速,简单,small,quick` + feature | L1 | `lite-lite-lite` |
| 5 | ui design | `ui,design,component,style` | L3/L4 | `ui` |
| 6 | tdd | `tdd,test-driven,先写测试` | L3 | `tdd` |
| 7 | test-fix | `测试失败,test fail,fix test` | L3 | `test-fix-gen` |
| 8 | review | `review,审查,code review` | L3 | `review-fix` |
| 9 | documentation | `文档,docs,readme` | L2 | `docs` |
| 99 | feature | complexity-based | L2/L3 | `rapid`/`coupled` |
### Quick Selection Guide
| Scenario | Recommended Workflow | Level |
|----------|---------------------|-------|
| Quick fixes, config adjustments | `lite-lite-lite` | 1 |
| Clear single-module features | `lite-plan → lite-execute` | 2 |
| Bug diagnosis and fix | `lite-fix` | 2 |
| Production emergencies | `lite-fix --hotfix` | 2 |
| Technology selection, solution comparison | `multi-cli-plan → lite-execute` | 2 |
| Multi-module changes, refactoring | `plan → verify → execute` | 3 |
| Test-driven development | `tdd-plan → execute → tdd-verify` | 3 |
| Test failure fixes | `test-fix-gen → test-cycle-execute` | 3 |
| New features, architecture design | `brainstorm:auto-parallel → plan → execute` | 4 |
| Post-development issue fixes | Issue Workflow | - |
### Complexity Assessment
```javascript
function assessComplexity(text) {
let score = 0
if (/refactor|重构|migrate|迁移|architect|架构|system|系统/.test(text)) score += 2
if (/multiple|多个|across|跨|all|所有|entire|整个/.test(text)) score += 2
if (/integrate|集成|api|database|数据库/.test(text)) score += 1
if (/security|安全|performance|性能|scale|扩展/.test(text)) score += 1
return score >= 4 ? 'high' : score >= 2 ? 'medium' : 'low'
}
```
| Complexity | Flow |
|------------|------|
| high | `coupled` (plan → verify → execute) |
| medium/low | `rapid` (lite-plan → lite-execute) |
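Two quick examples of the scoring above (hypothetical task descriptions):
```javascript
assessComplexity('Refactor the auth system across multiple modules and integrate with the database')
// refactor (+2) + across/multiple (+2) + integrate/database (+1) = 5 → 'high' → coupled flow

assessComplexity('Add a loading spinner to the login button')
// no keywords matched → 0 → 'low' → rapid flow
```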
### Dimension Extraction (WHAT/WHERE/WHY/HOW)
Four dimensions are extracted from the user input and used for requirement clarification and workflow selection:
| Dimension | Extracted Content | Example Patterns |
|------|----------|----------|
| **WHAT** | action + target | `创建/修复/重构/优化/分析` (create/fix/refactor/optimize/analyze) + target object |
| **WHERE** | scope + paths | `file/module/system` + file paths |
| **WHY** | goal + motivation | `为了.../因为.../目的是...` (in order to / because / the goal is) |
| **HOW** | constraints + preferences | `必须.../不要.../应该...` (must / don't / should) |
**Clarity Score** (0-3), computed from the extracted dimensions (a scoring sketch follows the list):
- +0.5: explicit action present
- +0.5: concrete target present
- +0.5: file path present
- +0.5: scope is not unknown
- +0.5: explicit goal present
- +0.5: constraints present
- -0.5: contains uncertainty words (`不知道/maybe/怎么`)
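A minimal sketch of how this score could be computed; the field names mirror the dimension table above, while the implementation details are assumptions:
```javascript
// Hypothetical scoring helper mirroring the bullets above.
function clarityScore(d) {
  let score = 0
  if (d.action) score += 0.5                            // explicit action
  if (d.target) score += 0.5                            // concrete target
  if (d.paths?.length > 0) score += 0.5                 // file paths mentioned
  if (d.scope && d.scope !== 'unknown') score += 0.5
  if (d.goal) score += 0.5                              // explicit goal
  if (d.constraints?.length > 0) score += 0.5
  if (/不知道|maybe|怎么/.test(d.rawInput)) score -= 0.5 // uncertainty words
  return score
}
```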
### Requirement Clarification
Requirement clarification is triggered when `clarity_score < 2`:
```javascript
if (dimensions.clarity_score < 2) {
const questions = generateClarificationQuestions(dimensions)
// Generate questions: what is the goal? what is the scope? what constraints apply?
AskUserQuestion({ questions })
}
```
**Clarification question types**:
- Unclear target → "What do you want to operate on?"
- Unclear scope → "What is the scope of the operation?"
- Unclear goal → "What is the main objective of this operation?"
- Complex operation → "Are there special requirements or constraints?"
## TODO Tracking Protocol
### CRITICAL: Append-Only Rule
Todos created by CCW **must be appended to the existing list**; the user's other todos must not be overwritten.
### Implementation
```javascript
// 1. Use the CCW prefix to isolate workflow todos
const prefix = `CCW:${flowName}`
// 2. Use the prefixed format when creating new todos
TodoWrite({
todos: [
...existingNonCCWTodos, // keep the user's own todos
{ content: `${prefix}: [1/N] /command:step1`, status: "in_progress", activeForm: "..." },
{ content: `${prefix}: [2/N] /command:step2`, status: "pending", activeForm: "..." }
]
})
// 3. When updating status, only modify todos matching the prefix
```
### Todo Format
```
CCW:{flow}: [{N}/{Total}] /command:name
```
### Visual Example
```
✓ CCW:rapid: [1/2] /workflow:lite-plan
→ CCW:rapid: [2/2] /workflow:lite-execute
(the user's own todos are left untouched)
```
### Status Management
- Workflow start: create todos for every step; the first step is `in_progress`
- Step finished: mark the current step `completed` and the next step `in_progress`
- Workflow end: mark all CCW todos `completed` (a prefix-filtered update sketch follows below)
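A sketch of the status-update step (point 3 above): only todos carrying the CCW prefix are advanced; everything else passes through untouched. The todo shape matches the TodoWrite example, while the helper itself is an assumption:
```javascript
// Hypothetical helper: mark step N completed and step N+1 in_progress for this flow only.
function advanceCcwTodos(todos, flowName, completedStep) {
  const prefix = `CCW:${flowName}:`
  return todos.map(todo => {
    if (!todo.content.startsWith(prefix)) return todo                 // leave the user's own todos alone
    if (todo.content.includes(`[${completedStep}/`)) return { ...todo, status: 'completed' }
    if (todo.content.includes(`[${completedStep + 1}/`)) return { ...todo, status: 'in_progress' }
    return todo
  })
}
```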
## Execution Flow
```javascript
// 1. Check explicit command
if (input.startsWith('/workflow:') || input.startsWith('/issue:')) {
SlashCommand(input)
return
}
// 2. Classify intent
const intent = classifyIntent(input) // See command.json intent_rules
// 3. Select flow
const flow = selectFlow(intent) // See command.json flows
// 4. Create todos with CCW prefix
createWorkflowTodos(flow)
// 5. Dispatch first command
SlashCommand(flow.steps[0].command, args: input)
```
## CLI Tool Integration
CCW automatically injects CLI calls under specific conditions:
| Condition | CLI Inject |
|-----------|------------|
| Large code context (≥50k chars) | `gemini --mode analysis` |
| High-complexity task | `gemini --mode analysis` |
| Bug diagnosis | `gemini --mode analysis` |
| Multi-task execution (≥3 tasks) | `codex --mode write` |
### CLI Enhancement Phases
**Phase 1.5: CLI-Assisted Classification**
When rule matching is ambiguous, a CLI assists the classification:
| Trigger | Description |
|----------|------|
| matchCount < 2 | Multiple intent patterns match |
| complexity = high | High-complexity task |
| input > 100 chars | Long input needs semantic understanding |
**Phase 2.5: CLI-Assisted Action Planning**
Workflow optimization for high-complexity tasks:
| Trigger | Description |
|----------|------|
| complexity = high | High-complexity task |
| steps >= 3 | Multi-step workflow |
| input > 200 chars | Complex requirement description |
The CLI may return a suggestion: `use_default` | `modify` (adjust steps) | `upgrade` (upgrade the workflow)
## Continuation Commands
User control commands during workflow execution:
| Command | Effect |
|------|------|
| `continue` | Continue with the next step |
| `skip` | Skip the current step |
| `abort` | Abort the workflow |
| `/workflow:*` | Switch to the specified command |
| Natural language | Re-analyze intent |
## Workflow Flow Details
### Issue Workflow (supplement to the Main Workflow)
The Issue Workflow is a **supplementary mechanism** to the Main Workflow, focused on ongoing post-development maintenance.
#### Design Rationale
| Aspect | Main Workflow | Issue Workflow |
|------|---------------|----------------|
| **Purpose** | Primary development cycle | Post-development maintenance |
| **Timing** | Feature development phase | After the main workflow completes |
| **Scope** | Full feature implementation | Targeted fixes / enhancements |
| **Parallelism** | Dependency analysis → parallel agents | Worktree isolation (optional) |
| **Branch model** | Work on the current branch | May use an isolated worktree |
#### Why the Main Workflow Does Not Use Worktrees Automatically
**Dependency analysis already solves the parallelism problem**
1. The planning phase (`/workflow:plan`) performs dependency analysis
2. Task dependencies and the critical path are identified automatically
3. Tasks are split into **parallel groups** (independent tasks) and **serial chains** (dependent tasks)
4. Agents run independent tasks in parallel; no filesystem isolation is needed
#### Two-Phase Lifecycle
```
┌─────────────────────────────────────────────────────────────────────┐
│ Phase 1: Accumulation │
│ │
│ Triggers: post-task reviews, code-review findings, test failures │
│ │
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
│ │ discover │ │ discover- │ │ new │ │
│ │ Auto-find │ │ by-prompt │ │ Manual │ │
│ └────────────┘ └────────────┘ └────────────┘ │
│ │
│ Issues accumulate continuously into the pending queue │
└─────────────────────────────────────────────────────────────────────┘
│ After enough issues have accumulated
┌─────────────────────────────────────────────────────────────────────┐
│ Phase 2: Batch Resolution │
│ │
│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │
│ │ plan │ ──→ │ queue │ ──→ │ execute │ │
│ │ --all- │ │ Optimize │ │ Parallel │ │
│ │ pending │ │ order │ │ execution │ │
│ └────────────┘ └────────────┘ └────────────┘ │
│ │
│ Supports worktree isolation to keep the main branch stable │
└─────────────────────────────────────────────────────────────────────┘
```
#### Collaboration with the Main Workflow
```
Development iteration loop
┌─────────────────────────────────────────────────────────────────────┐
│ │
│ ┌─────────┐ ┌─────────┐ │
│ │ Feature │ ──→ Main Workflow ──→ Done ──→│ Review │ │
│ │ Request │ (Level 1-4) └────┬────┘ │
│ └─────────┘ │ │
│ ▲ │ Issues found │
│ │ ▼ │
│ │ ┌─────────┐ │
│ Continue │ │ Issue │ │
│ new features│ │ Workflow│ │
│ │ └────┬────┘ │
│ │ ┌──────────────────────────────┘ │
│ │ │ Fixes complete │
│ │ ▼ │
│ ┌────┴────┐◀────── │
│ │ Main │ Merge │
│ │ Branch │ back │
│ └─────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
```
#### Command List
**Accumulation phase:**
```bash
/issue:discover # multi-perspective auto-discovery
/issue:discover-by-prompt # prompt-based discovery
/issue:new # manual creation
```
**Batch resolution phase:**
```bash
/issue:plan --all-pending # batch-plan all pending issues
/issue:queue # generate an optimized execution queue
/issue:execute # parallel execution
```
### lite-lite-lite vs multi-cli-plan
| Dimension | lite-lite-lite | multi-cli-plan |
|------|---------------|----------------|
| **Artifacts** | No files | IMPL_PLAN.md + plan.json + synthesis.json |
| **State** | Stateless | Persistent session |
| **CLI selection** | Automatic, based on task-type analysis | Configuration-driven |
| **Iteration** | Via AskUser | Multiple rounds until convergence |
| **Execution** | Direct | Via lite-execute |
| **Best for** | Quick fixes, simple features | Complex multi-step implementations |
**Selection guide**
- Clear task, small change scope → `lite-lite-lite`
- Needs multi-perspective analysis or a complex architecture → `multi-cli-plan`
### multi-cli-plan vs lite-plan
| Dimension | multi-cli-plan | lite-plan |
|------|---------------|-----------|
| **Context** | ACE semantic search | Manual file patterns |
| **Analysis** | Multi-CLI cross-verification | Single-pass planning |
| **Iteration** | Multiple rounds until convergence | Single round |
| **Confidence** | High (consensus-driven) | Medium (single perspective) |
| **Best for** | Complex tasks needing multiple perspectives | Straightforward, well-defined implementations |
**Selection guide**
- Clear requirements and path → `lite-plan`
- Needs trade-offs or comparison of alternatives → `multi-cli-plan`
## Artifact Flow Protocol
An automatic hand-off mechanism for workflow artifacts; it supports intent extraction and completion evaluation across different artifact formats.
### Artifact Formats
| Command | Output Location | Format | Key Fields |
|------|----------|------|----------|
| `/workflow:lite-plan` | memory://plan | structured_plan | tasks, files, dependencies |
| `/workflow:plan` | .workflow/{session}/IMPL_PLAN.md | markdown_plan | phases, tasks, risks |
| `/workflow:execute` | execution_log.json | execution_report | completed_tasks, errors |
| `/workflow:test-cycle-execute` | test_results.json | test_report | pass_rate, failures, coverage |
| `/workflow:review-session-cycle` | review_report.md | review_report | findings, severity_counts |
### Intent Extraction
When handing off to the next step, the key information is extracted automatically:
```
plan → execute:
Extract: tasks (incomplete), priority_order, files_to_modify, context_summary
execute → test:
Extract: modified_files, test_scope (inferred), pending_verification
test → fix:
Condition: pass_rate < 0.95
Extract: failures, error_messages, affected_files, suggested_fixes
review → fix:
Condition: critical > 0 OR high > 3
Extract: findings (critical/high), fix_priority, affected_files
```
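A minimal sketch of the `execute → test` extraction, assuming the execution report exposes a `modified_files` list alongside the `completed_tasks` field from the format table above; field names follow the rules, the implementation is illustrative:
```javascript
// Hypothetical extraction for the execute → test hand-off.
function extractExecuteToTest(executionReport) {
  const modifiedFiles = executionReport.modified_files || []   // assumed field on the report
  return {
    modified_files: modifiedFiles,
    // test_scope is inferred, e.g. src/auth/login.ts → the "auth" test scope
    test_scope: [...new Set(modifiedFiles.map(f => f.split('/')[1]).filter(Boolean))],
    pending_verification: executionReport.completed_tasks || []
  }
}
```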
### Completion Evaluation
**Test completion routing**:
```
pass_rate >= 0.95 AND coverage >= 0.80 → complete
pass_rate >= 0.95 AND coverage < 0.80 → add_more_tests
pass_rate >= 0.80 → fix_failures_then_continue
pass_rate < 0.80 → major_fix_required
```
**Review completion routing**:
```
critical == 0 AND high <= 3 → complete_or_optional_fix
critical > 0 → mandatory_fix
high > 3 → recommended_fix
```
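The test-routing thresholds above, written out as a small router (the review branch would follow the same pattern); routing labels match the rules:
```javascript
// Sketch of the test completion router.
function evaluateTestCompletion({ pass_rate, coverage }) {
  if (pass_rate >= 0.95 && coverage >= 0.80) return 'complete'
  if (pass_rate >= 0.95) return 'add_more_tests'
  if (pass_rate >= 0.80) return 'fix_failures_then_continue'
  return 'major_fix_required'
}

evaluateTestCompletion({ pass_rate: 0.85, coverage: 0.70 }) // => 'fix_failures_then_continue'
```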
### Hand-off Decision Patterns
**plan_execute_test**:
```
plan → execute → test
↓ (if test fail)
extract_failures → fix → test (max 3 iterations)
↓ (if still fail)
manual_intervention
```
**iterative_improvement**:
```
execute → test → fix → test → ...
loop until: pass_rate >= 0.95 OR iterations >= 3
```
### Usage Example
```javascript
// After execution completes, decide the next step from its artifacts
const result = await execute(plan)
// Extract intent and hand off to testing
const testContext = extractIntent('execute_to_test', result)
// testContext = { modified_files, test_scope, pending_verification }
// After tests complete, route based on completion
const testResult = await test(testContext)
const nextStep = evaluateCompletion('test', testResult)
// nextStep = 'fix_failures_then_continue' if pass_rate = 0.85
```
## Reference
- [command.json](command.json) - Command metadata, flow definitions, intent rules, artifact flow

View File

@@ -1,641 +0,0 @@
{
"_metadata": {
"version": "2.0.0",
"description": "Unified CCW command index with capabilities, flows, and intent rules"
},
"capabilities": {
"explore": {
"description": "Codebase exploration and context gathering",
"commands": ["/workflow:init", "/workflow:tools:gather", "/memory:load"],
"agents": ["cli-explore-agent", "context-search-agent"]
},
"brainstorm": {
"description": "Multi-perspective analysis and ideation",
"commands": ["/workflow:brainstorm:auto-parallel", "/workflow:brainstorm:artifacts", "/workflow:brainstorm:synthesis"],
"roles": ["product-manager", "system-architect", "ux-expert", "data-architect", "api-designer"]
},
"plan": {
"description": "Task planning and decomposition",
"commands": ["/workflow:lite-plan", "/workflow:plan", "/workflow:tdd-plan", "/task:create", "/task:breakdown"],
"agents": ["cli-lite-planning-agent", "action-planning-agent"]
},
"verify": {
"description": "Plan and quality verification",
"commands": ["/workflow:plan-verify", "/workflow:tdd-verify"]
},
"execute": {
"description": "Task execution and implementation",
"commands": ["/workflow:lite-execute", "/workflow:execute", "/task:execute"],
"agents": ["code-developer", "cli-execution-agent", "universal-executor"]
},
"bugfix": {
"description": "Bug diagnosis and fixing",
"commands": ["/workflow:lite-fix"],
"agents": ["code-developer"]
},
"test": {
"description": "Test generation and execution",
"commands": ["/workflow:test-gen", "/workflow:test-fix-gen", "/workflow:test-cycle-execute"],
"agents": ["test-fix-agent"]
},
"review": {
"description": "Code review and quality analysis",
"commands": ["/workflow:review-session-cycle", "/workflow:review-module-cycle", "/workflow:review", "/workflow:review-fix"]
},
"issue": {
"description": "Issue lifecycle management - discover, accumulate, batch resolve",
"commands": ["/issue:new", "/issue:discover", "/issue:discover-by-prompt", "/issue:plan", "/issue:queue", "/issue:execute", "/issue:manage"],
"agents": ["issue-plan-agent", "issue-queue-agent", "cli-explore-agent"],
"lifecycle": {
"accumulation": {
"description": "任务完成后进行需求扩展、bug分析、测试发现",
"triggers": ["post-task review", "code review findings", "test failures"],
"commands": ["/issue:discover", "/issue:discover-by-prompt", "/issue:new"]
},
"batch_resolution": {
"description": "积累的issue集中规划和并行执行",
"flow": ["plan", "queue", "execute"],
"commands": ["/issue:plan --all-pending", "/issue:queue", "/issue:execute"]
}
}
},
"ui-design": {
"description": "UI design and prototyping",
"commands": ["/workflow:ui-design:explore-auto", "/workflow:ui-design:imitate-auto", "/workflow:ui-design:design-sync"],
"agents": ["ui-design-agent"]
},
"memory": {
"description": "Documentation and knowledge management",
"commands": ["/memory:docs", "/memory:update-related", "/memory:update-full", "/memory:skill-memory"],
"agents": ["doc-generator", "memory-bridge"]
}
},
"flows": {
"_level_guide": {
"L1": "Rapid - No artifacts, direct execution",
"L2": "Lightweight - Memory/lightweight files, → lite-execute",
"L3": "Standard - Session persistence, → execute/test-cycle-execute",
"L4": "Brainstorm - Multi-role analysis + Session, → execute"
},
"lite-lite-lite": {
"name": "Ultra-Rapid Execution",
"level": "L1",
"description": "零文件 + 自动CLI选择 + 语义描述 + 直接执行",
"complexity": ["low"],
"artifacts": "none",
"steps": [
{ "phase": "clarify", "description": "需求澄清 (AskUser if needed)" },
{ "phase": "auto-select", "description": "任务分析 → 自动选择CLI组合" },
{ "phase": "multi-cli", "description": "并行多CLI分析" },
{ "phase": "decision", "description": "展示结果 → AskUser决策" },
{ "phase": "execute", "description": "直接执行 (无中间文件)" }
],
"cli_hints": {
"analysis": { "tool": "auto", "mode": "analysis", "parallel": true },
"execution": { "tool": "auto", "mode": "write" }
},
"estimated_time": "10-30 min"
},
"rapid": {
"name": "Rapid Iteration",
"level": "L2",
"description": "内存规划 + 直接执行",
"complexity": ["low", "medium"],
"artifacts": "memory://plan",
"steps": [
{ "command": "/workflow:lite-plan", "optional": false, "auto_continue": true },
{ "command": "/workflow:lite-execute", "optional": false }
],
"cli_hints": {
"explore_phase": { "tool": "gemini", "mode": "analysis", "trigger": "needs_exploration" },
"execution": { "tool": "codex", "mode": "write", "trigger": "complexity >= medium" }
},
"estimated_time": "15-45 min"
},
"multi-cli-plan": {
"name": "Multi-CLI Collaborative Planning",
"level": "L2",
"description": "ACE上下文 + 多CLI协作分析 + 迭代收敛 + 计划生成",
"complexity": ["medium", "high"],
"artifacts": ".workflow/.multi-cli-plan/{session}/",
"steps": [
{ "command": "/workflow:multi-cli-plan", "optional": false, "phases": [
"context_gathering: ACE语义搜索",
"multi_cli_discussion: cli-discuss-agent多轮分析",
"present_options: 展示解决方案",
"user_decision: 用户选择",
"plan_generation: cli-lite-planning-agent生成计划"
]},
{ "command": "/workflow:lite-execute", "optional": false }
],
"vs_lite_plan": {
"context": "ACE semantic search vs Manual file patterns",
"analysis": "Multi-CLI cross-verification vs Single-pass planning",
"iteration": "Multiple rounds until convergence vs Single round",
"confidence": "High (consensus-based) vs Medium (single perspective)",
"best_for": "Complex tasks needing multiple perspectives vs Straightforward implementations"
},
"agents": ["cli-discuss-agent", "cli-lite-planning-agent"],
"cli_hints": {
"discussion": { "tools": ["gemini", "codex", "claude"], "mode": "analysis", "parallel": true },
"planning": { "tool": "gemini", "mode": "analysis" }
},
"estimated_time": "30-90 min"
},
"coupled": {
"name": "Standard Planning",
"level": "L3",
"description": "完整规划 + 验证 + 执行",
"complexity": ["medium", "high"],
"artifacts": ".workflow/active/{session}/",
"steps": [
{ "command": "/workflow:plan", "optional": false },
{ "command": "/workflow:plan-verify", "optional": false, "auto_continue": true },
{ "command": "/workflow:execute", "optional": false },
{ "command": "/workflow:review", "optional": true }
],
"cli_hints": {
"pre_analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
"execution": { "tool": "codex", "mode": "write", "trigger": "always" }
},
"estimated_time": "2-4 hours"
},
"full": {
"name": "Full Exploration (Brainstorm)",
"level": "L4",
"description": "头脑风暴 + 规划 + 执行",
"complexity": ["high"],
"artifacts": ".workflow/active/{session}/.brainstorming/",
"steps": [
{ "command": "/workflow:brainstorm:auto-parallel", "optional": false, "confirm_before": true },
{ "command": "/workflow:plan", "optional": false },
{ "command": "/workflow:plan-verify", "optional": true, "auto_continue": true },
{ "command": "/workflow:execute", "optional": false }
],
"cli_hints": {
"role_analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
"execution": { "tool": "codex", "mode": "write", "trigger": "task_count >= 3" }
},
"estimated_time": "1-3 hours"
},
"bugfix": {
"name": "Bug Fix",
"level": "L2",
"description": "智能诊断 + 修复 (5 phases)",
"complexity": ["low", "medium"],
"artifacts": ".workflow/.lite-fix/{bug-slug}-{date}/",
"variants": {
"standard": [{ "command": "/workflow:lite-fix", "optional": false }],
"hotfix": [{ "command": "/workflow:lite-fix --hotfix", "optional": false }]
},
"phases": [
"Phase 1: Bug Analysis & Diagnosis (severity pre-assessment)",
"Phase 2: Clarification (optional, AskUserQuestion)",
"Phase 3: Fix Planning (Low/Medium → Claude, High/Critical → cli-lite-planning-agent)",
"Phase 4: Confirmation & Selection",
"Phase 5: Execute (→ lite-execute --mode bugfix)"
],
"cli_hints": {
"diagnosis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
"fix": { "tool": "codex", "mode": "write", "trigger": "severity >= medium" }
},
"estimated_time": "10-30 min"
},
"issue": {
"name": "Issue Lifecycle",
"level": "Supplementary",
"description": "发现积累 → 批量规划 → 队列优化 → 并行执行 (Main Workflow 补充机制)",
"complexity": ["medium", "high"],
"artifacts": ".workflow/.issues/",
"purpose": "Post-development continuous maintenance, maintain main branch stability",
"phases": {
"accumulation": {
"description": "项目迭代中持续发现和积累issue",
"commands": ["/issue:discover", "/issue:discover-by-prompt", "/issue:new"],
"trigger": "post-task, code-review, test-failure"
},
"resolution": {
"description": "集中规划和执行积累的issue",
"steps": [
{ "command": "/issue:plan --all-pending", "optional": false },
{ "command": "/issue:queue", "optional": false },
{ "command": "/issue:execute", "optional": false }
]
}
},
"worktree_support": {
"description": "可选的 worktree 隔离,保持主分支稳定",
"use_case": "主开发完成后的 issue 修复"
},
"cli_hints": {
"discovery": { "tool": "gemini", "mode": "analysis", "trigger": "perspective_analysis", "parallel": true },
"solution_generation": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
"batch_execution": { "tool": "codex", "mode": "write", "trigger": "always" }
},
"estimated_time": "1-4 hours"
},
"tdd": {
"name": "Test-Driven Development",
"level": "L3",
"description": "TDD规划 + 执行 + 验证 (6 phases)",
"complexity": ["medium", "high"],
"artifacts": ".workflow/active/{session}/",
"steps": [
{ "command": "/workflow:tdd-plan", "optional": false },
{ "command": "/workflow:plan-verify", "optional": true, "auto_continue": true },
{ "command": "/workflow:execute", "optional": false },
{ "command": "/workflow:tdd-verify", "optional": false }
],
"tdd_structure": {
"description": "Each IMPL task contains complete internal Red-Green-Refactor cycle",
"meta": "tdd_workflow: true",
"flow_control": "implementation_approach contains 3 steps (red/green/refactor)"
},
"cli_hints": {
"test_strategy": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
"red_green_refactor": { "tool": "codex", "mode": "write", "trigger": "always" }
},
"estimated_time": "1-3 hours"
},
"test-fix": {
"name": "Test Fix Generation",
"level": "L3",
"description": "测试修复生成 + 执行循环 (5 phases)",
"complexity": ["medium", "high"],
"artifacts": ".workflow/active/WFS-test-{session}/",
"dual_mode": {
"session_mode": { "input": "WFS-xxx", "context_source": "Source session summaries" },
"prompt_mode": { "input": "Text/file path", "context_source": "Direct codebase analysis" }
},
"steps": [
{ "command": "/workflow:test-fix-gen", "optional": false },
{ "command": "/workflow:test-cycle-execute", "optional": false }
],
"task_structure": [
"IMPL-001.json (test understanding & generation)",
"IMPL-001.5-review.json (quality gate)",
"IMPL-002.json (test execution & fix cycle)"
],
"cli_hints": {
"analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
"fix_cycle": { "tool": "codex", "mode": "write", "trigger": "pass_rate < 0.95" }
},
"estimated_time": "1-2 hours"
},
"ui": {
"name": "UI-First Development",
"level": "L3/L4",
"description": "UI设计 + 规划 + 执行",
"complexity": ["medium", "high"],
"artifacts": ".workflow/active/{session}/",
"variants": {
"explore": [
{ "command": "/workflow:ui-design:explore-auto", "optional": false },
{ "command": "/workflow:ui-design:design-sync", "optional": false, "auto_continue": true },
{ "command": "/workflow:plan", "optional": false },
{ "command": "/workflow:execute", "optional": false }
],
"imitate": [
{ "command": "/workflow:ui-design:imitate-auto", "optional": false },
{ "command": "/workflow:ui-design:design-sync", "optional": false, "auto_continue": true },
{ "command": "/workflow:plan", "optional": false },
{ "command": "/workflow:execute", "optional": false }
]
},
"estimated_time": "2-4 hours"
},
"review-fix": {
"name": "Review and Fix",
"level": "L3",
"description": "多维审查 + 自动修复",
"complexity": ["medium"],
"artifacts": ".workflow/active/{session}/review_report.md",
"steps": [
{ "command": "/workflow:review-session-cycle", "optional": false },
{ "command": "/workflow:review-fix", "optional": true }
],
"cli_hints": {
"multi_dimension_review": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
"auto_fix": { "tool": "codex", "mode": "write", "trigger": "findings_count >= 3" }
},
"estimated_time": "30-90 min"
},
"docs": {
"name": "Documentation",
"level": "L2",
"description": "批量文档生成",
"complexity": ["low", "medium"],
"variants": {
"incremental": [{ "command": "/memory:update-related", "optional": false }],
"full": [
{ "command": "/memory:docs", "optional": false },
{ "command": "/workflow:execute", "optional": false }
]
},
"estimated_time": "15-60 min"
}
},
"intent_rules": {
"_level_mapping": {
"description": "Intent → Level → Flow mapping guide",
"L1": ["lite-lite-lite"],
"L2": ["rapid", "bugfix", "multi-cli-plan", "docs"],
"L3": ["coupled", "tdd", "test-fix", "review-fix", "ui"],
"L4": ["full"],
"Supplementary": ["issue"]
},
"bugfix": {
"priority": 1,
"level": "L2",
"variants": {
"hotfix": {
"patterns": ["hotfix", "urgent", "production", "critical", "emergency", "紧急", "生产环境", "线上"],
"flow": "bugfix.hotfix"
},
"standard": {
"patterns": ["fix", "bug", "error", "issue", "crash", "broken", "fail", "wrong", "修复", "错误", "崩溃"],
"flow": "bugfix.standard"
}
}
},
"issue_batch": {
"priority": 2,
"level": "Supplementary",
"patterns": {
"batch": ["issues", "batch", "queue", "多个", "批量"],
"action": ["fix", "resolve", "处理", "解决"]
},
"require_both": true,
"flow": "issue"
},
"exploration": {
"priority": 3,
"level": "L4",
"patterns": ["不确定", "不知道", "explore", "研究", "分析一下", "怎么做", "what if", "探索"],
"flow": "full"
},
"multi_perspective": {
"priority": 3,
"level": "L2",
"patterns": ["多视角", "权衡", "比较方案", "cross-verify", "多CLI", "协作分析"],
"flow": "multi-cli-plan"
},
"quick_task": {
"priority": 4,
"level": "L1",
"patterns": ["快速", "简单", "small", "quick", "simple", "trivial", "小改动"],
"flow": "lite-lite-lite"
},
"ui_design": {
"priority": 5,
"level": "L3/L4",
"patterns": ["ui", "界面", "design", "设计", "component", "组件", "style", "样式", "layout", "布局"],
"variants": {
"imitate": { "triggers": ["参考", "模仿", "像", "类似"], "flow": "ui.imitate" },
"explore": { "triggers": [], "flow": "ui.explore" }
}
},
"tdd": {
"priority": 6,
"level": "L3",
"patterns": ["tdd", "test-driven", "测试驱动", "先写测试", "test first"],
"flow": "tdd"
},
"test_fix": {
"priority": 7,
"level": "L3",
"patterns": ["测试失败", "test fail", "fix test", "test error", "pass rate", "coverage gap"],
"flow": "test-fix"
},
"review": {
"priority": 8,
"level": "L3",
"patterns": ["review", "审查", "检查代码", "code review", "质量检查"],
"flow": "review-fix"
},
"documentation": {
"priority": 9,
"level": "L2",
"patterns": ["文档", "documentation", "docs", "readme"],
"variants": {
"incremental": { "triggers": ["更新", "增量"], "flow": "docs.incremental" },
"full": { "triggers": ["全部", "完整"], "flow": "docs.full" }
}
},
"feature": {
"priority": 99,
"complexity_map": {
"high": { "level": "L3", "flow": "coupled" },
"medium": { "level": "L2", "flow": "rapid" },
"low": { "level": "L1", "flow": "lite-lite-lite" }
}
}
},
"complexity_indicators": {
"high": {
"threshold": 4,
"patterns": {
"architecture": { "keywords": ["refactor", "重构", "migrate", "迁移", "architect", "架构", "system", "系统"], "weight": 2 },
"multi_module": { "keywords": ["multiple", "多个", "across", "跨", "all", "所有", "entire", "整个"], "weight": 2 },
"integration": { "keywords": ["integrate", "集成", "api", "database", "数据库"], "weight": 1 },
"quality": { "keywords": ["security", "安全", "performance", "性能", "scale", "扩展"], "weight": 1 }
}
},
"medium": { "threshold": 2 },
"low": { "threshold": 0 }
},
"cli_tools": {
"gemini": {
"strengths": ["超长上下文", "深度分析", "架构理解", "执行流追踪"],
"triggers": ["分析", "理解", "设计", "架构", "诊断"],
"mode": "analysis"
},
"qwen": {
"strengths": ["代码模式识别", "多维度分析"],
"triggers": ["评估", "对比", "验证"],
"mode": "analysis"
},
"codex": {
"strengths": ["精确代码生成", "自主执行"],
"triggers": ["实现", "重构", "修复", "生成"],
"mode": "write"
}
},
"cli_injection_rules": {
"context_gathering": { "trigger": "file_read >= 50k OR module_count >= 5", "inject": "gemini --mode analysis" },
"pre_planning_analysis": { "trigger": "complexity === high", "inject": "gemini --mode analysis" },
"debug_diagnosis": { "trigger": "intent === bugfix AND root_cause_unclear", "inject": "gemini --mode analysis" },
"code_review": { "trigger": "step === review", "inject": "gemini --mode analysis" },
"implementation": { "trigger": "step === execute AND task_count >= 3", "inject": "codex --mode write" }
},
"artifact_flow": {
"_description": "定义工作流产出的格式、意图提取和流转规则",
"outputs": {
"/workflow:lite-plan": {
"artifact": "memory://plan",
"format": "structured_plan",
"fields": ["tasks", "files", "dependencies", "approach"]
},
"/workflow:plan": {
"artifact": ".workflow/{session}/IMPL_PLAN.md",
"format": "markdown_plan",
"fields": ["phases", "tasks", "dependencies", "risks", "test_strategy"]
},
"/workflow:multi-cli-plan": {
"artifact": ".workflow/.multi-cli-plan/{session}/",
"format": "multi_file",
"files": ["IMPL_PLAN.md", "plan.json", "synthesis.json"],
"fields": ["consensus", "divergences", "recommended_approach", "tasks"]
},
"/workflow:lite-execute": {
"artifact": "git_changes",
"format": "code_diff",
"fields": ["modified_files", "added_files", "deleted_files", "build_status"]
},
"/workflow:execute": {
"artifact": ".workflow/{session}/execution_log.json",
"format": "execution_report",
"fields": ["completed_tasks", "pending_tasks", "errors", "warnings"]
},
"/workflow:test-cycle-execute": {
"artifact": ".workflow/{session}/test_results.json",
"format": "test_report",
"fields": ["pass_rate", "failures", "coverage", "duration"]
},
"/workflow:review-session-cycle": {
"artifact": ".workflow/{session}/review_report.md",
"format": "review_report",
"fields": ["findings", "severity_counts", "recommendations"]
},
"/workflow:lite-fix": {
"artifact": "git_changes",
"format": "fix_report",
"fields": ["root_cause", "fix_applied", "files_modified", "verification_status"]
}
},
"intent_extraction": {
"plan_to_execute": {
"from": ["lite-plan", "plan", "multi-cli-plan"],
"to": ["lite-execute", "execute"],
"extract": {
"tasks": "$.tasks[] | filter(status != 'completed')",
"priority_order": "$.tasks | sort_by(priority)",
"files_to_modify": "$.tasks[].files | flatten | unique",
"dependencies": "$.dependencies",
"context_summary": "$.approach OR $.recommended_approach"
}
},
"execute_to_test": {
"from": ["lite-execute", "execute"],
"to": ["test-cycle-execute", "test-fix-gen"],
"extract": {
"modified_files": "$.modified_files",
"test_scope": "infer_from($.modified_files)",
"build_status": "$.build_status",
"pending_verification": "$.completed_tasks | needs_test"
}
},
"test_to_fix": {
"from": ["test-cycle-execute"],
"to": ["lite-fix", "review-fix"],
"condition": "$.pass_rate < 0.95",
"extract": {
"failures": "$.failures",
"error_messages": "$.failures[].message",
"affected_files": "$.failures[].file",
"suggested_fixes": "$.failures[].suggested_fix"
}
},
"review_to_fix": {
"from": ["review-session-cycle", "review-module-cycle"],
"to": ["review-fix"],
"condition": "$.severity_counts.critical > 0 OR $.severity_counts.high > 3",
"extract": {
"findings": "$.findings | filter(severity in ['critical', 'high'])",
"fix_priority": "$.findings | group_by(category) | sort_by(severity)",
"affected_files": "$.findings[].file | unique"
}
}
},
"completion_criteria": {
"plan": {
"required": ["has_tasks", "has_files"],
"optional": ["has_tests", "no_blocking_risks"],
"threshold": 0.8,
"routing": {
"complete": "proceed_to_execute",
"incomplete": "clarify_requirements"
}
},
"execute": {
"required": ["all_tasks_attempted", "no_critical_errors"],
"optional": ["build_passes", "lint_passes"],
"threshold": 1.0,
"routing": {
"complete": "proceed_to_test_or_review",
"partial": "continue_execution",
"failed": "diagnose_and_retry"
}
},
"test": {
"metrics": {
"pass_rate": { "target": 0.95, "minimum": 0.80 },
"coverage": { "target": 0.80, "minimum": 0.60 }
},
"routing": {
"pass_rate >= 0.95 AND coverage >= 0.80": "complete",
"pass_rate >= 0.95 AND coverage < 0.80": "add_more_tests",
"pass_rate >= 0.80": "fix_failures_then_continue",
"pass_rate < 0.80": "major_fix_required"
}
},
"review": {
"metrics": {
"critical_findings": { "target": 0, "maximum": 0 },
"high_findings": { "target": 0, "maximum": 3 }
},
"routing": {
"critical == 0 AND high <= 3": "complete_or_optional_fix",
"critical > 0": "mandatory_fix",
"high > 3": "recommended_fix"
}
}
},
"flow_decisions": {
"_description": "根据产出完成度决定下一步",
"patterns": {
"plan_execute_test": {
"sequence": ["plan", "execute", "test"],
"on_test_fail": {
"action": "extract_failures_and_fix",
"max_iterations": 3,
"fallback": "manual_intervention"
}
},
"plan_execute_review": {
"sequence": ["plan", "execute", "review"],
"on_review_issues": {
"action": "prioritize_and_fix",
"auto_fix_threshold": "severity < high"
}
},
"iterative_improvement": {
"sequence": ["execute", "test", "fix"],
"loop_until": "pass_rate >= 0.95 OR iterations >= 3",
"on_loop_exit": "report_status"
}
}
}
}
}
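The `complexity_indicators` block above implies a weighted keyword score: each matched keyword group contributes its weight, and the total is compared against the level thresholds (4 for high, 2 for medium, otherwise low). A rough TypeScript sketch of that scoring follows; the group data is taken from the JSON (English keywords only), while the function itself is an illustrative assumption:

```typescript
// Illustrative scoring sketch derived from complexity_indicators above.
interface KeywordGroup { keywords: string[]; weight: number; }

const groups: KeywordGroup[] = [
  { keywords: ['refactor', 'migrate', 'architect', 'system'], weight: 2 }, // architecture
  { keywords: ['multiple', 'across', 'all', 'entire'], weight: 2 },        // multi_module
  { keywords: ['integrate', 'api', 'database'], weight: 1 },               // integration
  { keywords: ['security', 'performance', 'scale'], weight: 1 },           // quality
];

function scoreComplexity(request: string): 'high' | 'medium' | 'low' {
  const text = request.toLowerCase();
  let score = 0;
  for (const group of groups) {
    // Assumption: a group counts once if any of its keywords appears in the request.
    if (group.keywords.some(k => text.includes(k))) score += group.weight;
  }
  if (score >= 4) return 'high';   // complexity_indicators.high.threshold
  if (score >= 2) return 'medium'; // complexity_indicators.medium.threshold
  return 'low';
}

// "refactor the api layer across all modules" scores 2 + 2 + 1 = 5, so 'high'.
```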

View File

@@ -263,6 +263,49 @@ Open Dashboard via `ccw view`, manage indexes and execute searches in **CodexLen
## 💻 CCW CLI Commands
### 🌟 Recommended Commands (Main Features)
<div align="center">
<table>
<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>Auto workflow orchestrator - analyzes intent, selects workflow level, executes command chain in main process</td>
<td>✅ General tasks, auto workflow selection, quick development</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>Manual orchestrator - recommends command chains, executes via external CLI with state persistence</td>
<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
</tr>
</table>
</div>
**Quick Examples**:
```bash
# /ccw - Auto workflow selection (Main Process)
/ccw "Add user authentication" # Auto-selects workflow based on intent
/ccw "Fix memory leak in WebSocket" # Detects bugfix workflow
/ccw "Implement with TDD" # Routes to TDD workflow
# /ccw-coordinator - Manual chain orchestration (External CLI)
/ccw-coordinator "Implement OAuth2 system" # Analyzes → Recommends chain → User confirms → Executes
```
**Key Differences**:
| Aspect | /ccw | /ccw-coordinator |
|--------|------|------------------|
| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
| **Selection** | Auto intent-based | Manual chain confirmation |
| **State** | TodoWrite tracking | Persistent state.json |
| **Use Case** | General tasks, quick dev | Complex chains, resumable |
---
### Other CLI Commands
```bash
ccw install # Install workflow files
ccw view # Open dashboard

View File

@@ -263,6 +263,49 @@ codexlens index /path/to/project
## 💻 CCW CLI 命令
### 🌟 推荐命令(核心功能)
<div align="center">
<table>
<tr><th>命令</th><th>说明</th><th>适用场景</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>自动工作流编排器 - 分析意图、自动选择工作流级别、在主进程中执行命令链</td>
<td>✅ 通用任务、自动选择工作流、快速开发</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>手动编排器 - 推荐命令链、通过外部 CLI 执行、持久化状态</td>
<td>🔧 复杂多步骤工作流、自定义链、可恢复会话</td>
</tr>
</table>
</div>
**快速示例**
```bash
# /ccw - 自动工作流选择(主进程)
/ccw "添加用户认证" # 自动根据意图选择工作流
/ccw "修复 WebSocket 中的内存泄漏" # 识别为 bugfix 工作流
/ccw "使用 TDD 方式实现" # 路由到 TDD 工作流
# /ccw-coordinator - 手动链编排(外部 CLI)
/ccw-coordinator "实现 OAuth2 系统" # 分析 → 推荐链 → 用户确认 → 执行
```
**主要区别**
| 方面 | /ccw | /ccw-coordinator |
|------|------|------------------|
| **执行方式** | 主进程(SlashCommand) | 外部 CLI(后台任务) |
| **选择方式** | 自动基于意图识别 | 手动链确认 |
| **状态管理** | TodoWrite 跟踪 | 持久化 state.json |
| **适用场景** | 通用任务、快速开发 | 复杂链条、可恢复 |
---
### 其他 CLI 命令
```bash
ccw install # 安装工作流文件
ccw view # 打开 Dashboard

View File

@@ -0,0 +1,669 @@
/**
* CommandRegistry Tests
*
* Test coverage:
* - YAML header parsing
* - Command metadata extraction
* - Directory detection (relative and home)
* - Caching mechanism
* - Batch operations
* - Categorization
* - Error handling
*/
import { CommandRegistry, createCommandRegistry, getAllCommandsSync, getCommandSync } from './command-registry';
import * as fs from 'fs';
import * as path from 'path';
import * as os from 'os';
// Mock fs module
jest.mock('fs');
jest.mock('os');
describe('CommandRegistry', () => {
const mockReadFileSync = fs.readFileSync as jest.MockedFunction<typeof fs.readFileSync>;
const mockExistsSync = fs.existsSync as jest.MockedFunction<typeof fs.existsSync>;
const mockReaddirSync = fs.readdirSync as jest.MockedFunction<typeof fs.readdirSync>;
const mockStatSync = fs.statSync as jest.MockedFunction<typeof fs.statSync>;
const mockHomedir = os.homedir as jest.MockedFunction<typeof os.homedir>;
// Sample YAML headers
const sampleLitePlanYaml = `---
name: lite-plan
description: Quick planning for simple features
argument-hint: "\"feature description\""
allowed-tools: Task(*), Read(*), Write(*), Bash(*)
---
# Content here`;
const sampleExecuteYaml = `---
name: execute
description: Execute implementation from plan
argument-hint: "--resume-session=\"WFS-xxx\""
allowed-tools: Task(*), Bash(*)
---
# Content here`;
const sampleTestYaml = `---
name: test-cycle-execute
description: Run tests and fix failures
argument-hint: "--session=\"WFS-xxx\""
allowed-tools: Task(*), Bash(*)
---
# Content here`;
const sampleReviewYaml = `---
name: review
description: Code review workflow
argument-hint: "--session=\"WFS-xxx\""
allowed-tools: Task(*), Read(*)
---
# Content here`;
beforeEach(() => {
jest.clearAllMocks();
});
describe('constructor & directory detection', () => {
it('should use provided command directory', () => {
const customDir = '/custom/path';
const registry = new CommandRegistry(customDir);
expect((registry as any).commandDir).toBe(customDir);
});
it('should auto-detect relative .claude/commands/workflow directory', () => {
mockExistsSync.mockImplementation((checkPath: string) => {
return checkPath === '.claude/commands/workflow';
});
const registry = new CommandRegistry();
expect((registry as any).commandDir).toBe('.claude/commands/workflow');
expect(mockExistsSync).toHaveBeenCalledWith('.claude/commands/workflow');
});
it('should auto-detect home directory ~/.claude/commands/workflow', () => {
mockExistsSync.mockImplementation((checkPath: string) => {
return checkPath === path.join('/home/user', '.claude', 'commands', 'workflow');
});
mockHomedir.mockReturnValue('/home/user');
const registry = new CommandRegistry();
expect((registry as any).commandDir).toBe(
path.join('/home/user', '.claude', 'commands', 'workflow')
);
});
it('should return null if no command directory found', () => {
mockExistsSync.mockReturnValue(false);
mockHomedir.mockReturnValue('/home/user');
const registry = new CommandRegistry();
expect((registry as any).commandDir).toBeNull();
});
});
describe('parseYamlHeader', () => {
it('should parse simple YAML header with Unix line endings', () => {
const yaml = `---
name: test-command
description: Test description
argument-hint: "\"test\""
allowed-tools: Task(*), Read(*)
---
Content here`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toEqual({
name: 'test-command',
description: 'Test description',
'argument-hint': '"test"',
'allowed-tools': 'Task(*), Read(*)'
});
});
it('should parse YAML header with Windows line endings (\\r\\n)', () => {
const yaml = `---\r\nname: test-command\r\ndescription: Test\r\n---`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toEqual({
name: 'test-command',
description: 'Test'
});
});
it('should handle quoted values', () => {
const yaml = `---
name: "cmd"
description: 'double quoted'
---`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toEqual({
name: 'cmd',
description: 'double quoted'
});
});
it('should parse allowed-tools and trim spaces', () => {
const yaml = `---
name: test
allowed-tools: Task(*), Read(*) , Write(*), Bash(*)
---`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result['allowed-tools']).toBe('Task(*), Read(*), Write(*), Bash(*)');
});
it('should skip comments and empty lines', () => {
const yaml = `---
# This is a comment
name: test-command
# Another comment
description: Test
---`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toEqual({
name: 'test-command',
description: 'Test'
});
});
it('should return null for missing YAML markers', () => {
const yaml = `name: test-command
description: Test`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toBeNull();
});
it('should return an empty object when the YAML body has no key-value pairs', () => {
const yaml = `---
invalid yaml content without colons
---`;
const registry = new CommandRegistry('/fake/path');
const result = (registry as any).parseYamlHeader(yaml);
expect(result).toEqual({});
});
});
describe('getCommand', () => {
it('should get command metadata by name', () => {
const cmdDir = '/workflows';
mockExistsSync.mockImplementation((checkPath: string) => {
return checkPath === path.join(cmdDir, 'lite-plan.md');
});
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('lite-plan');
expect(result).toEqual({
name: 'lite-plan',
command: '/workflow:lite-plan',
description: 'Quick planning for simple features',
argumentHint: '"feature description"',
allowedTools: ['Task(*)', 'Read(*)', 'Write(*)', 'Bash(*)'],
filePath: path.join(cmdDir, 'lite-plan.md')
});
});
it('should normalize /workflow: prefix', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('/workflow:lite-plan');
expect(result?.name).toBe('lite-plan');
});
it('should use cache for repeated requests', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
registry.getCommand('lite-plan');
registry.getCommand('lite-plan');
// readFileSync should only be called once due to cache
expect(mockReadFileSync).toHaveBeenCalledTimes(1);
});
it('should return null if command file not found', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(false);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('nonexistent');
expect(result).toBeNull();
});
it('should return null if no command directory', () => {
mockExistsSync.mockReturnValue(false);
mockHomedir.mockReturnValue('/home/user');
const registry = new CommandRegistry();
const result = registry.getCommand('lite-plan');
expect(result).toBeNull();
});
it('should return null if YAML header is invalid', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue('No YAML header here');
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('lite-plan');
expect(result).toBeNull();
});
it('should parse allowedTools correctly', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(sampleExecuteYaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('execute');
expect(result?.allowedTools).toEqual(['Task(*)', 'Bash(*)']);
});
it('should handle empty allowedTools', () => {
const yaml = `---
name: minimal-cmd
description: Minimal command
---`;
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(yaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('minimal-cmd');
expect(result?.allowedTools).toEqual([]);
});
});
describe('getCommands', () => {
it('should get multiple commands', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
if (filePath.includes('execute')) return sampleExecuteYaml;
return '';
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommands(['lite-plan', 'execute', 'nonexistent']);
expect(result.size).toBe(2);
expect(result.has('/workflow:lite-plan')).toBe(true);
expect(result.has('/workflow:execute')).toBe(true);
});
it('should skip nonexistent commands', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(false);
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommands(['nonexistent1', 'nonexistent2']);
expect(result.size).toBe(0);
});
});
describe('getAllCommandsSummary', () => {
it('should get all commands summary', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
if (filePath.includes('execute')) return sampleExecuteYaml;
if (filePath.includes('test')) return sampleTestYaml;
return '';
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(3);
expect(result.get('/workflow:lite-plan')).toEqual({
name: 'lite-plan',
description: 'Quick planning for simple features'
});
});
it('should skip directories', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['file.md', 'directory'] as any);
mockStatSync.mockImplementation((filePath: string) => ({
isDirectory: () => filePath.includes('directory')
} as any));
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
// Only file.md should be processed
expect(mockReadFileSync).toHaveBeenCalledTimes(1);
});
it('should skip files with invalid YAML headers', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['valid.md', 'invalid.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('valid')) return sampleLitePlanYaml;
return 'No YAML header';
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(1);
});
it('should return empty map if no command directory', () => {
mockExistsSync.mockReturnValue(false);
mockHomedir.mockReturnValue('/home/user');
const registry = new CommandRegistry();
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(0);
});
it('should handle directory read errors gracefully', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockImplementation(() => {
throw new Error('Permission denied');
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(0);
});
});
describe('getAllCommandsByCategory', () => {
it('should categorize commands by name patterns', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test-cycle-execute.md', 'review.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
if (filePath.includes('test')) return sampleTestYaml; // check 'test' before 'execute' so test-cycle-execute.md maps to the test fixture
if (filePath.includes('execute')) return sampleExecuteYaml;
if (filePath.includes('review')) return sampleReviewYaml;
return '';
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsByCategory();
expect(result.planning.length).toBe(1);
expect(result.execution.length).toBe(1);
expect(result.testing.length).toBe(1);
expect(result.review.length).toBe(1);
expect(result.other.length).toBe(0);
expect(result.planning[0].name).toBe('lite-plan');
expect(result.execution[0].name).toBe('execute');
});
it('should handle commands matching multiple patterns', () => {
const yamlMultiMatch = `---
name: test-plan
description: TDD planning
allowed-tools: Task(*)
---`;
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['test-plan.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockReturnValue(yamlMultiMatch);
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsByCategory();
// Should match 'plan' pattern (planning)
expect(result.planning.length).toBe(1);
});
});
describe('toJSON', () => {
it('should serialize cached commands to JSON', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
registry.getCommand('lite-plan');
const json = registry.toJSON();
expect(json['/workflow:lite-plan']).toEqual({
name: 'lite-plan',
command: '/workflow:lite-plan',
description: 'Quick planning for simple features',
argumentHint: '"feature description"',
allowedTools: ['Task(*)', 'Read(*)', 'Write(*)', 'Bash(*)'],
filePath: path.join(cmdDir, 'lite-plan.md')
});
});
it('should only include cached commands', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
return sampleExecuteYaml;
});
const registry = new CommandRegistry(cmdDir);
registry.getCommand('lite-plan');
// Don't call getCommand for 'execute'
const json = registry.toJSON();
expect(Object.keys(json).length).toBe(1);
expect(json['/workflow:lite-plan']).toBeDefined();
expect(json['/workflow:execute']).toBeUndefined();
});
});
describe('exported functions', () => {
it('createCommandRegistry should create new instance', () => {
mockExistsSync.mockReturnValue(true);
const registry = createCommandRegistry('/custom/path');
expect((registry as any).commandDir).toBe('/custom/path');
});
it('getAllCommandsSync should return all commands', () => {
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
mockHomedir.mockReturnValue('/home/user');
const result = getAllCommandsSync();
expect(result.size).toBeGreaterThanOrEqual(1);
});
it('getCommandSync should return specific command', () => {
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
mockHomedir.mockReturnValue('/home/user');
const result = getCommandSync('lite-plan');
expect(result?.name).toBe('lite-plan');
});
});
describe('edge cases', () => {
it('should handle file read errors', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReadFileSync.mockImplementation(() => {
throw new Error('File read error');
});
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('lite-plan');
expect(result).toBeNull();
});
it('should handle YAML parsing errors', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
// Return a header that parses but contains no 'name' field
mockReadFileSync.mockReturnValue('---\ninvalid: : : yaml\n---');
const registry = new CommandRegistry(cmdDir);
const result = registry.getCommand('lite-plan');
// Should return null since name is not in result
expect(result).toBeNull();
});
it('should handle empty command directory', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue([] as any);
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(0);
});
it('should handle non-md files in command directory', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md', 'readme.txt', '.gitignore'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
const registry = new CommandRegistry(cmdDir);
const result = registry.getAllCommandsSummary();
expect(result.size).toBe(1);
});
});
describe('integration tests', () => {
it('should work with full workflow', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test-cycle-execute.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
if (filePath.includes('test')) return sampleTestYaml; // 'test' before 'execute' so test-cycle-execute.md gets the test fixture
if (filePath.includes('execute')) return sampleExecuteYaml;
return '';
});
const registry = new CommandRegistry(cmdDir);
// Get all summary
const summary = registry.getAllCommandsSummary();
expect(summary.size).toBe(3);
// Get by category
const byCategory = registry.getAllCommandsByCategory();
expect(byCategory.planning.length).toBe(1);
expect(byCategory.execution.length).toBe(1);
expect(byCategory.testing.length).toBe(1);
// Get specific command
const cmd = registry.getCommand('lite-plan');
expect(cmd?.name).toBe('lite-plan');
// Get multiple commands
const multiple = registry.getCommands(['lite-plan', 'execute']);
expect(multiple.size).toBe(2);
// Convert to JSON
const json = registry.toJSON();
expect(Object.keys(json).length).toBeGreaterThan(0);
});
it('should maintain cache across operations', () => {
const cmdDir = '/workflows';
mockExistsSync.mockReturnValue(true);
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md'] as any);
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
mockReadFileSync.mockImplementation((filePath: string) => {
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
return sampleExecuteYaml;
});
const registry = new CommandRegistry(cmdDir);
// First call
registry.getCommand('lite-plan');
const initialCallCount = mockReadFileSync.mock.calls.length;
// getAllCommandsSummary will read all files
registry.getAllCommandsSummary();
const afterSummaryCallCount = mockReadFileSync.mock.calls.length;
// Second getCommand should use cache
registry.getCommand('lite-plan');
const finalCallCount = mockReadFileSync.mock.calls.length;
// lite-plan.md should only be read twice:
// 1. Initial getCommand
// 2. getAllCommandsSummary (must read all files)
// Not again in second getCommand due to cache
expect(finalCallCount).toBe(afterSummaryCallCount);
});
});
});

View File

@@ -0,0 +1,308 @@
/**
* Command Registry Tool
*
* Features:
* 1. Scan and parse YAML headers from command files
* 2. Read from global ~/.claude/commands/workflow directory
* 3. Support on-demand extraction (not full scan)
* 4. Cache parsed metadata for performance
*/
import { existsSync, readdirSync, readFileSync, statSync } from 'fs';
import { join } from 'path';
import { homedir } from 'os';
export interface CommandMetadata {
name: string;
command: string;
description: string;
argumentHint: string;
allowedTools: string[];
filePath: string;
}
export interface CommandSummary {
name: string;
description: string;
}
export class CommandRegistry {
private commandDir: string | null;
private cache: Map<string, CommandMetadata>;
constructor(commandDir?: string) {
this.cache = new Map();
if (commandDir) {
this.commandDir = commandDir;
} else {
this.commandDir = this.findCommandDir();
}
}
/**
* Auto-detect the command directory: .claude/commands/workflow relative to cwd, falling back to ~/.claude/commands/workflow
*/
private findCommandDir(): string | null {
// Try relative to current working directory
const relativePath = join('.claude', 'commands', 'workflow');
if (existsSync(relativePath)) {
return relativePath;
}
// Try user home directory
const homeDir = homedir();
const homeCommandDir = join(homeDir, '.claude', 'commands', 'workflow');
if (existsSync(homeCommandDir)) {
return homeCommandDir;
}
return null;
}
/**
* Parse YAML header (simplified version)
*
* Limitations:
* - Only supports simple key: value pairs (single-line values)
* - No support for multi-line values, nested objects, complex lists
* - allowed-tools is kept as a normalized comma-separated string here; getCommand converts it to an array
*/
private parseYamlHeader(content: string): Record<string, any> | null {
// Handle Windows line endings (\r\n)
const match = content.match(/^---[\r\n]+([\s\S]*?)[\r\n]+---/);
if (!match) return null;
const yamlContent = match[1];
const result: Record<string, any> = {};
try {
const lines = yamlContent.split(/[\r\n]+/);
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed || trimmed.startsWith('#')) continue; // Skip empty lines and comments
const colonIndex = trimmed.indexOf(':');
if (colonIndex === -1) continue;
const key = trimmed.substring(0, colonIndex).trim();
let value = trimmed.substring(colonIndex + 1).trim();
if (!key) continue; // Skip invalid lines
// Remove quotes (single or double)
let cleanValue = value.replace(/^["']|["']$/g, '');
// Special handling for allowed-tools field: convert to array
// Supports format: "Read, Write, Bash" or "Read,Write,Bash"
if (key === 'allowed-tools') {
cleanValue = cleanValue
.split(',')
.map(t => t.trim())
.filter(t => t)
.join(', '); // Keep as a comma-separated string; getCommand converts it to an array
}
result[key] = cleanValue;
}
} catch (error) {
const err = error as Error;
console.error('YAML parsing error:', err.message);
return null;
}
return result;
}
/**
* Get single command metadata
* @param commandName Command name (e.g., "lite-plan" or "/workflow:lite-plan")
* @returns Command metadata or null
*/
public getCommand(commandName: string): CommandMetadata | null {
if (!this.commandDir) {
console.error('ERROR: ~/.claude/commands/workflow directory not found');
return null;
}
// Normalize command name
const normalized = commandName.startsWith('/workflow:')
? commandName.substring('/workflow:'.length)
: commandName;
// Check cache
const cached = this.cache.get(normalized);
if (cached) {
return cached;
}
// Read command file
const filePath = join(this.commandDir, `${normalized}.md`);
if (!existsSync(filePath)) {
return null;
}
try {
const content = readFileSync(filePath, 'utf-8');
const header = this.parseYamlHeader(content);
if (header && header.name) {
const toolsStr = header['allowed-tools'] || '';
const allowedTools = toolsStr
.split(',')
.map((t: string) => t.trim())
.filter((t: string) => t);
const result: CommandMetadata = {
name: header.name,
command: `/workflow:${header.name}`,
description: header.description || '',
argumentHint: header['argument-hint'] || '',
allowedTools: allowedTools,
filePath: filePath
};
// Cache result
this.cache.set(normalized, result);
return result;
}
} catch (error) {
const err = error as Error;
console.error(`Failed to read command ${filePath}:`, err.message);
}
return null;
}
/**
* Get multiple commands metadata
* @param commandNames Array of command names
* @returns Map of command metadata
*/
public getCommands(commandNames: string[]): Map<string, CommandMetadata> {
const result = new Map<string, CommandMetadata>();
for (const name of commandNames) {
const cmd = this.getCommand(name);
if (cmd) {
result.set(cmd.command, cmd);
}
}
return result;
}
/**
* Get all commands' names and descriptions
* @returns Map of command names to summaries
*/
public getAllCommandsSummary(): Map<string, CommandSummary> {
const result = new Map<string, CommandSummary>();
if (!this.commandDir) {
return result;
}
try {
const files = readdirSync(this.commandDir);
for (const file of files) {
if (!file.endsWith('.md')) continue;
const filePath = join(this.commandDir, file);
const stat = statSync(filePath);
if (stat.isDirectory()) continue;
try {
const content = readFileSync(filePath, 'utf-8');
const header = this.parseYamlHeader(content);
if (header && header.name) {
const commandName = `/workflow:${header.name}`;
result.set(commandName, {
name: header.name,
description: header.description || ''
});
}
} catch (error) {
// Skip files that fail to read
continue;
}
}
} catch (error) {
// Return empty map if directory read fails
return result;
}
return result;
}
/**
* Get all commands organized by category/tags
*/
public getAllCommandsByCategory(): Record<string, CommandMetadata[]> {
const summary = this.getAllCommandsSummary();
const result: Record<string, CommandMetadata[]> = {
planning: [],
execution: [],
testing: [],
review: [],
other: []
};
for (const [cmdName] of summary) {
const cmd = this.getCommand(cmdName);
if (cmd) {
// Categorize based on command name patterns
if (cmd.name.includes('plan')) {
result.planning.push(cmd);
} else if (cmd.name.includes('test')) {
// Check 'test' before 'execute' so names like test-cycle-execute land in testing
result.testing.push(cmd);
} else if (cmd.name.includes('execute')) {
result.execution.push(cmd);
} else if (cmd.name.includes('review')) {
result.review.push(cmd);
} else {
result.other.push(cmd);
}
}
}
return result;
}
/**
* Convert to JSON for serialization
*/
public toJSON(): Record<string, any> {
const result: Record<string, CommandMetadata> = {};
for (const [key, value] of this.cache) {
result[`/workflow:${key}`] = value;
}
return result;
}
}
/**
* Export function for direct usage
*/
export function createCommandRegistry(commandDir?: string): CommandRegistry {
return new CommandRegistry(commandDir);
}
/**
* Export function to get all commands
*/
export function getAllCommandsSync(): Map<string, CommandSummary> {
const registry = new CommandRegistry();
return registry.getAllCommandsSummary();
}
/**
* Export function to get specific command
*/
export function getCommandSync(name: string): CommandMetadata | null {
const registry = new CommandRegistry();
return registry.getCommand(name);
}
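A short usage sketch of the registry exported above; the command names are placeholders and error handling is omitted, but the calls match the public API defined in this file:

```typescript
import { createCommandRegistry, getCommandSync } from './command-registry';

// Auto-detects .claude/commands/workflow or ~/.claude/commands/workflow.
const registry = createCommandRegistry();

// Single lookup: accepts "lite-plan" or "/workflow:lite-plan".
const litePlan = registry.getCommand('lite-plan');
console.log(litePlan?.description, litePlan?.allowedTools);

// Name and description overview of every *.md command in the directory.
for (const [command, summary] of registry.getAllCommandsSummary()) {
  console.log(`${command}: ${summary.description}`);
}

// One-off lookup without holding a registry instance.
console.log(getCommandSync('execute')?.argumentHint);
```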

View File

@@ -378,3 +378,7 @@ export { registerTool };
// Export ToolSchema type
export type { ToolSchema };
// Export CommandRegistry for direct import
export { CommandRegistry, createCommandRegistry, getAllCommandsSync, getCommandSync } from './command-registry.js';
export type { CommandMetadata, CommandSummary } from './command-registry.js';

View File

@@ -19,5 +19,5 @@
"noEmit": false
},
"include": ["src/**/*"],
"exclude": ["src/templates/**/*", "node_modules", "dist"]
"exclude": ["src/templates/**/*", "src/**/*.test.ts", "node_modules", "dist"]
}

21
codex-lens/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 CodexLens Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

59
codex-lens/README.md Normal file
View File

@@ -0,0 +1,59 @@
# CodexLens
CodexLens is a multi-modal code analysis platform that combines Tree-sitter based code parsing with semantic search for comprehensive code understanding.
## Features
- **Multi-language Support**: Analyze code in Python, JavaScript, TypeScript and more using Tree-sitter parsers
- **Semantic Search**: Find relevant code snippets using semantic understanding with fastembed and HNSWLIB
- **Code Parsing**: Advanced code structure parsing with tree-sitter
- **Flexible Architecture**: Modular design for easy extension and customization
## Installation
### Basic Installation
```bash
pip install codex-lens
```
### With Semantic Search
```bash
pip install codex-lens[semantic]
```
### With GPU Acceleration (NVIDIA CUDA)
```bash
pip install codex-lens[semantic-gpu]
```
### With DirectML (Windows - NVIDIA/AMD/Intel)
```bash
pip install codex-lens[semantic-directml]
```
### With All Optional Features
```bash
pip install codex-lens[full]
```
## Requirements
- Python >= 3.10
- See `pyproject.toml` for detailed dependency list
## Development
This project uses setuptools for building and packaging.
## License
MIT License
## Authors
CodexLens Contributors

View File

@@ -0,0 +1,28 @@
"""CodexLens package."""
from __future__ import annotations
from . import config, entities, errors
from .config import Config
from .entities import IndexedFile, SearchResult, SemanticChunk, Symbol
from .errors import CodexLensError, ConfigError, ParseError, SearchError, StorageError
__version__ = "0.1.0"
__all__ = [
"__version__",
"config",
"entities",
"errors",
"Config",
"IndexedFile",
"SearchResult",
"SemanticChunk",
"Symbol",
"CodexLensError",
"ConfigError",
"ParseError",
"StorageError",
"SearchError",
]

View File

@@ -0,0 +1,14 @@
"""Module entrypoint for `python -m codexlens`."""
from __future__ import annotations
from codexlens.cli import app
def main() -> None:
app()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.
This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.
Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result
Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity
Example:
>>> from codexlens.api import (
... DefinitionResult,
... resolve_project,
... normalize_relationship_type
... )
>>> project = resolve_project("/path/to/project")
>>> rel_type = normalize_relationship_type("calls")
>>> rel_type
'call'
"""
from __future__ import annotations
# Dataclasses
from .models import (
CallInfo,
MethodContext,
FileContextResult,
DefinitionResult,
ReferenceResult,
GroupedReferences,
SymbolInfo,
HoverInfo,
SemanticResult,
)
# Utility functions
from .utils import (
resolve_project,
normalize_relationship_type,
rank_by_proximity,
rank_by_score,
)
# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search
__all__ = [
# Dataclasses
"CallInfo",
"MethodContext",
"FileContextResult",
"DefinitionResult",
"ReferenceResult",
"GroupedReferences",
"SymbolInfo",
"HoverInfo",
"SemanticResult",
# Utility functions
"resolve_project",
"normalize_relationship_type",
"rank_by_proximity",
"rank_by_score",
# API functions
"find_definition",
"workspace_symbols",
"get_hover",
"file_context",
"find_references",
"semantic_search",
]

View File

@@ -0,0 +1,126 @@
"""find_definition API implementation.
This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity
logger = logging.getLogger(__name__)
def find_definition(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
file_context: Optional[str] = None,
limit: int = 10
) -> List[DefinitionResult]:
"""Find definition locations for a symbol.
Uses a 3-stage fallback strategy:
1. Exact match with kind filter
2. Exact match without kind filter
3. Prefix match
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to find
symbol_kind: Optional symbol kind filter (class, function, etc.)
file_context: Optional file path for proximity ranking
limit: Maximum number of results to return
Returns:
List of DefinitionResult sorted by proximity if file_context provided
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Stage 1: Exact match with kind filter
results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
if results:
logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 2: Exact match without kind (if kind was specified)
if symbol_kind:
results = _search_with_kind(global_index, symbol_name, None, limit)
if results:
logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 3: Prefix match
results = global_index.search(
name=symbol_name,
kind=None,
limit=limit,
prefix_mode=True
)
if results:
logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
logger.debug(f"No definitions found for {symbol_name}")
return []
def _search_with_kind(
global_index: GlobalSymbolIndex,
symbol_name: str,
symbol_kind: Optional[str],
limit: int
) -> List[Symbol]:
"""Search for symbols with optional kind filter."""
return global_index.search(
name=symbol_name,
kind=symbol_kind,
limit=limit,
prefix_mode=False
)
def _rank_and_convert(
symbols: List[Symbol],
file_context: Optional[str]
) -> List[DefinitionResult]:
"""Convert symbols to DefinitionResult and rank by proximity."""
results = [
DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Could extract from file if needed
container=None, # Could extract from parent symbol
score=1.0
)
for sym in symbols
]
return rank_by_proximity(results, file_context)

View File

@@ -0,0 +1,271 @@
"""file_context API implementation.
This module provides the file_context() function for retrieving
method call graphs from a source file.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
FileContextResult,
MethodContext,
CallInfo,
)
from .utils import resolve_project, normalize_relationship_type
logger = logging.getLogger(__name__)
def file_context(
project_root: str,
file_path: str,
include_calls: bool = True,
include_callers: bool = True,
max_depth: int = 1,
format: str = "brief"
) -> FileContextResult:
"""Get method call context for a code file.
Retrieves all methods/functions in the file along with their
outgoing calls and incoming callers.
Args:
project_root: Project root directory (for index location)
file_path: Path to the code file to analyze
include_calls: Whether to include outgoing calls
include_callers: Whether to include incoming callers
max_depth: Call chain depth (V1 only supports 1)
format: Output format (brief | detailed | tree)
Returns:
FileContextResult with method contexts and summary
Raises:
IndexNotFoundError: If project is not indexed
FileNotFoundError: If file does not exist
ValueError: If max_depth > 1 (V1 limitation)
"""
# V1 limitation: only depth=1 supported
if max_depth > 1:
raise ValueError(
f"max_depth > 1 not supported in V1. "
f"Requested: {max_depth}, supported: 1"
)
project_path = resolve_project(project_root)
file_path_resolved = Path(file_path).resolve()
# Validate file exists
if not file_path_resolved.exists():
raise FileNotFoundError(f"File not found: {file_path_resolved}")
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Get all symbols in the file
symbols = global_index.get_file_symbols(str(file_path_resolved))
# Filter to functions, methods, and classes
method_symbols = [
s for s in symbols
if s.kind in ("function", "method", "class")
]
logger.debug(f"Found {len(method_symbols)} methods in {file_path}")
# Try to find dir_index for relationship queries
dir_index = _find_dir_index(project_info, file_path_resolved)
# Build method contexts
methods: List[MethodContext] = []
outgoing_resolved = True
incoming_resolved = True
targets_resolved = True
for symbol in method_symbols:
calls: List[CallInfo] = []
callers: List[CallInfo] = []
if include_calls and dir_index:
try:
outgoing = dir_index.get_outgoing_calls(
str(file_path_resolved),
symbol.name
)
for target_name, rel_type, line, target_file in outgoing:
calls.append(CallInfo(
symbol_name=target_name,
file_path=target_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
if target_file is None:
targets_resolved = False
except Exception as e:
logger.debug(f"Failed to get outgoing calls: {e}")
outgoing_resolved = False
if include_callers and dir_index:
try:
incoming = dir_index.get_incoming_calls(symbol.name)
for source_name, rel_type, line, source_file in incoming:
callers.append(CallInfo(
symbol_name=source_name,
file_path=source_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
except Exception as e:
logger.debug(f"Failed to get incoming calls: {e}")
incoming_resolved = False
methods.append(MethodContext(
name=symbol.name,
kind=symbol.kind,
line_range=symbol.range if symbol.range else (1, 1),
signature=None, # Could extract from source
calls=calls,
callers=callers
))
# Detect language from file extension
language = _detect_language(file_path_resolved)
# Generate summary
summary = _generate_summary(file_path_resolved, methods, format)
return FileContextResult(
file_path=str(file_path_resolved),
language=language,
methods=methods,
summary=summary,
discovery_status={
"outgoing_resolved": outgoing_resolved,
"incoming_resolved": incoming_resolved,
"targets_resolved": targets_resolved
}
)
def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
"""Find the dir_index that contains the file.
Args:
project_info: Project information from registry
file_path: Path to the file
Returns:
DirIndexStore if found, None otherwise
"""
try:
# Look for _index.db in file's directory or parent directories
current = file_path.parent
while current != current.parent:
index_db = current / "_index.db"
if index_db.exists():
return DirIndexStore(str(index_db))
# Also check in project's index_root
relative = current.relative_to(project_info.source_root)
index_in_cache = project_info.index_root / relative / "_index.db"
if index_in_cache.exists():
return DirIndexStore(str(index_in_cache))
current = current.parent
except Exception as e:
logger.debug(f"Failed to find dir_index: {e}")
return None
def _detect_language(file_path: Path) -> str:
"""Detect programming language from file extension.
Args:
file_path: Path to the file
Returns:
Language name
"""
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".jsx": "javascript",
".tsx": "typescript",
".go": "go",
".rs": "rust",
".java": "java",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
}
return ext_map.get(file_path.suffix.lower(), "unknown")
def _generate_summary(
file_path: Path,
methods: List[MethodContext],
format: str
) -> str:
"""Generate human-readable summary of file context.
Args:
file_path: Path to the file
methods: List of method contexts
format: Output format (brief | detailed | tree)
Returns:
Markdown-formatted summary
"""
lines = [f"## {file_path.name} ({len(methods)} methods)\n"]
for method in methods:
start, end = method.line_range
lines.append(f"### {method.name} (line {start}-{end})")
if method.calls:
calls_str = ", ".join(
f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.calls
)
lines.append(f"- Calls: {calls_str}")
if method.callers:
callers_str = ", ".join(
f"{c.symbol_name} ({c.file_path}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.callers
)
lines.append(f"- Called by: {callers_str}")
if not method.calls and not method.callers:
lines.append("- (no call relationships)")
lines.append("")
return "\n".join(lines)

View File

@@ -0,0 +1,148 @@
"""get_hover API implementation.
This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def get_hover(
project_root: str,
symbol_name: str,
file_path: Optional[str] = None
) -> Optional[HoverInfo]:
"""Get detailed hover information for a symbol.
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to look up
file_path: Optional file path to disambiguate when symbol
appears in multiple files
Returns:
HoverInfo if symbol found, None otherwise
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search for the symbol
results = global_index.search(
name=symbol_name,
kind=None,
limit=50,
prefix_mode=False
)
if not results:
logger.debug(f"No hover info found for {symbol_name}")
return None
# If file_path provided, filter to that file
if file_path:
file_path_resolved = str(Path(file_path).resolve())
matching = [s for s in results if s.file == file_path_resolved]
if matching:
results = matching
# Take the first result
symbol = results[0]
# Build hover info
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=_extract_signature(symbol),
documentation=_extract_documentation(symbol),
file_path=symbol.file or "",
line_range=symbol.range if symbol.range else (1, 1),
type_info=_extract_type_info(symbol)
)
def _extract_signature(symbol: Symbol) -> str:
"""Extract signature from symbol.
For now, generates a basic signature based on kind and name.
In a full implementation, this would parse the actual source code.
Args:
symbol: The symbol to extract signature from
Returns:
Signature string
"""
if symbol.kind == "function":
return f"def {symbol.name}(...)"
elif symbol.kind == "method":
return f"def {symbol.name}(self, ...)"
elif symbol.kind == "class":
return f"class {symbol.name}"
elif symbol.kind == "variable":
return symbol.name
elif symbol.kind == "constant":
return f"{symbol.name} = ..."
else:
return f"{symbol.kind} {symbol.name}"
def _extract_documentation(symbol: Symbol) -> Optional[str]:
"""Extract documentation from symbol.
In a full implementation, this would parse docstrings from source.
For now, returns None.
Args:
symbol: The symbol to extract documentation from
Returns:
Documentation string if available, None otherwise
"""
# Would need to read source file and parse docstring
# For V1, return None
return None
def _extract_type_info(symbol: Symbol) -> Optional[str]:
"""Extract type information from symbol.
In a full implementation, this would parse type annotations.
For now, returns None.
Args:
symbol: The symbol to extract type info from
Returns:
Type info string if available, None otherwise
"""
# Would need to parse type annotations from source
# For V1, return None
return None
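A minimal get_hover() sketch (import path hypothetical); in V1 the signature is heuristic and documentation/type_info stay None, so to_dict() omits them.

from codexlens.api import get_hover  # hypothetical import path

hover = get_hover("/path/to/project", "UserService",
                  file_path="/path/to/project/src/services/user.py")
if hover is not None:
    print(hover.signature)  # e.g. "class UserService" (heuristic _extract_signature)
    print(hover.to_dict())  # documentation/type_info omitted while they are None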

View File

@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.
This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple
# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================
@dataclass
class CallInfo:
"""Call relationship information.
Attributes:
symbol_name: Name of the called/calling symbol
file_path: Target file path (may be None if unresolved)
line: Line number of the call
relationship: Type of relationship (call | import | inheritance)
"""
symbol_name: str
file_path: Optional[str]
line: int
relationship: str # call | import | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MethodContext:
"""Method context with call relationships.
Attributes:
name: Method/function name
kind: Symbol kind (function | method | class)
line_range: Start and end line numbers
signature: Function signature (if available)
calls: List of outgoing calls
callers: List of incoming calls
"""
name: str
kind: str # function | method | class
line_range: Tuple[int, int]
signature: Optional[str]
calls: List[CallInfo] = field(default_factory=list)
callers: List[CallInfo] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"line_range": list(self.line_range),
"calls": [c.to_dict() for c in self.calls],
"callers": [c.to_dict() for c in self.callers],
}
if self.signature is not None:
result["signature"] = self.signature
return result
@dataclass
class FileContextResult:
"""File context result with method summaries.
Attributes:
file_path: Path to the analyzed file
language: Programming language
methods: List of method contexts
summary: Human-readable summary
discovery_status: Status flags for call resolution
"""
file_path: str
language: str
methods: List[MethodContext]
summary: str
discovery_status: Dict[str, bool] = field(default_factory=lambda: {
"outgoing_resolved": False,
"incoming_resolved": True,
"targets_resolved": False
})
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"file_path": self.file_path,
"language": self.language,
"methods": [m.to_dict() for m in self.methods],
"summary": self.summary,
"discovery_status": self.discovery_status,
}
# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================
@dataclass
class DefinitionResult:
"""Definition lookup result.
Attributes:
name: Symbol name
kind: Symbol kind (class, function, method, etc.)
file_path: File where symbol is defined
line: Start line number
end_line: End line number
signature: Symbol signature (if available)
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
end_line: int
signature: Optional[str] = None
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================
@dataclass
class ReferenceResult:
"""Reference lookup result.
Attributes:
file_path: File containing the reference
line: Line number
column: Column number
context_line: The line of code containing the reference
relationship: Type of reference (call | import | type_annotation | inheritance)
"""
file_path: str
line: int
column: int
context_line: str
relationship: str # call | import | type_annotation | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class GroupedReferences:
"""References grouped by definition.
Used when a symbol has multiple definitions (e.g., overloads).
Attributes:
definition: The definition this group refers to
references: List of references to this definition
"""
definition: DefinitionResult
references: List[ReferenceResult] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"definition": self.definition.to_dict(),
"references": [r.to_dict() for r in self.references],
}
# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================
@dataclass
class SymbolInfo:
"""Symbol information for workspace search.
Attributes:
name: Symbol name
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================
@dataclass
class HoverInfo:
"""Hover information for a symbol.
Attributes:
name: Symbol name
kind: Symbol kind
signature: Symbol signature
documentation: Documentation string (if available)
file_path: File where symbol is defined
line_range: Start and end line numbers
type_info: Type information (if available)
"""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: Tuple[int, int]
type_info: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"signature": self.signature,
"file_path": self.file_path,
"line_range": list(self.line_range),
}
if self.documentation is not None:
result["documentation"] = self.documentation
if self.type_info is not None:
result["type_info"] = self.type_info
return result
# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================
@dataclass
class SemanticResult:
"""Semantic search result.
Attributes:
symbol_name: Name of the matched symbol
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
vector_score: Vector similarity score (None if not available)
structural_score: Structural match score (None if not available)
fusion_score: Combined fusion score
snippet: Code snippet
match_reason: Explanation of why this matched (optional)
"""
symbol_name: str
kind: str
file_path: str
line: int
vector_score: Optional[float]
structural_score: Optional[float]
fusion_score: float
snippet: str
match_reason: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
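A short sketch of how the to_dict() helpers above drop None-valued fields, which keeps JSON payloads compact (import path hypothetical, values illustrative).

from codexlens.api.models import CallInfo, DefinitionResult  # hypothetical import path

call = CallInfo(symbol_name="save", file_path=None, line=42, relationship="call")
print(call.to_dict())  # {'symbol_name': 'save', 'line': 42, 'relationship': 'call'}

defn = DefinitionResult(name="save", kind="method",
                        file_path="src/db.py", line=10, end_line=25)
print(defn.to_dict())  # signature/container dropped (None); score kept (defaults to 1.0)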

View File

@@ -0,0 +1,345 @@
"""Find references API for codexlens.
This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional, Dict
from .models import (
DefinitionResult,
ReferenceResult,
GroupedReferences,
)
from .utils import (
resolve_project,
normalize_relationship_type,
)
logger = logging.getLogger(__name__)
def _read_line_from_file(file_path: str, line: int) -> str:
"""Read a specific line from a file.
Args:
file_path: Path to the file
line: Line number (1-based)
Returns:
The line content, stripped of trailing whitespace.
Returns empty string if file cannot be read or line doesn't exist.
"""
try:
path = Path(file_path)
if not path.exists():
return ""
with path.open("r", encoding="utf-8", errors="replace") as f:
for i, content in enumerate(f, 1):
if i == line:
return content.rstrip()
return ""
except Exception as exc:
logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
return ""
def _transform_to_reference_result(
raw_ref: "RawReferenceResult",
) -> ReferenceResult:
"""Transform raw ChainSearchEngine reference to API ReferenceResult.
Args:
raw_ref: Raw reference result from ChainSearchEngine
Returns:
API ReferenceResult with context_line and normalized relationship
"""
# Read the actual line from the file
context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)
# Normalize relationship type
relationship = normalize_relationship_type(raw_ref.relationship_type)
return ReferenceResult(
file_path=raw_ref.file_path,
line=raw_ref.line,
column=raw_ref.column,
context_line=context_line,
relationship=relationship,
)
def find_references(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
include_definition: bool = True,
group_by_definition: bool = True,
limit: int = 100,
) -> List[GroupedReferences]:
"""Find all reference locations for a symbol.
When a symbol has multiple definitions, results are grouped per definition to resolve the ambiguity.
This function wraps ChainSearchEngine.search_references() and groups
the results by definition location. Each GroupedReferences contains
a definition and all references that point to it.
Args:
project_root: Project root directory path
symbol_name: Name of the symbol to find references for
symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
include_definition: Whether to include the definition location
in the result (default True)
group_by_definition: Whether to group references by definition.
If False, returns a single group with all references.
(default True)
limit: Maximum number of references to return (default 100)
Returns:
List of GroupedReferences. Each group contains:
- definition: The DefinitionResult for this symbol definition
- references: List of ReferenceResult pointing to this definition
Raises:
ValueError: If project_root does not exist or is not a directory
Examples:
>>> refs = find_references("/path/to/project", "authenticate")
>>> for group in refs:
... print(f"Definition: {group.definition.file_path}:{group.definition.line}")
... for ref in group.references:
... print(f" Reference: {ref.file_path}:{ref.line} ({ref.relationship})")
Note:
Reference relationship types are normalized:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
"""
# Validate and resolve project root
project_path = resolve_project(project_root)
# Import here to avoid circular imports
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
from codexlens.entities import Symbol
# Initialize infrastructure
config = Config()
registry = RegistryStore()
mapper = PathMapper(config.index_dir)
# Create chain search engine
engine = ChainSearchEngine(registry, mapper, config=config)
try:
# Step 1: Find definitions for the symbol
definitions: List[DefinitionResult] = []
if include_definition or group_by_definition:
# Search for symbol definitions
symbols = engine.search_symbols(
name=symbol_name,
source_path=project_path,
kind=symbol_kind,
)
# Convert Symbol to DefinitionResult
for sym in symbols:
# Only include exact name matches for definitions
if sym.name != symbol_name:
continue
# Optionally filter by kind
if symbol_kind and sym.kind != symbol_kind:
continue
definitions.append(DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Not available from Symbol
container=None, # Not available from Symbol
score=1.0,
))
# Step 2: Get all references using ChainSearchEngine
raw_references = engine.search_references(
symbol_name=symbol_name,
source_path=project_path,
depth=-1,
limit=limit,
)
# Step 3: Transform raw references to API ReferenceResult
api_references: List[ReferenceResult] = []
for raw_ref in raw_references:
api_ref = _transform_to_reference_result(raw_ref)
api_references.append(api_ref)
# Step 4: Group references by definition
if group_by_definition and definitions:
return _group_references_by_definition(
definitions=definitions,
references=api_references,
include_definition=include_definition,
)
else:
# Return single group with placeholder definition or first definition
if definitions:
definition = definitions[0]
else:
# Create placeholder definition when no definition found
definition = DefinitionResult(
name=symbol_name,
kind=symbol_kind or "unknown",
file_path="",
line=0,
end_line=0,
signature=None,
container=None,
score=0.0,
)
return [GroupedReferences(
definition=definition,
references=api_references,
)]
finally:
engine.close()
def _group_references_by_definition(
definitions: List[DefinitionResult],
references: List[ReferenceResult],
include_definition: bool = True,
) -> List[GroupedReferences]:
"""Group references by their likely definition.
Uses file proximity heuristic to assign references to definitions.
References in the same file or directory as a definition are
assigned to that definition.
Args:
definitions: List of definition locations
references: List of reference locations
include_definition: Whether to include definition in results
Returns:
List of GroupedReferences with references assigned to definitions
"""
if not definitions:
return []
if len(definitions) == 1:
# Single definition - all references belong to it
return [GroupedReferences(
definition=definitions[0],
references=references,
)]
# Multiple definitions - group by proximity
groups: Dict[int, List[ReferenceResult]] = {
i: [] for i in range(len(definitions))
}
for ref in references:
# Find the closest definition by file proximity
best_def_idx = 0
best_score = -1
for i, defn in enumerate(definitions):
score = _proximity_score(ref.file_path, defn.file_path)
if score > best_score:
best_score = score
best_def_idx = i
groups[best_def_idx].append(ref)
# Build result groups
result: List[GroupedReferences] = []
for i, defn in enumerate(definitions):
# Skip definitions with no references if not including definition itself
if not include_definition and not groups[i]:
continue
result.append(GroupedReferences(
definition=defn,
references=groups[i],
))
return result
def _proximity_score(ref_path: str, def_path: str) -> int:
"""Calculate proximity score between two file paths.
Args:
ref_path: Reference file path
def_path: Definition file path
Returns:
Proximity score (higher = closer):
- Same file: 1000
- Same directory: 100
- Otherwise: common path prefix length
"""
import os
if not ref_path or not def_path:
return 0
# Normalize paths
ref_path = os.path.normpath(ref_path)
def_path = os.path.normpath(def_path)
# Same file
if ref_path == def_path:
return 1000
ref_dir = os.path.dirname(ref_path)
def_dir = os.path.dirname(def_path)
# Same directory
if ref_dir == def_dir:
return 100
# Common path prefix
try:
common = os.path.commonpath([ref_path, def_path])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
# Type alias for the raw reference from ChainSearchEngine
class RawReferenceResult:
"""Type stub for ChainSearchEngine.ReferenceResult.
This is only used for type hints and is replaced at runtime
by the actual import.
"""
file_path: str
line: int
column: int
context: str
relationship_type: str
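Building on the docstring example in find_references(), a sketch of serializing the grouped results to JSON through the to_dict() helpers (import path hypothetical).

import json

from codexlens.api import find_references  # hypothetical import path

groups = find_references("/path/to/project", "authenticate", symbol_kind="function")
payload = [g.to_dict() for g in groups]    # GroupedReferences -> plain dicts
print(json.dumps(payload, indent=2, ensure_ascii=False))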

View File

@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.
This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from .models import SemanticResult
from .utils import resolve_project
logger = logging.getLogger(__name__)
def semantic_search(
project_root: str,
query: str,
mode: str = "fusion",
vector_weight: float = 0.5,
structural_weight: float = 0.3,
keyword_weight: float = 0.2,
fusion_strategy: str = "rrf",
kind_filter: Optional[List[str]] = None,
limit: int = 20,
include_match_reason: bool = False,
) -> List[SemanticResult]:
"""Semantic search - combining vector and structural search.
This function provides a high-level API for semantic code search,
combining vector similarity, structural (symbol + relationships),
and keyword-based search methods with configurable fusion.
Args:
project_root: Project root directory
query: Natural language query
mode: Search mode
- vector: Vector search only
- structural: Structural search only (symbol + relationships)
- fusion: Fusion search (default)
vector_weight: Vector search weight [0, 1] (default 0.5)
structural_weight: Structural search weight [0, 1] (default 0.3)
keyword_weight: Keyword search weight [0, 1] (default 0.2)
fusion_strategy: Fusion strategy (maps to chain_search.py)
- rrf: Reciprocal Rank Fusion (recommended, default)
- staged: Staged cascade -> staged_cascade_search
- binary: Binary rerank cascade -> binary_cascade_search
- hybrid: Hybrid cascade -> hybrid_cascade_search
kind_filter: Symbol type filter (e.g., ["function", "class"])
limit: Max return count (default 20)
include_match_reason: Generate match reason (heuristic, not LLM)
Returns:
Results sorted by fusion_score
Degradation:
- No vector index: vector_score=None, uses FTS + structural search
- No relationship data: structural_score=None, vector search only
Examples:
>>> results = semantic_search(
... "/path/to/project",
... "authentication handler",
... mode="fusion",
... fusion_strategy="rrf"
... )
>>> for r in results:
... print(f"{r.symbol_name}: {r.fusion_score:.3f}")
"""
# Validate and resolve project path
project_path = resolve_project(project_root)
# Normalize weights to sum to 1.0
total_weight = vector_weight + structural_weight + keyword_weight
if total_weight > 0:
vector_weight = vector_weight / total_weight
structural_weight = structural_weight / total_weight
keyword_weight = keyword_weight / total_weight
else:
# Default to equal weights if all zero
vector_weight = structural_weight = keyword_weight = 1.0 / 3.0
# Initialize search infrastructure
try:
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
except ImportError as exc:
logger.error("Failed to import search dependencies: %s", exc)
return []
# Load config
config = Config.load()
# Get or create registry and mapper
try:
registry = RegistryStore.default()
mapper = PathMapper(registry)
except Exception as exc:
logger.error("Failed to initialize search infrastructure: %s", exc)
return []
# Build search options based on mode
search_options = _build_search_options(
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
limit=limit,
)
# Execute search based on fusion_strategy
try:
with ChainSearchEngine(registry, mapper, config=config) as engine:
chain_result = _execute_search(
engine=engine,
query=query,
source_path=project_path,
fusion_strategy=fusion_strategy,
options=search_options,
limit=limit,
)
except Exception as exc:
logger.error("Search execution failed: %s", exc)
return []
# Transform results to SemanticResult
semantic_results = _transform_results(
results=chain_result.results,
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
kind_filter=kind_filter,
include_match_reason=include_match_reason,
query=query,
)
return semantic_results[:limit]
def _build_search_options(
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
limit: int,
) -> "SearchOptions":
"""Build SearchOptions based on mode and weights.
Args:
mode: Search mode (vector, structural, fusion)
vector_weight: Vector search weight
structural_weight: Structural search weight
keyword_weight: Keyword search weight
limit: Result limit
Returns:
Configured SearchOptions
"""
from codexlens.search.chain_search import SearchOptions
# Default options
options = SearchOptions(
total_limit=limit * 2, # Fetch extra for filtering
limit_per_dir=limit,
include_symbols=True, # Always include symbols for structural
)
if mode == "vector":
# Pure vector mode
options.hybrid_mode = True
options.enable_vector = True
options.pure_vector = True
options.enable_fuzzy = False
elif mode == "structural":
# Structural only - use FTS + symbols
options.hybrid_mode = True
options.enable_vector = False
options.enable_fuzzy = True
options.include_symbols = True
else:
# Fusion mode (default)
options.hybrid_mode = True
options.enable_vector = vector_weight > 0
options.enable_fuzzy = keyword_weight > 0
options.include_symbols = structural_weight > 0
# Set custom weights for RRF
if options.enable_vector and keyword_weight > 0:
options.hybrid_weights = {
"vector": vector_weight,
"exact": keyword_weight * 0.7,
"fuzzy": keyword_weight * 0.3,
}
return options
def _execute_search(
engine: "ChainSearchEngine",
query: str,
source_path: Path,
fusion_strategy: str,
options: "SearchOptions",
limit: int,
) -> "ChainSearchResult":
"""Execute search using appropriate strategy.
Maps fusion_strategy to ChainSearchEngine methods:
- rrf: Standard hybrid search with RRF fusion
- staged: staged_cascade_search
- binary: binary_cascade_search
- hybrid: hybrid_cascade_search
Args:
engine: ChainSearchEngine instance
query: Search query
source_path: Project root path
fusion_strategy: Strategy name
options: Search options
limit: Result limit
Returns:
ChainSearchResult from the search
"""
from codexlens.search.chain_search import ChainSearchResult
if fusion_strategy == "staged":
# Use staged cascade search (4-stage pipeline)
return engine.staged_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "binary":
# Use binary cascade search (binary coarse + dense fine)
return engine.binary_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "hybrid":
# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
return engine.hybrid_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
else:
# Default: rrf - Standard search with RRF fusion
return engine.search(
query=query,
source_path=source_path,
options=options,
)
def _transform_results(
results: List,
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
kind_filter: Optional[List[str]],
include_match_reason: bool,
query: str,
) -> List[SemanticResult]:
"""Transform ChainSearchEngine results to SemanticResult.
Args:
results: List of SearchResult objects
mode: Search mode
vector_weight: Vector weight used
structural_weight: Structural weight used
keyword_weight: Keyword weight used
kind_filter: Optional symbol kind filter
include_match_reason: Whether to generate match reasons
query: Original query (for match reason generation)
Returns:
List of SemanticResult objects
"""
semantic_results = []
for result in results:
# Extract symbol info
symbol_name = getattr(result, "symbol_name", None)
symbol_kind = getattr(result, "symbol_kind", None)
start_line = getattr(result, "start_line", None)
# Use symbol object if available
if hasattr(result, "symbol") and result.symbol:
symbol_name = symbol_name or result.symbol.name
symbol_kind = symbol_kind or result.symbol.kind
if hasattr(result.symbol, "range") and result.symbol.range:
start_line = start_line or result.symbol.range[0]
# Filter by kind if specified
if kind_filter and symbol_kind:
if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
continue
# Determine scores based on mode and metadata
metadata = getattr(result, "metadata", {}) or {}
fusion_score = result.score
# Try to extract source scores from metadata
source_scores = metadata.get("source_scores", {})
vector_score: Optional[float] = None
structural_score: Optional[float] = None
if mode == "vector":
# In pure vector mode, the main score is the vector score
vector_score = result.score
structural_score = None
elif mode == "structural":
# In structural mode, no vector score
vector_score = None
structural_score = result.score
else:
# Fusion mode - try to extract individual scores
if "vector" in source_scores:
vector_score = source_scores["vector"]
elif metadata.get("fusion_method") == "simple_weighted":
# From weighted fusion
vector_score = source_scores.get("vector")
# Structural score approximation (from exact/fuzzy FTS)
fts_scores = []
if "exact" in source_scores:
fts_scores.append(source_scores["exact"])
if "fuzzy" in source_scores:
fts_scores.append(source_scores["fuzzy"])
if "splade" in source_scores:
fts_scores.append(source_scores["splade"])
if fts_scores:
structural_score = max(fts_scores)
# Build snippet
snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
if len(snippet) > 500:
snippet = snippet[:500] + "..."
# Generate match reason if requested
match_reason = None
if include_match_reason:
match_reason = _generate_match_reason(
query=query,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
snippet=snippet,
vector_score=vector_score,
structural_score=structural_score,
)
semantic_result = SemanticResult(
symbol_name=symbol_name or Path(result.path).stem,
kind=symbol_kind or "unknown",
file_path=result.path,
line=start_line or 1,
vector_score=vector_score,
structural_score=structural_score,
fusion_score=fusion_score,
snippet=snippet,
match_reason=match_reason,
)
semantic_results.append(semantic_result)
# Sort by fusion_score descending
semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)
return semantic_results
def _generate_match_reason(
query: str,
symbol_name: Optional[str],
symbol_kind: Optional[str],
snippet: str,
vector_score: Optional[float],
structural_score: Optional[float],
) -> str:
"""Generate human-readable match reason heuristically.
This is a simple heuristic-based approach, not LLM-powered.
Args:
query: Original search query
symbol_name: Symbol name if available
symbol_kind: Symbol kind if available
snippet: Code snippet
vector_score: Vector similarity score
structural_score: Structural match score
Returns:
Human-readable explanation string
"""
reasons = []
# Check for direct name match
query_lower = query.lower()
query_words = set(query_lower.split())
if symbol_name:
name_lower = symbol_name.lower()
# Direct substring match
if query_lower in name_lower or name_lower in query_lower:
reasons.append(f"Symbol name '{symbol_name}' matches query")
# Word overlap
name_words = set(_split_camel_case(symbol_name).lower().split())
overlap = query_words & name_words
if overlap and not reasons:
reasons.append(f"Symbol name contains: {', '.join(overlap)}")
# Check snippet for keyword matches
snippet_lower = snippet.lower()
matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
if matching_words and len(reasons) < 2:
reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")
# Add score-based reasoning
if vector_score is not None and vector_score > 0.7:
reasons.append("High semantic similarity")
elif vector_score is not None and vector_score > 0.5:
reasons.append("Moderate semantic similarity")
if structural_score is not None and structural_score > 0.8:
reasons.append("Strong structural match")
# Symbol kind context
if symbol_kind and len(reasons) < 3:
reasons.append(f"Matched {symbol_kind}")
if not reasons:
reasons.append("Partial relevance based on content analysis")
return "; ".join(reasons[:3])
def _split_camel_case(name: str) -> str:
"""Split camelCase and PascalCase to words.
Args:
name: Symbol name in camelCase or PascalCase
Returns:
Space-separated words
"""
import re
# Insert space before uppercase letters
result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
# Insert space before uppercase followed by lowercase
result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
# Replace underscores with spaces
result = result.replace("_", " ")
return result
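Complementing the docstring example, a sketch of a pure-vector query with heuristic match reasons; as the Degradation notes state, vector_score may be None when no vector index exists (import path hypothetical).

from codexlens.api import semantic_search  # hypothetical import path

results = semantic_search(
    "/path/to/project",
    "retry failed HTTP requests with backoff",
    mode="vector",
    fusion_strategy="rrf",
    limit=5,
    include_match_reason=True,
)
for r in results:
    print(f"{r.fusion_score:.3f} {r.kind:<9} {r.symbol_name} ({r.file_path}:{r.line})")
    print(f"    reason: {r.match_reason}")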

View File

@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.
This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""
from __future__ import annotations
import fnmatch
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def workspace_symbols(
project_root: str,
query: str,
kind_filter: Optional[List[str]] = None,
file_pattern: Optional[str] = None,
limit: int = 50
) -> List[SymbolInfo]:
"""Search for symbols across the entire workspace.
Uses prefix matching for efficient searching.
Args:
project_root: Project root directory (for index location)
query: Search query (prefix match)
kind_filter: Optional list of symbol kinds to include
(e.g., ["class", "function"])
file_pattern: Optional glob pattern to filter by file path
(e.g., "*.py", "src/**/*.ts")
limit: Maximum number of results to return
Returns:
List of SymbolInfo sorted by score
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search with prefix matching
# If kind_filter has multiple kinds, we need to search for each
all_results: List[Symbol] = []
if kind_filter and len(kind_filter) > 0:
# Search for each kind separately
for kind in kind_filter:
results = global_index.search(
name=query,
kind=kind,
limit=limit,
prefix_mode=True
)
all_results.extend(results)
else:
# Search without kind filter
all_results = global_index.search(
name=query,
kind=None,
limit=limit,
prefix_mode=True
)
logger.debug(f"Found {len(all_results)} symbols matching '{query}'")
# Apply file pattern filter if specified
if file_pattern:
all_results = [
sym for sym in all_results
if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
]
logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")
# Convert to SymbolInfo and sort by relevance
symbols = [
SymbolInfo(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
container=None, # Could extract from parent
score=_calculate_score(sym.name, query)
)
for sym in all_results
]
# Sort by score (exact matches first)
symbols.sort(key=lambda s: s.score, reverse=True)
return symbols[:limit]
def _calculate_score(symbol_name: str, query: str) -> float:
"""Calculate relevance score for a symbol match.
Scoring:
- Exact match: 1.0
- Case-insensitive exact match: 0.9
- Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
- Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
- Any other index hit: 0.5
Args:
symbol_name: The matched symbol name
query: The search query
Returns:
Score between 0.0 and 1.0
"""
if symbol_name == query:
return 1.0
if symbol_name.lower() == query.lower():
return 0.9
if symbol_name.startswith(query):
ratio = len(query) / len(symbol_name)
return 0.8 + 0.2 * ratio
if symbol_name.lower().startswith(query.lower()):
ratio = len(query) / len(symbol_name)
return 0.6 + 0.2 * ratio
return 0.5
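A minimal workspace_symbols() sketch (import path hypothetical). With the query "Auth", an exact hit scores 1.0 and a prefix hit such as AuthService scores 0.8 + 0.2 * (4 / 11) ≈ 0.87 under _calculate_score.

from codexlens.api import workspace_symbols  # hypothetical import path

symbols = workspace_symbols(
    "/path/to/project",
    query="Auth",
    kind_filter=["class", "function"],
    file_pattern="*.py",
    limit=20,
)
for s in symbols:
    print(f"{s.score:.2f} {s.kind:<9} {s.name} -> {s.file_path}:{s.line}")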

View File

@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.
This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable
from .models import DefinitionResult
# Type variable for generic ranking
T = TypeVar('T')
def resolve_project(project_root: str) -> Path:
"""Resolve and validate project root path.
Args:
project_root: Path to project root (relative or absolute)
Returns:
Resolved absolute Path
Raises:
ValueError: If path does not exist or is not a directory
"""
path = Path(project_root).resolve()
if not path.exists():
raise ValueError(f"Project root does not exist: {path}")
if not path.is_dir():
raise ValueError(f"Project root is not a directory: {path}")
return path
# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
# Plural to singular
"calls": "call",
"imports": "import",
"inherits": "inheritance",
"uses": "use",
# Already normalized (passthrough)
"call": "call",
"import": "import",
"inheritance": "inheritance",
"use": "use",
"type_annotation": "type_annotation",
}
def normalize_relationship_type(relationship: str) -> str:
"""Normalize relationship type to canonical form.
Converts plural forms and variations to standard singular forms:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
- 'uses' -> 'use'
Args:
relationship: Raw relationship type string
Returns:
Normalized relationship type
Examples:
>>> normalize_relationship_type('calls')
'call'
>>> normalize_relationship_type('inherits')
'inheritance'
>>> normalize_relationship_type('call')
'call'
"""
return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)
def rank_by_proximity(
results: List[DefinitionResult],
file_context: Optional[str] = None
) -> List[DefinitionResult]:
"""Rank results by file path proximity to context.
V1 Implementation: Uses path-based proximity scoring.
Scoring algorithm:
1. Same directory: highest score (100)
2. Otherwise: length of common path prefix
Args:
results: List of definition results to rank
file_context: Reference file path for proximity calculation.
If None, returns results unchanged.
Returns:
Results sorted by proximity score (highest first)
Examples:
>>> results = [
... DefinitionResult(name="foo", kind="function",
... file_path="/a/b/c.py", line=1, end_line=10),
... DefinitionResult(name="foo", kind="function",
... file_path="/a/x/y.py", line=1, end_line=10),
... ]
>>> ranked = rank_by_proximity(results, "/a/b/test.py")
>>> ranked[0].file_path
'/a/b/c.py'
"""
if not file_context or not results:
return results
def proximity_score(result: DefinitionResult) -> int:
"""Calculate proximity score for a result."""
result_dir = os.path.dirname(result.file_path)
context_dir = os.path.dirname(file_context)
# Same directory gets highest score
if result_dir == context_dir:
return 100
# Otherwise, score by common path prefix length
try:
common = os.path.commonpath([result.file_path, file_context])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
return sorted(results, key=proximity_score, reverse=True)
def rank_by_score(
results: List[T],
score_fn: Callable[[T], float],
reverse: bool = True
) -> List[T]:
"""Generic ranking function by custom score.
Args:
results: List of items to rank
score_fn: Function to extract score from item
reverse: If True, highest scores first (default)
Returns:
Sorted list
"""
return sorted(results, key=score_fn, reverse=reverse)
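rank_by_proximity() and normalize_relationship_type() already carry doctest examples above; rank_by_score() does not, so here is a tiny illustrative sketch (the data values are made up).

from codexlens.api.utils import rank_by_score  # hypothetical import path

hits = [{"name": "a", "score": 0.2}, {"name": "b", "score": 0.9}]
ranked = rank_by_score(hits, score_fn=lambda h: h["score"])
print([h["name"] for h in ranked])  # ['b', 'a'] -- highest score first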

View File

@@ -0,0 +1,27 @@
"""CLI package for CodexLens."""
from __future__ import annotations
import sys
import os
# Force UTF-8 encoding for Windows console
# This ensures Chinese characters display correctly instead of GBK garbled text
if sys.platform == "win32":
# Set environment variable for Python I/O encoding
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
# Reconfigure stdout/stderr to use UTF-8 if possible
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
# Fallback: some environments don't support reconfigure
pass
from .commands import app
__all__ = ["app"]

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,135 @@
"""Rich and JSON output helpers for CodexLens CLI."""
from __future__ import annotations
import json
import sys
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence
from rich.console import Console
from rich.table import Table
from rich.text import Text
from codexlens.entities import SearchResult, Symbol
# Force UTF-8 encoding for Windows console to properly display Chinese text
# Use force_terminal=True and legacy_windows=False to avoid GBK encoding issues
console = Console(force_terminal=True, legacy_windows=False)
def _to_jsonable(value: Any) -> Any:
if value is None:
return None
if hasattr(value, "model_dump"):
return value.model_dump()
if is_dataclass(value):
return asdict(value)
if isinstance(value, Path):
return str(value)
if isinstance(value, Mapping):
return {k: _to_jsonable(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return [_to_jsonable(v) for v in value]
return value
def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
"""Print JSON output with optional additional fields.
Args:
success: Whether the operation succeeded
result: Result data (used when success=True)
error: Error message (used when success=False)
**kwargs: Additional fields to include in the payload (e.g., code, details)
"""
payload: dict[str, Any] = {"success": success}
if success:
payload["result"] = _to_jsonable(result)
else:
payload["error"] = error or "Unknown error"
# Include additional error details if provided
for key, value in kwargs.items():
payload[key] = _to_jsonable(value)
console.print_json(json.dumps(payload, ensure_ascii=False))
def render_search_results(
results: Sequence[SearchResult], *, title: str = "Search Results", verbose: bool = False
) -> None:
"""Render search results with optional source tags in verbose mode.
Args:
results: Search results to display
title: Table title
verbose: If True, show search source tags ([E] exact, [F] fuzzy, [V] vector, [RRF] fusion) alongside the scores
"""
table = Table(title=title, show_lines=False)
if verbose:
# Verbose mode: show source tags
table.add_column("Source", style="dim", width=6, justify="center")
table.add_column("Path", style="cyan", no_wrap=True)
table.add_column("Score", style="magenta", justify="right")
table.add_column("Excerpt", style="white")
for res in results:
excerpt = res.excerpt or ""
score_str = f"{res.score:.3f}"
if verbose:
# Extract search source tag if available
source = getattr(res, "search_source", None)
source_tag = ""
if source == "exact":
source_tag = "[E]"
elif source == "fuzzy":
source_tag = "[F]"
elif source == "vector":
source_tag = "[V]"
elif source == "fusion":
source_tag = "[RRF]"
table.add_row(source_tag, res.path, score_str, excerpt)
else:
table.add_row(res.path, score_str, excerpt)
console.print(table)
def render_symbols(symbols: Sequence[Symbol], *, title: str = "Symbols") -> None:
table = Table(title=title)
table.add_column("Name", style="green")
table.add_column("Kind", style="yellow")
table.add_column("Range", style="white", justify="right")
for sym in symbols:
start, end = sym.range
table.add_row(sym.name, sym.kind, f"{start}-{end}")
console.print(table)
def render_status(stats: Mapping[str, Any]) -> None:
table = Table(title="Index Status")
table.add_column("Metric", style="cyan")
table.add_column("Value", style="white")
for key, value in stats.items():
if isinstance(value, Mapping):
value_text = ", ".join(f"{k}:{v}" for k, v in value.items())
elif isinstance(value, (list, tuple)):
value_text = ", ".join(str(v) for v in value)
else:
value_text = str(value)
table.add_row(str(key), value_text)
console.print(table)
def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> None:
header = Text.assemble(("File: ", "bold"), (path, "cyan"), (" Language: ", "bold"), (language, "green"))
console.print(header)
render_symbols(list(symbols), title="Discovered Symbols")
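A small sketch of print_json() covering both the success and error paths; the module name and the extra code field are illustrative, not an established schema.

from codexlens.cli.output import print_json  # hypothetical import path

print_json(success=True, result={"indexed_files": 128})
# -> {"success": true, "result": {"indexed_files": 128}}

print_json(success=False, error="Project not indexed", code="INDEX_NOT_FOUND")
# -> {"success": false, "error": "Project not indexed", "code": "INDEX_NOT_FOUND"}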

View File

@@ -0,0 +1,692 @@
"""Configuration system for CodexLens."""
from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, List, Optional
from .errors import ConfigError
# Workspace-local directory name
WORKSPACE_DIR_NAME = ".codexlens"
# Settings file name
SETTINGS_FILE_NAME = "settings.json"
# SPLADE index database name (centralized storage)
SPLADE_DB_NAME = "_splade.db"
# Dense vector storage names (centralized storage)
VECTORS_HNSW_NAME = "_vectors.hnsw"
VECTORS_META_DB_NAME = "_vectors_meta.db"
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"
log = logging.getLogger(__name__)
def _default_global_dir() -> Path:
"""Get global CodexLens data directory."""
env_override = os.getenv("CODEXLENS_DATA_DIR")
if env_override:
return Path(env_override).expanduser().resolve()
return (Path.home() / ".codexlens").resolve()
def find_workspace_root(start_path: Path) -> Optional[Path]:
"""Find the workspace root by looking for .codexlens directory.
Searches from start_path upward to find an existing .codexlens directory.
Returns None if not found.
"""
current = start_path.resolve()
# Search up to filesystem root
while current != current.parent:
workspace_dir = current / WORKSPACE_DIR_NAME
if workspace_dir.is_dir():
return current
current = current.parent
# Check root as well
workspace_dir = current / WORKSPACE_DIR_NAME
if workspace_dir.is_dir():
return current
return None
@dataclass
class Config:
"""Runtime configuration for CodexLens.
- data_dir: Base directory for all persistent CodexLens data.
- venv_path: Optional virtualenv used for language tooling.
- supported_languages: Language IDs and their associated file extensions.
- parsing_rules: Per-language parsing and chunking hints.
"""
data_dir: Path = field(default_factory=_default_global_dir)
venv_path: Path = field(default_factory=lambda: _default_global_dir() / "venv")
supported_languages: Dict[str, Dict[str, Any]] = field(
default_factory=lambda: {
# Source code languages (category: "code")
"python": {"extensions": [".py"], "tree_sitter_language": "python", "category": "code"},
"javascript": {"extensions": [".js", ".jsx"], "tree_sitter_language": "javascript", "category": "code"},
"typescript": {"extensions": [".ts", ".tsx"], "tree_sitter_language": "typescript", "category": "code"},
"java": {"extensions": [".java"], "tree_sitter_language": "java", "category": "code"},
"go": {"extensions": [".go"], "tree_sitter_language": "go", "category": "code"},
"zig": {"extensions": [".zig"], "tree_sitter_language": "zig", "category": "code"},
"objective-c": {"extensions": [".m", ".mm"], "tree_sitter_language": "objc", "category": "code"},
"c": {"extensions": [".c", ".h"], "tree_sitter_language": "c", "category": "code"},
"cpp": {"extensions": [".cc", ".cpp", ".hpp", ".cxx"], "tree_sitter_language": "cpp", "category": "code"},
"rust": {"extensions": [".rs"], "tree_sitter_language": "rust", "category": "code"},
}
)
parsing_rules: Dict[str, Dict[str, Any]] = field(
default_factory=lambda: {
"default": {
"max_chunk_chars": 4000,
"max_chunk_lines": 200,
"overlap_lines": 20,
}
}
)
llm_enabled: bool = False
llm_tool: str = "gemini"
llm_timeout_ms: int = 300000
llm_batch_size: int = 5
# Hybrid chunker configuration
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
# Embedding configuration
embedding_backend: str = "fastembed" # "fastembed" (local) or "litellm" (API)
embedding_model: str = "code" # For fastembed: profile (fast/code/multilingual/balanced)
# For litellm: model name from config (e.g., "qwen3-embedding")
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
# SPLADE sparse retrieval configuration
enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead)
splade_model: str = "naver/splade-cocondenser-ensembledistil"
splade_threshold: float = 0.01 # Min weight to store in index
splade_onnx_path: Optional[str] = None # Custom ONNX model path
# FTS fallback (disabled by default, available via --use-fts)
use_fts_fallback: bool = True # Use FTS for sparse search (fast, SPLADE disabled)
# Indexing/search optimizations
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
# Graph expansion (search-time, uses precomputed neighbors)
enable_graph_expansion: bool = False
graph_expansion_depth: int = 2
# Optional search reranking (disabled by default)
enable_reranking: bool = False
reranking_top_k: int = 50
symbol_boost_factor: float = 1.5
# Optional cross-encoder reranking (second stage; requires optional reranker deps)
enable_cross_encoder_rerank: bool = False
reranker_backend: str = "onnx"
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker_top_k: int = 50
reranker_max_input_tokens: int = 8192 # Maximum tokens for reranker API batching
reranker_chunk_type_weights: Optional[Dict[str, float]] = None # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
reranker_test_file_penalty: float = 0.0 # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)
# Chunk stripping configuration (for semantic embedding)
chunk_strip_comments: bool = True # Strip comments from code chunks
chunk_strip_docstrings: bool = True # Strip docstrings from code chunks
# Cascade search configuration (two-stage retrieval)
enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
cascade_coarse_k: int = 100 # Number of coarse candidates from first stage
cascade_fine_k: int = 10 # Number of final results after reranking
cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
# Staged cascade search configuration (4-stage pipeline)
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
staged_lsp_depth: int = 2 # LSP relationship expansion depth in Stage 2
staged_clustering_strategy: str = "auto" # "auto", "hdbscan", "dbscan", "frequency", "noop"
staged_clustering_min_size: int = 3 # Minimum cluster size for Stage 3 grouping
enable_staged_rerank: bool = True # Enable optional cross-encoder reranking in Stage 4
# RRF fusion configuration
fusion_method: str = "rrf" # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
rrf_k: int = 60 # RRF constant (default 60)
# Category-based filtering to separate code/doc results
enable_category_filter: bool = True # Enable code/doc result separation
# Multi-endpoint configuration for litellm backend
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
embedding_pool_enabled: bool = False # Enable high availability pool for embeddings
embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
# Reranker multi-endpoint configuration
reranker_pool_enabled: bool = False # Enable high availability pool for reranker
reranker_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
reranker_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
# API concurrency settings
api_max_workers: int = 4 # Max concurrent API calls for embedding/reranking
api_batch_size: int = 8 # Batch size for API requests
api_batch_size_dynamic: bool = False # Enable dynamic batch size calculation
api_batch_size_utilization_factor: float = 0.8 # Use 80% of model token capacity
api_batch_size_max: int = 2048 # Absolute upper limit for batch size
chars_per_token_estimate: int = 4 # Characters per token estimation ratio
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()
self.venv_path = self.venv_path.expanduser().resolve()
self.data_dir.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise ConfigError(
f"Permission denied initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
except OSError as exc:
raise ConfigError(
f"Filesystem error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
except Exception as exc:
raise ConfigError(
f"Unexpected error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
@cached_property
def cache_dir(self) -> Path:
"""Directory for transient caches."""
return self.data_dir / "cache"
@cached_property
def index_dir(self) -> Path:
"""Directory where index artifacts are stored."""
return self.data_dir / "index"
@cached_property
def db_path(self) -> Path:
"""Default SQLite index path."""
return self.index_dir / "codexlens.db"
def ensure_runtime_dirs(self) -> None:
"""Create standard runtime directories if missing."""
for directory in (self.cache_dir, self.index_dir):
try:
directory.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise ConfigError(
f"Permission denied creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
except OSError as exc:
raise ConfigError(
f"Filesystem error creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
except Exception as exc:
raise ConfigError(
f"Unexpected error creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
def language_for_path(self, path: str | Path) -> str | None:
"""Infer a supported language ID from a file path."""
extension = Path(path).suffix.lower()
for language_id, spec in self.supported_languages.items():
extensions: List[str] = spec.get("extensions", [])
if extension in extensions:
return language_id
return None
def category_for_path(self, path: str | Path) -> str | None:
"""Get file category ('code' or 'doc') from a file path."""
language = self.language_for_path(path)
if language is None:
return None
spec = self.supported_languages.get(language, {})
return spec.get("category")
def rules_for_language(self, language_id: str) -> Dict[str, Any]:
"""Get parsing rules for a specific language, falling back to defaults."""
return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
@cached_property
def settings_path(self) -> Path:
"""Path to the settings file."""
return self.data_dir / SETTINGS_FILE_NAME
def save_settings(self) -> None:
"""Save embedding and other settings to file."""
embedding_config = {
"backend": self.embedding_backend,
"model": self.embedding_model,
"use_gpu": self.embedding_use_gpu,
"pool_enabled": self.embedding_pool_enabled,
"strategy": self.embedding_strategy,
"cooldown": self.embedding_cooldown,
}
# Include multi-endpoint config if present
if self.embedding_endpoints:
embedding_config["endpoints"] = self.embedding_endpoints
settings = {
"embedding": embedding_config,
"llm": {
"enabled": self.llm_enabled,
"tool": self.llm_tool,
"timeout_ms": self.llm_timeout_ms,
"batch_size": self.llm_batch_size,
},
"reranker": {
"enabled": self.enable_cross_encoder_rerank,
"backend": self.reranker_backend,
"model": self.reranker_model,
"top_k": self.reranker_top_k,
"max_input_tokens": self.reranker_max_input_tokens,
"pool_enabled": self.reranker_pool_enabled,
"strategy": self.reranker_strategy,
"cooldown": self.reranker_cooldown,
},
"cascade": {
"strategy": self.cascade_strategy,
"coarse_k": self.cascade_coarse_k,
"fine_k": self.cascade_fine_k,
},
"api": {
"max_workers": self.api_max_workers,
"batch_size": self.api_batch_size,
"batch_size_dynamic": self.api_batch_size_dynamic,
"batch_size_utilization_factor": self.api_batch_size_utilization_factor,
"batch_size_max": self.api_batch_size_max,
"chars_per_token_estimate": self.chars_per_token_estimate,
},
}
with open(self.settings_path, "w", encoding="utf-8") as f:
json.dump(settings, f, indent=2)
def load_settings(self) -> None:
"""Load settings from file if exists."""
if not self.settings_path.exists():
return
try:
with open(self.settings_path, "r", encoding="utf-8") as f:
settings = json.load(f)
# Load embedding settings
embedding = settings.get("embedding", {})
if "backend" in embedding:
backend = embedding["backend"]
# Support 'api' as alias for 'litellm'
if backend == "api":
backend = "litellm"
if backend in {"fastembed", "litellm"}:
self.embedding_backend = backend
else:
log.warning(
"Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
self.settings_path,
embedding["backend"],
)
if "model" in embedding:
self.embedding_model = embedding["model"]
if "use_gpu" in embedding:
self.embedding_use_gpu = embedding["use_gpu"]
# Load multi-endpoint configuration
if "endpoints" in embedding:
self.embedding_endpoints = embedding["endpoints"]
if "pool_enabled" in embedding:
self.embedding_pool_enabled = embedding["pool_enabled"]
if "strategy" in embedding:
self.embedding_strategy = embedding["strategy"]
if "cooldown" in embedding:
self.embedding_cooldown = embedding["cooldown"]
# Load LLM settings
llm = settings.get("llm", {})
if "enabled" in llm:
self.llm_enabled = llm["enabled"]
if "tool" in llm:
self.llm_tool = llm["tool"]
if "timeout_ms" in llm:
self.llm_timeout_ms = llm["timeout_ms"]
if "batch_size" in llm:
self.llm_batch_size = llm["batch_size"]
# Load reranker settings
reranker = settings.get("reranker", {})
if "enabled" in reranker:
self.enable_cross_encoder_rerank = reranker["enabled"]
if "backend" in reranker:
backend = reranker["backend"]
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
self.reranker_backend = backend
else:
log.warning(
"Invalid reranker backend in %s: %r (expected 'fastembed', 'onnx', 'api', 'litellm', or 'legacy')",
self.settings_path,
backend,
)
if "model" in reranker:
self.reranker_model = reranker["model"]
if "top_k" in reranker:
self.reranker_top_k = reranker["top_k"]
if "max_input_tokens" in reranker:
self.reranker_max_input_tokens = reranker["max_input_tokens"]
if "pool_enabled" in reranker:
self.reranker_pool_enabled = reranker["pool_enabled"]
if "strategy" in reranker:
self.reranker_strategy = reranker["strategy"]
if "cooldown" in reranker:
self.reranker_cooldown = reranker["cooldown"]
# Load cascade settings
cascade = settings.get("cascade", {})
if "strategy" in cascade:
strategy = cascade["strategy"]
if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
self.cascade_strategy = strategy
else:
log.warning(
"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
self.settings_path,
strategy,
)
if "coarse_k" in cascade:
self.cascade_coarse_k = cascade["coarse_k"]
if "fine_k" in cascade:
self.cascade_fine_k = cascade["fine_k"]
# Load API settings
api = settings.get("api", {})
if "max_workers" in api:
self.api_max_workers = api["max_workers"]
if "batch_size" in api:
self.api_batch_size = api["batch_size"]
if "batch_size_dynamic" in api:
self.api_batch_size_dynamic = api["batch_size_dynamic"]
if "batch_size_utilization_factor" in api:
self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
if "batch_size_max" in api:
self.api_batch_size_max = api["batch_size_max"]
if "chars_per_token_estimate" in api:
self.chars_per_token_estimate = api["chars_per_token_estimate"]
except Exception as exc:
log.warning(
"Failed to load settings from %s (%s): %s",
self.settings_path,
type(exc).__name__,
exc,
)
# Apply .env overrides (highest priority)
self._apply_env_overrides()
def _apply_env_overrides(self) -> None:
"""Apply environment variable overrides from .env file.
Priority: default → settings.json → .env (highest)
Supported variables (with or without CODEXLENS_ prefix):
EMBEDDING_MODEL: Override embedding model/profile
EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
EMBEDDING_STRATEGY: Load balance strategy for embedding
EMBEDDING_COOLDOWN: Rate limit cooldown for embedding
RERANKER_MODEL: Override reranker model
RERANKER_BACKEND: Override reranker backend
RERANKER_ENABLED: Override reranker enabled state (true/false)
RERANKER_POOL_ENABLED: Enable reranker high availability pool
RERANKER_STRATEGY: Load balance strategy for reranker
RERANKER_COOLDOWN: Rate limit cooldown for reranker
RERANKER_MAX_INPUT_TOKENS: Max tokens per reranker API batch
RERANKER_TEST_FILE_PENALTY: Score penalty for test files (0.0-1.0)
RERANKER_DOCSTRING_WEIGHT: Weight for docstring chunks in reranking
CHUNK_STRIP_COMMENTS: Strip comments from chunks before embedding (true/false)
CHUNK_STRIP_DOCSTRINGS: Strip docstrings from chunks before embedding (true/false)
"""
from .env_config import load_global_env
env_vars = load_global_env()
if not env_vars:
return
def get_env(key: str) -> str | None:
"""Get env var with or without CODEXLENS_ prefix."""
# Check prefixed version first (Dashboard format), then unprefixed
return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)
# Embedding overrides
embedding_model = get_env("EMBEDDING_MODEL")
if embedding_model:
self.embedding_model = embedding_model
log.debug("Overriding embedding_model from .env: %s", self.embedding_model)
embedding_backend = get_env("EMBEDDING_BACKEND")
if embedding_backend:
backend = embedding_backend.lower()
# Support 'api' as alias for 'litellm'
if backend == "api":
backend = "litellm"
if backend in {"fastembed", "litellm"}:
self.embedding_backend = backend
log.debug("Overriding embedding_backend from .env: %s", backend)
else:
log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)
embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
if embedding_pool:
value = embedding_pool.lower()
self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)
embedding_strategy = get_env("EMBEDDING_STRATEGY")
if embedding_strategy:
strategy = embedding_strategy.lower()
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
self.embedding_strategy = strategy
log.debug("Overriding embedding_strategy from .env: %s", strategy)
else:
log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)
embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
if embedding_cooldown:
try:
self.embedding_cooldown = float(embedding_cooldown)
log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
except ValueError:
log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)
# Reranker overrides
reranker_model = get_env("RERANKER_MODEL")
if reranker_model:
self.reranker_model = reranker_model
log.debug("Overriding reranker_model from .env: %s", self.reranker_model)
reranker_backend = get_env("RERANKER_BACKEND")
if reranker_backend:
backend = reranker_backend.lower()
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
self.reranker_backend = backend
log.debug("Overriding reranker_backend from .env: %s", backend)
else:
log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)
reranker_enabled = get_env("RERANKER_ENABLED")
if reranker_enabled:
value = reranker_enabled.lower()
self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)
reranker_pool = get_env("RERANKER_POOL_ENABLED")
if reranker_pool:
value = reranker_pool.lower()
self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)
reranker_strategy = get_env("RERANKER_STRATEGY")
if reranker_strategy:
strategy = reranker_strategy.lower()
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
self.reranker_strategy = strategy
log.debug("Overriding reranker_strategy from .env: %s", strategy)
else:
log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)
reranker_cooldown = get_env("RERANKER_COOLDOWN")
if reranker_cooldown:
try:
self.reranker_cooldown = float(reranker_cooldown)
log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
except ValueError:
log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)
reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
if reranker_max_tokens:
try:
self.reranker_max_input_tokens = int(reranker_max_tokens)
log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
except ValueError:
log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)
# Reranker tuning from environment
test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
if test_penalty:
try:
self.reranker_test_file_penalty = float(test_penalty)
log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
except ValueError:
log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)
docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
if docstring_weight:
try:
weight = float(docstring_weight)
self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
log.debug("Overriding reranker docstring weight from .env: %s", weight)
except ValueError:
log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)
# Chunk stripping from environment
strip_comments = get_env("CHUNK_STRIP_COMMENTS")
if strip_comments:
self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)
strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
if strip_docstrings:
self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
@classmethod
def load(cls) -> "Config":
"""Load config with settings from file."""
config = cls()
config.load_settings()
return config
@dataclass
class WorkspaceConfig:
"""Workspace-local configuration for CodexLens.
Stores index data in project/.codexlens/ directory.
"""
workspace_root: Path
def __post_init__(self) -> None:
self.workspace_root = Path(self.workspace_root).resolve()
@property
def codexlens_dir(self) -> Path:
"""The .codexlens directory in workspace root."""
return self.workspace_root / WORKSPACE_DIR_NAME
@property
def db_path(self) -> Path:
"""SQLite index path for this workspace."""
return self.codexlens_dir / "index.db"
@property
def cache_dir(self) -> Path:
"""Cache directory for this workspace."""
return self.codexlens_dir / "cache"
@property
def env_path(self) -> Path:
"""Path to workspace .env file."""
return self.codexlens_dir / ".env"
def load_env(self, *, override: bool = False) -> int:
"""Load .env file and apply to os.environ.
Args:
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
from .env_config import apply_workspace_env
return apply_workspace_env(self.workspace_root, override=override)
def get_api_config(self, prefix: str) -> dict:
"""Get API configuration from environment.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
Returns:
Dictionary with api_key, api_base, model, etc.
"""
from .env_config import get_api_config
return get_api_config(prefix, workspace_root=self.workspace_root)
def initialize(self) -> None:
"""Create the .codexlens directory structure."""
try:
self.codexlens_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Create .gitignore to exclude cache but keep index
gitignore_path = self.codexlens_dir / ".gitignore"
if not gitignore_path.exists():
gitignore_path.write_text(
"# CodexLens workspace data\n"
"cache/\n"
"*.log\n"
".env\n" # Exclude .env from git
)
except Exception as exc:
raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc
def exists(self) -> bool:
"""Check if workspace is already initialized."""
return self.codexlens_dir.is_dir() and self.db_path.exists()
@classmethod
def from_path(cls, path: Path) -> Optional["WorkspaceConfig"]:
"""Create WorkspaceConfig from a path by finding workspace root.
Returns None if no workspace found.
"""
root = find_workspace_root(path)
if root is None:
return None
return cls(workspace_root=root)
@classmethod
def create_at(cls, path: Path) -> "WorkspaceConfig":
"""Create a new workspace at the given path."""
config = cls(workspace_root=path)
config.initialize()
return config
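As a quick orientation to the two classes above, here is a minimal usage sketch. It assumes the module is importable as `codexlens.config` and uses a made-up workspace path; both are assumptions, not confirmed by this diff.

```python
# Minimal sketch, not part of the diff; import path and the /tmp path are assumed.
from pathlib import Path
from codexlens.config import Config, WorkspaceConfig

config = Config.load()            # defaults -> settings.json -> .env overrides
config.ensure_runtime_dirs()      # creates <data_dir>/cache and <data_dir>/index
print(config.db_path)             # <data_dir>/index/codexlens.db
print(config.embedding_backend)   # "fastembed" or "litellm" after overrides

ws = WorkspaceConfig.create_at(Path("/tmp/demo-project"))  # hypothetical path
ws.load_env()                     # applies .codexlens/.env to os.environ
print(ws.db_path)                 # /tmp/demo-project/.codexlens/index.db
```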

View File

@@ -0,0 +1,128 @@
"""Pydantic entity models for CodexLens."""
from __future__ import annotations
import math
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field, field_validator
class Symbol(BaseModel):
"""A code symbol discovered in a file."""
name: str = Field(..., min_length=1)
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
@field_validator("range")
@classmethod
def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
if len(value) != 2:
raise ValueError("range must be a (start_line, end_line) tuple")
start_line, end_line = value
if start_line < 1 or end_line < 1:
raise ValueError("range lines must be >= 1")
if end_line < start_line:
raise ValueError("end_line must be >= start_line")
return value
class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""
content: str = Field(..., min_length=1)
embedding: Optional[List[float]] = Field(default=None, description="Vector embedding for semantic search")
metadata: Dict[str, Any] = Field(default_factory=dict)
id: Optional[int] = Field(default=None, description="Database row ID")
file_path: Optional[str] = Field(default=None, description="Source file path")
@field_validator("embedding")
@classmethod
def validate_embedding(cls, value: Optional[List[float]]) -> Optional[List[float]]:
if value is None:
return value
if not value:
raise ValueError("embedding cannot be empty when provided")
norm = math.sqrt(sum(x * x for x in value))
epsilon = 1e-10
if norm < epsilon:
raise ValueError("embedding cannot be a zero vector")
return value
class IndexedFile(BaseModel):
"""An indexed source file with symbols and optional semantic chunks."""
path: str = Field(..., min_length=1)
language: str = Field(..., min_length=1)
symbols: List[Symbol] = Field(default_factory=list)
chunks: List[SemanticChunk] = Field(default_factory=list)
relationships: List["CodeRelationship"] = Field(default_factory=list)
@field_validator("path", "language")
@classmethod
def strip_and_validate_nonempty(cls, value: str) -> str:
cleaned = value.strip()
if not cleaned:
raise ValueError("value cannot be blank")
return cleaned
class RelationshipType(str, Enum):
"""Types of code relationships."""
CALL = "calls"
INHERITS = "inherits"
IMPORTS = "imports"
class CodeRelationship(BaseModel):
"""A relationship between code symbols (e.g., function calls, inheritance)."""
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
relationship_type: RelationshipType = Field(..., description="Type of relationship (call, inherits, etc.)")
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
class AdditionalLocation(BaseModel):
"""A pointer to another location where a similar result was found.
Used for grouping search results with similar scores and content,
where the primary result is stored in SearchResult and secondary
locations are stored in this model.
"""
path: str = Field(..., min_length=1)
score: float = Field(..., ge=0.0)
start_line: Optional[int] = Field(default=None, description="Start line of the result (1-based)")
end_line: Optional[int] = Field(default=None, description="End line of the result (1-based)")
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol")
class SearchResult(BaseModel):
"""A unified search result for lexical or semantic search."""
path: str = Field(..., min_length=1)
score: float = Field(..., ge=0.0)
excerpt: Optional[str] = None
content: Optional[str] = Field(default=None, description="Full content of matched code block")
symbol: Optional[Symbol] = None
chunk: Optional[SemanticChunk] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
# Additional context for complete code blocks
start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")
# Field for grouping similar results
additional_locations: List["AdditionalLocation"] = Field(
default_factory=list,
description="Other locations for grouped results with similar scores and content."
)
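The validators above fail fast on empty names, inverted ranges, and zero-norm embeddings. A small sketch of constructing and validating these models (the `codexlens.entities` module path and all values are assumptions):

```python
# Minimal sketch, not part of the diff; module path and values are assumed.
from pydantic import ValidationError
from codexlens.entities import SearchResult, Symbol

symbol = Symbol(name="load_settings", kind="function", range=(1, 40), file="config.py")
result = SearchResult(
    path="config.py",
    score=0.87,
    symbol=symbol,
    start_line=1,
    end_line=40,
    symbol_name="load_settings",
    symbol_kind="function",
)

try:
    Symbol(name="broken", kind="function", range=(10, 5))  # end_line < start_line
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # message raised by validate_range
```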

View File

@@ -0,0 +1,304 @@
"""Environment configuration loader for CodexLens.
Loads .env files from workspace .codexlens directory with fallback to project root.
Provides unified access to API configurations.
Priority order:
1. Environment variables (already set)
2. .codexlens/.env (workspace-local)
3. .env (project root)
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
log = logging.getLogger(__name__)
# Supported environment variables with descriptions
ENV_VARS = {
# Reranker configuration (overrides settings.json)
"RERANKER_MODEL": "Reranker model name (overrides settings.json)",
"RERANKER_BACKEND": "Reranker backend: fastembed, onnx, api, litellm, legacy",
"RERANKER_ENABLED": "Enable reranker: true/false",
"RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
"RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
"RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
"RERANKER_POOL_ENABLED": "Enable reranker high availability pool: true/false",
"RERANKER_STRATEGY": "Reranker load balance strategy: round_robin, latency_aware, weighted_random",
"RERANKER_COOLDOWN": "Reranker rate limit cooldown in seconds",
# Embedding configuration (overrides settings.json)
"EMBEDDING_MODEL": "Embedding model/profile name (overrides settings.json)",
"EMBEDDING_BACKEND": "Embedding backend: fastembed, litellm",
"EMBEDDING_API_KEY": "API key for embedding service",
"EMBEDDING_API_BASE": "Base URL for embedding API",
"EMBEDDING_POOL_ENABLED": "Enable embedding high availability pool: true/false",
"EMBEDDING_STRATEGY": "Embedding load balance strategy: round_robin, latency_aware, weighted_random",
"EMBEDDING_COOLDOWN": "Embedding rate limit cooldown in seconds",
# LiteLLM configuration
"LITELLM_API_KEY": "API key for LiteLLM",
"LITELLM_API_BASE": "Base URL for LiteLLM",
"LITELLM_MODEL": "LiteLLM model name",
# General configuration
"CODEXLENS_DATA_DIR": "Custom data directory path",
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
# Chunking configuration
"CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
"CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
# Reranker tuning
"RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
"RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
}
def _parse_env_line(line: str) -> tuple[str, str] | None:
"""Parse a single .env line, returning (key, value) or None."""
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith("#"):
return None
# Handle export prefix
if line.startswith("export "):
line = line[7:].strip()
# Split on first =
if "=" not in line:
return None
key, _, value = line.partition("=")
key = key.strip()
value = value.strip()
# Remove surrounding quotes
if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \
(value.startswith("'") and value.endswith("'")):
value = value[1:-1]
return key, value
def load_env_file(env_path: Path) -> Dict[str, str]:
"""Load environment variables from a .env file.
Args:
env_path: Path to .env file
Returns:
Dictionary of environment variables
"""
if not env_path.is_file():
return {}
env_vars: Dict[str, str] = {}
try:
content = env_path.read_text(encoding="utf-8")
for line in content.splitlines():
result = _parse_env_line(line)
if result:
key, value = result
env_vars[key] = value
except Exception as exc:
log.warning("Failed to load .env file %s: %s", env_path, exc)
return env_vars
def _get_global_data_dir() -> Path:
"""Get global CodexLens data directory."""
env_override = os.environ.get("CODEXLENS_DATA_DIR")
if env_override:
return Path(env_override).expanduser().resolve()
return (Path.home() / ".codexlens").resolve()
def load_global_env() -> Dict[str, str]:
"""Load environment variables from global ~/.codexlens/.env file.
Returns:
Dictionary of environment variables from global config
"""
global_env_path = _get_global_data_dir() / ".env"
if global_env_path.is_file():
env_vars = load_env_file(global_env_path)
log.debug("Loaded %d vars from global %s", len(env_vars), global_env_path)
return env_vars
return {}
def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
"""Load environment variables from workspace .env files.
Priority (later overrides earlier):
1. Global ~/.codexlens/.env (lowest priority)
2. Project root .env
3. .codexlens/.env (highest priority)
Args:
workspace_root: Workspace root directory. If None, uses current directory.
Returns:
Merged dictionary of environment variables
"""
if workspace_root is None:
workspace_root = Path.cwd()
workspace_root = Path(workspace_root).resolve()
env_vars: Dict[str, str] = {}
# Load from global ~/.codexlens/.env (lowest priority)
global_vars = load_global_env()
if global_vars:
env_vars.update(global_vars)
# Load from project root .env (medium priority)
root_env = workspace_root / ".env"
if root_env.is_file():
loaded = load_env_file(root_env)
env_vars.update(loaded)
log.debug("Loaded %d vars from %s", len(loaded), root_env)
# Load from .codexlens/.env (highest priority)
codexlens_env = workspace_root / ".codexlens" / ".env"
if codexlens_env.is_file():
loaded = load_env_file(codexlens_env)
env_vars.update(loaded)
log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)
return env_vars
def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
"""Load .env files and apply to os.environ.
Args:
workspace_root: Workspace root directory
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
env_vars = load_workspace_env(workspace_root)
applied = 0
for key, value in env_vars.items():
if override or key not in os.environ:
os.environ[key] = value
applied += 1
log.debug("Applied env var: %s", key)
return applied
def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback.
Priority:
1. os.environ (already set)
2. .codexlens/.env
3. .env
4. default value
Args:
key: Environment variable name
default: Default value if not found
workspace_root: Workspace root for .env file lookup
Returns:
Value or default
"""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Load from .env files
env_vars = load_workspace_env(workspace_root)
if key in env_vars:
return env_vars[key]
return default
def get_api_config(
prefix: str,
*,
workspace_root: Path | None = None,
defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
"""Get API configuration from environment.
Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
workspace_root: Workspace root for .env file lookup
defaults: Default values
Returns:
Dictionary with api_key, api_base, model, etc.
"""
defaults = defaults or {}
config: Dict[str, Any] = {}
# Standard API config fields
field_mapping = {
"api_key": f"{prefix}_API_KEY",
"api_base": f"{prefix}_API_BASE",
"model": f"{prefix}_MODEL",
"provider": f"{prefix}_PROVIDER",
"timeout": f"{prefix}_TIMEOUT",
}
for field, env_key in field_mapping.items():
value = get_env(env_key, workspace_root=workspace_root)
if value is not None:
# Type conversion for specific fields
if field == "timeout":
try:
config[field] = float(value)
except ValueError:
pass
else:
config[field] = value
elif field in defaults:
config[field] = defaults[field]
return config
def generate_env_example() -> str:
"""Generate .env.example content with all supported variables.
Returns:
String content for .env.example file
"""
lines = [
"# CodexLens Environment Configuration",
"# Copy this file to .codexlens/.env and fill in your values",
"",
]
# Group by prefix
groups: Dict[str, list] = {}
for key, desc in ENV_VARS.items():
prefix = key.split("_")[0]
if prefix not in groups:
groups[prefix] = []
groups[prefix].append((key, desc))
for prefix, items in groups.items():
lines.append(f"# {prefix} Configuration")
for key, desc in items:
lines.append(f"# {desc}")
lines.append(f"# {key}=")
lines.append("")
return "\n".join(lines)
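A short sketch of the intended call pattern: write a workspace `.env`, then resolve a prefixed API config and apply the variables to the process environment. Paths and values are hypothetical, and the `codexlens.env_config` import path is assumed.

```python
# Minimal sketch, not part of the diff; paths, values and import path are assumed.
from pathlib import Path
from codexlens.env_config import apply_workspace_env, get_api_config

workspace = Path("/tmp/demo-project")
(workspace / ".codexlens").mkdir(parents=True, exist_ok=True)
(workspace / ".codexlens" / ".env").write_text(
    'RERANKER_API_KEY="sk-example"\n'          # surrounding quotes are stripped
    "export RERANKER_PROVIDER=siliconflow\n"   # 'export ' prefix is accepted
    "RERANKER_TIMEOUT=30\n"
)

print(get_api_config("RERANKER", workspace_root=workspace))
# e.g. {'api_key': 'sk-example', 'provider': 'siliconflow', 'timeout': 30.0}
print(apply_workspace_env(workspace))  # number of variables applied to os.environ
```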

View File

@@ -0,0 +1,59 @@
"""CodexLens exception hierarchy."""
from __future__ import annotations
class CodexLensError(Exception):
"""Base class for all CodexLens errors."""
class ConfigError(CodexLensError):
"""Raised when configuration is invalid or cannot be loaded."""
class ParseError(CodexLensError):
"""Raised when parsing or indexing a file fails."""
class StorageError(CodexLensError):
"""Raised when reading/writing index storage fails.
Attributes:
message: Human-readable error description
db_path: Path to the database file (if applicable)
operation: The operation that failed (e.g., 'query', 'initialize', 'migrate')
details: Additional context for debugging
"""
def __init__(
self,
message: str,
db_path: str | None = None,
operation: str | None = None,
details: dict | None = None
) -> None:
super().__init__(message)
self.message = message
self.db_path = db_path
self.operation = operation
self.details = details or {}
def __str__(self) -> str:
parts = [self.message]
if self.db_path:
parts.append(f"[db: {self.db_path}]")
if self.operation:
parts.append(f"[op: {self.operation}]")
if self.details:
detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
parts.append(f"[{detail_str}]")
return " ".join(parts)
class SearchError(CodexLensError):
"""Raised when a search operation fails."""
class IndexNotFoundError(CodexLensError):
"""Raised when a project's index cannot be found."""

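StorageError carries structured context that its `__str__` folds into a single log-friendly line. A small sketch of raising and rendering it (the SQLite call and paths are illustrative, and the `codexlens.errors` import path is an assumption):

```python
# Minimal sketch, not part of the diff; import path and file paths are assumed.
import sqlite3
from codexlens.errors import StorageError

def open_index(db_path: str) -> sqlite3.Connection:
    try:
        return sqlite3.connect(db_path)
    except sqlite3.Error as exc:
        raise StorageError(
            "Failed to open index database",
            db_path=db_path,
            operation="initialize",
            details={"cause": type(exc).__name__},
        ) from exc

# str(err) renders as:
# "Failed to open index database [db: /missing/dir/index.db] [op: initialize] [cause=OperationalError]"
```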
View File

@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.
This module provides core data structures for hybrid search:
- CodeSymbolNode: Graph node representing a code symbol
- CodeAssociationGraph: Graph of code relationships
- SearchResultCluster: Clustered search results
- Range: Position range in source files
- CallHierarchyItem: LSP call hierarchy item
Note: The search engine is in codexlens.search.hybrid_search
LSP-based expansion is in codexlens.lsp module
"""
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeAssociationGraph,
CodeSymbolNode,
Range,
SearchResultCluster,
)
__all__ = [
"CallHierarchyItem",
"CodeAssociationGraph",
"CodeSymbolNode",
"Range",
"SearchResultCluster",
]

View File

@@ -0,0 +1,602 @@
"""Core data structures for the hybrid search system.
This module defines the fundamental data structures used throughout the
hybrid search pipeline, including code symbol representations, association
graphs, and clustered search results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import networkx as nx
@dataclass
class Range:
"""Position range within a source file.
Attributes:
start_line: Starting line number (0-based).
start_character: Starting character offset within the line.
end_line: Ending line number (0-based).
end_character: Ending character offset within the line.
"""
start_line: int
start_character: int
end_line: int
end_character: int
def __post_init__(self) -> None:
"""Validate range values."""
if self.start_line < 0:
raise ValueError("start_line must be >= 0")
if self.start_character < 0:
raise ValueError("start_character must be >= 0")
if self.end_line < 0:
raise ValueError("end_line must be >= 0")
if self.end_character < 0:
raise ValueError("end_character must be >= 0")
if self.end_line < self.start_line:
raise ValueError("end_line must be >= start_line")
if self.end_line == self.start_line and self.end_character < self.start_character:
raise ValueError("end_character must be >= start_character on the same line")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return {
"start": {"line": self.start_line, "character": self.start_character},
"end": {"line": self.end_line, "character": self.end_character},
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Range:
"""Create Range from dictionary representation."""
return cls(
start_line=data["start"]["line"],
start_character=data["start"]["character"],
end_line=data["end"]["line"],
end_character=data["end"]["character"],
)
@classmethod
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
"""Create Range from LSP Range object.
LSP Range format:
{"start": {"line": int, "character": int},
"end": {"line": int, "character": int}}
"""
return cls(
start_line=lsp_range["start"]["line"],
start_character=lsp_range["start"]["character"],
end_line=lsp_range["end"]["line"],
end_character=lsp_range["end"]["character"],
)
@dataclass
class CallHierarchyItem:
"""LSP CallHierarchyItem for representing callers/callees.
Attributes:
name: Symbol name (function, method, class name).
kind: Symbol kind (function, method, class, etc.).
file_path: Absolute file path where the symbol is defined.
range: Position range in the source file.
detail: Optional additional detail about the symbol.
"""
name: str
kind: str
file_path: str
range: Range
detail: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
}
if self.detail:
result["detail"] = self.detail
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from dictionary representation."""
return cls(
name=data["name"],
kind=data["kind"],
file_path=data["file_path"],
range=Range.from_dict(data["range"]),
detail=data.get("detail"),
)
@dataclass
class CodeSymbolNode:
"""Graph node representing a code symbol.
Attributes:
id: Unique identifier in format 'file_path:name:line'.
name: Symbol name (function, class, variable name).
kind: Symbol kind (function, class, method, variable, etc.).
file_path: Absolute file path where symbol is defined.
range: Start/end position in the source file.
embedding: Optional vector embedding for semantic search.
raw_code: Raw source code of the symbol.
docstring: Documentation string (if available).
score: Ranking score (used during reranking).
"""
id: str
name: str
kind: str
file_path: str
range: Range
embedding: Optional[List[float]] = None
raw_code: str = ""
docstring: str = ""
score: float = 0.0
def __post_init__(self) -> None:
"""Validate required fields."""
if not self.id:
raise ValueError("id cannot be empty")
if not self.name:
raise ValueError("name cannot be empty")
if not self.kind:
raise ValueError("kind cannot be empty")
if not self.file_path:
raise ValueError("file_path cannot be empty")
def __hash__(self) -> int:
"""Hash based on unique ID."""
return hash(self.id)
def __eq__(self, other: object) -> bool:
"""Equality based on unique ID."""
if not isinstance(other, CodeSymbolNode):
return False
return self.id == other.id
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"id": self.id,
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
"score": self.score,
}
if self.raw_code:
result["raw_code"] = self.raw_code
if self.docstring:
result["docstring"] = self.docstring
# Exclude embedding from serialization (too large for JSON responses)
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
"""Create CodeSymbolNode from dictionary representation."""
return cls(
id=data["id"],
name=data["name"],
kind=data["kind"],
file_path=data["file_path"],
range=Range.from_dict(data["range"]),
embedding=data.get("embedding"),
raw_code=data.get("raw_code", ""),
docstring=data.get("docstring", ""),
score=data.get("score", 0.0),
)
@classmethod
def from_lsp_location(
cls,
uri: str,
name: str,
kind: str,
lsp_range: Dict[str, Any],
raw_code: str = "",
docstring: str = "",
) -> CodeSymbolNode:
"""Create CodeSymbolNode from LSP location data.
Args:
uri: File URI (file:// prefix will be stripped).
name: Symbol name.
kind: Symbol kind.
lsp_range: LSP Range object.
raw_code: Optional raw source code.
docstring: Optional documentation string.
Returns:
New CodeSymbolNode instance.
"""
# Strip file:// prefix if present
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
# Handle Windows paths (file:///C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
range_obj = Range.from_lsp_range(lsp_range)
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
return cls(
id=symbol_id,
name=name,
kind=kind,
file_path=file_path,
range=range_obj,
raw_code=raw_code,
docstring=docstring,
)
@classmethod
def create_id(cls, file_path: str, name: str, line: int) -> str:
"""Generate a unique symbol ID.
Args:
file_path: Absolute file path.
name: Symbol name.
line: Start line number.
Returns:
Unique ID string in format 'file_path:name:line'.
"""
return f"{file_path}:{name}:{line}"
@dataclass
class CodeAssociationGraph:
"""Graph of code relationships between symbols.
This graph represents the association between code symbols discovered
through LSP queries (references, call hierarchy, etc.).
Attributes:
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
edges: List of (from_id, to_id, relationship_type) tuples.
relationship_type: 'calls', 'references', 'inherits', 'imports'.
"""
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
edges: List[Tuple[str, str, str]] = field(default_factory=list)
def add_node(self, node: CodeSymbolNode) -> None:
"""Add a node to the graph.
Args:
node: CodeSymbolNode to add. If a node with the same ID exists,
it will be replaced.
"""
self.nodes[node.id] = node
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge to the graph.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
Raises:
ValueError: If from_id or to_id not in graph nodes.
"""
if from_id not in self.nodes:
raise ValueError(f"Source node '{from_id}' not found in graph")
if to_id not in self.nodes:
raise ValueError(f"Target node '{to_id}' not found in graph")
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge without validating node existence.
Use this method during bulk graph construction where nodes may be
added after edges, or when performance is critical.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type.
"""
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
"""Get a node by ID.
Args:
node_id: Node ID to look up.
Returns:
CodeSymbolNode if found, None otherwise.
"""
return self.nodes.get(node_id)
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get neighboring nodes connected by outgoing edges.
Args:
node_id: Node ID to find neighbors for.
rel_type: Optional filter by relationship type.
Returns:
List of neighboring CodeSymbolNode objects.
"""
neighbors = []
for from_id, to_id, edge_rel in self.edges:
if from_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(to_id)
if node:
neighbors.append(node)
return neighbors
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get nodes connected by incoming edges.
Args:
node_id: Node ID to find incoming connections for.
rel_type: Optional filter by relationship type.
Returns:
List of CodeSymbolNode objects with edges pointing to node_id.
"""
incoming = []
for from_id, to_id, edge_rel in self.edges:
if to_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(from_id)
if node:
incoming.append(node)
return incoming
def to_networkx(self) -> "nx.DiGraph":
"""Convert to NetworkX DiGraph for graph algorithms.
Returns:
NetworkX directed graph with nodes and edges.
Raises:
ImportError: If networkx is not installed.
"""
try:
import networkx as nx
except ImportError:
raise ImportError(
"networkx is required for graph algorithms. "
"Install with: pip install networkx"
)
graph = nx.DiGraph()
# Add nodes with attributes
for node_id, node in self.nodes.items():
graph.add_node(
node_id,
name=node.name,
kind=node.kind,
file_path=node.file_path,
score=node.score,
)
# Add edges with relationship type
for from_id, to_id, rel_type in self.edges:
graph.add_edge(from_id, to_id, relationship=rel_type)
return graph
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary with 'nodes' and 'edges' keys.
"""
return {
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"edges": [
{"from": from_id, "to": to_id, "relationship": rel_type}
for from_id, to_id, rel_type in self.edges
],
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
"""Create CodeAssociationGraph from dictionary representation.
Args:
data: Dictionary with 'nodes' and 'edges' keys.
Returns:
New CodeAssociationGraph instance.
"""
graph = cls()
# Load nodes
for node_id, node_data in data.get("nodes", {}).items():
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
# Load edges
for edge_data in data.get("edges", []):
graph.edges.append((
edge_data["from"],
edge_data["to"],
edge_data["relationship"],
))
return graph
def __len__(self) -> int:
"""Return the number of nodes in the graph."""
return len(self.nodes)
@dataclass
class SearchResultCluster:
"""Clustered search result containing related code symbols.
Search results are grouped into clusters based on graph community
detection or embedding similarity. Each cluster represents a
conceptually related group of code symbols.
Attributes:
cluster_id: Unique cluster identifier.
score: Cluster relevance score (max of symbol scores).
title: Human-readable cluster title/summary.
symbols: List of CodeSymbolNode in this cluster.
metadata: Additional cluster metadata.
"""
cluster_id: str
score: float
title: str
symbols: List[CodeSymbolNode] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate cluster fields."""
if not self.cluster_id:
raise ValueError("cluster_id cannot be empty")
if self.score < 0:
raise ValueError("score must be >= 0")
def add_symbol(self, symbol: CodeSymbolNode) -> None:
"""Add a symbol to the cluster.
Args:
symbol: CodeSymbolNode to add.
"""
self.symbols.append(symbol)
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
"""Get top N symbols by score.
Args:
n: Number of symbols to return.
Returns:
List of top N CodeSymbolNode objects sorted by score descending.
"""
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
return sorted_symbols[:n]
def update_score(self) -> None:
"""Update cluster score to max of symbol scores."""
if self.symbols:
self.score = max(s.score for s in self.symbols)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary representation of the cluster.
"""
return {
"cluster_id": self.cluster_id,
"score": self.score,
"title": self.title,
"symbols": [s.to_dict() for s in self.symbols],
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
"""Create SearchResultCluster from dictionary representation.
Args:
data: Dictionary with cluster data.
Returns:
New SearchResultCluster instance.
"""
return cls(
cluster_id=data["cluster_id"],
score=data["score"],
title=data["title"],
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
metadata=data.get("metadata", {}),
)
def __len__(self) -> int:
"""Return the number of symbols in the cluster."""
return len(self.symbols)
@dataclass
class CallHierarchyItem:
"""LSP CallHierarchyItem for representing callers/callees.
Attributes:
name: Symbol name (function, method, etc.).
kind: Symbol kind (function, method, etc.).
file_path: Absolute file path.
range: Position range in the file.
detail: Optional additional detail (e.g., signature).
"""
name: str
kind: str
file_path: str
range: Range
detail: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
}
if self.detail:
result["detail"] = self.detail
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from dictionary representation."""
return cls(
name=data.get("name", "unknown"),
kind=data.get("kind", "unknown"),
file_path=data.get("file_path", data.get("uri", "")),
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)
@classmethod
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from LSP response format.
LSP uses 0-based line numbers and 'character' instead of 'char'.
"""
uri = data.get("uri", data.get("file_path", ""))
# Strip file:// prefix
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
return cls(
name=data.get("name", "unknown"),
kind=str(data.get("kind", "unknown")),
file_path=file_path,
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)
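To make the graph API concrete, here is a small sketch that builds a two-node association graph and queries it through the accessors defined above (symbol names and positions are invented):

```python
# Minimal sketch, not part of the diff; symbol names and positions are invented.
from codexlens.hybrid_search import CodeAssociationGraph, CodeSymbolNode, Range

def make_node(path: str, name: str, line: int) -> CodeSymbolNode:
    return CodeSymbolNode(
        id=CodeSymbolNode.create_id(path, name, line),
        name=name,
        kind="function",
        file_path=path,
        range=Range(start_line=line, start_character=0, end_line=line + 10, end_character=0),
    )

graph = CodeAssociationGraph()
caller = make_node("/src/search.py", "hybrid_search", 12)
callee = make_node("/src/rank.py", "rerank", 40)
graph.add_node(caller)
graph.add_node(callee)
graph.add_edge(caller.id, callee.id, "calls")

print([n.name for n in graph.get_neighbors(caller.id, rel_type="calls")])  # ['rerank']
print([n.name for n in graph.get_incoming(callee.id)])                     # ['hybrid_search']
print(len(graph))                                                          # 2
```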

View File

@@ -0,0 +1,26 @@
"""Code indexing and symbol extraction."""
from codexlens.indexing.symbol_extractor import SymbolExtractor
from codexlens.indexing.embedding import (
BinaryEmbeddingBackend,
DenseEmbeddingBackend,
CascadeEmbeddingBackend,
get_cascade_embedder,
binarize_embedding,
pack_binary_embedding,
unpack_binary_embedding,
hamming_distance,
)
__all__ = [
"SymbolExtractor",
# Cascade embedding backends
"BinaryEmbeddingBackend",
"DenseEmbeddingBackend",
"CascadeEmbeddingBackend",
"get_cascade_embedder",
# Utility functions
"binarize_embedding",
"pack_binary_embedding",
"unpack_binary_embedding",
"hamming_distance",
]

View File

@@ -0,0 +1,582 @@
"""Multi-type embedding backends for cascade retrieval.
This module provides embedding backends optimized for cascade retrieval:
1. BinaryEmbeddingBackend - Fast coarse filtering with binary vectors
2. DenseEmbeddingBackend - High-precision dense vectors for reranking
3. CascadeEmbeddingBackend - Combined binary + dense for two-stage retrieval
Cascade retrieval workflow:
1. Binary search (fast, ~32 bytes/vector) -> top-K candidates
2. Dense rerank (precise, ~8KB/vector) -> final results
"""
from __future__ import annotations
import logging
from typing import Iterable, List, Optional, Tuple
import numpy as np
from codexlens.semantic.base import BaseEmbedder
logger = logging.getLogger(__name__)
# =============================================================================
# Utility Functions
# =============================================================================
def binarize_embedding(embedding: np.ndarray) -> np.ndarray:
"""Convert float embedding to binary vector.
Applies sign-based quantization: values > 0 become 1, values <= 0 become 0.
Args:
embedding: Float32 embedding of any dimension
Returns:
Binary vector (uint8 with values 0 or 1) of same dimension
"""
return (embedding > 0).astype(np.uint8)
def pack_binary_embedding(binary_vector: np.ndarray) -> bytes:
"""Pack binary vector into compact bytes format.
Packs 8 binary values into each byte for storage efficiency.
For a 256-dim binary vector, output is 32 bytes.
Args:
binary_vector: Binary vector (uint8 with values 0 or 1)
Returns:
Packed bytes (length = ceil(dim / 8))
"""
# Ensure vector length is multiple of 8 by padding if needed
dim = len(binary_vector)
padded_dim = ((dim + 7) // 8) * 8
if padded_dim > dim:
padded = np.zeros(padded_dim, dtype=np.uint8)
padded[:dim] = binary_vector
binary_vector = padded
# Pack 8 bits per byte
packed = np.packbits(binary_vector)
return packed.tobytes()
def unpack_binary_embedding(packed_bytes: bytes, dim: int = 256) -> np.ndarray:
"""Unpack bytes back to binary vector.
Args:
packed_bytes: Packed binary data
dim: Original vector dimension (default: 256)
Returns:
Binary vector (uint8 with values 0 or 1)
"""
unpacked = np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))
return unpacked[:dim]
def hamming_distance(a: bytes, b: bytes) -> int:
"""Compute Hamming distance between two packed binary vectors.
Uses XOR and popcount for efficient distance computation.
Args:
a: First packed binary vector
b: Second packed binary vector
Returns:
Hamming distance (number of differing bits)
"""
a_arr = np.frombuffer(a, dtype=np.uint8)
b_arr = np.frombuffer(b, dtype=np.uint8)
xor = np.bitwise_xor(a_arr, b_arr)
return int(np.unpackbits(xor).sum())
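# Illustrative doctest-style sketch (not part of the original module): the three
# helpers above compose into a pack-then-compare flow for Hamming-based search.
#   >>> import numpy as np
#   >>> vec = np.random.randn(384).astype(np.float32)
#   >>> packed = pack_binary_embedding(binarize_embedding(vec))
#   >>> len(packed)                       # 384 bits packed into 48 bytes
#   48
#   >>> hamming_distance(packed, packed)
#   0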
# =============================================================================
# Binary Embedding Backend
# =============================================================================
class BinaryEmbeddingBackend(BaseEmbedder):
"""Generate 256-dimensional binary embeddings for fast coarse retrieval.
Uses a lightweight embedding model and applies sign-based quantization
to produce compact binary vectors (32 bytes per embedding).
Suitable for:
- First-stage candidate retrieval
- Hamming distance-based similarity search
- Memory-constrained environments
Default model: BAAI/bge-small-en-v1.5 (384 dim) -> quantized to 256 bits
"""
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, fast
BINARY_DIM = 256
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
) -> None:
"""Initialize binary embedding backend.
Args:
model_name: Base embedding model name. Defaults to BAAI/bge-small-en-v1.5
use_gpu: Whether to use GPU acceleration
"""
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
self._model_name = model_name or self.DEFAULT_MODEL
self._use_gpu = use_gpu
self._model = None
# Projection matrix for dimension reduction (lazily initialized)
self._projection_matrix: Optional[np.ndarray] = None
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Return binary embedding dimension (256)."""
return self.BINARY_DIM
@property
def packed_bytes(self) -> int:
"""Return packed bytes size (32 bytes for 256 bits)."""
return self.BINARY_DIM // 8
def _load_model(self) -> None:
"""Lazy load the embedding model."""
if self._model is not None:
return
from fastembed import TextEmbedding
from codexlens.semantic.gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
try:
self._model = TextEmbedding(
model_name=self._model_name,
providers=providers,
)
except TypeError:
# Fallback for older fastembed versions
self._model = TextEmbedding(model_name=self._model_name)
logger.debug(f"BinaryEmbeddingBackend loaded model: {self._model_name}")
def _get_projection_matrix(self, input_dim: int) -> np.ndarray:
"""Get or create projection matrix for dimension reduction.
Uses random projection with fixed seed for reproducibility.
Args:
input_dim: Input embedding dimension from base model
Returns:
Projection matrix of shape (input_dim, BINARY_DIM)
"""
if self._projection_matrix is not None:
return self._projection_matrix
# Fixed seed for reproducibility across sessions
rng = np.random.RandomState(42)
# Gaussian random projection
self._projection_matrix = rng.randn(input_dim, self.BINARY_DIM).astype(np.float32)
# Normalize columns for consistent scale
norms = np.linalg.norm(self._projection_matrix, axis=0, keepdims=True)
self._projection_matrix /= (norms + 1e-8)
return self._projection_matrix
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate binary embeddings as numpy array.
Args:
texts: Single text or iterable of texts
Returns:
Binary embeddings of shape (n_texts, 256) with values 0 or 1
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Get base float embeddings
float_embeddings = np.array(list(self._model.embed(texts)))
input_dim = float_embeddings.shape[1]
# Project to target dimension if needed
if input_dim != self.BINARY_DIM:
projection = self._get_projection_matrix(input_dim)
float_embeddings = float_embeddings @ projection
# Binarize
return binarize_embedding(float_embeddings)
def embed_packed(self, texts: str | Iterable[str]) -> List[bytes]:
"""Generate packed binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
List of packed bytes (32 bytes each for 256-dim)
"""
binary = self.embed_to_numpy(texts)
return [pack_binary_embedding(vec) for vec in binary]
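# Illustrative sketch (not part of the original module); the intended coarse-
# filtering use of BinaryEmbeddingBackend, assuming fastembed is installed:
#
#   backend = BinaryEmbeddingBackend(use_gpu=False)
#   corpus = backend.embed_packed(["def add(a, b): ...", "class Cache: ..."])
#   query = backend.embed_packed("function that adds two numbers")[0]
#   ranked = sorted(range(len(corpus)),
#                   key=lambda i: hamming_distance(query, corpus[i]))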
# =============================================================================
# Dense Embedding Backend
# =============================================================================
class DenseEmbeddingBackend(BaseEmbedder):
"""Generate high-dimensional dense embeddings for precise reranking.
Uses a dense embedding model, optionally expanded to TARGET_DIM float32
dimensions, for maximum retrieval quality.
Suitable for:
- Second-stage reranking
- High-precision similarity search
- Quality-critical applications
Default model: BAAI/bge-small-en-v1.5 (384 dim), with optional expansion to TARGET_DIM
"""
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, use small for testing
TARGET_DIM = 768 # Reduced target for faster testing
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
expand_dim: bool = True,
) -> None:
"""Initialize dense embedding backend.
Args:
model_name: Dense embedding model name. Defaults to BAAI/bge-small-en-v1.5
use_gpu: Whether to use GPU acceleration
expand_dim: If True, expand embeddings to TARGET_DIM using learned expansion
"""
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
self._model_name = model_name or self.DEFAULT_MODEL
self._use_gpu = use_gpu
self._expand_dim = expand_dim
self._model = None
self._native_dim: Optional[int] = None
# Expansion matrix for dimension expansion (lazily initialized)
self._expansion_matrix: Optional[np.ndarray] = None
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Return embedding dimension.
Returns TARGET_DIM if expand_dim is True, otherwise native model dimension.
"""
if self._expand_dim:
return self.TARGET_DIM
# Return cached native dim or estimate based on model
if self._native_dim is not None:
return self._native_dim
# Model dimension estimates
model_dims = {
"BAAI/bge-large-en-v1.5": 1024,
"BAAI/bge-base-en-v1.5": 768,
"BAAI/bge-small-en-v1.5": 384,
"intfloat/multilingual-e5-large": 1024,
}
return model_dims.get(self._model_name, 1024)
@property
def max_tokens(self) -> int:
"""Return maximum token limit."""
return 512 # Conservative default for large models
def _load_model(self) -> None:
"""Lazy load the embedding model."""
if self._model is not None:
return
from fastembed import TextEmbedding
from codexlens.semantic.gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
try:
self._model = TextEmbedding(
model_name=self._model_name,
providers=providers,
)
except TypeError:
self._model = TextEmbedding(model_name=self._model_name)
logger.debug(f"DenseEmbeddingBackend loaded model: {self._model_name}")
def _get_expansion_matrix(self, input_dim: int) -> np.ndarray:
"""Get or create expansion matrix for dimension expansion.
Uses random orthogonal projection for information-preserving expansion.
Args:
input_dim: Input embedding dimension from base model
Returns:
Expansion matrix of shape (input_dim, TARGET_DIM)
"""
if self._expansion_matrix is not None:
return self._expansion_matrix
# Fixed seed for reproducibility
rng = np.random.RandomState(123)
# Create semi-orthogonal expansion matrix
# First input_dim columns form identity-like structure
self._expansion_matrix = np.zeros((input_dim, self.TARGET_DIM), dtype=np.float32)
# Copy original dimensions
copy_dim = min(input_dim, self.TARGET_DIM)
self._expansion_matrix[:copy_dim, :copy_dim] = np.eye(copy_dim, dtype=np.float32)
# Fill remaining with random projections
if self.TARGET_DIM > input_dim:
random_part = rng.randn(input_dim, self.TARGET_DIM - input_dim).astype(np.float32)
# Normalize
norms = np.linalg.norm(random_part, axis=0, keepdims=True)
random_part /= (norms + 1e-8)
self._expansion_matrix[:, input_dim:] = random_part
return self._expansion_matrix
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate dense embeddings as numpy array.
Args:
texts: Single text or iterable of texts
Returns:
Dense embeddings of shape (n_texts, TARGET_DIM) as float32
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Get base float embeddings
float_embeddings = np.array(list(self._model.embed(texts)), dtype=np.float32)
self._native_dim = float_embeddings.shape[1]
# Expand to target dimension if needed
if self._expand_dim and self._native_dim < self.TARGET_DIM:
expansion = self._get_expansion_matrix(self._native_dim)
float_embeddings = float_embeddings @ expansion
return float_embeddings
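# --- Illustrative usage sketch (editorial example, not part of the original file) ---
# Shows the second-stage rerank this backend is intended for: cosine similarity
# between a query vector and the candidates surviving the binary prefilter.
def _example_cosine_rerank(query_vec: np.ndarray, candidate_vecs: np.ndarray, top_k: int = 10) -> np.ndarray:
    """Return candidate indices ordered by descending cosine similarity."""
    q = query_vec / (np.linalg.norm(query_vec) + 1e-8)
    c = candidate_vecs / (np.linalg.norm(candidate_vecs, axis=1, keepdims=True) + 1e-8)
    return np.argsort(-(c @ q))[:top_k]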
# =============================================================================
# Cascade Embedding Backend
# =============================================================================
class CascadeEmbeddingBackend(BaseEmbedder):
"""Combined binary + dense embedding backend for cascade retrieval.
Generates both binary (for fast coarse filtering) and dense (for precise
reranking) embeddings in a single pass, optimized for two-stage retrieval.
Cascade workflow:
1. encode_cascade() returns (binary_embeddings, dense_embeddings)
2. Binary search: Use Hamming distance on binary vectors -> top-K candidates
3. Dense rerank: Use cosine similarity on dense vectors -> final results
Memory efficiency:
- Binary: 32 bytes per vector (256 bits)
    - Dense: 4 bytes per dimension (3072 bytes at the default TARGET_DIM of 768)
    - Total: roughly 3 KB per document for full cascade support
"""
def __init__(
self,
binary_model: Optional[str] = None,
dense_model: Optional[str] = None,
use_gpu: bool = True,
) -> None:
"""Initialize cascade embedding backend.
Args:
binary_model: Model for binary embeddings. Defaults to BAAI/bge-small-en-v1.5
            dense_model: Model for dense embeddings. Defaults to DenseEmbeddingBackend.DEFAULT_MODEL (BAAI/bge-small-en-v1.5)
use_gpu: Whether to use GPU acceleration
"""
self._binary_backend = BinaryEmbeddingBackend(
model_name=binary_model,
use_gpu=use_gpu,
)
self._dense_backend = DenseEmbeddingBackend(
model_name=dense_model,
use_gpu=use_gpu,
expand_dim=True,
)
self._use_gpu = use_gpu
@property
def model_name(self) -> str:
"""Return model names for both backends."""
return f"cascade({self._binary_backend.model_name}, {self._dense_backend.model_name})"
@property
def embedding_dim(self) -> int:
"""Return dense embedding dimension (for compatibility)."""
return self._dense_backend.embedding_dim
@property
def binary_dim(self) -> int:
"""Return binary embedding dimension."""
return self._binary_backend.embedding_dim
@property
def dense_dim(self) -> int:
"""Return dense embedding dimension."""
return self._dense_backend.embedding_dim
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate dense embeddings (for BaseEmbedder compatibility).
For cascade embeddings, use encode_cascade() instead.
Args:
texts: Single text or iterable of texts
Returns:
Dense embeddings of shape (n_texts, dense_dim)
"""
return self._dense_backend.embed_to_numpy(texts)
def encode_cascade(
self,
texts: str | Iterable[str],
batch_size: int = 32,
) -> Tuple[np.ndarray, np.ndarray]:
"""Generate both binary and dense embeddings.
Args:
texts: Single text or iterable of texts
batch_size: Batch size for processing
Returns:
Tuple of:
- binary_embeddings: Shape (n_texts, 256), uint8 values 0/1
                - dense_embeddings: Shape (n_texts, dense_dim), float32
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
binary_embeddings = self._binary_backend.embed_to_numpy(texts)
dense_embeddings = self._dense_backend.embed_to_numpy(texts)
return binary_embeddings, dense_embeddings
def encode_binary(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate only binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
Binary embeddings of shape (n_texts, 256)
"""
return self._binary_backend.embed_to_numpy(texts)
def encode_dense(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate only dense embeddings.
Args:
texts: Single text or iterable of texts
Returns:
            Dense embeddings of shape (n_texts, dense_dim)
"""
return self._dense_backend.embed_to_numpy(texts)
def encode_binary_packed(self, texts: str | Iterable[str]) -> List[bytes]:
"""Generate packed binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
List of packed bytes (32 bytes each)
"""
return self._binary_backend.embed_packed(texts)
# =============================================================================
# Factory Function
# =============================================================================
def get_cascade_embedder(
binary_model: Optional[str] = None,
dense_model: Optional[str] = None,
use_gpu: bool = True,
) -> CascadeEmbeddingBackend:
"""Factory function to create a cascade embedder.
Args:
binary_model: Model for binary embeddings (default: BAAI/bge-small-en-v1.5)
        dense_model: Model for dense embeddings (default: BAAI/bge-small-en-v1.5)
use_gpu: Whether to use GPU acceleration
Returns:
Configured CascadeEmbeddingBackend instance
Example:
>>> embedder = get_cascade_embedder()
>>> binary, dense = embedder.encode_cascade(["hello world"])
>>> binary.shape # (1, 256)
        >>> dense.shape # (1, 768) with the default TARGET_DIM
"""
return CascadeEmbeddingBackend(
binary_model=binary_model,
dense_model=dense_model,
use_gpu=use_gpu,
)
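# --- Illustrative usage sketch (editorial example, not part of the original file) ---
# End-to-end cascade retrieval with the defaults defined above: Hamming
# prefilter on binary vectors, then cosine rerank on dense vectors. Corpus and
# query are placeholders; running this loads the underlying fastembed models.
def _example_cascade_search(query: str, corpus: List[str], top_k: int = 5) -> List[int]:
    embedder = get_cascade_embedder(use_gpu=False)
    corpus_binary, corpus_dense = embedder.encode_cascade(corpus)
    query_binary, query_dense = embedder.encode_cascade(query)
    # Stage 1: Hamming distance on binary vectors -> coarse candidate set
    hamming = np.count_nonzero(corpus_binary != query_binary[0], axis=1)
    candidates = np.argsort(hamming)[: top_k * 4]
    # Stage 2: cosine similarity on dense vectors -> final ranking
    cand = corpus_dense[candidates]
    cand = cand / (np.linalg.norm(cand, axis=1, keepdims=True) + 1e-8)
    q = query_dense[0] / (np.linalg.norm(query_dense[0]) + 1e-8)
    order = np.argsort(-(cand @ q))[:top_k]
    return [int(candidates[i]) for i in order]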


@@ -0,0 +1,277 @@
"""Symbol and relationship extraction from source code."""
import re
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
try:
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception: # pragma: no cover - optional dependency / platform variance
TreeSitterSymbolParser = None # type: ignore[assignment]
class SymbolExtractor:
"""Extract symbols and relationships from source code using regex patterns."""
# Pattern definitions for different languages
PATTERNS = {
'python': {
'function': r'^(?:async\s+)?def\s+(\w+)\s*\(',
'class': r'^class\s+(\w+)\s*[:\(]',
'import': r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)',
'call': r'(?<![.\w])(\w+)\s*\(',
},
'typescript': {
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<\(]',
'class': r'(?:export\s+)?class\s+(\w+)',
'import': r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
'call': r'(?<![.\w])(\w+)\s*[<\(]',
},
'javascript': {
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(',
'class': r'(?:export\s+)?class\s+(\w+)',
'import': r"(?:import|require)\s*\(?['\"]([^'\"]+)['\"]",
'call': r'(?<![.\w])(\w+)\s*\(',
}
}
LANGUAGE_MAP = {
'.py': 'python',
'.ts': 'typescript',
'.tsx': 'typescript',
'.js': 'javascript',
'.jsx': 'javascript',
}
def __init__(self, db_path: Path):
self.db_path = db_path
self.db_conn: Optional[sqlite3.Connection] = None
def connect(self) -> None:
"""Connect to database and ensure schema exists."""
self.db_conn = sqlite3.connect(str(self.db_path))
self._ensure_tables()
def __enter__(self) -> "SymbolExtractor":
"""Context manager entry: connect to database."""
self.connect()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit: close database connection."""
self.close()
def _ensure_tables(self) -> None:
"""Create symbols and relationships tables if they don't exist."""
if not self.db_conn:
return
cursor = self.db_conn.cursor()
# Create symbols table with qualified_name
cursor.execute('''
CREATE TABLE IF NOT EXISTS symbols (
id INTEGER PRIMARY KEY AUTOINCREMENT,
qualified_name TEXT NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
UNIQUE(file_path, name, start_line)
)
''')
# Create relationships table with target_symbol_fqn
cursor.execute('''
CREATE TABLE IF NOT EXISTS symbol_relationships (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_symbol_id INTEGER NOT NULL,
target_symbol_fqn TEXT NOT NULL,
relationship_type TEXT NOT NULL,
file_path TEXT NOT NULL,
line INTEGER,
FOREIGN KEY (source_symbol_id) REFERENCES symbols(id) ON DELETE CASCADE
)
''')
# Create performance indexes
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_path)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_source ON symbol_relationships(source_symbol_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_target ON symbol_relationships(target_symbol_fqn)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_type ON symbol_relationships(relationship_type)')
self.db_conn.commit()
def extract_from_file(self, file_path: Path, content: str) -> Tuple[List[Dict], List[Dict]]:
"""Extract symbols and relationships from file content.
Args:
file_path: Path to the source file
content: File content as string
Returns:
Tuple of (symbols, relationships) where:
- symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
- relationships: List of relationship dicts with source_scope, target, type, file_path, line
"""
ext = file_path.suffix.lower()
lang = self.LANGUAGE_MAP.get(ext)
if not lang or lang not in self.PATTERNS:
return [], []
patterns = self.PATTERNS[lang]
symbols = []
relationships: List[Dict] = []
lines = content.split('\n')
current_scope = None
for line_num, line in enumerate(lines, 1):
# Extract function/class definitions
for kind in ['function', 'class']:
if kind in patterns:
match = re.search(patterns[kind], line)
if match:
name = match.group(1)
qualified_name = f"{file_path.stem}.{name}"
symbols.append({
'qualified_name': qualified_name,
'name': name,
'kind': kind,
'file_path': str(file_path),
'start_line': line_num,
'end_line': line_num, # Simplified - would need proper parsing for actual end
})
current_scope = name
if TreeSitterSymbolParser is not None:
try:
ts_parser = TreeSitterSymbolParser(lang, file_path)
if ts_parser.is_available():
indexed = ts_parser.parse(content, file_path)
if indexed is not None and indexed.relationships:
relationships = [
{
"source_scope": r.source_symbol,
"target": r.target_symbol,
"type": r.relationship_type.value,
"file_path": str(file_path),
"line": r.source_line,
}
for r in indexed.relationships
]
except Exception:
relationships = []
# Regex fallback for relationships (when tree-sitter is unavailable)
if not relationships:
current_scope = None
for line_num, line in enumerate(lines, 1):
for kind in ['function', 'class']:
if kind in patterns:
match = re.search(patterns[kind], line)
if match:
current_scope = match.group(1)
# Extract imports
if 'import' in patterns:
match = re.search(patterns['import'], line)
if match:
                        # Use the second capture group (Python-style "import X") when the pattern defines one
                        import_target = (match.group(1) or match.group(2)) if match.re.groups >= 2 else match.group(1)
if import_target and current_scope:
relationships.append({
'source_scope': current_scope,
'target': import_target.strip(),
'type': 'imports',
'file_path': str(file_path),
'line': line_num,
})
# Extract function calls (simplified)
if 'call' in patterns and current_scope:
for match in re.finditer(patterns['call'], line):
call_name = match.group(1)
# Skip common keywords and the current function
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
relationships.append({
'source_scope': current_scope,
'target': call_name,
'type': 'calls',
'file_path': str(file_path),
'line': line_num,
})
return symbols, relationships
def save_symbols(self, symbols: List[Dict]) -> Dict[str, int]:
"""Save symbols to database and return name->id mapping.
Args:
symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
Returns:
Dictionary mapping symbol name to database id
"""
if not self.db_conn or not symbols:
return {}
cursor = self.db_conn.cursor()
name_to_id = {}
for sym in symbols:
try:
cursor.execute('''
INSERT OR IGNORE INTO symbols
(qualified_name, name, kind, file_path, start_line, end_line)
VALUES (?, ?, ?, ?, ?, ?)
''', (sym['qualified_name'], sym['name'], sym['kind'],
sym['file_path'], sym['start_line'], sym['end_line']))
# Get the id
cursor.execute('''
SELECT id FROM symbols
WHERE file_path = ? AND name = ? AND start_line = ?
''', (sym['file_path'], sym['name'], sym['start_line']))
row = cursor.fetchone()
if row:
name_to_id[sym['name']] = row[0]
except sqlite3.Error:
continue
self.db_conn.commit()
return name_to_id
def save_relationships(self, relationships: List[Dict], name_to_id: Dict[str, int]) -> None:
"""Save relationships to database.
Args:
relationships: List of relationship dicts with source_scope, target, type, file_path, line
name_to_id: Dictionary mapping symbol names to database ids
"""
if not self.db_conn or not relationships:
return
cursor = self.db_conn.cursor()
for rel in relationships:
source_id = name_to_id.get(rel['source_scope'])
if source_id:
try:
cursor.execute('''
INSERT INTO symbol_relationships
(source_symbol_id, target_symbol_fqn, relationship_type, file_path, line)
VALUES (?, ?, ?, ?, ?)
''', (source_id, rel['target'], rel['type'], rel['file_path'], rel['line']))
except sqlite3.Error:
continue
self.db_conn.commit()
def close(self) -> None:
"""Close database connection."""
if self.db_conn:
self.db_conn.close()
self.db_conn = None
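# --- Illustrative usage sketch (editorial example, not part of the original file) ---
# Typical use of SymbolExtractor as a context manager: extract symbols and
# relationships from a single file and persist them. The paths are placeholders.
def _example_index_file(db_path: Path, source_file: Path) -> None:
    content = source_file.read_text(encoding="utf-8", errors="ignore")
    with SymbolExtractor(db_path) as extractor:
        symbols, relationships = extractor.extract_from_file(source_file, content)
        name_to_id = extractor.save_symbols(symbols)
        extractor.save_relationships(relationships, name_to_id)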


@@ -0,0 +1,34 @@
"""LSP module for real-time language server integration.
This module provides:
- LspBridge: HTTP bridge to VSCode language servers
- LspGraphBuilder: Build code association graphs via LSP
- Location: Position in a source file
Example:
>>> from codexlens.lsp import LspBridge, LspGraphBuilder
>>>
>>> async with LspBridge() as bridge:
... refs = await bridge.get_references(symbol)
... graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
"""
from codexlens.lsp.lsp_bridge import (
CacheEntry,
Location,
LspBridge,
)
from codexlens.lsp.lsp_graph_builder import (
LspGraphBuilder,
)
# Alias for backward compatibility
GraphBuilder = LspGraphBuilder
__all__ = [
"CacheEntry",
"GraphBuilder",
"Location",
"LspBridge",
"LspGraphBuilder",
]


@@ -0,0 +1,551 @@
"""LSP request handlers for codex-lens.
This module contains handlers for LSP requests:
- textDocument/definition
- textDocument/completion
- workspace/symbol
- textDocument/didSave
- textDocument/hover
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote, unquote
try:
from lsprotocol import types as lsp
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.entities import Symbol
from codexlens.lsp.server import server
logger = logging.getLogger(__name__)
# Symbol kind mapping from codex-lens to LSP
SYMBOL_KIND_MAP = {
"class": lsp.SymbolKind.Class,
"function": lsp.SymbolKind.Function,
"method": lsp.SymbolKind.Method,
"variable": lsp.SymbolKind.Variable,
"constant": lsp.SymbolKind.Constant,
"property": lsp.SymbolKind.Property,
"field": lsp.SymbolKind.Field,
"interface": lsp.SymbolKind.Interface,
"module": lsp.SymbolKind.Module,
"namespace": lsp.SymbolKind.Namespace,
"package": lsp.SymbolKind.Package,
"enum": lsp.SymbolKind.Enum,
"enum_member": lsp.SymbolKind.EnumMember,
"struct": lsp.SymbolKind.Struct,
"type": lsp.SymbolKind.TypeParameter,
"type_alias": lsp.SymbolKind.TypeParameter,
}
# Completion kind mapping from codex-lens to LSP
COMPLETION_KIND_MAP = {
"class": lsp.CompletionItemKind.Class,
"function": lsp.CompletionItemKind.Function,
"method": lsp.CompletionItemKind.Method,
"variable": lsp.CompletionItemKind.Variable,
"constant": lsp.CompletionItemKind.Constant,
"property": lsp.CompletionItemKind.Property,
"field": lsp.CompletionItemKind.Field,
"interface": lsp.CompletionItemKind.Interface,
"module": lsp.CompletionItemKind.Module,
"enum": lsp.CompletionItemKind.Enum,
"enum_member": lsp.CompletionItemKind.EnumMember,
"struct": lsp.CompletionItemKind.Struct,
"type": lsp.CompletionItemKind.TypeParameter,
"type_alias": lsp.CompletionItemKind.TypeParameter,
}
def _path_to_uri(path: Union[str, Path]) -> str:
"""Convert a file path to a URI.
Args:
path: File path (string or Path object)
Returns:
File URI string
"""
path_str = str(Path(path).resolve())
    # POSIX absolute paths already start with '/'; otherwise treat as a Windows path and normalize backslashes
if path_str.startswith("/"):
return f"file://{quote(path_str)}"
else:
return f"file:///{quote(path_str.replace(chr(92), '/'))}"
def _uri_to_path(uri: str) -> Path:
"""Convert a URI to a file path.
Args:
uri: File URI string
Returns:
Path object
"""
path = uri.replace("file:///", "").replace("file://", "")
return Path(unquote(path))
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
"""Extract the word at the given position in the document.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Word at position, or None if no word found
"""
lines = document_text.splitlines()
if line >= len(lines):
return None
line_text = lines[line]
if character > len(line_text):
return None
# Find word boundaries
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
for match in word_pattern.finditer(line_text):
if match.start() <= character <= match.end():
return match.group()
return None
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
"""Extract the incomplete word prefix at the given position.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Prefix string (may be empty)
"""
lines = document_text.splitlines()
if line >= len(lines):
return ""
line_text = lines[line]
if character > len(line_text):
character = len(line_text)
# Extract text before cursor
before_cursor = line_text[:character]
# Find the start of the current word
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
if match:
return match.group()
return ""
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
"""Convert a codex-lens Symbol to an LSP Location.
Args:
symbol: codex-lens Symbol object
Returns:
LSP Location, or None if symbol has no file
"""
if not symbol.file:
return None
# LSP uses 0-based lines, codex-lens uses 1-based
start_line = max(0, symbol.range[0] - 1)
end_line = max(0, symbol.range[1] - 1)
return lsp.Location(
uri=_path_to_uri(symbol.file),
range=lsp.Range(
start=lsp.Position(line=start_line, character=0),
end=lsp.Position(line=end_line, character=0),
),
)
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
"""Map codex-lens symbol kind to LSP SymbolKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP SymbolKind
"""
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
"""Map codex-lens symbol kind to LSP CompletionItemKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP CompletionItemKind
"""
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
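# --- Illustrative sketch (editorial example, not part of the original file) ---
# How the position helpers above behave on a small buffer. Positions are
# 0-based line/character pairs, matching the LSP wire protocol.
def _example_position_helpers() -> None:
    text = "def handle(request):\n    result = handle(request)\n"
    # Cursor inside "handle" on the second line (line 1, character 17)
    assert _get_word_at_position(text, 1, 17) == "handle"
    assert _get_prefix_at_position(text, 1, 17) == "hand"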
# -----------------------------------------------------------------------------
# LSP Request Handlers
# -----------------------------------------------------------------------------
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
def lsp_definition(
params: lsp.DefinitionParams,
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
"""Handle textDocument/definition request.
Finds the definition of the symbol at the cursor position.
"""
if not server.global_index:
logger.debug("No global index available for definition lookup")
return None
# Get document
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
# Get word at position
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
logger.debug("No word found at position")
return None
logger.debug("Looking up definition for: %s", word)
# Search for exact symbol match
try:
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=False, # Exact match preferred
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
# Fall back to prefix search
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=True,
)
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
logger.debug("No definition found for: %s", word)
return None
# Convert to LSP locations
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
if len(locations) == 1:
return locations[0]
elif locations:
return locations
else:
return None
except Exception as exc:
logger.error("Error looking up definition: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
"""Handle textDocument/references request.
Finds all references to the symbol at the cursor position using
the code_relationships table for accurate call-site tracking.
Falls back to same-name symbol search if search_engine is unavailable.
"""
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Finding references for: %s", word)
try:
# Try using search_engine.search_references() for accurate reference tracking
if server.search_engine and server.workspace_root:
references = server.search_engine.search_references(
symbol_name=word,
source_path=server.workspace_root,
limit=200,
)
if references:
locations = []
for ref in references:
locations.append(
lsp.Location(
uri=_path_to_uri(ref.file_path),
range=lsp.Range(
start=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column,
),
end=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column + len(word),
),
),
)
)
return locations if locations else None
# Fallback: search for symbols with same name using global_index
if server.global_index:
symbols = server.global_index.search(
name=word,
limit=100,
prefix_mode=False,
)
# Filter for exact matches
exact_matches = [s for s in symbols if s.name == word]
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
return locations if locations else None
return None
except Exception as exc:
logger.error("Error finding references: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
"""Handle textDocument/completion request.
Provides code completion suggestions based on indexed symbols.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
prefix = _get_prefix_at_position(
document.source,
params.position.line,
params.position.character,
)
if not prefix or len(prefix) < 2:
# Require at least 2 characters for completion
return None
logger.debug("Completing prefix: %s", prefix)
try:
symbols = server.global_index.search(
name=prefix,
limit=50,
prefix_mode=True,
)
if not symbols:
return None
# Convert to completion items
items = []
seen_names = set()
for sym in symbols:
if sym.name in seen_names:
continue
seen_names.add(sym.name)
items.append(
lsp.CompletionItem(
label=sym.name,
kind=_symbol_kind_to_completion_kind(sym.kind),
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
sort_text=sym.name.lower(),
)
)
return lsp.CompletionList(
is_incomplete=len(symbols) >= 50,
items=items,
)
except Exception as exc:
logger.error("Error getting completions: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
"""Handle textDocument/hover request.
Provides hover information for the symbol at the cursor position
using HoverProvider for rich symbol information including
signature, documentation, and location.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Hover for: %s", word)
try:
# Use HoverProvider for rich symbol information
from codexlens.lsp.providers import HoverProvider
provider = HoverProvider(server.global_index, server.registry)
info = provider.get_hover_info(word)
if not info:
return None
# Format as markdown with signature and location
content = provider.format_hover_markdown(info)
return lsp.Hover(
contents=lsp.MarkupContent(
kind=lsp.MarkupKind.Markdown,
value=content,
),
)
except Exception as exc:
logger.error("Error getting hover info: %s", exc)
return None
@server.feature(lsp.WORKSPACE_SYMBOL)
def lsp_workspace_symbol(
params: lsp.WorkspaceSymbolParams,
) -> Optional[List[lsp.SymbolInformation]]:
"""Handle workspace/symbol request.
Searches for symbols across the workspace.
"""
if not server.global_index:
return None
query = params.query
if not query or len(query) < 2:
return None
logger.debug("Workspace symbol search: %s", query)
try:
symbols = server.global_index.search(
name=query,
limit=100,
prefix_mode=True,
)
if not symbols:
return None
result = []
for sym in symbols:
loc = symbol_to_location(sym)
if loc:
result.append(
lsp.SymbolInformation(
name=sym.name,
kind=_symbol_kind_to_lsp(sym.kind),
location=loc,
container_name=Path(sym.file).parent.name if sym.file else None,
)
)
return result if result else None
except Exception as exc:
logger.error("Error searching workspace symbols: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
"""Handle textDocument/didSave notification.
Triggers incremental re-indexing of the saved file.
Note: Full incremental indexing requires WatcherManager integration,
which is planned for Phase 2.
"""
file_path = _uri_to_path(params.text_document.uri)
logger.info("File saved: %s", file_path)
# Phase 1: Just log the save event
# Phase 2 will integrate with WatcherManager for incremental indexing
# if server.watcher_manager:
# server.watcher_manager.trigger_reindex(file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
"""Handle textDocument/didOpen notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File opened: %s", file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
"""Handle textDocument/didClose notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File closed: %s", file_path)


@@ -0,0 +1,834 @@
"""LspBridge service for real-time LSP communication with caching.
This module provides a bridge to communicate with language servers either via:
1. Standalone LSP Manager (direct subprocess communication - default)
2. VSCode Bridge extension (HTTP-based, legacy mode)
Features:
- Direct communication with language servers (no VSCode dependency)
- Cache with TTL and file modification time invalidation
- Graceful error handling with empty results on failure
- Support for definition, references, hover, and call hierarchy
"""
from __future__ import annotations
import asyncio
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from codexlens.lsp.standalone_manager import StandaloneLspManager
# Check for optional dependencies
try:
import aiohttp
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeSymbolNode,
Range,
)
@dataclass
class Location:
"""A location in a source file (LSP response format)."""
file_path: str
line: int
character: int
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"file_path": self.file_path,
"line": self.line,
"character": self.character,
}
@classmethod
def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
"""Create Location from LSP response format.
Handles both direct format and VSCode URI format.
"""
# Handle VSCode URI format (file:///path/to/file)
uri = data.get("uri", data.get("file_path", ""))
if uri.startswith("file:///"):
# Windows: file:///C:/path -> C:/path
# Unix: file:///path -> /path
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
elif uri.startswith("file://"):
file_path = uri[7:]
else:
file_path = uri
# Get position from range or direct fields
if "range" in data:
range_data = data["range"]
start = range_data.get("start", {})
line = start.get("line", 0) + 1 # LSP is 0-based, convert to 1-based
character = start.get("character", 0) + 1
else:
line = data.get("line", 1)
character = data.get("character", 1)
return cls(file_path=file_path, line=line, character=character)
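# --- Illustrative sketch (editorial example, not part of the original file) ---
# from_lsp_response accepts a raw LSP location payload and converts its
# 0-based positions to the 1-based convention used elsewhere in codex-lens.
# The URI below is a placeholder.
def _example_location_parse() -> Location:
    payload = {
        "uri": "file:///workspace/src/app.py",
        "range": {"start": {"line": 41, "character": 7}, "end": {"line": 41, "character": 19}},
    }
    loc = Location.from_lsp_response(payload)
    assert (loc.file_path, loc.line, loc.character) == ("/workspace/src/app.py", 42, 8)
    return loc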
@dataclass
class CacheEntry:
"""A cached LSP response with expiration metadata.
Attributes:
data: The cached response data
file_mtime: File modification time when cached (for invalidation)
cached_at: Unix timestamp when entry was cached
"""
data: Any
file_mtime: float
cached_at: float
class LspBridge:
"""Bridge for real-time LSP communication with language servers.
By default, uses StandaloneLspManager to directly spawn and communicate
with language servers via JSON-RPC over stdio. No VSCode dependency required.
For legacy mode, can use VSCode Bridge HTTP server (set use_vscode_bridge=True).
Features:
- Direct language server communication (default)
- Response caching with TTL and file modification invalidation
- Timeout handling
- Graceful error handling returning empty results
Example:
# Default: standalone mode (no VSCode needed)
async with LspBridge() as bridge:
refs = await bridge.get_references(symbol)
definition = await bridge.get_definition(symbol)
# Legacy: VSCode Bridge mode
async with LspBridge(use_vscode_bridge=True) as bridge:
refs = await bridge.get_references(symbol)
"""
DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
DEFAULT_TIMEOUT = 30.0 # seconds (increased for standalone mode)
DEFAULT_CACHE_TTL = 300 # 5 minutes
DEFAULT_MAX_CACHE_SIZE = 1000 # Maximum cache entries
def __init__(
self,
bridge_url: str = DEFAULT_BRIDGE_URL,
timeout: float = DEFAULT_TIMEOUT,
cache_ttl: int = DEFAULT_CACHE_TTL,
max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
use_vscode_bridge: bool = False,
workspace_root: Optional[str] = None,
config_file: Optional[str] = None,
):
"""Initialize LspBridge.
Args:
bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
timeout: Request timeout in seconds
cache_ttl: Cache time-to-live in seconds
max_cache_size: Maximum number of cache entries (LRU eviction)
use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
workspace_root: Root directory for standalone LSP manager
config_file: Path to lsp-servers.json configuration file
"""
self.bridge_url = bridge_url
self.timeout = timeout
self.cache_ttl = cache_ttl
self.max_cache_size = max_cache_size
self.use_vscode_bridge = use_vscode_bridge
self.workspace_root = workspace_root
self.config_file = config_file
self.cache: OrderedDict[str, CacheEntry] = OrderedDict()
# VSCode Bridge mode (legacy)
self._session: Optional["aiohttp.ClientSession"] = None
# Standalone mode (default)
self._manager: Optional["StandaloneLspManager"] = None
self._manager_started = False
# Validate dependencies
if use_vscode_bridge and not HAS_AIOHTTP:
raise ImportError(
"aiohttp is required for VSCode Bridge mode: pip install aiohttp"
)
async def _ensure_manager(self) -> "StandaloneLspManager":
"""Ensure standalone LSP manager is started."""
if self._manager is None:
from codexlens.lsp.standalone_manager import StandaloneLspManager
self._manager = StandaloneLspManager(
workspace_root=self.workspace_root,
config_file=self.config_file,
timeout=self.timeout,
)
if not self._manager_started:
await self._manager.start()
self._manager_started = True
return self._manager
async def _get_session(self) -> "aiohttp.ClientSession":
"""Get or create the aiohttp session (VSCode Bridge mode only)."""
if not HAS_AIOHTTP:
raise ImportError("aiohttp required for VSCode Bridge mode")
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self) -> None:
"""Close connections and cleanup resources."""
# Close VSCode Bridge session
if self._session and not self._session.closed:
await self._session.close()
self._session = None
# Stop standalone manager
if self._manager and self._manager_started:
await self._manager.stop()
self._manager_started = False
def _get_file_mtime(self, file_path: str) -> float:
"""Get file modification time, or 0 if file doesn't exist."""
try:
return os.path.getmtime(file_path)
except OSError:
return 0.0
def _is_cached(self, cache_key: str, file_path: str) -> bool:
"""Check if cache entry is valid.
Cache is invalid if:
- Entry doesn't exist
- TTL has expired
- File has been modified since caching
Args:
cache_key: The cache key to check
file_path: Path to source file for mtime check
Returns:
True if cache is valid and can be used
"""
if cache_key not in self.cache:
return False
entry = self.cache[cache_key]
now = time.time()
# Check TTL
if now - entry.cached_at > self.cache_ttl:
del self.cache[cache_key]
return False
# Check file modification time
current_mtime = self._get_file_mtime(file_path)
if current_mtime != entry.file_mtime:
del self.cache[cache_key]
return False
# Move to end on access (LRU behavior)
self.cache.move_to_end(cache_key)
return True
def _cache(self, key: str, file_path: str, data: Any) -> None:
"""Store data in cache with LRU eviction.
Args:
key: Cache key
file_path: Path to source file (for mtime tracking)
data: Data to cache
"""
# Remove oldest entries if at capacity
while len(self.cache) >= self.max_cache_size:
            self.cache.popitem(last=False)  # Evict the least-recently-used entry
# Move to end if key exists (update access order)
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = CacheEntry(
data=data,
file_mtime=self._get_file_mtime(file_path),
cached_at=time.time(),
)
def clear_cache(self) -> None:
"""Clear all cached entries."""
self.cache.clear()
async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
"""Make HTTP request to VSCode Bridge (legacy mode).
Args:
action: The endpoint/action name (e.g., "get_definition")
params: Request parameters
Returns:
Response data on success, None on failure
"""
url = f"{self.bridge_url}/{action}"
try:
session = await self._get_session()
async with session.post(url, json=params) as response:
if response.status != 200:
return None
data = await response.json()
if data.get("success") is False:
return None
return data.get("result")
except asyncio.TimeoutError:
return None
except Exception:
return None
async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
"""Get all references to a symbol via real-time LSP.
Args:
symbol: The code symbol to find references for
Returns:
List of Location objects where the symbol is referenced.
Returns empty list on error or timeout.
"""
cache_key = f"refs:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
locations: List[Location] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_references", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
# Don't cache on connection error (result is None)
if result is None:
return locations
if isinstance(result, list):
for item in result:
try:
locations.append(Location.from_lsp_response(item))
except (KeyError, TypeError):
continue
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_references(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
for item in result:
try:
locations.append(Location.from_lsp_response(item))
except (KeyError, TypeError):
continue
self._cache(cache_key, symbol.file_path, locations)
return locations
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
"""Get symbol definition location.
Args:
symbol: The code symbol to find definition for
Returns:
Location of the definition, or None if not found
"""
cache_key = f"def:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
location: Optional[Location] = None
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_definition", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result:
if isinstance(result, list) and len(result) > 0:
try:
location = Location.from_lsp_response(result[0])
except (KeyError, TypeError):
pass
elif isinstance(result, dict):
try:
location = Location.from_lsp_response(result)
except (KeyError, TypeError):
pass
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_definition(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
if result:
try:
location = Location.from_lsp_response(result)
except (KeyError, TypeError):
pass
self._cache(cache_key, symbol.file_path, location)
return location
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
"""Get incoming/outgoing calls for a symbol.
If call hierarchy is not supported by the language server,
falls back to using references.
Args:
symbol: The code symbol to get call hierarchy for
Returns:
List of CallHierarchyItem representing callers/callees.
Returns empty list on error or if not supported.
"""
cache_key = f"calls:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
items: List[CallHierarchyItem] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_call_hierarchy", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result is None:
# Fallback: use references
refs = await self.get_references(symbol)
for ref in refs:
items.append(CallHierarchyItem(
name=f"caller@{ref.line}",
kind="reference",
file_path=ref.file_path,
range=Range(
start_line=ref.line,
start_character=ref.character,
end_line=ref.line,
end_character=ref.character,
),
detail="Inferred from reference",
))
elif isinstance(result, list):
for item in result:
try:
range_data = item.get("range", {})
start = range_data.get("start", {})
end = range_data.get("end", {})
items.append(CallHierarchyItem(
name=item.get("name", "unknown"),
kind=item.get("kind", "unknown"),
file_path=item.get("file_path", item.get("uri", "")),
range=Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
),
detail=item.get("detail"),
))
except (KeyError, TypeError):
continue
else:
# Default: Standalone mode
manager = await self._ensure_manager()
# Try to get call hierarchy items
hierarchy_items = await manager.get_call_hierarchy_items(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
if hierarchy_items:
# Get incoming calls for each item
for h_item in hierarchy_items:
incoming = await manager.get_incoming_calls(h_item)
for call in incoming:
from_item = call.get("from", {})
range_data = from_item.get("range", {})
start = range_data.get("start", {})
end = range_data.get("end", {})
# Parse URI
uri = from_item.get("uri", "")
if uri.startswith("file:///"):
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
elif uri.startswith("file://"):
fp = uri[7:]
else:
fp = uri
items.append(CallHierarchyItem(
name=from_item.get("name", "unknown"),
kind=str(from_item.get("kind", "unknown")),
file_path=fp,
range=Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
),
detail=from_item.get("detail"),
))
else:
# Fallback: use references
refs = await self.get_references(symbol)
for ref in refs:
items.append(CallHierarchyItem(
name=f"caller@{ref.line}",
kind="reference",
file_path=ref.file_path,
range=Range(
start_line=ref.line,
start_character=ref.character,
end_line=ref.line,
end_character=ref.character,
),
detail="Inferred from reference",
))
self._cache(cache_key, symbol.file_path, items)
return items
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
"""Get all symbols in a document (batch operation).
This is more efficient than individual hover queries when processing
multiple locations in the same file.
Args:
file_path: Path to the source file
Returns:
List of symbol dictionaries with name, kind, range, etc.
Returns empty list on error or timeout.
"""
cache_key = f"symbols:{file_path}"
if self._is_cached(cache_key, file_path):
return self.cache[cache_key].data
symbols: List[Dict[str, Any]] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_document_symbols", {
"file_path": file_path,
})
if isinstance(result, list):
symbols = self._flatten_document_symbols(result)
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_document_symbols(file_path)
if result:
symbols = self._flatten_document_symbols(result)
self._cache(cache_key, file_path, symbols)
return symbols
def _flatten_document_symbols(
self, symbols: List[Dict[str, Any]], parent_name: str = ""
) -> List[Dict[str, Any]]:
"""Flatten nested document symbols into a flat list.
Document symbols can be nested (e.g., methods inside classes).
This flattens them for easier lookup by line number.
Args:
symbols: List of symbol dictionaries (may be nested)
parent_name: Name of parent symbol for qualification
Returns:
Flat list of all symbols with their ranges
"""
flat: List[Dict[str, Any]] = []
for sym in symbols:
# Add the symbol itself
symbol_entry = {
"name": sym.get("name", "unknown"),
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
"range": sym.get("range", sym.get("location", {}).get("range", {})),
"selection_range": sym.get("selectionRange", {}),
"detail": sym.get("detail", ""),
"parent": parent_name,
}
flat.append(symbol_entry)
# Recursively process children
children = sym.get("children", [])
if children:
qualified_name = sym.get("name", "")
if parent_name:
qualified_name = f"{parent_name}.{qualified_name}"
flat.extend(self._flatten_document_symbols(children, qualified_name))
return flat
def _symbol_kind_to_string(self, kind: int) -> str:
"""Convert LSP SymbolKind integer to string.
Args:
kind: LSP SymbolKind enum value
Returns:
Human-readable string representation
"""
# LSP SymbolKind enum (1-indexed)
kinds = {
1: "file",
2: "module",
3: "namespace",
4: "package",
5: "class",
6: "method",
7: "property",
8: "field",
9: "constructor",
10: "enum",
11: "interface",
12: "function",
13: "variable",
14: "constant",
15: "string",
16: "number",
17: "boolean",
18: "array",
19: "object",
20: "key",
21: "null",
22: "enum_member",
23: "struct",
24: "event",
25: "operator",
26: "type_parameter",
}
return kinds.get(kind, "unknown")
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
"""Get hover documentation for a symbol.
Args:
symbol: The code symbol to get hover info for
Returns:
Hover documentation as string, or None if not available
"""
cache_key = f"hover:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
hover_text: Optional[str] = None
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_hover", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result:
hover_text = self._parse_hover_result(result)
else:
# Default: Standalone mode
manager = await self._ensure_manager()
hover_text = await manager.get_hover(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
self._cache(cache_key, symbol.file_path, hover_text)
return hover_text
def _parse_hover_result(self, result: Any) -> Optional[str]:
"""Parse hover result into string."""
if isinstance(result, str):
return result
elif isinstance(result, list):
parts = []
for item in result:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
value = item.get("value", item.get("contents", ""))
if value:
parts.append(str(value))
return "\n\n".join(parts) if parts else None
elif isinstance(result, dict):
contents = result.get("contents", result.get("value", ""))
if isinstance(contents, str):
return contents
elif isinstance(contents, list):
parts = []
for c in contents:
if isinstance(c, str):
parts.append(c)
elif isinstance(c, dict):
parts.append(str(c.get("value", "")))
return "\n\n".join(parts) if parts else None
return None
async def __aenter__(self) -> "LspBridge":
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit - close connections."""
await self.close()
# Simple test
if __name__ == "__main__":
import sys
async def test_lsp_bridge():
"""Simple test of LspBridge functionality."""
print("Testing LspBridge (Standalone Mode)...")
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
print()
# Create a test symbol pointing to this file
test_file = os.path.abspath(__file__)
test_symbol = CodeSymbolNode(
id=f"{test_file}:LspBridge:96",
name="LspBridge",
kind="class",
file_path=test_file,
range=Range(
start_line=96,
start_character=1,
end_line=200,
end_character=1,
),
)
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
print()
# Use standalone mode (default)
async with LspBridge(
workspace_root=str(Path(__file__).parent.parent.parent.parent),
) as bridge:
print("1. Testing get_document_symbols...")
try:
symbols = await bridge.get_document_symbols(test_file)
print(f" Found {len(symbols)} symbols")
for sym in symbols[:5]:
print(f" - {sym.get('name')} ({sym.get('kind')})")
except Exception as e:
print(f" Error: {e}")
print()
print("2. Testing get_definition...")
try:
definition = await bridge.get_definition(test_symbol)
if definition:
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
else:
print(" No definition found")
except Exception as e:
print(f" Error: {e}")
print()
print("3. Testing get_references...")
try:
refs = await bridge.get_references(test_symbol)
print(f" Found {len(refs)} references")
for ref in refs[:3]:
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
except Exception as e:
print(f" Error: {e}")
print()
print("4. Testing get_hover...")
try:
hover = await bridge.get_hover(test_symbol)
if hover:
print(f" Hover: {hover[:100]}...")
else:
print(" No hover info found")
except Exception as e:
print(f" Error: {e}")
print()
print("5. Testing get_call_hierarchy...")
try:
calls = await bridge.get_call_hierarchy(test_symbol)
print(f" Found {len(calls)} call hierarchy items")
for call in calls[:3]:
print(f" - {call.name} in {os.path.basename(call.file_path)}")
except Exception as e:
print(f" Error: {e}")
print()
print("6. Testing cache...")
print(f" Cache entries: {len(bridge.cache)}")
for key in list(bridge.cache.keys())[:5]:
print(f" - {key}")
print()
print("Test complete!")
# Run the test
# Note: On Windows, use default ProactorEventLoop (supports subprocess creation)
asyncio.run(test_lsp_bridge())


@@ -0,0 +1,375 @@
"""Graph builder for code association graphs via LSP."""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, List, Optional, Set, Tuple
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeAssociationGraph,
CodeSymbolNode,
Range,
)
from codexlens.lsp.lsp_bridge import (
Location,
LspBridge,
)
logger = logging.getLogger(__name__)
class LspGraphBuilder:
"""Builds code association graph by expanding from seed symbols using LSP."""
def __init__(
self,
max_depth: int = 2,
max_nodes: int = 100,
max_concurrent: int = 10,
):
"""Initialize GraphBuilder.
Args:
max_depth: Maximum depth for BFS expansion from seeds.
max_nodes: Maximum number of nodes in the graph.
max_concurrent: Maximum concurrent LSP requests.
"""
self.max_depth = max_depth
self.max_nodes = max_nodes
self.max_concurrent = max_concurrent
# Cache for document symbols per file (avoids per-location hover queries)
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
async def build_from_seeds(
self,
seeds: List[CodeSymbolNode],
lsp_bridge: LspBridge,
) -> CodeAssociationGraph:
"""Build association graph by BFS expansion from seeds.
For each seed:
1. Get references via LSP
2. Get call hierarchy via LSP
3. Add nodes and edges to graph
4. Continue expanding until max_depth or max_nodes reached
Args:
seeds: Initial seed symbols to expand from.
lsp_bridge: LSP bridge for querying language servers.
Returns:
CodeAssociationGraph with expanded nodes and relationships.
"""
graph = CodeAssociationGraph()
visited: Set[str] = set()
semaphore = asyncio.Semaphore(self.max_concurrent)
# Initialize queue with seeds at depth 0
queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]
# Add seed nodes to graph
for seed in seeds:
graph.add_node(seed)
# BFS expansion
while queue and len(graph.nodes) < self.max_nodes:
# Take a batch of nodes from queue
batch_size = min(self.max_concurrent, len(queue))
batch = queue[:batch_size]
queue = queue[batch_size:]
# Expand nodes in parallel
tasks = [
self._expand_node(
node, depth, graph, lsp_bridge, visited, semaphore
)
for node, depth in batch
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and add new nodes to queue
for result in results:
if isinstance(result, Exception):
logger.warning("Error expanding node: %s", result)
continue
if result:
# Add new nodes to queue if not at max depth
for new_node, new_depth in result:
if (
new_depth <= self.max_depth
and len(graph.nodes) < self.max_nodes
):
queue.append((new_node, new_depth))
return graph
async def _expand_node(
self,
node: CodeSymbolNode,
depth: int,
graph: CodeAssociationGraph,
lsp_bridge: LspBridge,
visited: Set[str],
semaphore: asyncio.Semaphore,
) -> List[Tuple[CodeSymbolNode, int]]:
"""Expand a single node, return new nodes to process.
Args:
node: Node to expand.
depth: Current depth in BFS.
graph: Graph to add nodes and edges to.
lsp_bridge: LSP bridge for queries.
visited: Set of visited node IDs.
semaphore: Semaphore for concurrency control.
Returns:
List of (new_node, new_depth) tuples to add to queue.
"""
# Skip if already visited or at max depth
if node.id in visited:
return []
if depth > self.max_depth:
return []
if len(graph.nodes) >= self.max_nodes:
return []
visited.add(node.id)
new_nodes: List[Tuple[CodeSymbolNode, int]] = []
async with semaphore:
# Get relationships in parallel
try:
refs_task = lsp_bridge.get_references(node)
calls_task = lsp_bridge.get_call_hierarchy(node)
refs, calls = await asyncio.gather(
refs_task, calls_task, return_exceptions=True
)
# Handle reference results
if isinstance(refs, Exception):
logger.debug(
"Failed to get references for %s: %s", node.id, refs
)
refs = []
# Handle call hierarchy results
if isinstance(calls, Exception):
logger.debug(
"Failed to get call hierarchy for %s: %s",
node.id,
calls,
)
calls = []
# Process references
for ref in refs:
if len(graph.nodes) >= self.max_nodes:
break
ref_node = await self._location_to_node(ref, lsp_bridge)
if ref_node and ref_node.id != node.id:
if ref_node.id not in graph.nodes:
graph.add_node(ref_node)
new_nodes.append((ref_node, depth + 1))
# Use add_edge since both nodes should exist now
graph.add_edge(node.id, ref_node.id, "references")
# Process call hierarchy (incoming calls)
for call in calls:
if len(graph.nodes) >= self.max_nodes:
break
call_node = await self._call_hierarchy_to_node(
call, lsp_bridge
)
if call_node and call_node.id != node.id:
if call_node.id not in graph.nodes:
graph.add_node(call_node)
new_nodes.append((call_node, depth + 1))
# Incoming call: call_node calls node
graph.add_edge(call_node.id, node.id, "calls")
except Exception as e:
logger.warning(
"Error during node expansion for %s: %s", node.id, e
)
return new_nodes
def clear_cache(self) -> None:
"""Clear the document symbols cache.
Call this between searches to free memory and ensure fresh data.
"""
self._document_symbols_cache.clear()
async def _get_symbol_at_location(
self,
file_path: str,
line: int,
lsp_bridge: LspBridge,
) -> Optional[Dict[str, Any]]:
"""Find symbol at location using cached document symbols.
This is much more efficient than individual hover queries because
document symbols are fetched once per file and cached.
Args:
file_path: Path to the source file.
line: Line number (1-based).
lsp_bridge: LSP bridge for fetching document symbols.
Returns:
Symbol dictionary with name, kind, range, etc., or None if not found.
"""
# Get or fetch document symbols for this file
if file_path not in self._document_symbols_cache:
symbols = await lsp_bridge.get_document_symbols(file_path)
self._document_symbols_cache[file_path] = symbols
symbols = self._document_symbols_cache[file_path]
# Find symbol containing this line (best match = smallest range)
best_match: Optional[Dict[str, Any]] = None
best_range_size = float("inf")
for symbol in symbols:
sym_range = symbol.get("range", {})
start = sym_range.get("start", {})
end = sym_range.get("end", {})
# LSP ranges are 0-based, our line is 1-based
start_line = start.get("line", 0) + 1
end_line = end.get("line", 0) + 1
if start_line <= line <= end_line:
range_size = end_line - start_line
if range_size < best_range_size:
best_match = symbol
best_range_size = range_size
return best_match
async def _location_to_node(
self,
location: Location,
lsp_bridge: LspBridge,
) -> Optional[CodeSymbolNode]:
"""Convert LSP location to CodeSymbolNode.
Uses cached document symbols instead of individual hover queries
for better performance.
Args:
location: LSP location to convert.
lsp_bridge: LSP bridge for additional queries.
Returns:
CodeSymbolNode or None if conversion fails.
"""
try:
file_path = location.file_path
start_line = location.line
# Try to find symbol info from cached document symbols (fast)
symbol_info = await self._get_symbol_at_location(
file_path, start_line, lsp_bridge
)
if symbol_info:
name = symbol_info.get("name", f"symbol_L{start_line}")
kind = symbol_info.get("kind", "unknown")
# Extract range from symbol if available
sym_range = symbol_info.get("range", {})
start = sym_range.get("start", {})
end = sym_range.get("end", {})
location_range = Range(
start_line=start.get("line", start_line - 1) + 1,
start_character=start.get("character", location.character - 1) + 1,
end_line=end.get("line", start_line - 1) + 1,
end_character=end.get("character", location.character - 1) + 1,
)
else:
# Fallback to basic node without symbol info
name = f"symbol_L{start_line}"
kind = "unknown"
location_range = Range(
start_line=location.line,
start_character=location.character,
end_line=location.line,
end_character=location.character,
)
node_id = self._create_node_id(file_path, name, start_line)
return CodeSymbolNode(
id=node_id,
name=name,
kind=kind,
file_path=file_path,
range=location_range,
docstring="", # Skip hover for performance
)
except Exception as e:
logger.debug("Failed to convert location to node: %s", e)
return None
async def _call_hierarchy_to_node(
self,
call_item: CallHierarchyItem,
lsp_bridge: LspBridge,
) -> Optional[CodeSymbolNode]:
"""Convert CallHierarchyItem to CodeSymbolNode.
Args:
call_item: Call hierarchy item to convert.
lsp_bridge: LSP bridge (unused, kept for API consistency).
Returns:
CodeSymbolNode or None if conversion fails.
"""
try:
file_path = call_item.file_path
name = call_item.name
start_line = call_item.range.start_line
# CallHierarchyItem.kind is already a string
kind = call_item.kind
node_id = self._create_node_id(file_path, name, start_line)
return CodeSymbolNode(
id=node_id,
name=name,
kind=kind,
file_path=file_path,
range=call_item.range,
docstring=call_item.detail or "",
)
except Exception as e:
logger.debug(
"Failed to convert call hierarchy item to node: %s", e
)
return None
def _create_node_id(
self, file_path: str, name: str, line: int
) -> str:
"""Create unique node ID.
Args:
file_path: Path to the file.
name: Symbol name.
line: Line number (1-based, matching the ranges built by the callers above).
Returns:
Unique node ID string.
"""
return f"{file_path}:{name}:{line}"

View File

@@ -0,0 +1,177 @@
"""LSP feature providers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
@dataclass
class HoverInfo:
"""Hover information for a symbol."""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: tuple # (start_line, end_line)
class HoverProvider:
"""Provides hover information for symbols."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
registry: Optional["RegistryStore"] = None,
) -> None:
"""Initialize hover provider.
Args:
global_index: Global symbol index for lookups
registry: Optional registry store for index path resolution
"""
self.global_index = global_index
self.registry = registry
def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
"""Get hover information for a symbol.
Args:
symbol_name: Name of the symbol to look up
Returns:
HoverInfo or None if symbol not found
"""
# Look up symbol in global index using exact match
symbols = self.global_index.search(
name=symbol_name,
limit=1,
prefix_mode=False,
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == symbol_name]
if not exact_matches:
return None
symbol = exact_matches[0]
# Extract signature from source file
signature = self._extract_signature(symbol)
# Symbol uses 'file' attribute and 'range' tuple
file_path = symbol.file or ""
start_line, end_line = symbol.range
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=signature,
documentation=None, # Symbol doesn't have docstring field
file_path=file_path,
line_range=(start_line, end_line),
)
def _extract_signature(self, symbol) -> str:
"""Extract function/class signature from source file.
Args:
symbol: Symbol object with file and range information
Returns:
Extracted signature string or fallback kind + name
"""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return f"{symbol.kind} {symbol.name}"
content = file_path.read_text(encoding="utf-8", errors="ignore")
lines = content.split("\n")
# Extract signature lines (first line of definition + continuation)
start_line = symbol.range[0] - 1 # Convert 1-based to 0-based
if start_line >= len(lines) or start_line < 0:
return f"{symbol.kind} {symbol.name}"
signature_lines = []
first_line = lines[start_line]
signature_lines.append(first_line)
# Continue if multiline signature (no closing paren + colon yet)
# Look for patterns like "def func(", "class Foo(", etc.
i = start_line + 1
max_lines = min(start_line + 5, len(lines))
while i < max_lines:
line = signature_lines[-1]
# Stop if we see closing pattern
if "):" in line or line.rstrip().endswith(":"):
break
signature_lines.append(lines[i])
i += 1
return "\n".join(signature_lines)
except Exception as e:
logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
return f"{symbol.kind} {symbol.name}"
def format_hover_markdown(self, info: HoverInfo) -> str:
"""Format hover info as Markdown.
Args:
info: HoverInfo object to format
Returns:
Markdown-formatted hover content
"""
parts = []
# Detect language for code fence based on file extension
ext = Path(info.file_path).suffix.lower() if info.file_path else ""
lang_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".tsx": "typescript",
".jsx": "javascript",
".java": "java",
".go": "go",
".rs": "rust",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
".cs": "csharp",
".rb": "ruby",
".php": "php",
}
lang = lang_map.get(ext, "")
# Code block with signature
parts.append(f"```{lang}\n{info.signature}\n```")
# Documentation if available
if info.documentation:
parts.append(f"\n---\n\n{info.documentation}")
# Location info
file_name = Path(info.file_path).name if info.file_path else "unknown"
parts.append(
f"\n---\n\n*{info.kind}* defined in "
f"`{file_name}` "
f"(line {info.line_range[0]})"
)
return "\n".join(parts)

View File

@@ -0,0 +1,263 @@
"""codex-lens LSP Server implementation using pygls.
This module provides the main Language Server class and entry point.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Optional
try:
from lsprotocol import types as lsp
from pygls.lsp.server import LanguageServer
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.config import Config
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
class CodexLensLanguageServer(LanguageServer):
"""Language Server for codex-lens code indexing.
Provides IDE features using codex-lens symbol index:
- Go to Definition
- Find References
- Code Completion
- Hover Information
- Workspace Symbol Search
Attributes:
registry: Global project registry for path lookups
mapper: Path mapper for source/index conversions
global_index: Project-wide symbol index
search_engine: Chain search engine for symbol search
workspace_root: Current workspace root path
"""
def __init__(self) -> None:
super().__init__(name="codexlens-lsp", version="0.1.0")
self.registry: Optional[RegistryStore] = None
self.mapper: Optional[PathMapper] = None
self.global_index: Optional[GlobalSymbolIndex] = None
self.search_engine: Optional[ChainSearchEngine] = None
self.workspace_root: Optional[Path] = None
self._config: Optional[Config] = None
def initialize_components(self, workspace_root: Path) -> bool:
"""Initialize codex-lens components for the workspace.
Args:
workspace_root: Root path of the workspace
Returns:
True if initialization succeeded, False otherwise
"""
self.workspace_root = workspace_root.resolve()
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
try:
# Initialize registry
self.registry = RegistryStore()
self.registry.initialize()
# Initialize path mapper
self.mapper = PathMapper()
# Try to find project in registry
project_info = self.registry.find_by_source_path(str(self.workspace_root))
if project_info:
project_id = int(project_info["id"])
index_root = Path(project_info["index_root"])
# Initialize global symbol index
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
self.global_index = GlobalSymbolIndex(global_db, project_id)
self.global_index.initialize()
# Initialize search engine
self._config = Config()
self.search_engine = ChainSearchEngine(
registry=self.registry,
mapper=self.mapper,
config=self._config,
)
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
return True
else:
logger.warning(
"Workspace not indexed by codex-lens: %s. "
"Run 'codexlens index %s' to index first.",
self.workspace_root,
self.workspace_root,
)
return False
except Exception as exc:
logger.error("Failed to initialize codex-lens: %s", exc)
return False
def shutdown_components(self) -> None:
"""Clean up codex-lens components."""
if self.global_index:
try:
self.global_index.close()
except Exception as exc:
logger.debug("Error closing global index: %s", exc)
self.global_index = None
if self.search_engine:
try:
self.search_engine.close()
except Exception as exc:
logger.debug("Error closing search engine: %s", exc)
self.search_engine = None
if self.registry:
try:
self.registry.close()
except Exception as exc:
logger.debug("Error closing registry: %s", exc)
self.registry = None
# Create server instance
server = CodexLensLanguageServer()
@server.feature(lsp.INITIALIZE)
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
"""Handle LSP initialize request."""
logger.info("LSP initialize request received")
# Get workspace root
workspace_root: Optional[Path] = None
if params.root_uri:
workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
elif params.root_path:
workspace_root = Path(params.root_path)
if workspace_root:
server.initialize_components(workspace_root)
# Declare server capabilities
return lsp.InitializeResult(
capabilities=lsp.ServerCapabilities(
text_document_sync=lsp.TextDocumentSyncOptions(
open_close=True,
change=lsp.TextDocumentSyncKind.Incremental,
save=lsp.SaveOptions(include_text=False),
),
definition_provider=True,
references_provider=True,
completion_provider=lsp.CompletionOptions(
trigger_characters=[".", ":"],
resolve_provider=False,
),
hover_provider=True,
workspace_symbol_provider=True,
),
server_info=lsp.ServerInfo(
name="codexlens-lsp",
version="0.1.0",
),
)
@server.feature(lsp.SHUTDOWN)
def lsp_shutdown(params: None) -> None:
"""Handle LSP shutdown request."""
logger.info("LSP shutdown request received")
server.shutdown_components()
def main() -> int:
"""Entry point for codexlens-lsp command.
Returns:
Exit code (0 for success)
"""
# Import handlers to register them with the server
# This must be done before starting the server
import codexlens.lsp.handlers # noqa: F401
parser = argparse.ArgumentParser(
description="codex-lens Language Server",
prog="codexlens-lsp",
)
parser.add_argument(
"--stdio",
action="store_true",
default=True,
help="Use stdio for communication (default)",
)
parser.add_argument(
"--tcp",
action="store_true",
help="Use TCP for communication",
)
parser.add_argument(
"--host",
default="127.0.0.1",
help="TCP host (default: 127.0.0.1)",
)
parser.add_argument(
"--port",
type=int,
default=2087,
help="TCP port (default: 2087)",
)
parser.add_argument(
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Log level (default: INFO)",
)
parser.add_argument(
"--log-file",
help="Log file path (optional)",
)
args = parser.parse_args()
# Configure logging
log_handlers = []
if args.log_file:
log_handlers.append(logging.FileHandler(args.log_file))
else:
log_handlers.append(logging.StreamHandler(sys.stderr))
logging.basicConfig(
level=getattr(logging, args.log_level),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=log_handlers,
)
logger.info("Starting codexlens-lsp server")
if args.tcp:
logger.info("Starting TCP server on %s:%d", args.host, args.port)
server.start_tcp(args.host, args.port)
else:
logger.info("Starting stdio server")
server.start_io()
return 0
if __name__ == "__main__":
sys.exit(main())
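# Launch sketch. The flags mirror the argparse definition above; whether a
# "codexlens-lsp" console script is actually installed depends on package
# metadata that is not shown in this diff.
import subprocess

# stdio mode (default): an editor/LSP client spawns the process and talks over stdin/stdout
proc = subprocess.Popen(["codexlens-lsp", "--log-level", "DEBUG"])

# TCP mode on the documented defaults would be:
#   codexlens-lsp --tcp --host 127.0.0.1 --port 2087
proc.terminate()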

File diff suppressed because it is too large

View File

@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
__all__ = [
"MCPContext",
"SymbolInfo",
"ReferenceInfo",
"RelatedSymbol",
"MCPProvider",
"HookManager",
"create_context_for_prompt",
]
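# The re-exports above give a single import surface, e.g.:
from codexlens.mcp import MCPContext, MCPProvider, HookManager, create_context_for_prompt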

View File

@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING
from codexlens.mcp.schema import MCPContext
if TYPE_CHECKING:
from codexlens.mcp.provider import MCPProvider
logger = logging.getLogger(__name__)
class HookManager:
"""Manages hook registration and execution."""
def __init__(self, mcp_provider: "MCPProvider") -> None:
self.mcp_provider = mcp_provider
self._pre_hooks: Dict[str, Callable] = {}
self._post_hooks: Dict[str, Callable] = {}
# Register default hooks
self._register_default_hooks()
def _register_default_hooks(self) -> None:
"""Register built-in hooks."""
self._pre_hooks["explain"] = self._pre_explain_hook
self._pre_hooks["refactor"] = self._pre_refactor_hook
self._pre_hooks["document"] = self._pre_document_hook
def execute_pre_hook(
self,
action: str,
params: Dict[str, Any],
) -> Optional[MCPContext]:
"""Execute pre-tool hook to gather context.
Args:
action: The action being performed (e.g., "explain", "refactor")
params: Parameters for the action
Returns:
MCPContext to inject into prompt, or None
"""
hook = self._pre_hooks.get(action)
if not hook:
logger.debug(f"No pre-hook for action: {action}")
return None
try:
return hook(params)
except Exception as e:
logger.error(f"Pre-hook failed for {action}: {e}")
return None
def execute_post_hook(
self,
action: str,
result: Any,
) -> None:
"""Execute post-tool hook for proactive caching.
Args:
action: The action that was performed
result: Result of the action
"""
hook = self._post_hooks.get(action)
if not hook:
return
try:
hook(result)
except Exception as e:
logger.error(f"Post-hook failed for {action}: {e}")
def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'explain' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="symbol_explanation",
include_references=True,
include_related=True,
)
def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'refactor' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="refactor_context",
include_references=True,
include_related=True,
max_references=20,
)
def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'document' action."""
symbol_name = params.get("symbol")
file_path = params.get("file_path")
if symbol_name:
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="documentation_context",
include_references=False,
include_related=True,
)
elif file_path:
return self.mcp_provider.build_context_for_file(
Path(file_path),
context_type="file_documentation",
)
return None
def register_pre_hook(
self,
action: str,
hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
) -> None:
"""Register a custom pre-tool hook."""
self._pre_hooks[action] = hook
def register_post_hook(
self,
action: str,
hook: Callable[[Any], None],
) -> None:
"""Register a custom post-tool hook."""
self._post_hooks[action] = hook
def create_context_for_prompt(
mcp_provider: "MCPProvider",
action: str,
params: Dict[str, Any],
) -> str:
"""Create context string for prompt injection.
This is the main entry point for Claude Code hook integration.
Args:
mcp_provider: The MCP provider instance
action: Action being performed
params: Action parameters
Returns:
Formatted context string for prompt injection
"""
manager = HookManager(mcp_provider)
context = manager.execute_pre_hook(action, params)
if context:
return context.to_prompt_injection()
return ""

View File

@@ -0,0 +1,202 @@
"""MCP context provider."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, List, TYPE_CHECKING
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
from codexlens.search.chain_search import ChainSearchEngine
logger = logging.getLogger(__name__)
class MCPProvider:
"""Builds MCP context objects from codex-lens data."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
search_engine: "ChainSearchEngine",
registry: "RegistryStore",
) -> None:
self.global_index = global_index
self.search_engine = search_engine
self.registry = registry
def build_context(
self,
symbol_name: str,
context_type: str = "symbol_explanation",
include_references: bool = True,
include_related: bool = True,
max_references: int = 10,
) -> Optional[MCPContext]:
"""Build comprehensive context for a symbol.
Args:
symbol_name: Name of the symbol to contextualize
context_type: Type of context being requested
include_references: Whether to include reference locations
include_related: Whether to include related symbols
max_references: Maximum number of references to include
Returns:
MCPContext object or None if symbol not found
"""
# Look up symbol
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
if not symbols:
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
return None
symbol = symbols[0]
# Build SymbolInfo
symbol_info = SymbolInfo(
name=symbol.name,
kind=symbol.kind,
file_path=symbol.file or "",
line_start=symbol.range[0],
line_end=symbol.range[1],
signature=None, # Symbol entity doesn't have signature
documentation=None, # Symbol entity doesn't have docstring
)
# Extract definition source code
definition = self._extract_definition(symbol)
# Get references
references = []
if include_references:
refs = self.search_engine.search_references(
symbol_name,
limit=max_references,
)
references = [
ReferenceInfo(
file_path=r.file_path,
line=r.line,
column=r.column,
context=r.context,
relationship_type=r.relationship_type,
)
for r in refs
]
# Get related symbols
related_symbols = []
if include_related:
related_symbols = self._get_related_symbols(symbol)
return MCPContext(
context_type=context_type,
symbol=symbol_info,
definition=definition,
references=references,
related_symbols=related_symbols,
metadata={
"source": "codex-lens",
},
)
def _extract_definition(self, symbol) -> Optional[str]:
"""Extract source code for symbol definition."""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return None
content = file_path.read_text(encoding='utf-8', errors='ignore')
lines = content.split("\n")
start = symbol.range[0] - 1
end = symbol.range[1]
if start >= len(lines):
return None
return "\n".join(lines[start:end])
except Exception as e:
logger.debug(f"Failed to extract definition: {e}")
return None
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
"""Get symbols related to the given symbol."""
related = []
try:
# Search for symbols that might be related by name patterns
# This is a simplified implementation - could be enhanced with relationship data
# Look for imports/callers via reference search
refs = self.search_engine.search_references(symbol.name, limit=20)
seen_names = set()
for ref in refs:
# Extract potential symbol name from context
if ref.relationship_type and ref.relationship_type not in seen_names:
related.append(RelatedSymbol(
name=f"{Path(ref.file_path).stem}",
kind="module",
relationship=ref.relationship_type,
file_path=ref.file_path,
))
seen_names.add(ref.relationship_type)
if len(related) >= 10:
break
except Exception as e:
logger.debug(f"Failed to get related symbols: {e}")
return related
def build_context_for_file(
self,
file_path: Path,
context_type: str = "file_overview",
) -> MCPContext:
"""Build context for an entire file."""
# Try to get symbols by searching with file path
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
symbols = []
# Fetch a broad set of symbols and filter them by resolved file path
# (simple but not efficient); a full implementation would query by file path directly
try:
# Use the global index to search for symbols from this file
file_str = str(file_path.resolve())
# Get all symbols and filter by file path (not efficient but works)
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
except Exception as e:
logger.debug(f"Failed to get file symbols: {e}")
related = [
RelatedSymbol(
name=s.name,
kind=s.kind,
relationship="defines",
)
for s in symbols
]
return MCPContext(
context_type=context_type,
related_symbols=related,
metadata={
"file_path": str(file_path),
"symbol_count": len(symbols),
},
)
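# Construction sketch that mirrors how the LSP server above wires the same
# components; the database path and project id are placeholders.
from pathlib import Path

from codexlens.config import Config
from codexlens.mcp.provider import MCPProvider
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

registry = RegistryStore()
registry.initialize()
index = GlobalSymbolIndex(Path("global_symbols.db"), 1)  # placeholder db path / project id
index.initialize()
engine = ChainSearchEngine(registry=registry, mapper=PathMapper(), config=Config())

provider = MCPProvider(global_index=index, search_engine=engine, registry=registry)
ctx = provider.build_context("MCPProvider", include_references=True, max_references=5)
if ctx:
    print(ctx.to_prompt_injection())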

View File

@@ -0,0 +1,113 @@
"""MCP data models."""
from __future__ import annotations
import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional
@dataclass
class SymbolInfo:
"""Information about a code symbol."""
name: str
kind: str
file_path: str
line_start: int
line_end: int
signature: Optional[str] = None
documentation: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class ReferenceInfo:
"""Information about a symbol reference."""
file_path: str
line: int
column: int
context: str
relationship_type: str
def to_dict(self) -> dict:
return asdict(self)
@dataclass
class RelatedSymbol:
"""Related symbol (import, call target, etc.)."""
name: str
kind: str
relationship: str # "imports", "calls", "inherits", "uses"
file_path: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MCPContext:
"""Model Context Protocol context object.
This is the structured context that gets injected into
LLM prompts to provide code understanding.
"""
version: str = "1.0"
context_type: str = "code_context"
symbol: Optional[SymbolInfo] = None
definition: Optional[str] = None
references: List[ReferenceInfo] = field(default_factory=list)
related_symbols: List[RelatedSymbol] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
result = {
"version": self.version,
"context_type": self.context_type,
"metadata": self.metadata,
}
if self.symbol:
result["symbol"] = self.symbol.to_dict()
if self.definition:
result["definition"] = self.definition
if self.references:
result["references"] = [r.to_dict() for r in self.references]
if self.related_symbols:
result["related_symbols"] = [s.to_dict() for s in self.related_symbols]
return result
def to_json(self, indent: int = 2) -> str:
"""Serialize to JSON string."""
return json.dumps(self.to_dict(), indent=indent)
def to_prompt_injection(self) -> str:
"""Format for injection into LLM prompt."""
parts = ["<code_context>"]
if self.symbol:
parts.append(f"## Symbol: {self.symbol.name}")
parts.append(f"Type: {self.symbol.kind}")
parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")
if self.definition:
parts.append("\n## Definition")
parts.append(f"```\n{self.definition}\n```")
if self.references:
parts.append(f"\n## References ({len(self.references)} found)")
for ref in self.references[:5]: # Limit to 5
parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
parts.append(f" ```\n {ref.context}\n ```")
if self.related_symbols:
parts.append("\n## Related Symbols")
for sym in self.related_symbols[:10]: # Limit to 10
parts.append(f"- {sym.name} ({sym.relationship})")
parts.append("</code_context>")
return "\n".join(parts)

View File

@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""
from __future__ import annotations
from .factory import ParserFactory
__all__ = ["ParserFactory"]

View File

@@ -0,0 +1,202 @@
"""Optional encoding detection module for CodexLens.
Provides automatic encoding detection with graceful fallback to UTF-8.
Install with: pip install codexlens[encoding]
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Tuple, Optional
log = logging.getLogger(__name__)
# Feature flag for encoding detection availability
ENCODING_DETECTION_AVAILABLE = False
_import_error: Optional[str] = None
def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
"""Detect if chardet or charset-normalizer is available."""
try:
import chardet
return True, None
except ImportError:
pass
try:
from charset_normalizer import from_bytes
return True, None
except ImportError:
pass
return False, "chardet not available. Install with: pip install codexlens[encoding]"
# Initialize on module load
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()
def check_encoding_available() -> Tuple[bool, Optional[str]]:
"""Check if encoding detection dependencies are available.
Returns:
Tuple of (available, error_message)
"""
return ENCODING_DETECTION_AVAILABLE, _import_error
def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
"""Detect encoding from file content bytes.
Uses chardet or charset-normalizer with configurable confidence threshold.
Falls back to UTF-8 if confidence is too low or detection unavailable.
Args:
content_bytes: Raw file content as bytes
confidence_threshold: Minimum confidence (0.0-1.0) to accept detection
Returns:
Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk')
Returns 'utf-8' as fallback if detection fails or confidence too low
"""
if not ENCODING_DETECTION_AVAILABLE:
log.debug("Encoding detection not available, using UTF-8 fallback")
return "utf-8"
if not content_bytes:
return "utf-8"
try:
# Try chardet first
try:
import chardet
result = chardet.detect(content_bytes)
encoding = result.get("encoding")
confidence = result.get("confidence", 0.0)
if encoding and confidence >= confidence_threshold:
log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
# Normalize encoding name: replace underscores with hyphens
return encoding.lower().replace('_', '-')
else:
log.debug(
f"Low confidence encoding detection: {encoding} "
f"(confidence: {confidence:.2f}), using UTF-8 fallback"
)
return "utf-8"
except ImportError:
pass
# Fallback to charset-normalizer
try:
from charset_normalizer import from_bytes
results = from_bytes(content_bytes)
if results:
best = results.best()
if best and best.encoding:
log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
# Normalize encoding name: replace underscores with hyphens
return best.encoding.lower().replace('_', '-')
except ImportError:
pass
except Exception as e:
log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")
return "utf-8"
def read_file_safe(
path: Path | str,
confidence_threshold: float = 0.7,
max_detection_bytes: int = 100_000
) -> Tuple[str, str]:
"""Read file with automatic encoding detection and safe decoding.
Reads file bytes, detects encoding, and decodes with error replacement
to preserve file structure even with encoding issues.
Args:
path: Path to file to read
confidence_threshold: Minimum confidence for encoding detection
max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)
Returns:
Tuple of (content, detected_encoding)
- content: Decoded file content (undecodable bytes replaced with U+FFFD)
- detected_encoding: Detected encoding name
Raises:
OSError: If file cannot be read
IsADirectoryError: If path is a directory
"""
file_path = Path(path) if isinstance(path, str) else path
# Read file bytes
try:
content_bytes = file_path.read_bytes()
except Exception as e:
log.error(f"Failed to read file {file_path}: {e}")
raise
# Detect encoding from first N bytes for performance
detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
encoding = detect_encoding(detection_sample, confidence_threshold)
# Decode with error replacement to preserve structure
try:
content = content_bytes.decode(encoding, errors='replace')
log.debug(f"Successfully decoded {file_path} using {encoding}")
return content, encoding
except Exception as e:
# Final fallback to UTF-8 with replacement
log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
content = content_bytes.decode('utf-8', errors='replace')
return content, 'utf-8'
def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
"""Check if file is likely binary by sampling first bytes.
Heuristic: if more than 30% of sampled bytes are null, or more than 50% are non-text control bytes, the file is considered binary.
Args:
path: Path to file to check
sample_size: Number of bytes to sample (default 8KB)
Returns:
True if file appears to be binary, False otherwise
"""
file_path = Path(path) if isinstance(path, str) else path
try:
with file_path.open('rb') as f:
sample = f.read(sample_size)
if not sample:
return False
# Count null bytes and non-printable characters
null_count = sample.count(b'\x00')
non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))
# If >30% null bytes or >50% non-text, consider binary
null_ratio = null_count / len(sample)
non_text_ratio = non_text_count / len(sample)
return null_ratio > 0.3 or non_text_ratio > 0.5
except Exception as e:
log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
return False
__all__ = [
"ENCODING_DETECTION_AVAILABLE",
"check_encoding_available",
"detect_encoding",
"read_file_safe",
"is_binary_file",
]
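# Usage sketch. The module path is not shown in this diff; codexlens.parsers.encoding
# is assumed from the package layout.
from codexlens.parsers.encoding import check_encoding_available, is_binary_file, read_file_safe

available, err = check_encoding_available()
print("detection backend available:", available, err or "")

path = "README.md"  # any existing text file
if not is_binary_file(path):
    content, encoding = read_file_safe(path, confidence_threshold=0.7)
    print(f"decoded {len(content)} chars as {encoding}")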

View File

@@ -0,0 +1,385 @@
"""Parser factory for CodexLens.
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class Parser(Protocol):
def parse(self, text: str, path: Path) -> IndexedFile: ...
@dataclass
class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
# Try tree-sitter first for supported languages
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
indexed = ts_parser.parse(text, path)
if indexed is not None:
return indexed
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols_regex(text)
relationships = _parse_python_relationships_regex(text, path)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols_regex(text)
relationships = _parse_js_ts_relationships_regex(text, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
relationships = []
elif self.language_id == "go":
symbols = _parse_go_symbols(text)
relationships = []
elif self.language_id == "markdown":
symbols = _parse_markdown_symbols(text)
relationships = []
elif self.language_id == "text":
symbols = _parse_text_symbols(text)
relationships = []
else:
symbols = _parse_generic_symbols(text)
relationships = []
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
class ParserFactory:
def __init__(self, config: Config) -> None:
self.config = config
self._parsers: Dict[str, Parser] = {}
def get_parser(self, language_id: str) -> Parser:
if language_id not in self._parsers:
self._parsers[language_id] = SimpleRegexParser(language_id)
return self._parsers[language_id]
# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_python_symbols(text: str) -> List[Symbol]:
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser("python")
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Optional[Path] = None,
) -> List[Symbol]:
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser(language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
current_class_indent: Optional[int] = None
for i, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_class_indent = len(line) - len(line.lstrip(" "))
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
indent = len(line) - len(line.lstrip(" "))
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
continue
if current_class_indent is not None:
indent = len(line) - len(line.lstrip(" "))
if line.strip() and indent <= current_class_indent:
current_class_indent = None
return symbols
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
current_scope = def_match.group(1)
continue
if current_scope is None:
continue
import_match = _PY_IMPORT_RE.search(line)
if import_match:
import_target = import_match.group(1) or import_match.group(2)
if import_target:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_target.strip(),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _PY_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {
"if",
"for",
"while",
"return",
"print",
"len",
"str",
"int",
"float",
"list",
"dict",
"set",
"tuple",
current_scope,
}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
class_brace_depth = 0
brace_depth = 0
for i, line in enumerate(text.splitlines(), start=1):
brace_depth += line.count("{") - line.count("}")
class_match = _JS_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
in_class = True
class_brace_depth = brace_depth
continue
if in_class and brace_depth < class_brace_depth:
in_class = False
func_match = _JS_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
continue
if in_class:
method_match = _JS_METHOD_RE.match(line)
if method_match:
name = method_match.group(1)
if name != "constructor":
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
return symbols
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _JS_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
func_match = _JS_FUNC_RE.match(line)
if func_match:
current_scope = func_match.group(1)
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
current_scope = arrow_match.group(1)
continue
if current_scope is None:
continue
import_match = _JS_IMPORT_RE.search(line)
if import_match:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_match.group(1),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _JS_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {current_scope}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
)
def _parse_java_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _JAVA_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
method_match = _JAVA_METHOD_RE.match(line)
if method_match:
symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
return symbols
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
def _parse_go_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
type_match = _GO_TYPE_RE.match(line)
if type_match:
symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
continue
func_match = _GO_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
return symbols
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
def _parse_generic_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _GENERIC_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _GENERIC_DEF_RE.match(line)
if def_match:
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
return symbols
# Markdown heading regex: # Heading, ## Heading, etc.
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
def _parse_markdown_symbols(text: str) -> List[Symbol]:
"""Parse Markdown headings as symbols.
Extracts # headings as 'section' symbols with heading level as kind suffix.
"""
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
heading_match = _MD_HEADING_RE.match(line)
if heading_match:
level = len(heading_match.group(1))
title = heading_match.group(2).strip()
# Use 'section' kind with level indicator
kind = f"h{level}"
symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
return symbols
def _parse_text_symbols(text: str) -> List[Symbol]:
"""Parse plain text files - no symbols, just index content."""
# Text files don't have structured symbols, return empty list
# The file content will still be indexed for FTS search
return []
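# Usage sketch for the factory. Config() is constructed with no arguments,
# matching its use in the LSP server above; the sample source is inline, and
# tree-sitter is used automatically when its bindings are installed.
from pathlib import Path

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory

factory = ParserFactory(Config())
parser = factory.get_parser("python")

source = "class Greeter:\n    def hello(self):\n        return 'hi'\n"
indexed = parser.parse(source, Path("greeter.py"))
for sym in indexed.symbols:
    print(sym.name, sym.kind, sym.range)  # Greeter (class), hello (method)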

View File

@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.
Provides accurate token counting using tiktoken with character count fallback.
"""
from __future__ import annotations
from typing import Optional
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
class Tokenizer:
"""Token counter with tiktoken primary and character count fallback."""
def __init__(self, encoding_name: str = "cl100k_base") -> None:
"""Initialize tokenizer.
Args:
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
"""
self._encoding: Optional[object] = None
self._encoding_name = encoding_name
if TIKTOKEN_AVAILABLE:
try:
self._encoding = tiktoken.get_encoding(encoding_name)
except Exception:
# Fallback to character counting if encoding fails
self._encoding = None
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Uses tiktoken if available, otherwise falls back to character count / 4.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if not text:
return 0
if self._encoding is not None:
try:
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
except Exception:
# Fall through to character count fallback
pass
# Fallback: rough estimate using character count
# Average of ~4 characters per token for English text
return max(1, len(text) // 4)
def is_using_tiktoken(self) -> bool:
"""Check if tiktoken is being used.
Returns:
True if tiktoken is available and initialized
"""
return self._encoding is not None
# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None
def get_default_tokenizer() -> Tokenizer:
"""Get the global default tokenizer instance.
Returns:
Shared Tokenizer instance
"""
global _default_tokenizer
if _default_tokenizer is None:
_default_tokenizer = Tokenizer()
return _default_tokenizer
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
"""Count tokens in text using default or provided tokenizer.
Args:
text: Text to count tokens for
tokenizer: Optional tokenizer instance (uses default if None)
Returns:
Estimated token count
"""
if tokenizer is None:
tokenizer = get_default_tokenizer()
return tokenizer.count_tokens(text)
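# Quick sketch: module-level helper vs. an explicit Tokenizer instance.
from codexlens.parsers.tokenizer import Tokenizer, count_tokens

print(count_tokens("def add(a, b):\n    return a + b"))  # tiktoken if installed, else ~len/4
tok = Tokenizer("cl100k_base")
print(tok.is_using_tiktoken(), tok.count_tokens("hello world"))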

View File

@@ -0,0 +1,809 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing via tree-sitter.
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
"""Initialize tree-sitter parser for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx)
"""
self.language_id = language_id
self.path = path
self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer()
if TREE_SITTER_AVAILABLE:
self._initialize_parser()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
return
try:
# Load language grammar
if self.language_id == "python":
import tree_sitter_python
self._language = TreeSitterLanguage(tree_sitter_python.language())
elif self.language_id == "javascript":
import tree_sitter_javascript
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
elif self.language_id == "typescript":
import tree_sitter_typescript
# Detect TSX files by extension
if self.path is not None and self.path.suffix.lower() == ".tsx":
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
else:
return
# Create parser
self._parser = TreeSitterParser()
if hasattr(self._parser, "set_language"):
self._parser.set_language(self._language) # type: ignore[attr-defined]
else:
self._parser.language = self._language # type: ignore[assignment]
except Exception:
# Gracefully handle missing language bindings
self._parser = None
self._language = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
Returns:
True if parser is initialized and ready
"""
return self._parser is not None and self._language is not None
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
return source_bytes, tree.root_node
except Exception:
return None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
Args:
text: Source code text
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle extraction errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
"""Parse source code and extract symbols.
Args:
text: Source code text
path: File path
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
except Exception:
# Gracefully handle parsing errors
return None
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of extracted symbols
"""
if self.language_id == "python":
return self._extract_python_symbols(source_bytes, root)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_symbols(source_bytes, root)
else:
return []
def _extract_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"self", "cls"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "class_definition" and pushed_scope:
superclasses = node.child_by_field_name("superclasses")
if superclasses is not None:
for child in superclasses.children:
dotted = self._python_expression_to_dotted(source_bytes, child)
if not dotted:
continue
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type in {"import_statement", "import_from_statement"}:
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
if node.type == "call":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _extract_js_ts_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"this", "super"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if pushed_scope:
superclass = node.child_by_field_name("superclass")
if superclass is not None:
dotted = self._js_expression_to_dotted(source_bytes, superclass)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type in {"identifier", "property_identifier"}
and value_node.type == "arrow_function"
):
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name and scope_name != "constructor":
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"import_declaration", "import_statement"}:
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
# Best-effort support for CommonJS require() imports:
# const fs = require("fs")
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type == "identifier"
and value_node.type == "call_expression"
):
callee = value_node.child_by_field_name("function")
args = value_node.child_by_field_name("arguments")
if (
callee is not None
and self._node_text(source_bytes, callee).strip() == "require"
and args is not None
):
module_name = self._js_first_string_argument(source_bytes, args)
if module_name:
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
record_import(module_name, self._node_start_line(node))
if node.type == "call_expression":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _node_start_line(self, node: TreeSitterNode) -> int:
return node.start_point[0] + 1
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
dotted = (dotted or "").strip()
if not dotted:
return ""
base, sep, rest = dotted.partition(".")
resolved_base = aliases.get(base, base)
if not rest:
return resolved_base
if resolved_base and rest:
return f"{resolved_base}.{rest}"
return resolved_base
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"identifier", "dotted_name"}:
return self._node_text(source_bytes, node).strip()
if node.type == "attribute":
obj = node.child_by_field_name("object")
attr = node.child_by_field_name("attribute")
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
if obj_text and attr_text:
return f"{obj_text}.{attr_text}"
return obj_text or attr_text
return ""
def _python_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
if node.type == "import_statement":
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
module_name = self._node_text(source_bytes, name_node).strip()
if not module_name:
continue
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else module_name.split(".", 1)[0]
)
if bound_name:
aliases[bound_name] = module_name
targets.append(module_name)
elif child.type == "dotted_name":
module_name = self._node_text(source_bytes, child).strip()
if not module_name:
continue
bound_name = module_name.split(".", 1)[0]
if bound_name:
aliases[bound_name] = bound_name
targets.append(module_name)
if node.type == "import_from_statement":
module_name = ""
module_node = node.child_by_field_name("module_name")
if module_node is None:
for child in node.children:
if child.type == "dotted_name":
module_node = child
break
if module_node is not None:
module_name = self._node_text(source_bytes, module_node).strip()
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
imported_name = self._node_text(source_bytes, name_node).strip()
if not imported_name or imported_name == "*":
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported_name
)
if bound_name:
aliases[bound_name] = target
targets.append(target)
elif child.type == "identifier":
imported_name = self._node_text(source_bytes, child).strip()
if not imported_name or imported_name in {"from", "import", "*"}:
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
aliases[imported_name] = target
targets.append(target)
return aliases, targets
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"this", "super"}:
return node.type
if node.type in {"identifier", "property_identifier"}:
return self._node_text(source_bytes, node).strip()
if node.type == "member_expression":
obj = node.child_by_field_name("object")
prop = node.child_by_field_name("property")
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
if obj_text and prop_text:
return f"{obj_text}.{prop_text}"
return obj_text or prop_text
return ""
def _js_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
module_name = ""
source_node = node.child_by_field_name("source")
if source_node is not None:
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
if module_name:
targets.append(module_name)
for child in node.children:
if child.type == "import_clause":
for clause_child in child.children:
if clause_child.type == "identifier":
# Default import: import React from "react"
local = self._node_text(source_bytes, clause_child).strip()
if local and module_name:
aliases[local] = module_name
if clause_child.type == "namespace_import":
# Namespace import: import * as fs from "fs"
name_node = clause_child.child_by_field_name("name")
if name_node is not None and module_name:
local = self._node_text(source_bytes, name_node).strip()
if local:
aliases[local] = module_name
if clause_child.type == "named_imports":
for spec in clause_child.children:
if spec.type != "import_specifier":
continue
name_node = spec.child_by_field_name("name")
alias_node = spec.child_by_field_name("alias")
if name_node is None:
continue
imported = self._node_text(source_bytes, name_node).strip()
if not imported:
continue
local = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported
)
if local and module_name:
aliases[local] = f"{module_name}.{imported}"
targets.append(f"{module_name}.{imported}")
return aliases, targets
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
for child in args_node.children:
if child.type == "string":
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
return ""
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of Python symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind=self._python_function_kind(node),
range=self._node_range(node),
))
return symbols
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract JavaScript/TypeScript symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of JS/TS symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = self._node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=self._node_range(node),
))
return symbols
def _python_function_kind(self, node: TreeSitterNode) -> str:
"""Determine if Python function is a method or standalone function.
Args:
node: Function definition node
Returns:
'method' if inside a class, 'function' otherwise
"""
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
"""Check if node has a class ancestor.
Args:
node: AST node to check
Returns:
True if node is inside a class
"""
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
"""Get line range for a node.
Args:
node: AST node
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Args:
text: Text to count tokens for
Returns:
Token count
"""
return self._tokenizer.count_tokens(text)
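The relationship pass above reduces imports to an alias map keyed by the locally bound name, then resolves dotted call and inheritance targets by looking up only the first segment. A minimal standalone sketch of that resolution step; the resolve_dotted helper and sample aliases below are illustrative, not part of the codebase:

# Sketch of the alias-resolution idea behind _resolve_alias_dotted:
# the first dotted segment is looked up in the alias map built from imports,
# and the rest of the expression is re-attached unchanged.
from typing import Dict

def resolve_dotted(dotted: str, aliases: Dict[str, str]) -> str:
    """Hypothetical standalone version of the alias lookup."""
    dotted = (dotted or "").strip()
    if not dotted:
        return ""
    base, _, rest = dotted.partition(".")
    resolved_base = aliases.get(base, base)
    return f"{resolved_base}.{rest}" if rest and resolved_base else resolved_base

# Aliases as they would be recorded for:
#   import numpy as np        -> {"np": "numpy"}
#   from os import path as p  -> {"p": "os.path"}
aliases = {"np": "numpy", "p": "os.path"}
assert resolve_dotted("np.array", aliases) == "numpy.array"
assert resolve_dotted("p.join", aliases) == "os.path.join"
assert resolve_dotted("plain_call", aliases) == "plain_call"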

View File

@@ -0,0 +1,53 @@
from .chain_search import (
ChainSearchEngine,
SearchOptions,
SearchStats,
ChainSearchResult,
quick_search,
)
# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None
try:
from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
from .clustering import check_clustering_available
CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
_clustering_import_error = str(e)
def check_clustering_available() -> tuple[bool, str | None]:
"""Fallback when clustering module not loadable."""
return False, _clustering_import_error
# Clustering module exports (conditional)
try:
from .clustering import (
BaseClusteringStrategy,
ClusteringConfig,
ClusteringStrategyFactory,
get_strategy,
)
_clustering_exports = [
"BaseClusteringStrategy",
"ClusteringConfig",
"ClusteringStrategyFactory",
"get_strategy",
]
except ImportError:
_clustering_exports = []
__all__ = [
"ChainSearchEngine",
"SearchOptions",
"SearchStats",
"ChainSearchResult",
"quick_search",
# Clustering
"CLUSTERING_AVAILABLE",
"check_clustering_available",
*_clustering_exports,
]
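As a usage note, the flag and the fallback check_clustering_available let callers degrade gracefully when the optional clustering extra is missing. A short sketch of that branching (import path as laid out in this diff):

# Sketch: degrade gracefully when the clustering extra is not installed.
from codexlens.search import CLUSTERING_AVAILABLE, check_clustering_available

ok, error = check_clustering_available()
print(f"clustering available: {CLUSTERING_AVAILABLE}")
if not ok:
    # `error` carries the ImportError message captured by the fallback above
    print(f"clustering disabled: {error}")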

View File

@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.
This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from .builder import AssociationTreeBuilder
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
from .deduplicator import ResultDeduplicator
__all__ = [
"AssociationTreeBuilder",
"CallTree",
"TreeNode",
"UniqueNode",
"ResultDeduplicator",
]

View File

@@ -0,0 +1,450 @@
"""Association tree builder using LSP call hierarchy.
Builds call relationship trees by recursively expanding from seed locations
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.lsp.standalone_manager import StandaloneLspManager
from .data_structures import CallTree, TreeNode
logger = logging.getLogger(__name__)
class AssociationTreeBuilder:
"""Builds association trees from seed locations using LSP call hierarchy.
Uses depth-first recursive expansion to build a tree of code relationships
starting from seed locations (typically from vector search results).
Strategy:
- Start from seed locations (vector search results)
- For each seed, get call hierarchy items via LSP
- Recursively expand incoming calls (callers) if expand_callers=True
- Recursively expand outgoing calls (callees) if expand_callees=True
- Track visited nodes to prevent cycles
- Stop at max_depth or when no more relations found
Attributes:
lsp_manager: StandaloneLspManager for LSP communication
visited: Set of visited node IDs to prevent cycles
timeout: Timeout for individual LSP requests (seconds)
"""
def __init__(
self,
lsp_manager: StandaloneLspManager,
timeout: float = 5.0,
analysis_wait: float = 2.0,
):
"""Initialize AssociationTreeBuilder.
Args:
lsp_manager: StandaloneLspManager instance for LSP communication
timeout: Timeout for individual LSP requests in seconds
analysis_wait: Time to wait for LSP analysis on first file (seconds)
"""
self.lsp_manager = lsp_manager
self.timeout = timeout
self.analysis_wait = analysis_wait
self.visited: Set[str] = set()
self._analyzed_files: Set[str] = set() # Track files already analyzed
async def build_tree(
self,
seed_file_path: str,
seed_line: int,
seed_character: int = 1,
max_depth: int = 5,
expand_callers: bool = True,
expand_callees: bool = True,
) -> CallTree:
"""Build call tree from a single seed location.
Args:
seed_file_path: Path to the seed file
seed_line: Line number of the seed symbol (1-based)
seed_character: Character position (1-based, default 1)
max_depth: Maximum recursion depth (default 5)
expand_callers: Whether to expand incoming calls (callers)
expand_callees: Whether to expand outgoing calls (callees)
Returns:
CallTree containing all discovered nodes and relationships
"""
tree = CallTree()
self.visited.clear()
# Determine wait time - only wait for analysis on first encounter of file
wait_time = 0.0
if seed_file_path not in self._analyzed_files:
wait_time = self.analysis_wait
self._analyzed_files.add(seed_file_path)
# Get call hierarchy items for the seed position
try:
hierarchy_items = await asyncio.wait_for(
self.lsp_manager.get_call_hierarchy_items(
file_path=seed_file_path,
line=seed_line,
character=seed_character,
wait_for_analysis=wait_time,
),
timeout=self.timeout + wait_time,
)
except asyncio.TimeoutError:
logger.warning(
"Timeout getting call hierarchy items for %s:%d",
seed_file_path,
seed_line,
)
return tree
except Exception as e:
logger.error(
"Error getting call hierarchy items for %s:%d: %s",
seed_file_path,
seed_line,
e,
)
return tree
if not hierarchy_items:
logger.debug(
"No call hierarchy items found for %s:%d",
seed_file_path,
seed_line,
)
return tree
# Create root nodes from hierarchy items
for item_dict in hierarchy_items:
# Convert LSP dict to CallHierarchyItem
item = self._dict_to_call_hierarchy_item(item_dict)
if not item:
continue
root_node = TreeNode(
item=item,
depth=0,
path_from_root=[self._create_node_id(item)],
)
tree.roots.append(root_node)
tree.add_node(root_node)
# Mark as visited
self.visited.add(root_node.node_id)
# Recursively expand the tree
await self._expand_node(
node=root_node,
node_dict=item_dict,
tree=tree,
current_depth=0,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
tree.depth_reached = max_depth
return tree
async def _expand_node(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Recursively expand a node by fetching its callers and callees.
Args:
node: TreeNode to expand
node_dict: LSP CallHierarchyItem dict (for LSP requests)
tree: CallTree to add discovered nodes to
current_depth: Current recursion depth
max_depth: Maximum allowed depth
expand_callers: Whether to expand incoming calls
expand_callees: Whether to expand outgoing calls
"""
# Stop if max depth reached
if current_depth >= max_depth:
return
# Prepare tasks for parallel expansion
tasks = []
if expand_callers:
tasks.append(
self._expand_incoming_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
if expand_callees:
tasks.append(
self._expand_outgoing_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
# Execute expansions in parallel
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _expand_incoming_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand incoming calls (callers) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to continue expanding callers
expand_callees: Whether to expand callees
"""
try:
incoming_calls = await asyncio.wait_for(
self.lsp_manager.get_incoming_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting incoming calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
return
if not incoming_calls:
return
# Process each incoming call
for call_dict in incoming_calls:
caller_dict = call_dict.get("from")
if not caller_dict:
continue
# Convert to CallHierarchyItem
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
if not caller_item:
continue
caller_id = self._create_node_id(caller_item)
# Check for cycles
if caller_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [caller_id],
)
node.parents.append(cycle_node)
continue
# Create new caller node
caller_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [caller_id],
)
# Add to tree
tree.add_node(caller_node)
tree.add_edge(caller_node, node)
# Update relationships
node.parents.append(caller_node)
caller_node.children.append(node)
# Mark as visited
self.visited.add(caller_id)
# Recursively expand the caller
await self._expand_node(
node=caller_node,
node_dict=caller_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
async def _expand_outgoing_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand outgoing calls (callees) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to expand callers
expand_callees: Whether to continue expanding callees
"""
try:
outgoing_calls = await asyncio.wait_for(
self.lsp_manager.get_outgoing_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
return
if not outgoing_calls:
return
# Process each outgoing call
for call_dict in outgoing_calls:
callee_dict = call_dict.get("to")
if not callee_dict:
continue
# Convert to CallHierarchyItem
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
if not callee_item:
continue
callee_id = self._create_node_id(callee_item)
# Check for cycles
if callee_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [callee_id],
)
node.children.append(cycle_node)
continue
# Create new callee node
callee_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [callee_id],
)
# Add to tree
tree.add_node(callee_node)
tree.add_edge(node, callee_node)
# Update relationships
node.children.append(callee_node)
callee_node.parents.append(node)
# Mark as visited
self.visited.add(callee_id)
# Recursively expand the callee
await self._expand_node(
node=callee_node,
node_dict=callee_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
def _dict_to_call_hierarchy_item(
self, item_dict: Dict
) -> Optional[CallHierarchyItem]:
"""Convert LSP dict to CallHierarchyItem.
Args:
item_dict: LSP CallHierarchyItem dictionary
Returns:
CallHierarchyItem or None if conversion fails
"""
try:
# Extract URI and convert to file path
uri = item_dict.get("uri", "")
file_path = uri.replace("file://", "", 1)
# Handle Windows URIs (file:///C:/... -> /C:/... -> C:/...); POSIX paths keep their leading slash
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
# Extract range
range_dict = item_dict.get("range", {})
start = range_dict.get("start", {})
end = range_dict.get("end", {})
# Create Range (convert from 0-based to 1-based)
item_range = Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
)
return CallHierarchyItem(
name=item_dict.get("name", "unknown"),
kind=str(item_dict.get("kind", "unknown")),
file_path=file_path,
range=item_range,
detail=item_dict.get("detail"),
)
except Exception as e:
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
return None
def _create_node_id(self, item: CallHierarchyItem) -> str:
"""Create unique node ID from CallHierarchyItem.
Args:
item: CallHierarchyItem
Returns:
Unique node ID string
"""
return f"{item.file_path}:{item.name}:{item.range.start_line}"

View File

@@ -0,0 +1,191 @@
"""Data structures for association tree building.
Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
@dataclass
class TreeNode:
"""Node in the call association tree.
Represents a single function/method in the tree, including its position
in the hierarchy and relationships.
Attributes:
item: LSP CallHierarchyItem containing symbol information
depth: Distance from the root node (seed) - 0 for roots
children: List of child nodes (functions called by this node)
parents: List of parent nodes (functions that call this node)
is_cycle: Whether this node creates a circular reference
path_from_root: Path (list of node IDs) from root to this node
"""
item: CallHierarchyItem
depth: int = 0
children: List[TreeNode] = field(default_factory=list)
parents: List[TreeNode] = field(default_factory=list)
is_cycle: bool = False
path_from_root: List[str] = field(default_factory=list)
@property
def node_id(self) -> str:
"""Unique identifier for this node."""
return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"
def __hash__(self) -> int:
"""Hash based on node ID."""
return hash(self.node_id)
def __eq__(self, other: object) -> bool:
"""Equality based on node ID."""
if not isinstance(other, TreeNode):
return False
return self.node_id == other.node_id
def __repr__(self) -> str:
"""String representation of the node."""
cycle_marker = " [CYCLE]" if self.is_cycle else ""
return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"
@dataclass
class CallTree:
"""Complete call tree structure built from seeds.
Contains all nodes discovered through recursive expansion and
the relationships between them.
Attributes:
roots: List of root nodes (seed symbols)
all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
node_list: Flat list of all nodes in tree order
edges: List of (from_node_id, to_node_id) tuples representing calls
depth_reached: Maximum expansion depth requested when the tree was built
"""
roots: List[TreeNode] = field(default_factory=list)
all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
node_list: List[TreeNode] = field(default_factory=list)
edges: List[tuple[str, str]] = field(default_factory=list)
depth_reached: int = 0
def add_node(self, node: TreeNode) -> None:
"""Add a node to the tree.
Args:
node: TreeNode to add
"""
if node.node_id not in self.all_nodes:
self.all_nodes[node.node_id] = node
self.node_list.append(node)
def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
"""Add an edge between two nodes.
Args:
from_node: Source node
to_node: Target node
"""
edge = (from_node.node_id, to_node.node_id)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[TreeNode]:
"""Get a node by ID.
Args:
node_id: Node identifier
Returns:
TreeNode if found, None otherwise
"""
return self.all_nodes.get(node_id)
def __len__(self) -> int:
"""Return total number of nodes in tree."""
return len(self.all_nodes)
def __repr__(self) -> str:
"""String representation of the tree."""
return (
f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
f"depth={self.depth_reached})"
)
@dataclass
class UniqueNode:
"""Deduplicated unique code symbol from the tree.
Represents a single unique code location that may appear multiple times
in the tree under different contexts. Contains aggregated information
about all occurrences.
Attributes:
file_path: Absolute path to the file
name: Symbol name (function, method, class, etc.)
kind: Symbol kind (function, method, class, etc.)
range: Code range in the file
min_depth: Minimum depth at which this node appears in the tree
occurrences: Number of times this node appears in the tree
paths: List of paths from roots to this node
context_nodes: Related nodes from the tree
score: Composite relevance score (higher is better)
"""
file_path: str
name: str
kind: str
range: Range
min_depth: int = 0
occurrences: int = 1
paths: List[List[str]] = field(default_factory=list)
context_nodes: List[str] = field(default_factory=list)
score: float = 0.0
@property
def node_key(self) -> tuple[str, int, int]:
"""Unique key for deduplication.
Uses (file_path, start_line, end_line) as the unique identifier
for this symbol across all occurrences.
"""
return (
self.file_path,
self.range.start_line,
self.range.end_line,
)
def add_path(self, path: List[str]) -> None:
"""Add a path from root to this node.
Args:
path: List of node IDs from root to this node
"""
if path not in self.paths:
self.paths.append(path)
def __hash__(self) -> int:
"""Hash based on node key."""
return hash(self.node_key)
def __eq__(self, other: object) -> bool:
"""Equality based on node key."""
if not isinstance(other, UniqueNode):
return False
return self.node_key == other.node_key
def __repr__(self) -> str:
"""String representation of the unique node."""
return (
f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
)
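A short sketch of how these structures compose, reusing the CallHierarchyItem and Range keyword arguments exactly as builder.py constructs them above; the file path and the association_tree package path are illustrative assumptions:

# Sketch: nodes at the same location share a node_id, so the tree stores them once.
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.search.association_tree import CallTree, TreeNode  # package path assumed

item = CallHierarchyItem(
    name="load",
    kind="function",
    file_path="/workspace/io.py",
    range=Range(start_line=10, start_character=1, end_line=20, end_character=1),
    detail=None,
)
a = TreeNode(item=item, depth=0)
b = TreeNode(item=item, depth=2)

tree = CallTree()
tree.add_node(a)
tree.add_node(b)      # same node_id, so not added again
tree.add_edge(a, b)   # duplicate edges are skipped as well
assert a == b and len(tree) == 1 and len(tree.edges) == 1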

View File

@@ -0,0 +1,301 @@
"""Result deduplication for association tree nodes.
Provides functionality to extract unique nodes from a call tree and assign
relevance scores based on various factors.
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
logger = logging.getLogger(__name__)
# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
# Functions and methods are primary targets
"function": 1.0,
"method": 1.0,
"12": 1.0, # LSP SymbolKind.Function
"6": 1.0, # LSP SymbolKind.Method
# Classes are important but secondary
"class": 0.8,
"5": 0.8, # LSP SymbolKind.Class
# Interfaces and types
"interface": 0.7,
"11": 0.7, # LSP SymbolKind.Interface
"type": 0.6,
# Constructors
"constructor": 0.9,
"9": 0.9, # LSP SymbolKind.Constructor
# Variables and constants
"variable": 0.4,
"13": 0.4, # LSP SymbolKind.Variable
"constant": 0.5,
"14": 0.5, # LSP SymbolKind.Constant
# Default for unknown kinds
"unknown": 0.3,
}
class ResultDeduplicator:
"""Extracts and scores unique nodes from call trees.
Processes a CallTree to extract unique code locations, merging duplicates
and assigning relevance scores based on:
- Depth: Shallower nodes (closer to seeds) score higher
- Frequency: Nodes appearing multiple times score higher
- Kind: Function/method > class > variable
Attributes:
depth_weight: Weight for depth factor in scoring (default 0.4)
frequency_weight: Weight for frequency factor (default 0.3)
kind_weight: Weight for symbol kind factor (default 0.3)
max_depth_penalty: Maximum depth before full penalty applied
"""
def __init__(
self,
depth_weight: float = 0.4,
frequency_weight: float = 0.3,
kind_weight: float = 0.3,
max_depth_penalty: int = 10,
):
"""Initialize ResultDeduplicator.
Args:
depth_weight: Weight for depth factor (0.0-1.0)
frequency_weight: Weight for frequency factor (0.0-1.0)
kind_weight: Weight for symbol kind factor (0.0-1.0)
max_depth_penalty: Depth at which score becomes 0 for depth factor
"""
self.depth_weight = depth_weight
self.frequency_weight = frequency_weight
self.kind_weight = kind_weight
self.max_depth_penalty = max_depth_penalty
def deduplicate(
self,
tree: CallTree,
max_results: Optional[int] = None,
) -> List[UniqueNode]:
"""Extract unique nodes from the call tree.
Traverses the tree, groups nodes by their unique key (file_path,
start_line, end_line), and merges duplicate occurrences.
Args:
tree: CallTree to process
max_results: Maximum number of results to return (None = all)
Returns:
List of UniqueNode objects, sorted by score descending
"""
if not tree.node_list:
return []
# Group nodes by unique key
unique_map: Dict[tuple, UniqueNode] = {}
for node in tree.node_list:
if node.is_cycle:
# Skip cycle markers - they point to already-counted nodes
continue
key = self._get_node_key(node)
if key in unique_map:
# Update existing unique node
unique_node = unique_map[key]
unique_node.occurrences += 1
unique_node.min_depth = min(unique_node.min_depth, node.depth)
unique_node.add_path(node.path_from_root)
# Collect context from relationships
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
else:
# Create new unique node
unique_node = UniqueNode(
file_path=node.item.file_path,
name=node.item.name,
kind=node.item.kind,
range=node.item.range,
min_depth=node.depth,
occurrences=1,
paths=[node.path_from_root.copy()],
context_nodes=[],
score=0.0,
)
# Collect initial context
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
unique_map[key] = unique_node
# Calculate scores for all unique nodes
unique_nodes = list(unique_map.values())
# Find max frequency for normalization
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
for node in unique_nodes:
node.score = self._score_node(node, max_frequency)
# Sort by score descending
unique_nodes.sort(key=lambda n: n.score, reverse=True)
# Apply max_results limit
if max_results is not None and max_results > 0:
unique_nodes = unique_nodes[:max_results]
logger.debug(
"Deduplicated %d tree nodes to %d unique nodes",
len(tree.node_list),
len(unique_nodes),
)
return unique_nodes
def _score_node(
self,
node: UniqueNode,
max_frequency: int,
) -> float:
"""Calculate composite score for a unique node.
Score = depth_weight * depth_score +
frequency_weight * frequency_score +
kind_weight * kind_score
Args:
node: UniqueNode to score
max_frequency: Maximum occurrence count for normalization
Returns:
Composite score between 0.0 and 1.0
"""
# Depth score: closer to root = higher score
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
depth_score = max(
0.0,
1.0 - (node.min_depth / self.max_depth_penalty),
)
# Frequency score: more occurrences = higher score
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
# Kind score: function/method > class > variable
kind_str = str(node.kind).lower()
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
# Composite score
score = (
self.depth_weight * depth_score
+ self.frequency_weight * frequency_score
+ self.kind_weight * kind_score
)
return score
def _get_node_key(self, node: TreeNode) -> tuple:
"""Get unique key for a tree node.
Uses (file_path, start_line, end_line) as the unique identifier.
Args:
node: TreeNode
Returns:
Tuple key for deduplication
"""
return (
node.item.file_path,
node.item.range.start_line,
node.item.range.end_line,
)
def filter_by_kind(
self,
nodes: List[UniqueNode],
kinds: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by symbol kind.
Args:
nodes: List of UniqueNode to filter
kinds: List of allowed kinds (e.g., ["function", "method"])
Returns:
Filtered list of UniqueNode
"""
kinds_lower = [k.lower() for k in kinds]
return [
node
for node in nodes
if str(node.kind).lower() in kinds_lower
]
def filter_by_file(
self,
nodes: List[UniqueNode],
file_patterns: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by file path patterns.
Args:
nodes: List of UniqueNode to filter
file_patterns: List of path substrings to match
Returns:
Filtered list of UniqueNode
"""
return [
node
for node in nodes
if any(pattern in node.file_path for pattern in file_patterns)
]
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
"""Convert list of UniqueNode to JSON-serializable dicts.
Args:
nodes: List of UniqueNode
Returns:
List of dictionaries
"""
return [
{
"file_path": node.file_path,
"name": node.name,
"kind": node.kind,
"range": {
"start_line": node.range.start_line,
"start_character": node.range.start_character,
"end_line": node.range.end_line,
"end_character": node.range.end_character,
},
"min_depth": node.min_depth,
"occurrences": node.occurrences,
"path_count": len(node.paths),
"score": round(node.score, 4),
}
for node in nodes
]
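To make the scoring formula concrete, a worked example with the default weights (depth 0.4, frequency 0.3, kind 0.3) and max_depth_penalty=10, for a method at minimum depth 1 that occurs twice while the most frequent node in the tree occurs four times:

# Worked example of the composite score computed by _score_node.
depth_weight, frequency_weight, kind_weight = 0.4, 0.3, 0.3
max_depth_penalty = 10

min_depth, occurrences, max_frequency = 1, 2, 4
kind_score = 1.0  # KIND_WEIGHTS["method"]

depth_score = max(0.0, 1.0 - min_depth / max_depth_penalty)  # 0.9
frequency_score = occurrences / max_frequency                # 0.5
score = (depth_weight * depth_score
         + frequency_weight * frequency_score
         + kind_weight * kind_score)
assert abs(score - 0.81) < 1e-9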

View File

@@ -0,0 +1,277 @@
"""Binary vector searcher for cascade search.
This module provides fast binary vector search using Hamming distance
for the first stage of cascade search (coarse filtering).
Supports two loading modes:
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
2. Database loading (fallback): Loads all vectors into RAM
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
# Pre-computed popcount lookup table for vectorized Hamming distance
# Each byte value (0-255) maps to its bit count
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
class BinarySearcher:
"""Fast binary vector search using Hamming distance.
This class implements the first stage of cascade search:
fast, approximate retrieval using binary vectors and Hamming distance.
The binary vectors are derived from dense embeddings by thresholding:
binary[i] = 1 if dense[i] > 0 else 0
Hamming distance between two binary vectors counts the number of
differing bits, which can be computed very efficiently using XOR
and population count.
Supports two loading modes:
- Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
- Database (fallback): Loads all vectors into memory from SQLite
"""
def __init__(self, index_root_or_meta_path: Path) -> None:
"""Initialize BinarySearcher.
Args:
index_root_or_meta_path: Either:
- Path to index root directory (containing _binary_vectors.mmap)
- Path to _vectors_meta.db (legacy mode, loads from DB)
"""
path = Path(index_root_or_meta_path)
# Determine if this is an index root or a specific DB path
if path.suffix == '.db':
# Legacy mode: specific DB path
self.index_root = path.parent
self.meta_store_path = path
else:
# New mode: index root directory
self.index_root = path
self.meta_store_path = path / "_vectors_meta.db"
self._chunk_ids: Optional[np.ndarray] = None
self._binary_matrix: Optional[np.ndarray] = None
self._is_memmap = False
self._loaded = False
def load(self) -> bool:
"""Load binary vectors using memory-mapped file or database fallback.
Tries to load from memory-mapped file first (preferred for large indexes),
falls back to database loading if mmap file doesn't exist.
Returns:
True if vectors were loaded successfully.
"""
if self._loaded:
return True
# Try memory-mapped file first (preferred)
mmap_path = self.index_root / "_binary_vectors.mmap"
meta_path = mmap_path.with_suffix('.meta.json')
if mmap_path.exists() and meta_path.exists():
try:
with open(meta_path, 'r') as f:
meta = json.load(f)
shape = tuple(meta['shape'])
self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)
# Memory-map the binary matrix (read-only)
self._binary_matrix = np.memmap(
str(mmap_path),
dtype=np.uint8,
mode='r',
shape=shape
)
self._is_memmap = True
self._loaded = True
logger.info(
"Memory-mapped %d binary vectors (%d bytes each)",
len(self._chunk_ids), shape[1]
)
return True
except Exception as e:
logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)
# Fallback: load from database
return self._load_from_db()
def _load_from_db(self) -> bool:
"""Load binary vectors from database (legacy/fallback mode).
Returns:
True if vectors were loaded successfully.
"""
try:
from codexlens.storage.vector_meta_store import VectorMetadataStore
with VectorMetadataStore(self.meta_store_path) as store:
rows = store.get_all_binary_vectors()
if not rows:
logger.warning("No binary vectors found in %s", self.meta_store_path)
return False
# Convert to numpy arrays for fast computation
self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)
# Unpack bytes to numpy array
binary_arrays = []
for _, vec_bytes in rows:
arr = np.frombuffer(vec_bytes, dtype=np.uint8)
binary_arrays.append(arr)
self._binary_matrix = np.vstack(binary_arrays)
self._is_memmap = False
self._loaded = True
logger.info(
"Loaded %d binary vectors from DB (%d bytes each)",
len(self._chunk_ids), self._binary_matrix.shape[1]
)
return True
except Exception as e:
logger.error("Failed to load binary vectors: %s", e)
return False
def search(
self,
query_vector: np.ndarray,
top_k: int = 100
) -> List[Tuple[int, int]]:
"""Search for similar vectors using Hamming distance.
Args:
query_vector: Dense query vector (will be binarized).
top_k: Number of top results to return.
Returns:
List of (chunk_id, hamming_distance) tuples sorted by distance.
"""
if not self._loaded and not self.load():
return []
# Binarize query vector
query_binary = (query_vector > 0).astype(np.uint8)
query_packed = np.packbits(query_binary)
# Compute Hamming distances using XOR and popcount
# XOR gives 1 for differing bits
xor_result = np.bitwise_xor(self._binary_matrix, query_packed)
# Vectorized popcount via lookup table (far faster than a per-bit Python loop)
# Sum the per-byte bit counts across columns to get Hamming distances
distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)
# Get top-k with smallest distances
if top_k >= len(distances):
top_indices = np.argsort(distances)
else:
# Partial sort for efficiency
top_indices = np.argpartition(distances, top_k)[:top_k]
top_indices = top_indices[np.argsort(distances[top_indices])]
results = [
(int(self._chunk_ids[i]), int(distances[i]))
for i in top_indices
]
return results
def search_with_rerank(
self,
query_dense: np.ndarray,
dense_vectors: np.ndarray,
dense_chunk_ids: np.ndarray,
top_k: int = 10,
candidates: int = 100
) -> List[Tuple[int, float]]:
"""Two-stage cascade search: binary filter + dense rerank.
Args:
query_dense: Dense query vector.
dense_vectors: Dense vectors for reranking (from HNSW or stored).
dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
top_k: Final number of results.
candidates: Number of candidates from binary search.
Returns:
List of (chunk_id, cosine_similarity) tuples.
"""
# Stage 1: Binary filtering
binary_results = self.search(query_dense, top_k=candidates)
if not binary_results:
return []
candidate_ids = {r[0] for r in binary_results}
# Stage 2: Dense reranking
# Find indices of candidates in dense_vectors
candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
# Fallback: return binary results with normalized distance
max_dist = max((r[1] for r in binary_results), default=1) or 1
return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]
# Compute cosine similarities for candidates
candidate_vectors = dense_vectors[candidate_indices]
candidate_ids_array = dense_chunk_ids[candidate_indices]
# Normalize vectors
query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
cand_norms = candidate_vectors / (
np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
)
# Cosine similarities
similarities = np.dot(cand_norms, query_norm)
# Sort by similarity (descending)
sorted_indices = np.argsort(-similarities)[:top_k]
results = [
(int(candidate_ids_array[i]), float(similarities[i]))
for i in sorted_indices
]
return results
@property
def vector_count(self) -> int:
"""Get number of loaded binary vectors."""
return len(self._chunk_ids) if self._chunk_ids is not None else 0
@property
def is_memmap(self) -> bool:
"""Check if using memory-mapped file (vs in-memory array)."""
return self._is_memmap
def clear(self) -> None:
"""Clear loaded vectors from memory."""
# For memmap, just delete the reference (OS will handle cleanup)
if self._is_memmap and self._binary_matrix is not None:
del self._binary_matrix
self._chunk_ids = None
self._binary_matrix = None
self._is_memmap = False
self._loaded = False
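The fast path above is just XOR plus a per-byte popcount lookup over packed bits. A self-contained numpy sketch of that computation on toy data (shapes and seed are arbitrary):

import numpy as np

# Per-byte popcount table, as in _POPCOUNT_TABLE above.
POPCOUNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

rng = np.random.default_rng(0)
dense_db = rng.standard_normal((1000, 256))  # toy stored embeddings
dense_q = rng.standard_normal(256)           # toy query embedding

# Binarize by sign and pack 8 bits per byte (256 dims -> 32 bytes per vector).
db_bits = np.packbits((dense_db > 0).astype(np.uint8), axis=1)
q_bits = np.packbits((dense_q > 0).astype(np.uint8))

# Hamming distance: XOR marks differing bits, the table counts them per byte.
distances = POPCOUNT[np.bitwise_xor(db_bits, q_bits)].sum(axis=1)

top10 = np.argpartition(distances, 10)[:10]
top10 = top10[np.argsort(distances[top10])]
print(distances.shape, int(distances[top10[0]]))  # (1000,) and the best distance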

File diff suppressed because it is too large

View File

@@ -0,0 +1,124 @@
"""Clustering strategies for the staged hybrid search pipeline.
This module provides extensible clustering infrastructure for grouping
similar search results and selecting representative results.
Install with: pip install codexlens[clustering]
Example:
>>> from codexlens.search.clustering import (
... CLUSTERING_AVAILABLE,
... ClusteringConfig,
... get_strategy,
... )
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy with fallback
>>> strategy = get_strategy("auto", config)
>>> representatives = strategy.fit_predict(embeddings, results)
>>>
>>> # Or explicitly use a specific strategy
>>> if CLUSTERING_AVAILABLE:
... from codexlens.search.clustering import HDBSCANStrategy
... strategy = HDBSCANStrategy(config)
... representatives = strategy.fit_predict(embeddings, results)
"""
from __future__ import annotations
# Always export base classes and factory (no heavy dependencies)
from .base import BaseClusteringStrategy, ClusteringConfig
from .factory import (
ClusteringStrategyFactory,
check_clustering_strategy_available,
get_strategy,
)
from .noop_strategy import NoOpStrategy
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# Feature flag for clustering availability (hdbscan + sklearn)
CLUSTERING_AVAILABLE = False
HDBSCAN_AVAILABLE = False
DBSCAN_AVAILABLE = False
_import_error: str | None = None
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
"""Detect if clustering dependencies are available.
Returns:
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
"""
hdbscan_ok = False
dbscan_ok = False
try:
import hdbscan # noqa: F401
hdbscan_ok = True
except ImportError:
pass
try:
from sklearn.cluster import DBSCAN # noqa: F401
dbscan_ok = True
except ImportError:
pass
all_ok = hdbscan_ok and dbscan_ok
error = None
if not all_ok:
missing = []
if not hdbscan_ok:
missing.append("hdbscan")
if not dbscan_ok:
missing.append("scikit-learn")
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
return all_ok, hdbscan_ok, dbscan_ok, error
# Initialize on module load
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
_detect_clustering_available()
)
def check_clustering_available() -> tuple[bool, str | None]:
"""Check if all clustering dependencies are available.
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
return CLUSTERING_AVAILABLE, _import_error
# Conditionally export strategy implementations
__all__ = [
# Feature flags
"CLUSTERING_AVAILABLE",
"HDBSCAN_AVAILABLE",
"DBSCAN_AVAILABLE",
"check_clustering_available",
# Base classes
"BaseClusteringStrategy",
"ClusteringConfig",
# Factory
"ClusteringStrategyFactory",
"get_strategy",
"check_clustering_strategy_available",
# Always-available strategies
"NoOpStrategy",
"FrequencyStrategy",
"FrequencyConfig",
]
# Conditionally add strategy classes to __all__ and module namespace
if HDBSCAN_AVAILABLE:
from .hdbscan_strategy import HDBSCANStrategy
__all__.append("HDBSCANStrategy")
if DBSCAN_AVAILABLE:
from .dbscan_strategy import DBSCANStrategy
__all__.append("DBSCANStrategy")
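A short sketch of inspecting these flags at runtime; the per-backend flags make it easy to report exactly which dependency is missing:

from codexlens.search.clustering import (
    CLUSTERING_AVAILABLE,
    DBSCAN_AVAILABLE,
    HDBSCAN_AVAILABLE,
    check_clustering_available,
)

ok, error = check_clustering_available()
print(f"hdbscan={HDBSCAN_AVAILABLE} dbscan={DBSCAN_AVAILABLE} all={CLUSTERING_AVAILABLE}")
if not ok:
    print(error)  # includes the pip install codexlens[clustering] hint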

View File

@@ -0,0 +1,153 @@
"""Base classes for clustering strategies in the hybrid search pipeline.
This module defines the abstract base class for clustering strategies used
in the staged hybrid search pipeline. Strategies cluster search results
based on their embeddings and select representative results from each cluster.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class ClusteringConfig:
"""Configuration parameters for clustering strategies.
Attributes:
min_cluster_size: Minimum number of results to form a cluster.
HDBSCAN default is 5, but for search results 2-3 is often better.
min_samples: Number of samples in a neighborhood for a point to be
considered a core point. Lower values allow more clusters.
metric: Distance metric for clustering. Common options:
- 'euclidean': Standard L2 distance
- 'cosine': Cosine distance (1 - cosine_similarity)
- 'manhattan': L1 distance
cluster_selection_epsilon: Distance threshold for cluster selection.
Results within this distance may be merged into the same cluster.
allow_single_cluster: If True, allow all results to form one cluster.
Useful when results are very similar.
prediction_data: If True, generate prediction data for new points.
"""
min_cluster_size: int = 3
min_samples: int = 2
metric: str = "cosine"
cluster_selection_epsilon: float = 0.0
allow_single_cluster: bool = True
prediction_data: bool = False
def __post_init__(self) -> None:
"""Validate configuration parameters."""
if self.min_cluster_size < 2:
raise ValueError("min_cluster_size must be >= 2")
if self.min_samples < 1:
raise ValueError("min_samples must be >= 1")
if self.metric not in ("euclidean", "cosine", "manhattan"):
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
if self.cluster_selection_epsilon < 0:
raise ValueError("cluster_selection_epsilon must be >= 0")
class BaseClusteringStrategy(ABC):
"""Abstract base class for clustering strategies.
Clustering strategies are used in the staged hybrid search pipeline to
group similar search results and select representative results from each
cluster, reducing redundancy while maintaining diversity.
Subclasses must implement:
- cluster(): Group results into clusters based on embeddings
- select_representatives(): Choose best result(s) from each cluster
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize the clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
"""
self.config = config or ClusteringConfig()
@abstractmethod
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results based on their embeddings.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Used for additional metadata during clustering.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Results not assigned to any cluster
(noise points) should be returned as single-element clusters.
Example:
>>> strategy = HDBSCANStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
>>> # Result indices 0, 2, 5 are in cluster 0
>>> # Result indices 1, 3 are in cluster 1
>>> # Result index 4 is a noise point (singleton cluster)
>>> # Result indices 6, 7, 8 are in cluster 2
"""
...
@abstractmethod
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
This method chooses the best result(s) from each cluster to include
in the final search results. The selection can be based on:
- Highest score within cluster
- Closest to cluster centroid
- Custom selection logic
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings array for centroid-based selection.
Returns:
List of representative SearchResult objects, one or more per cluster,
ordered by relevance (highest score first).
Example:
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns best result from each cluster
"""
...
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)
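The abstract contract is small enough that a strategy fits in a few lines. A hedged sketch of a trivial custom strategy that puts every result in its own cluster (roughly what the always-available no-op strategy does); it assumes SearchResult objects expose a .score attribute, as used elsewhere in this diff:

# Sketch: a minimal concrete BaseClusteringStrategy implementation.
from codexlens.search.clustering import BaseClusteringStrategy, ClusteringConfig

class SingletonStrategy(BaseClusteringStrategy):
    """Treat every result as its own cluster."""

    def cluster(self, embeddings, results):
        return [[i] for i in range(len(results))]

    def select_representatives(self, clusters, results, embeddings=None):
        reps = [results[c[0]] for c in clusters if c]
        reps.sort(key=lambda r: r.score, reverse=True)
        return reps

# strategy = SingletonStrategy(ClusteringConfig(min_cluster_size=2))
# top = strategy.fit_predict(embeddings, results)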

View File

@@ -0,0 +1,197 @@
"""DBSCAN-based clustering strategy for search results.
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
is the fallback clustering strategy when HDBSCAN is not available.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class DBSCANStrategy(BaseClusteringStrategy):
"""DBSCAN-based clustering strategy.
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
DBSCAN requires an explicit eps parameter, which is auto-computed from the
distance distribution if not provided.
Example:
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = DBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
# Default eps percentile for auto-computation
DEFAULT_EPS_PERCENTILE: float = 15.0
def __init__(
self,
config: Optional[ClusteringConfig] = None,
eps: Optional[float] = None,
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
) -> None:
"""Initialize DBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
from the distance distribution.
eps_percentile: Percentile of pairwise distances to use for
auto-computing eps. Default is 15th percentile.
Raises:
ImportError: If sklearn is not installed.
"""
super().__init__(config)
self.eps = eps
self.eps_percentile = eps_percentile
# Validate sklearn is available
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError as exc:
raise ImportError(
"scikit-learn package is required for DBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def _compute_eps(self, embeddings: "np.ndarray") -> float:
"""Auto-compute eps from pairwise distance distribution.
Uses the specified percentile of pairwise distances as eps,
which typically captures local density well.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Returns:
Computed eps value.
"""
import numpy as np
from sklearn.metrics import pairwise_distances
# Compute pairwise distances
distances = pairwise_distances(embeddings, metric=self.config.metric)
# Get upper triangle (excluding diagonal)
upper_tri = distances[np.triu_indices_from(distances, k=1)]
if len(upper_tri) == 0:
# Only one point, return a default small eps
return 0.1
# Use percentile of distances as eps
eps = float(np.percentile(upper_tri, self.eps_percentile))
# Ensure eps is positive
return max(eps, 1e-6)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using DBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
from sklearn.cluster import DBSCAN
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: single result
if n_results == 1:
return [[0]]
# Determine eps value
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
# Configure DBSCAN clusterer
# Note: DBSCAN has no min_cluster_size; cluster density is controlled by min_samples and eps
clusterer = DBSCAN(
eps=eps,
min_samples=self.config.min_samples,
metric=self.config.metric,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives
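The eps heuristic above simply takes a low percentile of the pairwise distance distribution. A standalone sketch of the same calculation on toy embeddings (sizes and seed are arbitrary):

import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
embeddings = rng.standard_normal((20, 8))

# 15th percentile of pairwise cosine distances, mirroring _compute_eps above.
distances = pairwise_distances(embeddings, metric="cosine")
upper_tri = distances[np.triu_indices_from(distances, k=1)]
eps = max(float(np.percentile(upper_tri, 15.0)), 1e-6)
print(f"auto eps = {eps:.3f}")  # would feed DBSCAN(eps=eps, min_samples=2, metric="cosine")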

View File

@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.
Provides a unified interface for instantiating different clustering backends
with automatic fallback chain: hdbscan -> dbscan -> noop.
"""
from __future__ import annotations
from typing import Any, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
from .noop_strategy import NoOpStrategy
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
"""Check whether a specific clustering strategy can be used.
Args:
strategy: Strategy name to check. Options:
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
- "dbscan": DBSCAN clustering (requires sklearn)
- "frequency": Frequency-based clustering (always available)
- "noop": No-op strategy (always available)
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
strategy = (strategy or "").strip().lower()
if strategy == "hdbscan":
try:
import hdbscan # noqa: F401
except ImportError:
return False, (
"hdbscan package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "dbscan":
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError:
return False, (
"scikit-learn package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "frequency":
# Frequency strategy is always available (no external deps)
return True, None
if strategy == "noop":
return True, None
return False, (
f"Invalid clustering strategy: {strategy}. "
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
)
def get_strategy(
strategy: str = "hdbscan",
config: Optional[ClusteringConfig] = None,
*,
fallback: bool = True,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Factory function to create clustering strategy with fallback chain.
The automatic fallback chain is: hdbscan -> dbscan -> noop; the frequency strategy must be requested explicitly.
Args:
strategy: Clustering strategy to use. Options:
- "hdbscan": HDBSCAN clustering (default, recommended)
- "dbscan": DBSCAN clustering (fallback)
- "frequency": Frequency-based clustering (groups by symbol occurrence)
- "noop": No-op strategy (returns all results ungrouped)
- "auto": Try hdbscan, then dbscan, then noop
config: Clustering configuration. Uses defaults if not provided.
For frequency strategy, pass FrequencyConfig for full control.
fallback: If True (default), automatically fall back to next strategy
in the chain when primary is unavailable. If False, raise ImportError
when requested strategy is unavailable.
**kwargs: Additional strategy-specific arguments.
For DBSCANStrategy: eps, eps_percentile
For FrequencyStrategy: group_by, min_frequency, etc.
Returns:
BaseClusteringStrategy: Configured clustering strategy instance.
Raises:
ValueError: If strategy is not recognized.
ImportError: If required dependencies are not installed and fallback=False.
Example:
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy
>>> strategy = get_strategy("auto", config)
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
>>> strategy = get_strategy("hdbscan", config)
>>> # Use frequency-based strategy
>>> from codexlens.search.clustering import FrequencyConfig
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = get_strategy("frequency", freq_config)
"""
strategy = (strategy or "").strip().lower()
# Handle "auto" - try strategies in order
if strategy == "auto":
return _get_best_available_strategy(config, **kwargs)
if strategy == "hdbscan":
ok, err = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
if fallback:
# Try dbscan fallback
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
if ok_dbscan:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Final fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "dbscan":
ok, err = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
if fallback:
# Fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "frequency":
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
if config is None or not isinstance(config, FrequencyConfig):
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
else:
freq_config = config
return FrequencyStrategy(freq_config)
if strategy == "noop":
return NoOpStrategy(config)
raise ValueError(
f"Unknown clustering strategy: {strategy}. "
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
)
def _get_best_available_strategy(
config: Optional[ClusteringConfig] = None,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Get the best available clustering strategy.
Tries strategies in order: hdbscan -> dbscan -> noop
Args:
config: Clustering configuration.
**kwargs: Additional strategy-specific arguments.
Returns:
Best available clustering strategy instance.
"""
# Try HDBSCAN first
ok, _ = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
# Try DBSCAN second
ok, _ = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Fallback to NoOp
return NoOpStrategy(config)
# Alias for backward compatibility
ClusteringStrategyFactory = type(
"ClusteringStrategyFactory",
(),
{
"get_strategy": staticmethod(get_strategy),
"check_available": staticmethod(check_clustering_strategy_available),
},
)
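A minimal usage sketch for the factory, assuming the package re-exports get_strategy and ClusteringConfig as the docstring example above suggests. With no optional extras installed, "auto" quietly degrades to NoOpStrategy, while an explicit request with fallback=False surfaces the install hint instead.

from codexlens.search.clustering import get_strategy, ClusteringConfig

config = ClusteringConfig(min_cluster_size=3)
strategy = get_strategy("auto", config)  # walks hdbscan -> dbscan -> noop, never raises

try:
    strategy = get_strategy("hdbscan", config, fallback=False)
except ImportError as exc:
    print(exc)  # points at: pip install codexlens[clustering]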


@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol/method name and prunes based on
occurrence frequency. High-frequency symbols (frequently referenced methods)
are considered more important and retained, while low-frequency results
(potentially noise) can be filtered out.
Use cases:
- Prioritize commonly called methods/functions
- Filter out one-off results that may be less relevant
- Deduplicate results pointing to the same symbol from different locations
"""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class FrequencyConfig(ClusteringConfig):
"""Configuration for frequency-based clustering strategy.
Attributes:
group_by: Field to group results by for frequency counting.
- 'symbol': Group by symbol_name (default, for method/function dedup)
- 'file': Group by file path
- 'symbol_kind': Group by symbol type (function, class, etc.)
min_frequency: Minimum occurrence count to keep a result.
Results appearing less than this are considered noise and pruned.
max_representatives_per_group: Maximum results to keep per symbol group.
frequency_weight: How much to boost score based on frequency.
Final score = original_score * (1 + frequency_weight * log(frequency + 1))
keep_mode: How to handle low-frequency results.
- 'filter': Remove results below min_frequency
- 'demote': Keep but lower their score ranking
"""
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
max_representatives_per_group: int = 3
frequency_weight: float = 0.1 # Boost factor for frequency
keep_mode: Literal["filter", "demote"] = "demote"
def __post_init__(self) -> None:
"""Validate configuration parameters."""
# Skip parent validation since we don't use HDBSCAN params
if self.min_frequency < 1:
raise ValueError("min_frequency must be >= 1")
if self.max_representatives_per_group < 1:
raise ValueError("max_representatives_per_group must be >= 1")
if self.frequency_weight < 0:
raise ValueError("frequency_weight must be >= 0")
if self.group_by not in ("symbol", "file", "symbol_kind"):
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
if self.keep_mode not in ("filter", "demote"):
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
class FrequencyStrategy(BaseClusteringStrategy):
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol name (or file/kind) and:
1. Counts how many times each symbol appears in results
2. Higher frequency = more important (frequently referenced method)
3. Filters or demotes low-frequency results
4. Selects top representatives from each frequency group
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
- Does NOT require embeddings (works with metadata only)
- Is very fast (O(n) complexity)
- Is deterministic (no random initialization)
- Works well for symbol-level deduplication
Example:
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = FrequencyStrategy(config)
>>> # Results with symbol "authenticate" appearing 5 times
>>> # will be prioritized over "helper_func" appearing once
>>> representatives = strategy.fit_predict(embeddings, results)
"""
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
"""Initialize the frequency strategy.
Args:
config: Frequency configuration. Uses defaults if not provided.
"""
self.config: FrequencyConfig = config or FrequencyConfig()
def _get_group_key(self, result: "SearchResult") -> str:
"""Extract grouping key from a search result.
Args:
result: SearchResult to extract key from.
Returns:
String key for grouping (symbol name, file path, or kind).
"""
if self.config.group_by == "symbol":
# Use symbol_name if available, otherwise fall back to file:line
symbol = getattr(result, "symbol_name", None)
if symbol:
return str(symbol)
# Fallback: use file path + start_line as pseudo-symbol
start_line = getattr(result, "start_line", 0) or 0
return f"{result.path}:{start_line}"
elif self.config.group_by == "file":
return str(result.path)
elif self.config.group_by == "symbol_kind":
kind = getattr(result, "symbol_kind", None)
return str(kind) if kind else "unknown"
return str(result.path) # Default fallback
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Group search results by frequency of occurrence.
Note: This method ignores embeddings and groups by metadata only.
The embeddings parameter is kept for interface compatibility.
Args:
embeddings: Ignored (kept for interface compatibility).
results: List of SearchResult objects to cluster.
Returns:
List of clusters (groups), where each cluster contains indices
of results with the same grouping key. Clusters are ordered by
frequency (highest frequency first).
"""
if not results:
return []
# Group results by key
groups: Dict[str, List[int]] = defaultdict(list)
for idx, result in enumerate(results):
key = self._get_group_key(result)
groups[key].append(idx)
# Sort groups by frequency (descending) then by key (for stability)
sorted_groups = sorted(
groups.items(),
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
)
# Convert to list of clusters
clusters = [indices for _, indices in sorted_groups]
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results based on frequency and score.
For each frequency group:
1. If frequency < min_frequency: filter or demote based on keep_mode
2. Sort by score within group
3. Apply frequency boost to scores
4. Select top N representatives
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (used for tie-breaking if provided).
Returns:
List of representative SearchResult objects, ordered by
frequency-adjusted score (highest first).
"""
import math
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
demoted: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
frequency = len(cluster_indices)
# Get results in this cluster, sorted by score
cluster_results = [results[i] for i in cluster_indices]
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
# Check frequency threshold
if frequency < self.config.min_frequency:
if self.config.keep_mode == "filter":
# Skip low-frequency results entirely
continue
else: # demote mode
# Keep but add to demoted list (lower priority)
for result in cluster_results[: self.config.max_representatives_per_group]:
demoted.append(result)
continue
# Apply frequency boost and select top representatives
for result in cluster_results[: self.config.max_representatives_per_group]:
# Calculate frequency-boosted score
original_score = getattr(result, "score", 0.0)
# log(frequency + 1) to handle frequency=1 case smoothly
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
boosted_score = original_score * frequency_boost
# Create new result with boosted score and frequency metadata
# Note: SearchResult might be immutable, so we preserve original
# and track boosted score in metadata
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
result.metadata["frequency"] = frequency
result.metadata["frequency_boosted_score"] = boosted_score
representatives.append(result)
# Sort representatives by boosted score (or original score as fallback)
def get_sort_score(r: "SearchResult") -> float:
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
return getattr(r, "score", 0.0)
representatives.sort(key=get_sort_score, reverse=True)
# Add demoted results at the end
if demoted:
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
representatives.extend(demoted)
return representatives
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array (may be ignored for frequency-based clustering).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)
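A small worked example of the frequency grouping, assuming SearchResult accepts the keyword arguments used elsewhere in this changeset (path, score, excerpt, symbol_name) and that the package re-exports FrequencyStrategy next to FrequencyConfig; embeddings are ignored by this strategy, so None is passed.

from codexlens.entities import SearchResult
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig

results = [
    SearchResult(path="auth/service.py", score=0.90, excerpt="def authenticate(...)", symbol_name="authenticate"),
    SearchResult(path="auth/api.py", score=0.80, excerpt="authenticate(user)", symbol_name="authenticate"),
    SearchResult(path="util/helpers.py", score=0.95, excerpt="def helper_func(...)", symbol_name="helper_func"),
]

strategy = FrequencyStrategy(FrequencyConfig(min_frequency=2, keep_mode="demote"))
representatives = strategy.fit_predict(None, results)
# "authenticate" appears twice, so its hits get a log-frequency boost and lead the list;
# "helper_func" falls below min_frequency and is kept but demoted to the tail.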


@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.
HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class HDBSCANStrategy(BaseClusteringStrategy):
"""HDBSCAN-based clustering strategy.
Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
HDBSCAN is preferred over DBSCAN because it:
- Automatically determines the number of clusters
- Handles varying density clusters well
- Identifies noise points (outliers) effectively
Example:
>>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = HDBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize HDBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
Raises:
ImportError: If hdbscan package is not installed.
"""
super().__init__(config)
# Validate hdbscan is available
try:
import hdbscan # noqa: F401
except ImportError as exc:
raise ImportError(
"hdbscan package is required for HDBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using HDBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
import hdbscan
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: fewer results than min_cluster_size
if n_results < self.config.min_cluster_size:
# Return each result as its own singleton cluster
return [[i] for i in range(n_results)]
# Configure HDBSCAN clusterer
clusterer = hdbscan.HDBSCAN(
min_cluster_size=self.config.min_cluster_size,
min_samples=self.config.min_samples,
metric=self.config.metric,
cluster_selection_epsilon=self.config.cluster_selection_epsilon,
allow_single_cluster=self.config.allow_single_cluster,
prediction_data=self.config.prediction_data,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives
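A minimal end-to-end sketch of the HDBSCAN path, assuming hdbscan and numpy are installed and that ClusteringConfig's remaining defaults map onto parameters hdbscan accepts; the embeddings are random placeholders, so most points will land in noise and come back as singleton clusters.

import numpy as np
from codexlens.entities import SearchResult
from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig

results = [SearchResult(path=f"mod_{i}.py", score=1.0 - i * 0.05, excerpt="...") for i in range(12)]
embeddings = np.random.rand(12, 384).astype(np.float32)  # placeholder vectors

strategy = HDBSCANStrategy(ClusteringConfig(min_cluster_size=3, metric="euclidean"))
clusters = strategy.cluster(embeddings, results)                       # noise points become singletons
representatives = strategy.select_representatives(clusters, results)  # one hit per cluster, best score first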


@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.
NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class NoOpStrategy(BaseClusteringStrategy):
"""No-op clustering strategy that returns all results ungrouped.
This strategy is used as a final fallback when no clustering dependencies
are available, or when clustering is explicitly disabled. Each result
is treated as its own singleton cluster.
Example:
>>> from codexlens.search.clustering import NoOpStrategy
>>> strategy = NoOpStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # Returns [[0], [1], [2], ...] - each result in its own cluster
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns all results sorted by score
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize NoOp clustering strategy.
Args:
config: Clustering configuration. Ignored for NoOpStrategy
but accepted for interface compatibility.
"""
super().__init__(config)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Return each result as its own singleton cluster.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Not used but accepted for interface compatibility.
results: List of SearchResult objects.
Returns:
List of singleton clusters, one per result.
"""
return [[i] for i in range(len(results))]
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Return all results sorted by score.
Since each cluster is a singleton, this effectively returns all
results sorted by score descending.
Args:
clusters: List of singleton clusters.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used).
Returns:
All SearchResult objects sorted by score (highest first).
"""
if not results:
return []
# Return all results sorted by score
return sorted(results, key=lambda r: r.score, reverse=True)


@@ -0,0 +1,171 @@
# codex-lens/src/codexlens/search/enrichment.py
"""Relationship enrichment for search results."""
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper
class RelationshipEnricher:
"""Enriches search results with code graph relationships."""
def __init__(self, index_path: Path):
"""Initialize with path to index database.
Args:
index_path: Path to _index.db SQLite database
"""
self.index_path = index_path
self.db_conn: Optional[sqlite3.Connection] = None
self._connect()
def _connect(self) -> None:
"""Establish read-only database connection."""
if self.index_path.exists():
self.db_conn = sqlite3.connect(
f"file:{self.index_path}?mode=ro",
uri=True,
check_same_thread=False
)
self.db_conn.row_factory = sqlite3.Row
def enrich(self, results: List[Dict[str, Any]], limit: int = 10) -> List[Dict[str, Any]]:
"""Add relationship data to search results.
Args:
results: List of search result dictionaries
limit: Maximum number of results to enrich
Returns:
Results with relationships field added
"""
if not self.db_conn:
return results
for result in results[:limit]:
file_path = result.get('file') or result.get('path')
symbol_name = result.get('symbol')
result['relationships'] = self._find_relationships(file_path, symbol_name)
return results
def _find_relationships(self, file_path: Optional[str], symbol_name: Optional[str]) -> List[Dict[str, Any]]:
"""Query relationships for a symbol.
Args:
file_path: Path to file containing the symbol
symbol_name: Name of the symbol
Returns:
List of relationship dictionaries with type, direction, target/source, file, line
"""
if not self.db_conn or not symbol_name:
return []
relationships = []
cursor = self.db_conn.cursor()
try:
# Find symbol ID(s) by name and optionally file
if file_path:
cursor.execute(
'SELECT id FROM symbols WHERE name = ? AND file_path = ?',
(symbol_name, file_path)
)
else:
cursor.execute('SELECT id FROM symbols WHERE name = ?', (symbol_name,))
symbol_ids = [row[0] for row in cursor.fetchall()]
if not symbol_ids:
return []
# Query outgoing relationships (symbol is source)
placeholders = ','.join('?' * len(symbol_ids))
cursor.execute(f'''
SELECT sr.relationship_type, sr.target_symbol_fqn, sr.file_path, sr.line
FROM symbol_relationships sr
WHERE sr.source_symbol_id IN ({placeholders})
''', symbol_ids)
for row in cursor.fetchall():
relationships.append({
'type': row[0],
'direction': 'outgoing',
'target': row[1],
'file': row[2],
'line': row[3],
})
# Query incoming relationships (symbol is target)
# Match against symbol name or qualified name patterns
cursor.execute('''
SELECT sr.relationship_type, s.name AS source_name, sr.file_path, sr.line
FROM symbol_relationships sr
JOIN symbols s ON sr.source_symbol_id = s.id
WHERE sr.target_symbol_fqn = ? OR sr.target_symbol_fqn LIKE ?
''', (symbol_name, f'%.{symbol_name}'))
for row in cursor.fetchall():
rel_type = row[0]
# Convert to incoming type
incoming_type = self._to_incoming_type(rel_type)
relationships.append({
'type': incoming_type,
'direction': 'incoming',
'source': row[1],
'file': row[2],
'line': row[3],
})
except sqlite3.Error:
return []
return relationships
def _to_incoming_type(self, outgoing_type: str) -> str:
"""Convert outgoing relationship type to incoming type.
Args:
outgoing_type: The outgoing relationship type (e.g., 'calls', 'imports')
Returns:
Corresponding incoming type (e.g., 'called_by', 'imported_by')
"""
type_map = {
'calls': 'called_by',
'imports': 'imported_by',
'extends': 'extended_by',
}
return type_map.get(outgoing_type, f'{outgoing_type}_by')
def close(self) -> None:
"""Close database connection."""
if self.db_conn:
self.db_conn.close()
self.db_conn = None
def __enter__(self) -> 'RelationshipEnricher':
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
class SearchEnrichmentPipeline:
"""Search post-processing pipeline (optional enrichments)."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._config = config
self._graph_expander = GraphExpander(mapper, config=config)
def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
"""Expand base results with related symbols when enabled in config."""
if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
return []
depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
return self._graph_expander.expand(results, depth=depth)
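A short usage sketch for the enricher, assuming an existing on-disk _index.db (the path below is hypothetical) and dictionary-shaped hits as consumed by enrich() above.

from pathlib import Path
from codexlens.search.enrichment import RelationshipEnricher

hits = [{"file": "src/auth/service.py", "symbol": "authenticate", "score": 0.91}]

with RelationshipEnricher(Path(".codexlens/_index.db")) as enricher:
    enriched = enricher.enrich(hits, limit=5)

for rel in enriched[0].get("relationships", []):
    print(rel["direction"], rel["type"], rel.get("target") or rel.get("source"))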


@@ -0,0 +1,264 @@
"""Graph expansion for search results using precomputed neighbors.
Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""
from __future__ import annotations
import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper
logger = logging.getLogger(__name__)
def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
return (result.path, result.symbol_name, result.start_line, result.end_line)
def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
if content is None:
return None
if start_line is None or end_line is None:
return None
if start_line < 1 or end_line < start_line:
return None
lines = content.splitlines()
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
if start_idx >= len(lines):
return None
return "\n".join(lines[start_idx:end_idx])
class GraphExpander:
"""Expands SearchResult lists with related symbols from the code graph."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._mapper = mapper
self._config = config
self._logger = logging.getLogger(__name__)
def expand(
self,
results: Sequence[SearchResult],
*,
depth: Optional[int] = None,
max_expand: int = 10,
max_related: int = 50,
) -> List[SearchResult]:
"""Expand top results with related symbols.
Args:
results: Base ranked results.
depth: Maximum relationship depth to include (defaults to Config or 2).
max_expand: Only expand the top-N base results to bound cost.
max_related: Maximum related results to return.
Returns:
A list of related SearchResult objects with relationship_depth metadata.
"""
if not results:
return []
configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
max_depth = int(depth if depth is not None else configured_depth)
if max_depth <= 0:
return []
max_depth = min(max_depth, 2)
expand_count = max(0, int(max_expand))
related_limit = max(0, int(max_related))
if expand_count == 0 or related_limit == 0:
return []
seen = {_result_key(r) for r in results}
related_results: List[SearchResult] = []
conn_cache: Dict[Path, sqlite3.Connection] = {}
try:
for base in list(results)[:expand_count]:
if len(related_results) >= related_limit:
break
if not base.symbol_name or not base.path:
continue
index_path = self._mapper.source_to_index_db(Path(base.path).parent)
conn = conn_cache.get(index_path)
if conn is None:
conn = self._connect_readonly(index_path)
if conn is None:
continue
conn_cache[index_path] = conn
source_ids = self._resolve_source_symbol_ids(
conn,
file_path=base.path,
symbol_name=base.symbol_name,
symbol_kind=base.symbol_kind,
)
if not source_ids:
continue
for source_id in source_ids:
neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
for neighbor_id, rel_depth in neighbors:
if len(related_results) >= related_limit:
break
row = self._get_symbol_details(conn, neighbor_id)
if row is None:
continue
path = str(row["full_path"])
symbol_name = str(row["name"])
symbol_kind = str(row["kind"])
start_line = int(row["start_line"]) if row["start_line"] is not None else None
end_line = int(row["end_line"]) if row["end_line"] is not None else None
content_block = _slice_content_block(
str(row["content"]) if row["content"] is not None else "",
start_line,
end_line,
)
score = float(base.score) * (0.5 ** int(rel_depth))
candidate = SearchResult(
path=path,
score=max(0.0, score),
excerpt=None,
content=content_block,
start_line=start_line,
end_line=end_line,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
metadata={"relationship_depth": int(rel_depth)},
)
key = _result_key(candidate)
if key in seen:
continue
seen.add(key)
related_results.append(candidate)
finally:
for conn in conn_cache.values():
try:
conn.close()
except Exception:
pass
return related_results
def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
try:
if not index_path.exists() or index_path.stat().st_size == 0:
return None
except OSError:
return None
try:
conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
except Exception as exc:
self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
return None
def _resolve_source_symbol_ids(
self,
conn: sqlite3.Connection,
*,
file_path: str,
symbol_name: str,
symbol_kind: Optional[str],
) -> List[int]:
try:
if symbol_kind:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
""",
(file_path, symbol_name, symbol_kind),
).fetchall()
else:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ?
""",
(file_path, symbol_name),
).fetchall()
except sqlite3.Error:
return []
ids: List[int] = []
for row in rows:
try:
ids.append(int(row["id"]))
except Exception:
continue
return ids
def _get_neighbors(
self,
conn: sqlite3.Connection,
source_symbol_id: int,
*,
max_depth: int,
limit: int,
) -> List[Tuple[int, int]]:
try:
rows = conn.execute(
"""
SELECT neighbor_symbol_id, relationship_depth
FROM graph_neighbors
WHERE source_symbol_id = ? AND relationship_depth <= ?
ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
LIMIT ?
""",
(int(source_symbol_id), int(max_depth), int(limit)),
).fetchall()
except sqlite3.Error:
return []
neighbors: List[Tuple[int, int]] = []
for row in rows:
try:
neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
except Exception:
continue
return neighbors
def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
try:
return conn.execute(
"""
SELECT
s.id,
s.name,
s.kind,
s.start_line,
s.end_line,
f.full_path,
f.content
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE s.id = ?
""",
(int(symbol_id),),
).fetchone()
except sqlite3.Error:
return None
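Two details of the expansion above are easy to miss: the traversal depth is clamped to 2 hops, and each related hit inherits a score that halves per hop (score * 0.5 ** depth). A tiny sketch of the decay, with a hypothetical base score:

base_score = 0.8
for depth in (1, 2):
    print(depth, base_score * (0.5 ** depth))  # 1 -> 0.4, 2 -> 0.2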

File diff suppressed because it is too large


@@ -0,0 +1,242 @@
"""Query preprocessing for CodexLens search.
Provides query expansion for better identifier matching:
- CamelCase splitting: UserAuth → User OR Auth
- snake_case splitting: user_auth → user OR auth
- Preserves original query for exact matching
"""
from __future__ import annotations
import logging
import re
from typing import Set, List
log = logging.getLogger(__name__)
class QueryParser:
"""Parser for preprocessing search queries before FTS5 execution.
Expands identifier-style queries (CamelCase, snake_case) into OR queries
to improve recall when searching for code symbols.
Example transformations:
- 'UserAuth' → 'UserAuth OR User OR Auth'
- 'user_auth' → 'user_auth OR user OR auth'
- 'getUserData' → 'getUserData OR get OR User OR Data'
"""
# Patterns for identifier splitting
CAMEL_CASE_PATTERN = re.compile(r'([a-z])([A-Z])')
SNAKE_CASE_PATTERN = re.compile(r'_+')
KEBAB_CASE_PATTERN = re.compile(r'-+')
# Minimum token length to include in expansion (avoid noise from single chars)
MIN_TOKEN_LENGTH = 2
# All-caps acronyms pattern (e.g., HTTP, SQL, API)
ALL_CAPS_PATTERN = re.compile(r'^[A-Z]{2,}$')
def __init__(self, enable: bool = True, min_token_length: int = 2):
"""Initialize query parser.
Args:
enable: Whether to enable query preprocessing
min_token_length: Minimum token length to include in expansion
"""
self.enable = enable
self.min_token_length = min_token_length
def preprocess_query(self, query: str) -> str:
"""Preprocess query with identifier expansion.
Args:
query: Original search query
Returns:
Expanded query with OR operator connecting original and split tokens
Example:
>>> parser = QueryParser()
>>> parser.preprocess_query('UserAuth')
'UserAuth OR User OR Auth'
>>> parser.preprocess_query('get_user_data')
'get_user_data OR get OR user OR data'
"""
if not self.enable:
return query
query = query.strip()
if not query:
return query
# Extract tokens from query (handle multiple words/terms)
# For simple queries, just process the whole thing
# For complex FTS5 queries with operators, preserve structure
if self._is_simple_query(query):
return self._expand_simple_query(query)
else:
# Complex query with FTS5 operators, don't expand
log.debug(f"Skipping expansion for complex FTS5 query: {query}")
return query
def _is_simple_query(self, query: str) -> bool:
"""Check if query is simple (no FTS5 operators).
Args:
query: Search query
Returns:
True if query is simple (safe to expand), False otherwise
"""
# Check for FTS5 operators that indicate complex query
fts5_operators = ['OR', 'AND', 'NOT', 'NEAR', '*', '^', '"']
return not any(op in query for op in fts5_operators)
def _expand_simple_query(self, query: str) -> str:
"""Expand a simple query with identifier splitting.
Args:
query: Simple search query
Returns:
Expanded query with OR operators
"""
tokens: Set[str] = set()
# Always include original query
tokens.add(query)
# Split on whitespace first
words = query.split()
for word in words:
# Extract tokens from this word
word_tokens = self._extract_tokens(word)
tokens.update(word_tokens)
# Filter out short tokens and duplicates
filtered_tokens = [
t for t in tokens
if len(t) >= self.min_token_length
]
# Remove duplicates while preserving original query first
unique_tokens: List[str] = []
seen: Set[str] = set()
# Always put original query first
if query not in seen and len(query) >= self.min_token_length:
unique_tokens.append(query)
seen.add(query)
# Add other tokens
for token in filtered_tokens:
if token not in seen:
unique_tokens.append(token)
seen.add(token)
# Join with OR operator (only if we have multiple tokens)
if len(unique_tokens) > 1:
expanded = ' OR '.join(unique_tokens)
log.debug(f"Expanded query: '{query}''{expanded}'")
return expanded
else:
return query
def _extract_tokens(self, word: str) -> Set[str]:
"""Extract tokens from a single word using various splitting strategies.
Args:
word: Single word/identifier to split
Returns:
Set of extracted tokens
"""
tokens: Set[str] = set()
# Add original word
tokens.add(word)
# Handle all-caps acronyms (don't split)
if self.ALL_CAPS_PATTERN.match(word):
return tokens
# CamelCase splitting
camel_tokens = self._split_camel_case(word)
tokens.update(camel_tokens)
# snake_case splitting
snake_tokens = self._split_snake_case(word)
tokens.update(snake_tokens)
# kebab-case splitting
kebab_tokens = self._split_kebab_case(word)
tokens.update(kebab_tokens)
return tokens
def _split_camel_case(self, word: str) -> List[str]:
"""Split CamelCase identifier into tokens.
Args:
word: CamelCase identifier (e.g., 'getUserData')
Returns:
List of tokens (e.g., ['get', 'User', 'Data'])
"""
# Insert space before uppercase letters preceded by lowercase
spaced = self.CAMEL_CASE_PATTERN.sub(r'\1 \2', word)
# Split on spaces and filter empty
return [t for t in spaced.split() if t]
def _split_snake_case(self, word: str) -> List[str]:
"""Split snake_case identifier into tokens.
Args:
word: snake_case identifier (e.g., 'get_user_data')
Returns:
List of tokens (e.g., ['get', 'user', 'data'])
"""
# Split on underscores
return [t for t in self.SNAKE_CASE_PATTERN.split(word) if t]
def _split_kebab_case(self, word: str) -> List[str]:
"""Split kebab-case identifier into tokens.
Args:
word: kebab-case identifier (e.g., 'get-user-data')
Returns:
List of tokens (e.g., ['get', 'user', 'data'])
"""
# Split on hyphens
return [t for t in self.KEBAB_CASE_PATTERN.split(word) if t]
# Global default parser instance
_default_parser = QueryParser(enable=True)
def preprocess_query(query: str, enable: bool = True) -> str:
"""Convenience function for query preprocessing.
Args:
query: Original search query
enable: Whether to enable preprocessing
Returns:
Preprocessed query with identifier expansion
"""
if not enable:
return query
return _default_parser.preprocess_query(query)
__all__ = [
"QueryParser",
"preprocess_query",
]
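A brief usage sketch, assuming the module is importable as codexlens.search.query_parser (the file name is not shown in this diff); note that token order after the original query is not guaranteed, since intermediate tokens are collected in a set.

from codexlens.search.query_parser import QueryParser

parser = QueryParser(enable=True, min_token_length=2)
print(parser.preprocess_query("getUserData"))     # e.g. 'getUserData OR get OR User OR Data' (order may vary)
print(parser.preprocess_query('"exact phrase"'))  # returned unchanged: contains an FTS5 operator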


@@ -0,0 +1,942 @@
"""Ranking algorithms for hybrid search result fusion.
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
"""
from __future__ import annotations
import re
import math
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from codexlens.entities import SearchResult, AdditionalLocation
# Default RRF weights for SPLADE-based hybrid search
DEFAULT_WEIGHTS = {
"splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1)
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
}
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
"exact": 0.25,
"fuzzy": 0.1,
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
}
class QueryIntent(str, Enum):
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
KEYWORD = "keyword"
SEMANTIC = "semantic"
MIXED = "mixed"
def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
"""Normalize weights to sum to 1.0 (best-effort)."""
total = sum(float(v) for v in weights.values() if v is not None)
# NaN total: do not attempt to normalize (division would propagate NaNs).
if math.isnan(total):
return dict(weights)
# Infinite total: do not attempt to normalize (division yields 0 or NaN).
if not math.isfinite(total):
return dict(weights)
# Zero/negative total: do not attempt to normalize (invalid denominator).
if total <= 0:
return dict(weights)
return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}
def detect_query_intent(query: str) -> QueryIntent:
"""Detect whether a query is code-like, natural-language, or mixed.
Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
"""
trimmed = (query or "").strip()
if not trimmed:
return QueryIntent.MIXED
lower = trimmed.lower()
word_count = len([w for w in re.split(r"\s+", trimmed) if w])
has_code_signals = bool(
re.search(r"(::|->|\.)", trimmed)
or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
or re.search(r"\b\w+_\w+\b", trimmed)
or re.search(
r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
lower,
flags=re.IGNORECASE,
)
)
has_natural_signals = bool(
word_count > 5
or "?" in trimmed
or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
or re.search(
r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
trimmed,
flags=re.IGNORECASE,
)
)
if has_code_signals and has_natural_signals:
return QueryIntent.MIXED
if has_code_signals:
return QueryIntent.KEYWORD
if has_natural_signals:
return QueryIntent.SEMANTIC
return QueryIntent.MIXED
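# Worked examples of the heuristics above (hypothetical queries):
#   detect_query_intent("getUserData")                      -> KEYWORD  (CamelCase signal only)
#   detect_query_intent("how do I parse yaml config files") -> SEMANTIC (question word, more than 5 words)
#   detect_query_intent("fix user_auth token refresh")      -> MIXED    (snake_case plus a natural-language verb)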
def adjust_weights_by_intent(
intent: QueryIntent,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Adjust RRF weights based on query intent."""
# Check if using SPLADE or FTS mode
use_splade = "splade" in base_weights
if intent == QueryIntent.KEYWORD:
if use_splade:
target = {"splade": 0.6, "vector": 0.4}
else:
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
elif intent == QueryIntent.SEMANTIC:
if use_splade:
target = {"splade": 0.3, "vector": 0.7}
else:
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
else:
target = dict(base_weights)
# Filter to active backends
keys = list(base_weights.keys())
filtered = {k: float(target.get(k, 0.0)) for k in keys}
return normalize_weights(filtered)
def get_rrf_weights(
query: str,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Compute adaptive RRF weights from query intent."""
return adjust_weights_by_intent(detect_query_intent(query), base_weights)
# File extensions to category mapping for fast lookup
_EXT_TO_CATEGORY: Dict[str, str] = {
# Code extensions
".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
".rs": "code",
# Doc extensions
".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
}
def get_file_category(path: str) -> Optional[str]:
"""Get file category ('code' or 'doc') from path extension.
Args:
path: File path string
Returns:
'code', 'doc', or None if unknown
"""
ext = Path(path).suffix.lower()
return _EXT_TO_CATEGORY.get(ext)
def filter_results_by_category(
results: List[SearchResult],
intent: QueryIntent,
allow_mixed: bool = True,
) -> List[SearchResult]:
"""Filter results by category based on query intent.
Strategy:
- KEYWORD (code intent): Only return code files
- SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
- MIXED: Return all results
Args:
results: List of SearchResult objects
intent: Query intent from detect_query_intent()
allow_mixed: If True, SEMANTIC intent includes code files with lower priority
Returns:
Filtered and re-ranked list of SearchResult objects
"""
if not results or intent == QueryIntent.MIXED:
return results
code_results = []
doc_results = []
unknown_results = []
for r in results:
category = get_file_category(r.path)
if category == "code":
code_results.append(r)
elif category == "doc":
doc_results.append(r)
else:
unknown_results.append(r)
if intent == QueryIntent.KEYWORD:
# Code intent: return only code files + unknown (might be code)
filtered = code_results + unknown_results
elif intent == QueryIntent.SEMANTIC:
if allow_mixed:
# Semantic intent with mixed: docs first, then code
filtered = doc_results + code_results + unknown_results
else:
# Semantic intent strict: only docs
filtered = doc_results + unknown_results
else:
filtered = results
return filtered
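# Illustration of the filtering policy above (hypothetical result paths):
#   KEYWORD  -> code files only, e.g. [auth.py, utils.ts]; README.md is dropped
#   SEMANTIC -> docs first, code kept when allow_mixed=True, e.g. [README.md, auth.py, utils.ts]
#   MIXED    -> results returned unchanged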
def simple_weighted_fusion(
results_map: Dict[str, List[SearchResult]],
weights: Optional[Dict[str, float]] = None,
) -> List[SearchResult]:
"""Combine search results using simple weighted sum of normalized scores.
This is an alternative to RRF that preserves score magnitude information.
Scores are min-max normalized per source before weighted combination.
Formula: score(d) = Σ weight_source * normalized_score_source(d)
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
Returns:
List of SearchResult objects sorted by fused score (descending)
Examples:
>>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
>>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
>>> results_map = {'exact': fts_results, 'vector': vector_results}
>>> fused = simple_weighted_fusion(results_map)
"""
if not results_map:
return []
# Default equal weights if not provided
if weights is None:
num_sources = len(results_map)
weights = {source: 1.0 / num_sources for source in results_map}
# Normalize weights to sum to 1.0
weight_sum = sum(weights.values())
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
weights = {source: w / weight_sum for source, w in weights.items()}
# Compute min-max normalization parameters per source
source_stats: Dict[str, tuple] = {}
for source_name, results in results_map.items():
if not results:
continue
scores = [r.score for r in results]
min_s, max_s = min(scores), max(scores)
source_stats[source_name] = (min_s, max_s)
def normalize_score(score: float, source: str) -> float:
"""Normalize score to [0, 1] range using min-max scaling."""
if source not in source_stats:
return 0.0
min_s, max_s = source_stats[source]
if max_s == min_s:
return 1.0 if score >= min_s else 0.0
return (score - min_s) / (max_s - min_s)
# Build unified result set with weighted scores
path_to_result: Dict[str, SearchResult] = {}
path_to_fusion_score: Dict[str, float] = {}
path_to_source_scores: Dict[str, Dict[str, float]] = {}
for source_name, results in results_map.items():
weight = weights.get(source_name, 0.0)
if weight == 0:
continue
for result in results:
path = result.path
normalized = normalize_score(result.score, source_name)
contribution = weight * normalized
if path not in path_to_fusion_score:
path_to_fusion_score[path] = 0.0
path_to_result[path] = result
path_to_source_scores[path] = {}
path_to_fusion_score[path] += contribution
path_to_source_scores[path][source_name] = normalized
# Create final results with fusion scores
fused_results = []
for path, base_result in path_to_result.items():
fusion_score = path_to_fusion_score[path]
fused_result = SearchResult(
path=base_result.path,
score=fusion_score,
excerpt=base_result.excerpt,
content=base_result.content,
symbol=base_result.symbol,
chunk=base_result.chunk,
metadata={
**base_result.metadata,
"fusion_method": "simple_weighted",
"fusion_score": fusion_score,
"original_score": base_result.score,
"source_scores": path_to_source_scores[path],
},
start_line=base_result.start_line,
end_line=base_result.end_line,
symbol_name=base_result.symbol_name,
symbol_kind=base_result.symbol_kind,
)
fused_results.append(fused_result)
fused_results.sort(key=lambda r: r.score, reverse=True)
return fused_results
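# Worked example of the min-max + weighted-sum arithmetic above (hypothetical scores):
#   exact  scores {a.py: 10.0, b.py: 5.0} -> normalized {a.py: 1.0, b.py: 0.0}
#   vector scores {a.py: 0.9,  b.py: 0.8} -> normalized {a.py: 1.0, b.py: 0.0}
#   with weights {exact: 0.5, vector: 0.5}:
#   a.py = 0.5*1.0 + 0.5*1.0 = 1.0    b.py = 0.5*0.0 + 0.5*0.0 = 0.0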
def reciprocal_rank_fusion(
results_map: Dict[str, List[SearchResult]],
weights: Optional[Dict[str, float]] = None,
k: int = 60,
) -> List[SearchResult]:
"""Combine search results from multiple sources using Reciprocal Rank Fusion.
RRF formula: score(d) = Σ weight_source / (k + rank_source(d))
Supports three-way fusion with FTS, Vector, and SPLADE sources.
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
Or: {'splade': 0.4, 'vector': 0.6}
k: Constant to avoid division by zero and control rank influence (default 60)
Returns:
List of SearchResult objects sorted by fused score (descending)
Examples:
>>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
>>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
>>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
>>> fused = reciprocal_rank_fusion(results_map)
# Three-way fusion with SPLADE
>>> results_map = {
... 'exact': exact_results,
... 'vector': vector_results,
... 'splade': splade_results
... }
>>> fused = reciprocal_rank_fusion(results_map, k=60)
"""
if not results_map:
return []
# Default equal weights if not provided
if weights is None:
num_sources = len(results_map)
weights = {source: 1.0 / num_sources for source in results_map}
# Normalize weights to sum to 1.0 (skip when the total is zero or negative)
weight_sum = sum(weights.values())
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
weights = {source: w / weight_sum for source, w in weights.items()}
# Build unified result set with RRF scores
path_to_result: Dict[str, SearchResult] = {}
path_to_fusion_score: Dict[str, float] = {}
path_to_source_ranks: Dict[str, Dict[str, int]] = {}
for source_name, results in results_map.items():
weight = weights.get(source_name, 0.0)
if weight == 0:
continue
for rank, result in enumerate(results, start=1):
path = result.path
rrf_contribution = weight / (k + rank)
# Initialize or accumulate fusion score
if path not in path_to_fusion_score:
path_to_fusion_score[path] = 0.0
path_to_result[path] = result
path_to_source_ranks[path] = {}
path_to_fusion_score[path] += rrf_contribution
path_to_source_ranks[path][source_name] = rank
# Create final results with fusion scores
fused_results = []
for path, base_result in path_to_result.items():
fusion_score = path_to_fusion_score[path]
# Create new SearchResult with fusion_score in metadata
fused_result = SearchResult(
path=base_result.path,
score=fusion_score,
excerpt=base_result.excerpt,
content=base_result.content,
symbol=base_result.symbol,
chunk=base_result.chunk,
metadata={
**base_result.metadata,
"fusion_method": "rrf",
"fusion_score": fusion_score,
"original_score": base_result.score,
"rrf_k": k,
"source_ranks": path_to_source_ranks[path],
},
start_line=base_result.start_line,
end_line=base_result.end_line,
symbol_name=base_result.symbol_name,
symbol_kind=base_result.symbol_kind,
)
fused_results.append(fused_result)
# Sort by fusion score descending
fused_results.sort(key=lambda r: r.score, reverse=True)
return fused_results
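# Worked example of the RRF arithmetic above (hypothetical ranks, k=60, weights {'splade': 0.6, 'vector': 0.4}):
#   a.py: rank 1 in splade, rank 3 in vector -> 0.6/(60+1) + 0.4/(60+3) ≈ 0.00984 + 0.00635 ≈ 0.01619
#   b.py: rank 2 in splade only              -> 0.6/(60+2)              ≈ 0.00968
#   a.py outranks b.py because it appears in both rankings.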
def apply_symbol_boost(
results: List[SearchResult],
boost_factor: float = 1.5,
) -> List[SearchResult]:
"""Boost fused scores for results that include an explicit symbol match.
The boost is multiplicative on the current result.score (typically the RRF fusion score).
When boosted, the original score is preserved in metadata["original_fusion_score"] and
metadata["boosted"] is set to True.
"""
if not results:
return []
if boost_factor <= 1.0:
# Still return new objects to follow immutable transformation pattern.
return [
SearchResult(
path=r.path,
score=r.score,
excerpt=r.excerpt,
content=r.content,
symbol=r.symbol,
chunk=r.chunk,
metadata={**r.metadata},
start_line=r.start_line,
end_line=r.end_line,
symbol_name=r.symbol_name,
symbol_kind=r.symbol_kind,
additional_locations=list(r.additional_locations),
)
for r in results
]
boosted_results: List[SearchResult] = []
for result in results:
has_symbol = bool(result.symbol_name)
original_score = float(result.score)
boosted_score = original_score * boost_factor if has_symbol else original_score
metadata = {**result.metadata}
if has_symbol:
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
metadata["boosted"] = True
metadata["symbol_boost_factor"] = boost_factor
boosted_results.append(
SearchResult(
path=result.path,
score=boosted_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata=metadata,
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
boosted_results.sort(key=lambda r: r.score, reverse=True)
return boosted_results
def rerank_results(
query: str,
results: List[SearchResult],
embedder: Any,
top_k: int = 50,
) -> List[SearchResult]:
"""Re-rank results with embedding cosine similarity, combined with current score.
Combined score formula:
0.5 * rrf_score + 0.5 * cosine_similarity
If embedder is None or embedding fails, returns results as-is.
"""
if not results:
return []
if embedder is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
# Defensive: handle mismatched lengths and zero vectors.
n = min(len(vec_a), len(vec_b))
if n == 0:
return 0.0
dot = 0.0
norm_a = 0.0
norm_b = 0.0
for i in range(n):
a = float(vec_a[i])
b = float(vec_b[i])
dot += a * b
norm_a += a * a
norm_b += b * b
if norm_a <= 0.0 or norm_b <= 0.0:
return 0.0
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
return max(0.0, min(1.0, sim))
def text_for_embedding(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
# Fallback: stable, non-empty text.
return r.symbol_name or r.path
try:
if hasattr(embedder, "embed_single"):
query_vec = embedder.embed_single(query)
else:
query_vec = embedder.embed(query)[0]
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
doc_vecs = embedder.embed(doc_texts)
except Exception:
return results
reranked_results: List[SearchResult] = []
for idx, result in enumerate(results):
if idx < rerank_count:
rrf_score = float(result.score)
sim = cosine_similarity(query_vec, doc_vecs[idx])
combined_score = 0.5 * rrf_score + 0.5 * sim
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"rrf_score": rrf_score,
"cosine_similarity": sim,
"reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
# Preserve remaining results without re-ranking, but keep immutability.
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
def cross_encoder_rerank(
query: str,
results: List[SearchResult],
reranker: Any,
top_k: int = 50,
batch_size: int = 32,
chunk_type_weights: Optional[Dict[str, float]] = None,
test_file_penalty: float = 0.0,
) -> List[SearchResult]:
"""Second-stage reranking using a cross-encoder model.
This function is dependency-agnostic: callers can pass any object that exposes
a compatible `score_pairs(pairs, batch_size=...)` method.
Args:
query: Search query string
results: List of search results to rerank
reranker: Cross-encoder model with score_pairs or predict method
top_k: Number of top results to rerank
batch_size: Batch size for reranking
chunk_type_weights: Optional weights for different chunk types.
Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
test_file_penalty: Penalty applied to test files (0.0-1.0).
Example: 0.2 means test files get 20% score reduction
"""
if not results:
return []
if reranker is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def text_for_pair(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
return r.symbol_name or r.path
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
try:
if hasattr(reranker, "score_pairs"):
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
elif hasattr(reranker, "predict"):
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
else:
return results
except Exception:
return results
if not raw_scores or len(raw_scores) != rerank_count:
return results
scores = [float(s) for s in raw_scores]
min_s = min(scores)
max_s = max(scores)
def sigmoid(x: float) -> float:
# Clamp to keep exp() stable.
x = max(-50.0, min(50.0, x))
return 1.0 / (1.0 + math.exp(-x))
if 0.0 <= min_s and max_s <= 1.0:
probs = scores
else:
probs = [sigmoid(s) for s in scores]
reranked_results: List[SearchResult] = []
# Helper to detect test files
def is_test_file(path: str) -> bool:
if not path:
return False
basename = path.split("/")[-1].split("\\")[-1]
return (
basename.startswith("test_") or
basename.endswith("_test.py") or
basename.endswith(".test.ts") or
basename.endswith(".test.js") or
basename.endswith(".spec.ts") or
basename.endswith(".spec.js") or
"/tests/" in path or
"\\tests\\" in path or
"/test/" in path or
"\\test\\" in path
)
for idx, result in enumerate(results):
if idx < rerank_count:
prev_score = float(result.score)
ce_score = scores[idx]
ce_prob = probs[idx]
# Base combined score
combined_score = 0.5 * prev_score + 0.5 * ce_prob
# Apply chunk_type weight adjustment
if chunk_type_weights:
chunk_type = None
if result.chunk and hasattr(result.chunk, "metadata"):
chunk_type = result.chunk.metadata.get("chunk_type")
elif result.metadata:
chunk_type = result.metadata.get("chunk_type")
if chunk_type and chunk_type in chunk_type_weights:
weight = chunk_type_weights[chunk_type]
# Apply weight to CE contribution only
combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight
# Apply test file penalty
if test_file_penalty > 0 and is_test_file(result.path):
combined_score = combined_score * (1.0 - test_file_penalty)
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"pre_cross_encoder_score": prev_score,
"cross_encoder_score": ce_score,
"cross_encoder_prob": ce_prob,
"cross_encoder_reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
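# Illustrative sketch (not part of the diff above): the reranker argument only
# needs a score_pairs(pairs, batch_size=...) or predict(pairs, batch_size=...)
# method that returns one score per (query, text) pair, as the duck-typing in
# the function shows. The class below is a hypothetical stand-in, not a real
# cross-encoder.
class KeywordOverlapReranker:
    """Toy reranker: scores pairs by query-token overlap in [0, 1]."""

    def score_pairs(self, pairs, batch_size=32):
        scores = []
        for query, text in pairs:
            query_tokens = set(query.lower().split())
            text_tokens = set(text.lower().split())
            overlap = len(query_tokens & text_tokens)
            scores.append(overlap / max(len(query_tokens), 1))
        return scores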
def normalize_bm25_score(score: float) -> float:
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
SQLite FTS5 returns negative BM25 scores (more negative = better match).
Uses sigmoid transformation for normalization.
Args:
score: Raw BM25 score from SQLite (typically negative)
Returns:
Normalized score in range [0, 1]
Examples:
>>> round(normalize_bm25_score(-10.5), 2)  # Good match
0.74
>>> round(normalize_bm25_score(-1.2), 2)  # Weak match
0.53
"""
# Take absolute value (BM25 is negative in SQLite)
abs_score = abs(score)
# Sigmoid transformation: 1 / (1 + e^(-x))
# Scale factor of 0.1 maps the typical BM25 range (-20 to 0) onto roughly (0.5, 0.9)
normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))
return normalized
def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
"""Tag search results with their source for RRF tracking.
Args:
results: List of SearchResult objects
source: Source identifier ('exact', 'fuzzy', 'vector')
Returns:
List of SearchResult objects with 'search_source' in metadata
"""
tagged_results = []
for result in results:
tagged_result = SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata, "search_source": source},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
)
tagged_results.append(tagged_result)
return tagged_results
def group_similar_results(
results: List[SearchResult],
score_threshold_abs: float = 0.01,
content_field: str = "excerpt"
) -> List[SearchResult]:
"""Group search results by content and score similarity.
Groups results that have similar content and similar scores into a single
representative result, with other locations stored in additional_locations.
Algorithm:
1. Group results by content (using excerpt or content field)
2. Within each content group, create subgroups based on score similarity
3. Select highest-scoring result as representative for each subgroup
4. Store other results in subgroup as additional_locations
Args:
results: A list of SearchResult objects (typically sorted by score)
score_threshold_abs: Absolute score difference to consider results similar.
Results with |score_a - score_b| <= threshold are grouped.
Default 0.01 is suitable for RRF fusion scores.
content_field: The field to use for content grouping ('excerpt' or 'content')
Returns:
A new list of SearchResult objects where similar items are grouped.
The list is sorted by score descending.
Examples:
>>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
... SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
>>> grouped = group_similar_results(results)
>>> len(grouped) # Two results merged into one
1
>>> len(grouped[0].additional_locations) # One additional location
1
"""
if not results:
return []
# Group results by content
content_map: Dict[str, List[SearchResult]] = {}
unidentifiable_results: List[SearchResult] = []
for r in results:
key = getattr(r, content_field, None)
if key and key.strip():
content_map.setdefault(key, []).append(r)
else:
# Results without content can't be grouped by content
unidentifiable_results.append(r)
final_results: List[SearchResult] = []
# Process each content group
for content_group in content_map.values():
# Sort by score descending within group
content_group.sort(key=lambda r: r.score, reverse=True)
while content_group:
# Take highest scoring as representative
representative = content_group.pop(0)
others_in_group = []
remaining_for_next_pass = []
# Find results with similar scores
for item in content_group:
if abs(representative.score - item.score) <= score_threshold_abs:
others_in_group.append(item)
else:
remaining_for_next_pass.append(item)
# Create grouped result with additional locations
if others_in_group:
# Build new result with additional_locations populated
grouped_result = SearchResult(
path=representative.path,
score=representative.score,
excerpt=representative.excerpt,
content=representative.content,
symbol=representative.symbol,
chunk=representative.chunk,
metadata={
**representative.metadata,
"grouped_count": len(others_in_group) + 1,
},
start_line=representative.start_line,
end_line=representative.end_line,
symbol_name=representative.symbol_name,
symbol_kind=representative.symbol_kind,
additional_locations=[
AdditionalLocation(
path=other.path,
score=other.score,
start_line=other.start_line,
end_line=other.end_line,
symbol_name=other.symbol_name,
) for other in others_in_group
],
)
final_results.append(grouped_result)
else:
final_results.append(representative)
content_group = remaining_for_next_pass
# Add ungroupable results
final_results.extend(unidentifiable_results)
# Sort final results by score descending
final_results.sort(key=lambda r: r.score, reverse=True)
return final_results

View File

@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.
Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)
GPU Acceleration:
- Automatic GPU detection and usage when available
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
"""
from __future__ import annotations
SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None
def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
"""Detect if fastembed and GPU are available."""
try:
import numpy as np
except ImportError as e:
return False, None, False, f"numpy not available: {e}"
try:
from fastembed import TextEmbedding
except ImportError:
return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"
# Check GPU availability
gpu_available = False
try:
from .gpu_support import is_gpu_available
gpu_available = is_gpu_available()
except ImportError:
pass
return True, "fastembed", gpu_available, None
# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()
def check_semantic_available() -> tuple[bool, str | None]:
"""Check if semantic search dependencies are available."""
return SEMANTIC_AVAILABLE, _import_error
def check_gpu_available() -> tuple[bool, str]:
"""Check if GPU acceleration is available.
Returns:
Tuple of (is_available, status_message)
"""
if not SEMANTIC_AVAILABLE:
return False, "Semantic search not available"
try:
from .gpu_support import is_gpu_available, get_gpu_summary
if is_gpu_available():
return True, get_gpu_summary()
return False, "No GPU detected (using CPU)"
except ImportError:
return False, "GPU support module not available"
# Export embedder components
# BaseEmbedder is always available (abstract base class)
from .base import BaseEmbedder
# Factory function for creating embedders
from .factory import get_embedder as get_embedder_factory
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
import ccw_litellm # noqa: F401
from .litellm_embedder import LiteLLMEmbedderWrapper
LITELLM_AVAILABLE = True
except ImportError:
LiteLLMEmbedderWrapper = None
LITELLM_AVAILABLE = False
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific embedding backend can be used.
Notes:
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
- "litellm" requires ccw-litellm to be installed in the same environment.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
if SEMANTIC_AVAILABLE:
return True, None
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
if backend == "litellm":
if LITELLM_AVAILABLE:
return True, None
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
__all__ = [
"SEMANTIC_AVAILABLE",
"SEMANTIC_BACKEND",
"GPU_AVAILABLE",
"LITELLM_AVAILABLE",
"check_semantic_available",
"is_embedding_backend_available",
"check_gpu_available",
"BaseEmbedder",
"get_embedder_factory",
"LiteLLMEmbedderWrapper",
]
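# Illustrative usage sketch (not part of the original file): a caller can probe
# the optional backends before building an index. Printed messages are placeholders.
_ok, _err = check_semantic_available()
if not _ok:
    print(f"Semantic search disabled: {_err}")
_gpu_ok, _gpu_msg = check_gpu_available()
print(f"GPU status: {_gpu_msg}")
_litellm_ok, _reason = is_embedding_backend_available("litellm")
if not _litellm_ok:
    print(f"litellm backend unavailable: {_reason}")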

File diff suppressed because it is too large

View File

@@ -0,0 +1,61 @@
"""Base class for embedders.
Defines the interface that all embedders must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable
import numpy as np
class BaseEmbedder(ABC):
"""Base class for all embedders.
All embedder implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
@abstractmethod
def embedding_dim(self) -> int:
"""Return embedding dimensions.
Returns:
int: Dimension of the embedding vectors.
"""
...
@property
@abstractmethod
def model_name(self) -> str:
"""Return model name.
Returns:
str: Name or identifier of the underlying model.
"""
...
@property
def max_tokens(self) -> int:
"""Return maximum token limit for embeddings.
Returns:
int: Maximum number of tokens that can be embedded at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Embed texts to numpy array.
Args:
texts: Single text or iterable of texts to embed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
...
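# Illustrative sketch (not part of the original file): a minimal concrete
# subclass showing the contract above. The hashing "model" is a toy used only
# to demonstrate the interface, not a real embedding backend.
class HashingEmbedder(BaseEmbedder):
    """Toy embedder that buckets character codes into a fixed-size vector."""

    _DIM = 64

    @property
    def embedding_dim(self) -> int:
        return self._DIM

    @property
    def model_name(self) -> str:
        return "toy-hashing-embedder"

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        items = [texts] if isinstance(texts, str) else list(texts)
        out = np.zeros((len(items), self._DIM), dtype=np.float32)
        for row, text in enumerate(items):
            for char in text:
                out[row, ord(char) % self._DIM] += 1.0
        return out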

View File

@@ -0,0 +1,821 @@
"""Code chunking strategies for semantic search.
This module provides various chunking strategies for breaking down source code
into semantic chunks suitable for embedding and search.
Lightweight Mode:
The ChunkConfig supports a `skip_token_count` option for performance optimization.
When enabled, token counting uses a fast character-based estimation (char/4)
instead of expensive tiktoken encoding.
Use cases for lightweight mode:
- Large-scale indexing where speed is critical
- Scenarios where approximate token counts are acceptable
- Memory-constrained environments
- Initial prototyping and development
Example:
# Default mode (accurate tiktoken encoding)
config = ChunkConfig()
chunker = Chunker(config)
# Lightweight mode (fast char/4 estimation)
config = ChunkConfig(skip_token_count=True)
chunker = Chunker(config)
chunks = chunker.chunk_file(content, symbols, path, language)
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
strip_comments: bool = True # Remove comments from chunk content for embedding
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
preserve_original: bool = True # Store original content in metadata when stripping
class CommentStripper:
"""Remove comments from source code while preserving structure."""
@staticmethod
def strip_python_comments(content: str) -> str:
"""Strip Python comments (# style) but preserve docstrings.
Args:
content: Python source code
Returns:
Code with comments removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
in_string = False
string_char = None
for line in lines:
new_line = []
i = 0
while i < len(line):
char = line[i]
# Handle string literals
if char in ('"', "'") and not in_string:
# Check for triple quotes
if line[i:i+3] in ('"""', "'''"):
in_string = True
string_char = line[i:i+3]
new_line.append(line[i:i+3])
i += 3
continue
else:
in_string = True
string_char = char
elif in_string:
if string_char and len(string_char) == 3:
if line[i:i+3] == string_char:
in_string = False
new_line.append(line[i:i+3])
i += 3
string_char = None
continue
elif char == string_char:
# Check for escape
if i > 0 and line[i-1] != '\\':
in_string = False
string_char = None
# Handle comments (only outside strings)
if char == '#' and not in_string:
# Rest of line is comment, skip it
new_line.append('\n' if line.endswith('\n') else '')
break
new_line.append(char)
i += 1
result_lines.append(''.join(new_line))
return ''.join(result_lines)
@staticmethod
def strip_c_style_comments(content: str) -> str:
"""Strip C-style comments (// and /* */) from code.
Args:
content: Source code with C-style comments
Returns:
Code with comments removed
"""
result = []
i = 0
in_string = False
string_char = None
in_multiline_comment = False
while i < len(content):
# Handle multi-line comment end
if in_multiline_comment:
if content[i:i+2] == '*/':
in_multiline_comment = False
i += 2
continue
i += 1
continue
char = content[i]
# Handle string literals
if char in ('"', "'", '`') and not in_string:
in_string = True
string_char = char
result.append(char)
i += 1
continue
elif in_string:
result.append(char)
if char == string_char and (i == 0 or content[i-1] != '\\'):
in_string = False
string_char = None
i += 1
continue
# Handle comments
if content[i:i+2] == '//':
# Single line comment - skip to end of line
while i < len(content) and content[i] != '\n':
i += 1
if i < len(content):
result.append('\n')
i += 1
continue
if content[i:i+2] == '/*':
in_multiline_comment = True
i += 2
continue
result.append(char)
i += 1
return ''.join(result)
@classmethod
def strip_comments(cls, content: str, language: str) -> str:
"""Strip comments based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with comments removed
"""
if language == "python":
return cls.strip_python_comments(content)
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
return cls.strip_c_style_comments(content)
return content
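# Illustrative usage sketch (not part of the original file): the dispatch above
# on a tiny Python snippet. The helper name is hypothetical; `#` comments are
# dropped while triple-quoted strings survive.
def _demo_strip_comments() -> str:
    sample = 'x = 1  # set x\n"""Kept: triple quotes are treated as strings."""\n'
    return CommentStripper.strip_comments(sample, "python")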
class DocstringStripper:
"""Remove docstrings from source code."""
@staticmethod
def strip_python_docstrings(content: str) -> str:
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
Args:
content: Python source code
Returns:
Code with docstrings removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
# Check for docstring start
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
# Single line docstring
if stripped.count(quote_type) >= 2:
# Skip this line (docstring)
i += 1
continue
# Multi-line docstring - skip until closing
i += 1
while i < len(lines):
if quote_type in lines[i]:
i += 1
break
i += 1
continue
result_lines.append(line)
i += 1
return ''.join(result_lines)
@staticmethod
def strip_jsdoc_comments(content: str) -> str:
"""Strip JSDoc comments (/** ... */) from code.
Args:
content: JavaScript/TypeScript source code
Returns:
Code with JSDoc comments removed
"""
result = []
i = 0
in_jsdoc = False
while i < len(content):
if in_jsdoc:
if content[i:i+2] == '*/':
in_jsdoc = False
i += 2
continue
i += 1
continue
# Check for JSDoc start (/** but not /*)
if content[i:i+3] == '/**':
in_jsdoc = True
i += 3
continue
result.append(content[i])
i += 1
return ''.join(result)
@classmethod
def strip_docstrings(cls, content: str, language: str) -> str:
"""Strip docstrings based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with docstrings removed
"""
if language == "python":
return cls.strip_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.strip_jsdoc_comments(content)
return content
class Chunker:
"""Chunk code files for semantic embedding."""
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
self._comment_stripper = CommentStripper()
self._docstring_stripper = DocstringStripper()
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
"""Process chunk content by stripping comments/docstrings if configured.
Args:
content: Original chunk content
language: Programming language
Returns:
Tuple of (processed_content, original_content_if_preserved)
"""
original = content if self.config.preserve_original else None
processed = content
if self.config.strip_comments:
processed = self._comment_stripper.strip_comments(processed, language)
if self.config.strip_docstrings:
processed = self._docstring_stripper.strip_docstrings(processed, language)
# If nothing changed, don't store original
if processed == content:
original = None
return processed, original
def _estimate_token_count(self, text: str) -> int:
"""Estimate token count based on config.
If skip_token_count is True, uses character-based estimation (char/4).
Otherwise, uses accurate tiktoken encoding.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if self.config.skip_token_count:
# Fast character-based estimation: ~4 chars per token
return max(1, len(text) // 4)
return self._tokenizer.count_tokens(text)
def chunk_by_symbol(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Symbols larger than max_chunk_size are split further with the sliding-window strategy.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
for symbol in symbols:
start_line, end_line = symbol.range
# Convert to 0-indexed
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
chunk_content = "".join(lines[start_idx:end_idx])
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Check if symbol content exceeds max_chunk_size
if len(chunk_content) > self.config.max_chunk_size:
# Create line mapping for correct line number tracking
line_mapping = list(range(start_line, end_line + 1))
# Use sliding window to split large symbol
sub_chunks = self.chunk_sliding_window(
chunk_content,
file_path=file_path,
language=language,
line_mapping=line_mapping
)
# Update sub_chunks with parent symbol metadata
for sub_chunk in sub_chunks:
sub_chunk.metadata["symbol_name"] = symbol.name
sub_chunk.metadata["symbol_kind"] = symbol.kind
sub_chunk.metadata["strategy"] = "symbol_split"
sub_chunk.metadata["chunk_type"] = "code"
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
chunks.extend(sub_chunks)
else:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(processed_content)
metadata = {
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
return chunks
def chunk_sliding_window(
self,
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries or very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
if not lines:
return chunks
# Calculate lines per chunk based on average line length
avg_line_len = len(content) / max(len(lines), 1)
lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
# Ensure overlap is less than chunk size to prevent infinite loop
overlap_lines = min(overlap_lines, lines_per_chunk - 1)
start = 0
chunk_idx = 0
while start < len(lines):
end = min(start + lines_per_chunk, len(lines))
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
# Move window forward
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1
start += step
continue
token_count = self._estimate_token_count(processed_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
metadata = {
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
chunk_idx += 1
# Move window, accounting for overlap
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1 # Failsafe to prevent infinite loop
start += step
# Break if we've reached the end
if end >= len(lines):
break
return chunks
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking if symbols available,
falls back to sliding window for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
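# Illustrative usage sketch (not part of the original file): a file with no
# extracted symbols falls back to the sliding window. The inline source and
# "example.py" path are invented for the sketch.
_config = ChunkConfig(max_chunk_size=400, overlap=80, skip_token_count=True)
_chunker = Chunker(_config)
_chunks = _chunker.chunk_file(
    content="def add(a, b):\n    return a + b\n" * 30,
    symbols=[],
    file_path="example.py",
    language="python",
)
# Each chunk's metadata carries start_line, end_line, and an estimated token_count.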
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
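# Illustrative usage sketch (not part of the original file): extracting a JSDoc
# block from an invented TypeScript snippet.
_ts_source = "/**\n * Adds two numbers.\n */\nexport function add(a: number, b: number) { return a + b; }\n"
_found = DocstringExtractor.extract_docstrings(_ts_source, "typescript")
# _found == [("/**\n * Adds two numbers.\n */\n", 1, 3)]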
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. For remaining code, uses base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def _find_parent_symbol(
self,
start_line: int,
end_line: int,
symbols: List[Symbol],
) -> Optional[Symbol]:
"""Find the smallest symbol range that fully contains a docstring span."""
candidates: List[Symbol] = []
for symbol in symbols:
sym_start, sym_end = symbol.range
if sym_start <= start_line and end_line <= sym_end:
candidates.append(symbol)
if not candidates:
return None
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
# Step 1: Extract docstrings as dedicated chunks
docstrings: List[Tuple[str, int, int]] = []
if language == "python":
# Fast path: avoid expensive docstring extraction if delimiters are absent.
if '"""' in content or "'''" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
elif language in {"javascript", "typescript"}:
if "/**" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
else:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
# Fast path: no docstrings -> delegate to base chunker directly.
if not docstrings:
if symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, symbols, file_path, language, symbol_token_counts
)
else:
base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
return base_chunks
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
# Use base chunker's token estimation method
token_count = self.base_chunker._estimate_token_count(docstring_content)
metadata = {
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
if parent_symbol is not None:
metadata["parent_symbol"] = parent_symbol.name
metadata["parent_symbol_kind"] = parent_symbol.kind
metadata["parent_symbol_range"] = parent_symbol.range
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata=metadata
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks
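# Illustrative usage sketch (not part of the original file): the hybrid strategy
# emits docstring chunks first, then code chunks. The sample source and
# "utils.py" path are invented.
_hybrid = HybridChunker(config=ChunkConfig(min_chunk_size=20))
_source = '"""Utility helpers."""\n\ndef add(a, b):\n    return a + b\n'
_hybrid_chunks = _hybrid.chunk_file(_source, symbols=[], file_path="utils.py", language="python")
# chunk_type is "docstring" for the module docstring and "code" for the rest.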

View File

@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SearchResult, Symbol
def extract_complete_code_block(
result: SearchResult,
source_file_path: Optional[str] = None,
context_lines: int = 0,
) -> str:
"""Extract complete code block from a search result.
Args:
result: SearchResult from semantic search.
source_file_path: Optional path to source file for re-reading.
context_lines: Additional lines of context to include above/below.
Returns:
Complete code block as string.
"""
# If we have full content stored, use it
if result.content:
if context_lines == 0:
return result.content
# Need to add context, read from file
# Try to read from source file
file_path = source_file_path or result.path
if not file_path or not Path(file_path).exists():
# Fall back to excerpt
return result.excerpt or ""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
# Get line range
start_line = result.start_line or 1
end_line = result.end_line or len(lines)
# Add context
start_idx = max(0, start_line - 1 - context_lines)
end_idx = min(len(lines), end_line + context_lines)
return "\n".join(lines[start_idx:end_idx])
except Exception:
return result.excerpt or result.content or ""
def extract_symbol_with_context(
file_path: str,
symbol: Symbol,
include_docstring: bool = True,
include_decorators: bool = True,
) -> str:
"""Extract a symbol (function/class) with its docstring and decorators.
Args:
file_path: Path to source file.
symbol: Symbol to extract.
include_docstring: Include docstring if present.
include_decorators: Include decorators/annotations above symbol.
Returns:
Complete symbol code with context.
"""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
start_line, end_line = symbol.range
start_idx = start_line - 1
end_idx = end_line
# Look for decorators above the symbol
if include_decorators and start_idx > 0:
decorator_start = start_idx
# Search backwards for decorators
i = start_idx - 1
while i >= 0 and i >= start_idx - 20: # Look up to 20 lines back
line = lines[i].strip()
if line.startswith("@"):
decorator_start = i
i -= 1
elif line == "" or line.startswith("#"):
# Skip empty lines and comments, continue looking
i -= 1
elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
# JavaScript/Java style comments
decorator_start = i
i -= 1
else:
# Found non-decorator, non-comment line, stop
break
start_idx = decorator_start
return "\n".join(lines[start_idx:end_idx])
except Exception:
return ""
def format_search_result_code(
result: SearchResult,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
highlight_match: bool = False,
) -> str:
"""Format search result code for display.
Args:
result: SearchResult to format.
max_lines: Maximum lines to show (None for all).
show_line_numbers: Include line numbers in output.
highlight_match: Add markers for matched region.
Returns:
Formatted code string.
"""
content = result.content or result.excerpt or ""
if not content:
return ""
lines = content.splitlines()
# Truncate if needed
truncated = False
if max_lines and len(lines) > max_lines:
lines = lines[:max_lines]
truncated = True
# Format with line numbers
if show_line_numbers:
start = result.start_line or 1
formatted_lines = []
for i, line in enumerate(lines):
line_num = start + i
formatted_lines.append(f"{line_num:4d} | {line}")
output = "\n".join(formatted_lines)
else:
output = "\n".join(lines)
if truncated:
output += "\n... (truncated)"
return output
def get_code_block_summary(result: SearchResult) -> str:
"""Get a concise summary of a code block.
Args:
result: SearchResult to summarize.
Returns:
Summary string like "function hello_world (lines 10-25)"
"""
parts = []
if result.symbol_kind:
parts.append(result.symbol_kind)
if result.symbol_name:
parts.append(f"`{result.symbol_name}`")
elif result.excerpt:
# Extract first meaningful identifier
first_line = result.excerpt.split("\n")[0][:50]
parts.append(f'"{first_line}..."')
if result.start_line and result.end_line:
if result.start_line == result.end_line:
parts.append(f"(line {result.start_line})")
else:
parts.append(f"(lines {result.start_line}-{result.end_line})")
if result.path:
file_name = Path(result.path).name
parts.append(f"in {file_name}")
return " ".join(parts) if parts else "unknown code block"
class CodeBlockResult:
"""Enhanced search result with complete code block."""
def __init__(self, result: SearchResult, source_path: Optional[str] = None):
self.result = result
self.source_path = source_path or result.path
self._full_code: Optional[str] = None
@property
def score(self) -> float:
return self.result.score
@property
def path(self) -> str:
return self.result.path
@property
def file_name(self) -> str:
return Path(self.result.path).name
@property
def symbol_name(self) -> Optional[str]:
return self.result.symbol_name
@property
def symbol_kind(self) -> Optional[str]:
return self.result.symbol_kind
@property
def line_range(self) -> Tuple[int, int]:
return (
self.result.start_line or 1,
self.result.end_line or 1
)
@property
def full_code(self) -> str:
"""Get full code block content."""
if self._full_code is None:
self._full_code = extract_complete_code_block(self.result, self.source_path)
return self._full_code
@property
def excerpt(self) -> str:
"""Get short excerpt."""
return self.result.excerpt or ""
@property
def summary(self) -> str:
"""Get code block summary."""
return get_code_block_summary(self.result)
def format(
self,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
) -> str:
"""Format code for display."""
# Use full code if available
display_result = SearchResult(
path=self.result.path,
score=self.result.score,
content=self.full_code,
start_line=self.result.start_line,
end_line=self.result.end_line,
)
return format_search_result_code(
display_result,
max_lines=max_lines,
show_line_numbers=show_line_numbers
)
def __repr__(self) -> str:
return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"
def enhance_search_results(
results: List[SearchResult],
) -> List[CodeBlockResult]:
"""Enhance search results with complete code block access.
Args:
results: List of SearchResult from semantic search.
Returns:
List of CodeBlockResult with full code access.
"""
return [CodeBlockResult(r) for r in results]
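# Illustrative usage sketch (not part of the original file): wrapping a raw
# result for display. The SearchResult fields and values below are invented and
# assume the optional fields have defaults, as format() above also relies on.
_raw = SearchResult(
    path="src/app.py",
    score=0.91,
    content="def main():\n    run()",
    start_line=10,
    end_line=11,
    symbol_name="main",
    symbol_kind="function",
)
_blocks = enhance_search_results([_raw])
print(_blocks[0].summary)              # e.g. "function `main` (lines 10-11) in app.py"
print(_blocks[0].format(max_lines=5))  # line-numbered code block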

View File

@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.
Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""
from __future__ import annotations
import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional
import numpy as np
from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id
logger = logging.getLogger(__name__)
# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()
def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
"""Get or create a cached Embedder instance (thread-safe singleton).
This function provides significant performance improvement by reusing
Embedder instances across multiple searches, avoiding repeated model
loading overhead (~0.8s per load).
Args:
profile: Model profile ("fast", "code", "multilingual", "balanced")
use_gpu: If True, use GPU acceleration when available (default: True)
Returns:
Cached Embedder instance for the given profile
"""
global _embedder_cache
# Cache key includes GPU preference to support mixed configurations
cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"
# All cache access is protected by _cache_lock to avoid races with
# clear_embedder_cache() during concurrent access.
with _cache_lock:
embedder = _embedder_cache.get(cache_key)
if embedder is not None:
return embedder
# Create new embedder and cache it
embedder = Embedder(profile=profile, use_gpu=use_gpu)
# Pre-load model to ensure it's ready
embedder._load_model()
_embedder_cache[cache_key] = embedder
# Log GPU status on first embedder creation
if use_gpu and is_gpu_available():
logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
elif use_gpu:
logger.debug("GPU not available, using CPU for embeddings")
return embedder
def clear_embedder_cache() -> None:
"""Clear the embedder cache and release ONNX resources.
This method ensures proper cleanup of ONNX model resources to prevent
memory leaks when embedders are no longer needed.
"""
global _embedder_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for embedder in _embedder_cache.values():
if embedder._model is not None:
del embedder._model
embedder._model = None
_embedder_cache.clear()
gc.collect()
class Embedder(BaseEmbedder):
"""Generate embeddings for code chunks using fastembed (ONNX-based).
Supported Model Profiles:
- fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
- code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
- multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
- balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
"""
# Model profiles for different use cases
MODELS = {
"fast": "BAAI/bge-small-en-v1.5", # 384 dim - Fast, lightweight
"code": "jinaai/jina-embeddings-v2-base-code", # 768 dim - Code-optimized
"multilingual": "intfloat/multilingual-e5-large", # 1024 dim - Multilingual
"balanced": "mixedbread-ai/mxbai-embed-large-v1", # 1024 dim - High accuracy
}
# Dimension mapping for each model
MODEL_DIMS = {
"BAAI/bge-small-en-v1.5": 384,
"jinaai/jina-embeddings-v2-base-code": 768,
"intfloat/multilingual-e5-large": 1024,
"mixedbread-ai/mxbai-embed-large-v1": 1024,
}
# Default model (fast profile)
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_PROFILE = "fast"
def __init__(
self,
model_name: str | None = None,
profile: str | None = None,
use_gpu: bool = True,
providers: List[str] | None = None,
) -> None:
"""Initialize embedder with model or profile.
Args:
model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
If both provided, model_name takes precedence.
use_gpu: If True, use GPU acceleration when available (default: True)
providers: Explicit ONNX providers list (overrides use_gpu if provided)
"""
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
# Resolve model name from profile or use explicit name
if model_name:
self._model_name = model_name
elif profile and profile in self.MODELS:
self._model_name = self.MODELS[profile]
else:
self._model_name = self.DEFAULT_MODEL
# Configure ONNX execution providers with device_id options for GPU selection
# Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
if providers is not None:
self._providers = providers
else:
self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)
self._use_gpu = use_gpu
self._model = None
@property
def model_name(self) -> str:
"""Get model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Get embedding dimension for current model."""
return self.MODEL_DIMS.get(self._model_name, 768) # Default to 768 if unknown
@property
def max_tokens(self) -> int:
"""Get maximum token limit for current model.
Returns:
int: Maximum number of tokens based on model profile.
- fast: 512 (lightweight, optimized for speed)
- code: 8192 (code-optimized, larger context)
- multilingual: 512 (standard multilingual model)
- balanced: 512 (general purpose)
"""
# Determine profile from model name
profile = None
for prof, model in self.MODELS.items():
if model == self._model_name:
profile = prof
break
# Return token limit based on profile
if profile == "code":
return 8192
elif profile in ("fast", "multilingual", "balanced"):
return 512
else:
# Default for unknown models
return 512
@property
def providers(self) -> List[str]:
"""Get configured ONNX execution providers."""
return self._providers
@property
def is_gpu_enabled(self) -> bool:
"""Check if GPU acceleration is enabled for this embedder."""
gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
"DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
# Handle both string providers and tuple providers (name, options)
for p in self._providers:
provider_name = p[0] if isinstance(p, tuple) else p
if provider_name in gpu_providers:
return True
return False
def _load_model(self) -> None:
"""Lazy load the embedding model with configured providers."""
if self._model is not None:
return
from fastembed import TextEmbedding
# providers already include device_id options via get_optimal_providers(with_device_options=True)
# DO NOT pass device_ids separately - fastembed ignores it when providers is specified
# See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
try:
self._model = TextEmbedding(
model_name=self.model_name,
providers=self._providers,
)
logger.debug(f"Model loaded with providers: {self._providers}")
except TypeError:
# Fallback for older fastembed versions without providers parameter
logger.warning(
"fastembed version doesn't support 'providers' parameter. "
"Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
)
self._model = TextEmbedding(model_name=self.model_name)
def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
"""Generate embeddings for one or more texts.
Args:
texts: Single text or iterable of texts to embed.
Returns:
List of embedding vectors (each is a list of floats).
Note:
This method converts numpy arrays to Python lists for backward compatibility.
For memory-efficient processing, use embed_to_numpy() instead.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
embeddings = list(self._model.embed(texts))
return [emb.tolist() for emb in embeddings]
def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
"""Generate embeddings for one or more texts (returns numpy arrays).
This method is more memory-efficient than embed() as it avoids converting
numpy arrays to Python lists, which can significantly reduce memory usage
during batch processing.
Args:
texts: Single text or iterable of texts to embed.
batch_size: Optional batch size for fastembed processing.
Larger values improve GPU utilization but use more memory.
Returns:
numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Pass batch_size to fastembed for optimal GPU utilization
# Default batch_size in fastembed is 256, but larger values can improve throughput
if batch_size is not None:
embeddings = list(self._model.embed(texts, batch_size=batch_size))
else:
embeddings = list(self._model.embed(texts))
return np.array(embeddings)
def embed_single(self, text: str) -> List[float]:
"""Generate embedding for a single text."""
return self.embed(text)[0]
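# Illustrative usage sketch (not part of the original file): reusing the cached
# singleton avoids reloading the ONNX model on every query. The texts are invented.
_embedder = get_embedder(profile="code", use_gpu=True)
_vectors = _embedder.embed_to_numpy(
    ["def parse(tokens): ...", "class Lexer: ..."],
    batch_size=64,
)
# _vectors.shape == (2, _embedder.embedding_dim), i.e. (2, 768) for the code profile.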

View File

@@ -0,0 +1,158 @@
"""Factory for creating embedders.
Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Dict, List, Optional
from .base import BaseEmbedder
# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)
def get_embedder(
backend: str = "fastembed",
profile: str = "code",
model: str = "default",
use_gpu: bool = True,
endpoints: Optional[List[Dict[str, Any]]] = None,
strategy: str = "latency_aware",
cooldown: float = 60.0,
**kwargs: Any,
) -> BaseEmbedder:
"""Factory function to create embedder based on backend.
Args:
backend: Embedder backend to use. Options:
- "fastembed": Use fastembed (ONNX-based) embedder (default)
- "litellm": Use ccw-litellm embedder
profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
Used only when backend="fastembed". Default: "code"
model: Model identifier for litellm backend.
Used only when backend="litellm". Default: "default"
use_gpu: Whether to use GPU acceleration when available (default: True).
Used only when backend="fastembed".
endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
Each endpoint is a dict with keys: model, api_key, api_base, weight.
Used only when backend="litellm" and multiple endpoints provided.
strategy: Selection strategy for multi-endpoint mode:
"round_robin", "latency_aware", "weighted_random".
Default: "latency_aware"
cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
**kwargs: Additional backend-specific arguments
Returns:
BaseEmbedder: Configured embedder instance
Raises:
ValueError: If backend is not recognized
ImportError: If required backend dependencies are not installed
Examples:
Create fastembed embedder with code profile:
>>> embedder = get_embedder(backend="fastembed", profile="code")
Create fastembed embedder with fast profile and CPU only:
>>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
Create litellm embedder:
>>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
Create rotational embedder with multiple endpoints:
>>> endpoints = [
... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
... ]
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
"""
# Build cache key from immutable configuration
if backend == "fastembed":
cache_key = ("fastembed", profile, None, use_gpu)
elif backend == "litellm":
# For litellm, use model as part of cache key
# Multi-endpoint mode is not cached as it's more complex
if endpoints and len(endpoints) > 1:
cache_key = None # Skip cache for multi-endpoint
else:
effective_model = endpoints[0]["model"] if endpoints else model
cache_key = ("litellm", None, effective_model, None)
else:
cache_key = None
# Check cache first (thread-safe)
if cache_key is not None:
with _cache_lock:
if cache_key in _embedder_cache:
_logger.debug("Returning cached embedder for %s", cache_key)
return _embedder_cache[cache_key]
# Create new embedder instance
embedder: Optional[BaseEmbedder] = None
if backend == "fastembed":
from .embedder import Embedder
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
elif backend == "litellm":
# Check if multi-endpoint mode is requested
if endpoints and len(endpoints) > 1:
from .rotational_embedder import create_rotational_embedder
# Multi-endpoint is not cached
return create_rotational_embedder(
endpoints_config=endpoints,
strategy=strategy,
default_cooldown=cooldown,
)
elif endpoints and len(endpoints) == 1:
# Single endpoint in list - use it directly
ep = endpoints[0]
ep_kwargs = {**kwargs}
if "api_key" in ep:
ep_kwargs["api_key"] = ep["api_key"]
if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
else:
# No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
else:
raise ValueError(
f"Unknown backend: {backend}. "
f"Supported backends: 'fastembed', 'litellm'"
)
# Cache the embedder for future use (thread-safe)
if cache_key is not None and embedder is not None:
with _cache_lock:
# Double-check to avoid race condition
if cache_key not in _embedder_cache:
_embedder_cache[cache_key] = embedder
_logger.debug("Cached new embedder for %s", cache_key)
else:
# Another thread created it already, use that one
embedder = _embedder_cache[cache_key]
return embedder # type: ignore
def clear_embedder_cache() -> int:
"""Clear the embedder cache.
Returns:
Number of embedders cleared from cache
"""
with _cache_lock:
count = len(_embedder_cache)
_embedder_cache.clear()
_logger.debug("Cleared %d embedders from cache", count)
return count
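# Illustrative configuration sketch (not part of the original file): the endpoint
# values are placeholders; only the keys (model / api_key / api_base) follow the
# structure documented in get_embedder above. The litellm call assumes ccw-litellm
# is installed.
_local = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
_remote = get_embedder(
    backend="litellm",
    endpoints=[{"model": "openai/text-embedding-3-small", "api_key": "sk-PLACEHOLDER"}],
)
clear_embedder_cache()  # empties the factory cache (both instances above are cached)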

View File

@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.
This module provides GPU detection, initialization, and fallback handling
for ONNX-based embedding generation.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional
logger = logging.getLogger(__name__)
@dataclass
class GPUDevice:
"""Individual GPU device info."""
device_id: int
name: str
is_discrete: bool # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
vendor: str # "nvidia", "amd", "intel", "unknown"
@dataclass
class GPUInfo:
"""GPU availability and configuration info."""
gpu_available: bool = False
cuda_available: bool = False
gpu_count: int = 0
gpu_name: Optional[str] = None
onnx_providers: Optional[List[str]] = None
devices: Optional[List[GPUDevice]] = None  # List of detected GPU devices
preferred_device_id: Optional[int] = None # Preferred GPU for embedding
def __post_init__(self):
if self.onnx_providers is None:
self.onnx_providers = ["CPUExecutionProvider"]
if self.devices is None:
self.devices = []
_gpu_info_cache: Optional[GPUInfo] = None
def _enumerate_gpus() -> List[GPUDevice]:
"""Enumerate available GPU devices using WMI on Windows.
Returns:
List of GPUDevice with device info, ordered by device_id.
"""
devices = []
try:
import subprocess
import sys
if sys.platform == "win32":
# Use PowerShell to query GPU information via WMI
cmd = [
"powershell", "-NoProfile", "-Command",
"Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0 and result.stdout.strip():
import json
gpu_data = json.loads(result.stdout)
# Handle single GPU case (returns dict instead of list)
if isinstance(gpu_data, dict):
gpu_data = [gpu_data]
for idx, gpu in enumerate(gpu_data):
name = gpu.get("Name", "Unknown GPU")
compat = gpu.get("AdapterCompatibility", "").lower()
# Determine vendor
name_lower = name.lower()
if "nvidia" in name_lower or "nvidia" in compat:
vendor = "nvidia"
is_discrete = True
elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
vendor = "amd"
is_discrete = True
elif "intel" in name_lower or "intel" in compat:
vendor = "intel"
# Intel UHD/Iris are integrated, Intel Arc is discrete
is_discrete = "arc" in name_lower
else:
vendor = "unknown"
is_discrete = False
devices.append(GPUDevice(
device_id=idx,
name=name,
is_discrete=is_discrete,
vendor=vendor
))
logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")
except Exception as e:
logger.debug(f"GPU enumeration failed: {e}")
return devices
def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
"""Determine the preferred GPU device_id for embedding.
Preference order:
1. NVIDIA discrete GPU (best DirectML/CUDA support)
2. AMD discrete GPU
3. Intel Arc (discrete)
4. Intel integrated (fallback)
Returns:
device_id of preferred GPU, or None to use default.
"""
if not devices:
return None
# Priority: NVIDIA > AMD > Intel Arc > Intel integrated
priority_order = [
("nvidia", True), # NVIDIA discrete
("amd", True), # AMD discrete
("intel", True), # Intel Arc (discrete)
("intel", False), # Intel integrated (fallback)
]
for target_vendor, target_discrete in priority_order:
for device in devices:
if device.vendor == target_vendor and device.is_discrete == target_discrete:
logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
return device.device_id
# If no match, use first device
if devices:
return devices[0].device_id
return None
def detect_gpu(force_refresh: bool = False) -> GPUInfo:
"""Detect available GPU resources for embedding acceleration.
Args:
force_refresh: If True, re-detect GPU even if cached.
Returns:
GPUInfo with detection results.
"""
global _gpu_info_cache
if _gpu_info_cache is not None and not force_refresh:
return _gpu_info_cache
info = GPUInfo()
# Enumerate GPU devices first
info.devices = _enumerate_gpus()
info.gpu_count = len(info.devices)
if info.devices:
# Set preferred device (discrete GPU preferred over integrated)
info.preferred_device_id = _get_preferred_device_id(info.devices)
# Set gpu_name to preferred device name
for dev in info.devices:
if dev.device_id == info.preferred_device_id:
info.gpu_name = dev.name
break
# Check PyTorch CUDA availability (most reliable detection)
try:
import torch
if torch.cuda.is_available():
info.cuda_available = True
info.gpu_available = True
info.gpu_count = torch.cuda.device_count()
if info.gpu_count > 0:
info.gpu_name = torch.cuda.get_device_name(0)
logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
except ImportError:
logger.debug("PyTorch not available for GPU detection")
# Check ONNX Runtime providers with validation
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
# Build provider list with priority order
providers = []
# Test each provider to ensure it actually works
def test_provider(provider_name: str) -> bool:
"""Test if a provider actually works by creating a dummy session."""
try:
# Create a minimal ONNX model to test provider
import numpy as np
# Simple test: just check if provider can be instantiated
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 4 # Suppress warnings
return True
except Exception:
return False
# CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
if "CUDAExecutionProvider" in available_providers:
# Verify CUDA is actually usable by checking for cuBLAS
cuda_works = False
try:
import ctypes
# Try to load cuBLAS to verify CUDA installation
try:
ctypes.CDLL("cublas64_12.dll")
cuda_works = True
except OSError:
try:
ctypes.CDLL("cublas64_11.dll")
cuda_works = True
except OSError:
pass
except Exception:
pass
if cuda_works:
providers.append("CUDAExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CUDAExecutionProvider available and working")
else:
logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")
# TensorRT provider (optimized NVIDIA inference)
if "TensorrtExecutionProvider" in available_providers:
# TensorRT requires additional libraries, skip for now
logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")
# DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
if "DmlExecutionProvider" in available_providers:
providers.append("DmlExecutionProvider")
info.gpu_available = True
logger.debug("ONNX DmlExecutionProvider available (DirectML)")
# ROCm provider (AMD GPU on Linux)
if "ROCMExecutionProvider" in available_providers:
providers.append("ROCMExecutionProvider")
info.gpu_available = True
logger.debug("ONNX ROCMExecutionProvider available (AMD)")
# CoreML provider (Apple Silicon)
if "CoreMLExecutionProvider" in available_providers:
providers.append("CoreMLExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CoreMLExecutionProvider available (Apple)")
# Always include CPU as fallback
providers.append("CPUExecutionProvider")
info.onnx_providers = providers
except ImportError:
logger.debug("ONNX Runtime not available")
info.onnx_providers = ["CPUExecutionProvider"]
_gpu_info_cache = info
return info
def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
"""Get optimal ONNX execution providers based on availability.
Args:
use_gpu: If True, include GPU providers when available.
If False, force CPU-only execution.
with_device_options: If True, return providers as tuples with device_id options
for proper GPU device selection (required for DirectML).
Returns:
List of provider names or tuples (provider_name, options_dict) in priority order.
"""
if not use_gpu:
return ["CPUExecutionProvider"]
gpu_info = detect_gpu()
# Check if GPU was requested but not available - log warning
if not gpu_info.gpu_available:
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
except ImportError:
available_providers = []
logger.warning(
"GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
f"was found. Available providers: {available_providers}. Falling back to CPU."
)
else:
# Log which GPU provider is being used
gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")
if not with_device_options:
return gpu_info.onnx_providers
# Build providers with device_id options for GPU providers
device_id = get_selected_device_id()
providers = []
for provider in gpu_info.onnx_providers:
if provider == "DmlExecutionProvider" and device_id is not None:
# DirectML requires device_id in provider_options tuple
providers.append(("DmlExecutionProvider", {"device_id": device_id}))
logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
elif provider == "CUDAExecutionProvider" and device_id is not None:
# CUDA also supports device_id in provider_options
providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
elif provider == "ROCMExecutionProvider" and device_id is not None:
# ROCm supports device_id
providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
else:
# CPU and other providers don't need device_id
providers.append(provider)
return providers
def is_gpu_available() -> bool:
"""Check if any GPU acceleration is available."""
return detect_gpu().gpu_available
def get_gpu_summary() -> str:
"""Get human-readable GPU status summary."""
info = detect_gpu()
if not info.gpu_available:
return "GPU: Not available (using CPU)"
parts = []
if info.gpu_name:
parts.append(f"GPU: {info.gpu_name}")
if info.gpu_count > 1:
parts.append(f"({info.gpu_count} devices)")
# Show active providers (excluding CPU fallback)
gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
parts.append(f"Providers: {', '.join(gpu_providers)}")
return " | ".join(parts) if parts else "GPU: Available"
def clear_gpu_cache() -> None:
"""Clear cached GPU detection info."""
global _gpu_info_cache
_gpu_info_cache = None
# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None
def get_gpu_devices() -> List[dict]:
"""Get list of available GPU devices for frontend selection.
Returns:
List of dicts with device info for each GPU.
"""
info = detect_gpu()
devices = []
for dev in info.devices:
devices.append({
"device_id": dev.device_id,
"name": dev.name,
"vendor": dev.vendor,
"is_discrete": dev.is_discrete,
"is_preferred": dev.device_id == info.preferred_device_id,
"is_selected": dev.device_id == get_selected_device_id(),
})
return devices
def get_selected_device_id() -> Optional[int]:
"""Get the user-selected GPU device_id.
Returns:
User-selected device_id, or auto-detected preferred device_id if not set.
"""
global _selected_device_id
if _selected_device_id is not None:
return _selected_device_id
# Fall back to auto-detected preferred device
info = detect_gpu()
return info.preferred_device_id
def set_selected_device_id(device_id: Optional[int]) -> bool:
"""Set the GPU device_id to use for embeddings.
Args:
device_id: GPU device_id to use, or None to use auto-detection.
Returns:
True if device_id is valid, False otherwise.
"""
global _selected_device_id
if device_id is None:
_selected_device_id = None
logger.info("GPU selection reset to auto-detection")
return True
# Validate device_id exists
info = detect_gpu()
valid_ids = [dev.device_id for dev in info.devices]
if device_id in valid_ids:
_selected_device_id = device_id
device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
logger.info(f"GPU selection set to device {device_id}: {device_name}")
return True
else:
logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
return False
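To tie the helpers above together, a minimal usage sketch (not from the diff; the gpu_utils import path and the model file name are assumptions, the functions are the ones defined above):

    # Illustrative sketch; module path and model file are assumptions.
    import onnxruntime as ort
    from codexlens.gpu_utils import (
        get_gpu_summary, get_optimal_providers, set_selected_device_id,
    )

    print(get_gpu_summary())      # runs detect_gpu() once and caches the result
    set_selected_device_id(1)     # optional: pin a detected GPU; returns False for unknown ids

    # With with_device_options=True, GPU providers are returned as
    # (name, {"device_id": ...}) tuples, which InferenceSession accepts
    # alongside plain provider-name strings.
    providers = get_optimal_providers(use_gpu=True, with_device_options=True)
    session = ort.InferenceSession("model.onnx", providers=providers)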

View File

@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.
Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
"""
from __future__ import annotations
from typing import Iterable
import numpy as np
from .base import BaseEmbedder
class LiteLLMEmbedderWrapper(BaseEmbedder):
"""Wrapper for ccw-litellm LiteLLMEmbedder.
This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
BaseEmbedder interface, enabling seamless integration with CodexLens
semantic search functionality.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
def __init__(self, model: str = "default", **kwargs) -> None:
"""Initialize LiteLLM embedder wrapper.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
try:
from ccw_litellm import LiteLLMEmbedder
self._embedder = LiteLLMEmbedder(model=model, **kwargs)
except ImportError as e:
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from e
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions from LiteLLMEmbedder.
Returns:
int: Dimension of the embedding vectors.
"""
return self._embedder.dimensions
@property
def model_name(self) -> str:
"""Return model name from LiteLLMEmbedder.
Returns:
str: Name or identifier of the underlying model.
"""
return self._embedder.model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit for the embedding model.
Returns:
int: Maximum number of tokens that can be embedded at once.
Reads from LiteLLM config's max_input_tokens property.
"""
# Get from LiteLLM embedder's max_input_tokens property (now exposed)
if hasattr(self._embedder, 'max_input_tokens'):
return self._embedder.max_input_tokens
# Fallback: infer from model name
model_name_lower = self.model_name.lower()
# Large models (8B or "large" in name)
if '8b' in model_name_lower or 'large' in model_name_lower:
return 32768
# OpenAI text-embedding-3-* models
if 'text-embedding-3' in model_name_lower:
return 8191
# Default fallback
return 8192
def _sanitize_text(self, text: str) -> str:
"""Sanitize text to work around ModelScope API routing bug.
ModelScope incorrectly routes text starting with lowercase 'import'
to an Ollama endpoint, causing failures. This adds a leading space
to work around the issue without affecting embedding quality.
Args:
text: Text to sanitize.
Returns:
Sanitized text safe for embedding API.
"""
if text.startswith('import'):
return ' ' + text
return text
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts to numpy array using LiteLLMEmbedder.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments (ignored for LiteLLM backend).
Accepts batch_size for API compatibility with fastembed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Sanitize texts to avoid ModelScope routing bug
texts = [self._sanitize_text(t) for t in texts]
# LiteLLM handles batching internally, ignore batch_size parameter
return self._embedder.embed(texts)
def embed_single(self, text: str) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text to embed.
Returns:
list[float]: Embedding vector as a list of floats.
"""
# Sanitize text before embedding
sanitized = self._sanitize_text(text)
embedding = self._embedder.embed([sanitized])
return embedding[0].tolist()
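For orientation, a short usage sketch of the wrapper above (not from the diff; it assumes ccw-litellm is installed with a configured "default" model, and the import path is an assumption):

    # Illustrative sketch; requires ccw-litellm and a configured embedding model.
    from codexlens.embedders.litellm import LiteLLMEmbedderWrapper  # assumed path

    embedder = LiteLLMEmbedderWrapper(model="default")
    vectors = embedder.embed_to_numpy(["def add(a, b): return a + b", "import os"])
    print(vectors.shape)          # (2, embedder.embedding_dim)
    print(embedder.max_tokens)    # from config, or inferred from the model name
    single = embedder.embed_single("import numpy as np")  # leading 'import' is sanitized first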

View File

@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker",
"check_onnx_reranker_available",
]

View File

@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
max_input_tokens: int | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
# Store max_input_tokens with model-aware defaults
if max_input_tokens is not None:
self._max_input_tokens = max_input_tokens
else:
# Infer from model name
model_lower = self.model_name.lower()
if '8b' in model_lower or 'large' in model_lower:
self._max_input_tokens = 32768
else:
self._max_input_tokens = 8192
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking."""
return self._max_input_tokens
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count using fast heuristic.
Uses len(text) // 4 as approximation (~4 chars per token for English).
Not perfectly accurate for all models/languages but sufficient for
batch sizing decisions where exact counts aren't critical.
"""
return len(text) // 4
def _create_token_aware_batches(
self,
query: str,
documents: Sequence[str],
) -> list[list[tuple[int, str]]]:
"""Split documents into batches that fit within token limits.
Uses 90% of max_input_tokens as safety margin.
Each batch includes the query tokens overhead.
"""
max_tokens = int(self._max_input_tokens * 0.9)
query_tokens = self._estimate_tokens(query)
batches: list[list[tuple[int, str]]] = []
current_batch: list[tuple[int, str]] = []
current_tokens = query_tokens # Start with query overhead
for idx, doc in enumerate(documents):
doc_tokens = self._estimate_tokens(doc)
# Warn if single document exceeds token limit (will be truncated by API)
if doc_tokens > max_tokens - query_tokens:
logger.warning(
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
f"(limit: {max_tokens - query_tokens} after query overhead). "
"Document will likely be truncated by the API."
)
# If batch would exceed limit, start new batch
if current_tokens + doc_tokens > max_tokens and current_batch:
batches.append(current_batch)
current_batch = []
current_tokens = query_tokens
current_batch.append((idx, doc))
current_tokens += doc_tokens
if current_batch:
batches.append(current_batch)
return batches
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
# Create token-aware batches
batches = self._create_token_aware_batches(query, documents)
if len(batches) == 1:
# Single batch - original behavior
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
# Multiple batches - process each and merge results
logger.info(
f"Splitting {len(documents)} documents into {len(batches)} batches "
f"(max_input_tokens: {self._max_input_tokens})"
)
all_scores: list[float] = [0.0] * len(documents)
for batch in batches:
batch_docs = [doc for _, doc in batch]
payload = self._build_payload(query=query, documents=batch_docs)
data = self._request_json(payload)
results = data.get("results")
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
# Map scores back to original indices
for (orig_idx, _), score in zip(batch, batch_scores):
all_scores[orig_idx] = score
return all_scores
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
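A brief usage sketch of APIReranker (not from the diff; the import path is an assumption, and a real key must be supplied via api_key=... or the RERANKER_API_KEY environment variable):

    # Illustrative sketch; requires httpx and a valid reranker API key.
    from codexlens.rerankers.api_reranker import APIReranker  # assumed path

    reranker = APIReranker(provider="siliconflow", max_retries=2)
    pairs = [
        ("token-aware batching", "Documents are grouped so each request stays under max_input_tokens."),
        ("token-aware batching", "Unrelated notes about GPU detection."),
    ]
    scores = reranker.score_pairs(pairs)  # one relevance score per (query, doc) pair
    reranker.close()                      # releases the underlying httpx.Client

Internally, score_pairs groups the pairs by query, and _rerank_one_query splits each group into token-aware batches using the len(text) // 4 heuristic, so oversized candidate sets become several API calls instead of one failing request.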

View File

@@ -0,0 +1,46 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking.
Returns:
int: Maximum number of tokens that can be processed at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
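To make the contract concrete, a toy implementation of the interface above (not from the diff; the package import path is an assumption):

    # Hypothetical reranker that scores by naive term overlap.
    from typing import Sequence

    from codexlens.rerankers import BaseReranker  # re-exported by the package __init__ shown earlier (path assumed)

    class OverlapReranker(BaseReranker):
        """Toy reranker: fraction of query terms found in the document."""

        def score_pairs(
            self,
            pairs: Sequence[tuple[str, str]],
            *,
            batch_size: int = 32,
        ) -> list[float]:
            scores: list[float] = []
            for query, doc in pairs:
                terms = set(query.lower().split())
                hits = sum(1 for term in terms if term in doc.lower())
                scores.append(hits / len(terms) if terms else 0.0)
            return scores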

Some files were not shown because too many files have changed in this diff.