Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-07 02:04:11 +08:00

Compare commits (17 commits):
88ff109ac4, 261196a804, ea6cb8440f, bf896342f4, f2b0a5bbc9, cf5fecd66d, 86d469ccc9, 357d3524f5, 4334162ddf, 2dcd1637f0, 38e1cdc737, 097a7346b9, 9df8063fbd, d00f0bc7ca, 24efef7f17, 44b8269a74, dd51837bbc

.claude/commands/ccw-coordinator.md (new file, 948 lines)
@@ -0,0 +1,948 @@

---
name: ccw-coordinator
description: Command orchestration tool - analyze requirements, recommend chain, execute sequentially with state persistence
argument-hint: "[task description]"
allowed-tools: Task(*), AskUserQuestion(*), Read(*), Write(*), Bash(*), Glob(*), Grep(*)
---

# CCW Coordinator Command

Interactive orchestration tool: analyze task → discover commands → recommend chain → execute sequentially → track state.

**Execution Model**: Pseudocode guidance. Claude intelligently executes each phase based on context.

## Core Concept: Minimum Execution Units (最小执行单元)

### What is a Minimum Execution Unit?

**Definition**: A set of commands that must execute together as an atomic group to reach a meaningful workflow milestone. Splitting these commands breaks the logical flow and leaves the workflow in an incomplete state.

**Why This Matters**:
- **Prevents Incomplete States**: Avoids stopping after task generation without execution
- **User Experience**: The user gets complete results, not intermediate artifacts that require manual follow-up
- **Workflow Integrity**: Maintains the logical coherence of multi-step operations

### Minimum Execution Units

**Planning + Execution Units** (规划+执行单元):

| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Quick Implementation** | lite-plan → lite-execute | Lightweight plan and immediate execution | Working code |
| **Multi-CLI Planning** | multi-cli-plan → lite-execute | Multi-perspective analysis and execution | Working code |
| **Bug Fix** | lite-fix → lite-execute | Quick bug diagnosis and fix execution | Fixed code |
| **Full Planning + Execution** | plan → execute | Detailed planning and execution | Working code |
| **Verified Planning + Execution** | plan → plan-verify → execute | Planning with verification and execution | Working code |
| **Replanning + Execution** | replan → execute | Update plan and execute changes | Working code |
| **TDD Planning + Execution** | tdd-plan → execute | Test-driven development planning and execution | Working code |
| **Test Generation + Execution** | test-gen → execute | Generate test suite and execute | Generated tests |

**Testing Units** (测试单元):

| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Test Validation** | test-fix-gen → test-cycle-execute | Generate test tasks and execute test-fix cycle | Tests passed |

**Review Units** (审查单元):

| Unit Name | Commands | Purpose | Output |
|-----------|----------|---------|--------|
| **Code Review (Session)** | review-session-cycle → review-fix | Complete review cycle and apply fixes | Fixed code |
| **Code Review (Module)** | review-module-cycle → review-fix | Module review cycle and apply fixes | Fixed code |

### Command-to-Unit Mapping (命令与最小单元的映射)

| Command | Can Precede | Atomic Units |
|---------|-------------|--------------|
| lite-plan | lite-execute | Quick Implementation |
| multi-cli-plan | lite-execute | Multi-CLI Planning |
| lite-fix | lite-execute | Bug Fix |
| plan | plan-verify, execute | Full Planning + Execution, Verified Planning + Execution |
| plan-verify | execute | Verified Planning + Execution |
| replan | execute | Replanning + Execution |
| test-gen | execute | Test Generation + Execution |
| tdd-plan | execute | TDD Planning + Execution |
| review-session-cycle | review-fix | Code Review (Session) |
| review-module-cycle | review-fix | Code Review (Module) |
| test-fix-gen | test-cycle-execute | Test Validation |

### Atomic Group Rules

1. **Never Split Units**: The coordinator must recommend complete units, not partial chains
2. **Multi-Unit Participation**: Some commands can participate in multiple units (e.g., plan → execute or plan → plan-verify → execute)
3. **User Override**: The user can explicitly request partial execution (advanced mode)
4. **Visualization**: The pipeline view marks unit boundaries with `【 】`
5. **Validation**: Before execution, verify that all commands of each unit are included (see the sketch after the example below)

**Example Pipeline with Units**:
```
Requirement → 【lite-plan → lite-execute】→ Code → 【test-fix-gen → test-cycle-execute】→ Tests passed
              └──── Quick Implementation ────┘      └────── Test Validation ──────┘
```
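
Rule 5 can be made concrete with a small pre-flight check. The sketch below is not part of the original command spec: the `UNIT_COMMANDS` table is an assumed helper derived from the unit tables above, and the two `code-review-*` keys split the document's single `code-review` group by cycle type.

```javascript
// Sketch only: verify unit completeness before execution (Atomic Group Rule 5).
// UNIT_COMMANDS is an assumed helper table derived from the unit tables above.
const UNIT_COMMANDS = {
  'quick-implementation':        ['lite-plan', 'lite-execute'],
  'multi-cli-planning':          ['multi-cli-plan', 'lite-execute'],
  'bug-fix':                     ['lite-fix', 'lite-execute'],
  'full-planning-execution':     ['plan', 'execute'],
  'verified-planning-execution': ['plan', 'plan-verify', 'execute'],
  'replanning-execution':        ['replan', 'execute'],
  'tdd-planning-execution':      ['tdd-plan', 'execute'],
  'test-generation-execution':   ['test-gen', 'execute'],
  'test-validation':             ['test-fix-gen', 'test-cycle-execute'],
  'code-review-session':         ['review-session-cycle', 'review-fix'],
  'code-review-module':          ['review-module-cycle', 'review-fix']
};

// A chain is unit-complete when every unit-bound command in it belongs to at
// least one unit whose members are all present in the chain.
function validateUnitCompleteness(chainNames) {
  const inChain = new Set(chainNames);
  const uncovered = chainNames.filter(name => {
    const units = Object.values(UNIT_COMMANDS).filter(members => members.includes(name));
    if (units.length === 0) return false; // standalone command, no unit constraint
    return !units.some(members => members.every(m => inChain.has(m)));
  });
  return { valid: uncovered.length === 0, uncovered };
}
```

For example, `validateUnitCompleteness(['plan', 'execute'])` is valid, while `['plan']` alone is flagged as incomplete.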

## 3-Phase Workflow

### Phase 1: Analyze Requirements

Parse task to extract: goal, scope, constraints, complexity, and task type.

```javascript
function analyzeRequirements(taskDescription) {
  return {
    goal: extractMainGoal(taskDescription),           // e.g., "Implement user registration"
    scope: extractScope(taskDescription),             // e.g., ["auth", "user_management"]
    constraints: extractConstraints(taskDescription), // e.g., ["no breaking changes"]
    complexity: determineComplexity(taskDescription), // 'simple' | 'medium' | 'complex'
    task_type: detectTaskType(taskDescription)        // See task type patterns below
  };
}

// Task Type Detection Patterns
function detectTaskType(text) {
  // Priority order (first match wins)
  if (/fix|bug|error|crash|fail|debug|diagnose/.test(text)) return 'bugfix';
  if (/tdd|test-driven|先写测试|test first/.test(text)) return 'tdd';
  if (/测试失败|test fail|fix test|failing test/.test(text)) return 'test-fix';
  if (/generate test|写测试|add test|补充测试/.test(text)) return 'test-gen';
  if (/review|审查|code review/.test(text)) return 'review';
  if (/不确定|explore|研究|what if|brainstorm|权衡/.test(text)) return 'brainstorm';
  if (/多视角|比较方案|cross-verify|multi-cli/.test(text)) return 'multi-cli';
  return 'feature'; // Default
}

// Complexity Assessment
function determineComplexity(text) {
  let score = 0;
  if (/refactor|重构|migrate|迁移|architect|架构|system|系统/.test(text)) score += 2;
  if (/multiple|多个|across|跨|all|所有|entire|整个/.test(text)) score += 2;
  if (/integrate|集成|api|database|数据库/.test(text)) score += 1;
  if (/security|安全|performance|性能|scale|扩展/.test(text)) score += 1;
  return score >= 4 ? 'complex' : score >= 2 ? 'medium' : 'simple';
}
```

**Display to user**:
```
Analysis Complete:
Goal: [extracted goal]
Scope: [identified areas]
Constraints: [identified constraints]
Complexity: [level]
Task Type: [detected type]
```
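
As a concrete illustration (values are hypothetical), the request "Fix crash when uploading large files, no breaking changes" would come out of this phase roughly as:

```javascript
// Hypothetical result of analyzeRequirements() for the request above
({
  goal: 'Fix crash when uploading large files',
  scope: ['upload'],                    // scope extraction is left to Claude's judgment
  constraints: ['no breaking changes'],
  complexity: 'simple',                 // no complexity keywords matched (score 0)
  task_type: 'bugfix'                   // "fix"/"crash" hit the first pattern
})
```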

### Phase 2: Discover Commands & Recommend Chain

Dynamic command chain assembly using port-based matching.

#### Command Port Definition

Each command has input/output ports (tags) for pipeline composition:

```javascript
// Port labels represent the data types flowing through the pipeline
const commandPorts = {
  'lite-plan': { name: 'lite-plan', input: ['requirement'], output: ['plan'],
    tags: ['planning'], atomic_group: 'quick-implementation' },                 // bound to lite-execute

  'lite-execute': { name: 'lite-execute',
    input: ['plan', 'multi-cli-plan', 'lite-fix'],                              // accepts several planning outputs
    output: ['code'], tags: ['execution'],
    atomic_groups: ['quick-implementation', 'multi-cli-planning', 'bug-fix'] }, // may join several units

  'plan': { name: 'plan', input: ['requirement'], output: ['detailed-plan'], tags: ['planning'],
    atomic_groups: ['full-planning-execution', 'verified-planning-execution'] },

  'plan-verify': { name: 'plan-verify', input: ['detailed-plan'], output: ['verified-plan'],
    tags: ['planning'], atomic_group: 'verified-planning-execution' },          // plan → plan-verify → execute

  'replan': { name: 'replan', input: ['session', 'feedback'],
    output: ['replan'],                                                         // updated plan, consumed by execute
    tags: ['planning'], atomic_group: 'replanning-execution' },

  'execute': { name: 'execute',
    input: ['detailed-plan', 'verified-plan', 'replan', 'test-tasks', 'tdd-tasks'], // accepts several planning outputs
    output: ['code'], tags: ['execution'],
    atomic_groups: ['full-planning-execution', 'verified-planning-execution', 'replanning-execution',
                    'test-generation-execution', 'tdd-planning-execution'] },

  'test-cycle-execute': { name: 'test-cycle-execute', input: ['test-tasks'], output: ['test-passed'],
    tags: ['testing'], atomic_group: 'test-validation',
    note: 'Requires test tasks generated by test-fix-gen first; this command then runs the test cycle' },

  'tdd-plan': { name: 'tdd-plan', input: ['requirement'],
    output: ['tdd-tasks'],                                                      // TDD tasks, consumed by execute
    tags: ['planning', 'tdd'], atomic_group: 'tdd-planning-execution' },

  'tdd-verify': { name: 'tdd-verify', input: ['code'], output: ['tdd-verified'], tags: ['testing'] },

  'lite-fix': { name: 'lite-fix', input: ['bug-report'],
    output: ['lite-fix'],                                                       // fix plan, consumed by lite-execute
    tags: ['bugfix'], atomic_group: 'bug-fix' },

  'debug': { name: 'debug', input: ['bug-report'], output: ['debug-log'], tags: ['bugfix'] },

  'test-gen': { name: 'test-gen', input: ['code', 'session'],
    output: ['test-tasks'],                                                     // test tasks (IMPL-001, IMPL-002), consumed by execute
    tags: ['testing'], atomic_group: 'test-generation-execution' },

  'test-fix-gen': { name: 'test-fix-gen', input: ['failing-tests', 'session'], output: ['test-tasks'],
    tags: ['testing'], atomic_group: 'test-validation',
    note: 'Generates targeted test tasks for test-cycle-execute to run and fix' },

  'review': { name: 'review', input: ['code', 'session'], output: ['review-findings'], tags: ['review'] },

  'review-fix': { name: 'review-fix',
    input: ['review-findings', 'review-verified'],                              // accepts review-session-cycle or review-module-cycle output
    output: ['fixed-code'], tags: ['review'], atomic_group: 'code-review' },

  'brainstorm:auto-parallel': { name: 'brainstorm:auto-parallel', input: ['exploration-topic'],
    output: ['brainstorm-analysis'], tags: ['brainstorm'] },

  'multi-cli-plan': { name: 'multi-cli-plan', input: ['requirement'],
    output: ['multi-cli-plan'],                                                 // comparative analysis plan, consumed by lite-execute
    tags: ['planning', 'multi-cli'], atomic_group: 'multi-cli-planning' },

  'review-session-cycle': { name: 'review-session-cycle', input: ['code', 'session'],
    output: ['review-verified'], tags: ['review'], atomic_group: 'code-review' }, // bound to review-fix

  'review-module-cycle': { name: 'review-module-cycle', input: ['module-pattern'],
    output: ['review-verified'], tags: ['review'], atomic_group: 'code-review' }  // bound to review-fix
};
```
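
As an illustration of how the port table drives chain assembly, a small helper (not part of the spec) can list which commands are able to consume a given port:

```javascript
// Sketch only: commands whose input ports accept a given data type.
function consumersOf(port) {
  return Object.values(commandPorts)
    .filter(cmd => cmd.input.includes(port))
    .map(cmd => cmd.name);
}

// consumersOf('plan')          -> ['lite-execute']
// consumersOf('detailed-plan') -> ['plan-verify', 'execute']
// consumersOf('test-tasks')    -> ['execute', 'test-cycle-execute']
```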

#### Recommendation Algorithm

```javascript
async function recommendCommandChain(analysis) {
  // Step 1: Determine the start and target ports from the task type
  const { inputPort, outputPort } = determinePortFlow(analysis.task_type, analysis.constraints);

  // Step 2: Claude selects a command sequence from the port definitions and task characteristics
  // Priority: simple tasks → lite-* commands, complex tasks → full commands, special constraints → adjust the flow
  const chain = selectChainByPorts(inputPort, outputPort, analysis);

  return chain;
}

// Port flow for each task type
function determinePortFlow(taskType, constraints) {
  const flows = {
    'bugfix': { inputPort: 'bug-report', outputPort: constraints?.includes('skip-tests') ? 'fixed-code' : 'test-passed' },
    'tdd': { inputPort: 'requirement', outputPort: 'tdd-verified' },
    'test-fix': { inputPort: 'failing-tests', outputPort: 'test-passed' },
    'test-gen': { inputPort: 'code', outputPort: 'test-passed' },
    'review': { inputPort: 'code', outputPort: 'review-verified' },
    'brainstorm': { inputPort: 'exploration-topic', outputPort: 'test-passed' },
    'multi-cli': { inputPort: 'requirement', outputPort: 'test-passed' },
    'feature': { inputPort: 'requirement', outputPort: constraints?.includes('skip-tests') ? 'code' : 'test-passed' }
  };
  return flows[taskType] || flows['feature'];
}

// Claude selects the command chain along the port flow
function selectChainByPorts(inputPort, outputPort, analysis) {
  // Claude picks a suitable sequence using the port definition table above and the execution examples
  // Example return value: [lite-plan, lite-execute, test-cycle-execute]
}
```
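
`selectChainByPorts` is intentionally left to Claude's judgment. For reference only, a naive breadth-first search over the port graph (an assumption, not the documented behavior) might look like the sketch below; real selection additionally respects atomic units and task constraints.

```javascript
// Sketch only: BFS from inputPort to outputPort over the port graph,
// preferring lite-* commands for simple tasks.
function selectChainByPortsSketch(inputPort, outputPort, analysis) {
  const preferLite = analysis.complexity === 'simple';
  const rank = cmds => [...cmds].sort((a, b) =>
    (b.name.startsWith('lite-') === preferLite) - (a.name.startsWith('lite-') === preferLite));
  const queue = [{ port: inputPort, chain: [] }];
  const seen = new Set([inputPort]);
  while (queue.length > 0) {
    const { port, chain } = queue.shift();
    if (port === outputPort) return chain;
    const candidates = Object.values(commandPorts).filter(c => c.input.includes(port));
    for (const cmd of rank(candidates)) {
      for (const out of cmd.output) {
        if (!seen.has(out)) {
          seen.add(out);
          queue.push({ port: out, chain: [...chain, cmd.name] });
        }
      }
    }
  }
  return []; // no path found; fall back to manual selection
}
```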

#### Display to User

```
Recommended Command Chain:

Pipeline (管道视图):
Requirement → lite-plan → Plan → lite-execute → Code → test-cycle-execute → Tests passed

Commands (命令列表):
1. /workflow:lite-plan
2. /workflow:lite-execute
3. /workflow:test-cycle-execute

Proceed? [Confirm / Show Details / Adjust / Cancel]
```
|
||||
|
||||
### Phase 2b: Get User Confirmation
|
||||
|
||||
```javascript
|
||||
async function getUserConfirmation(chain) {
|
||||
const response = await AskUserQuestion({
|
||||
questions: [{
|
||||
question: 'Proceed with this command chain?',
|
||||
header: 'Confirm',
|
||||
options: [
|
||||
{ label: 'Confirm and execute', description: 'Proceed with commands' },
|
||||
{ label: 'Show details', description: 'View each command' },
|
||||
{ label: 'Adjust chain', description: 'Remove or reorder' },
|
||||
{ label: 'Cancel', description: 'Abort' }
|
||||
]
|
||||
}]
|
||||
});
|
||||
|
||||
if (response.confirm === 'Cancel') throw new Error('Cancelled');
|
||||
if (response.confirm === 'Show details') {
|
||||
displayCommandDetails(chain);
|
||||
return getUserConfirmation(chain);
|
||||
}
|
||||
if (response.confirm === 'Adjust chain') {
|
||||
return await adjustChain(chain);
|
||||
}
|
||||
return chain;
|
||||
}
|
||||
```
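
`adjustChain` is referenced above but not specified. A minimal sketch under assumptions (the `multiSelect` option and the `response.adjust` key are guesses about AskUserQuestion's interface, not documented behavior):

```javascript
// Sketch only: let the user deselect commands, then re-run confirmation.
async function adjustChain(chain) {
  const response = await AskUserQuestion({
    questions: [{
      question: 'Select the commands to keep (unselected ones are removed):',
      header: 'Adjust',
      multiSelect: true, // assumed capability
      options: chain.map(cmd => ({ label: cmd.command, description: cmd.description || '' }))
    }]
  });
  const kept = chain.filter(cmd => response.adjust?.includes(cmd.command));
  return getUserConfirmation(kept.length > 0 ? kept : chain);
}
```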
|
||||
|
||||
### Phase 3: Execute Sequential Command Chain
|
||||
|
||||
```javascript
|
||||
async function executeCommandChain(chain, analysis) {
|
||||
const sessionId = `ccw-coord-${Date.now()}`;
|
||||
const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
|
||||
Bash(`mkdir -p "${stateDir}"`);
|
||||
|
||||
const state = {
|
||||
session_id: sessionId,
|
||||
status: 'running',
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
analysis: analysis,
|
||||
command_chain: chain.map((cmd, idx) => ({ ...cmd, index: idx, status: 'pending' })),
|
||||
execution_results: [],
|
||||
prompts_used: []
|
||||
};
|
||||
|
||||
// Save initial state immediately after confirmation
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
for (let i = 0; i < chain.length; i++) {
|
||||
const cmd = chain[i];
|
||||
console.log(`[${i+1}/${chain.length}] ${cmd.command}`);
|
||||
|
||||
// Update command_chain status to running
|
||||
state.command_chain[i].status = 'running';
|
||||
state.updated_at = new Date().toISOString();
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
// Assemble prompt: Command first, then context
|
||||
let promptContent = formatCommand(cmd, state.execution_results, analysis);
|
||||
|
||||
// Build full prompt: Command → Task → Previous Results
|
||||
let prompt = `${promptContent}\n\nTask: ${analysis.goal}`;
|
||||
if (state.execution_results.length > 0) {
|
||||
prompt += '\n\nPrevious results:\n';
|
||||
state.execution_results.forEach(r => {
|
||||
if (r.session_id) {
|
||||
prompt += `- ${r.command}: ${r.session_id} (${r.artifacts?.join(', ') || 'completed'})\n`;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Record prompt used
|
||||
state.prompts_used.push({
|
||||
index: i,
|
||||
command: cmd.command,
|
||||
prompt: prompt
|
||||
});
|
||||
|
||||
// Execute CLI command in background and stop
|
||||
// Format: ccw cli -p "PROMPT" --tool <tool> --mode <mode>
|
||||
// Note: -y is a command parameter INSIDE the prompt, not a ccw cli parameter
|
||||
// Example prompt: "/workflow:plan -y \"task description here\""
|
||||
try {
|
||||
const taskId = Bash(
|
||||
`ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
|
||||
{ run_in_background: true }
|
||||
).task_id;
|
||||
|
||||
// Save checkpoint
|
||||
state.execution_results.push({
|
||||
index: i,
|
||||
command: cmd.command,
|
||||
status: 'in-progress',
|
||||
task_id: taskId,
|
||||
session_id: null,
|
||||
artifacts: [],
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
state.command_chain[i].status = 'running';
|
||||
state.updated_at = new Date().toISOString();
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
console.log(`[${i+1}/${chain.length}] ${cmd.command}\n`);
|
||||
break; // Stop, wait for hook callback
|
||||
|
||||
} catch (error) {
|
||||
state.command_chain[i].status = 'failed';
|
||||
state.updated_at = new Date().toISOString();
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
const action = await AskUserQuestion({
|
||||
questions: [{
|
||||
question: `${cmd.command} failed to start: ${error.message}. What to do?`,
|
||||
header: 'Error',
|
||||
options: [
|
||||
{ label: 'Retry', description: 'Try again' },
|
||||
{ label: 'Skip', description: 'Continue next command' },
|
||||
{ label: 'Abort', description: 'Stop execution' }
|
||||
]
|
||||
}]
|
||||
});
|
||||
|
||||
if (action.error === 'Retry') {
|
||||
state.command_chain[i].status = 'pending';
|
||||
state.execution_results.pop();
|
||||
i--;
|
||||
} else if (action.error === 'Skip') {
|
||||
state.execution_results[state.execution_results.length - 1].status = 'skipped';
|
||||
} else if (action.error === 'Abort') {
|
||||
state.status = 'failed';
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
}
|
||||
|
||||
// Hook callbacks handle completion
|
||||
if (state.status !== 'failed') state.status = 'waiting';
|
||||
state.updated_at = new Date().toISOString();
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
console.log(`\n📋 Orchestrator paused: ${state.session_id}\n`);
|
||||
return state;
|
||||
}
|
||||
|
||||
// Smart parameter assembly
|
||||
// Returns prompt content to be used with: ccw cli -p "RETURNED_VALUE" --tool claude --mode write
|
||||
function formatCommand(cmd, previousResults, analysis) {
|
||||
// Format: /workflow:<command> -y <parameters>
|
||||
let prompt = `/workflow:${cmd.name} -y`;
|
||||
const name = cmd.name;
|
||||
|
||||
// Planning commands - take task description
|
||||
if (['lite-plan', 'plan', 'tdd-plan', 'multi-cli-plan'].includes(name)) {
|
||||
prompt += ` "${analysis.goal}"`;
|
||||
|
||||
// Lite execution - use --in-memory if plan exists
|
||||
} else if (name === 'lite-execute') {
|
||||
const hasPlan = previousResults.some(r => r.command.includes('plan'));
|
||||
prompt += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
|
||||
|
||||
// Standard execution - resume from planning session
|
||||
} else if (name === 'execute') {
|
||||
const plan = previousResults.find(r => r.command.includes('plan'));
|
||||
if (plan?.session_id) prompt += ` --resume-session="${plan.session_id}"`;
|
||||
|
||||
// Bug fix commands - take bug description
|
||||
} else if (['lite-fix', 'debug'].includes(name)) {
|
||||
prompt += ` "${analysis.goal}"`;
|
||||
|
||||
// Brainstorm - take topic description
|
||||
} else if (name === 'brainstorm:auto-parallel' || name === 'auto-parallel') {
|
||||
prompt += ` "${analysis.goal}"`;
|
||||
|
||||
// Test generation from session - needs source session
|
||||
} else if (name === 'test-gen') {
|
||||
const impl = previousResults.find(r =>
|
||||
r.command.includes('execute') || r.command.includes('lite-execute')
|
||||
);
|
||||
if (impl?.session_id) prompt += ` "${impl.session_id}"`;
|
||||
else prompt += ` "${analysis.goal}"`;
|
||||
|
||||
// Test fix generation - session or description
|
||||
} else if (name === 'test-fix-gen') {
|
||||
const latest = previousResults.filter(r => r.session_id).pop();
|
||||
if (latest?.session_id) prompt += ` "${latest.session_id}"`;
|
||||
else prompt += ` "${analysis.goal}"`;
|
||||
|
||||
// Review commands - take session or use latest
|
||||
} else if (name === 'review') {
|
||||
const latest = previousResults.filter(r => r.session_id).pop();
|
||||
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
|
||||
|
||||
// Review fix - takes session from review
|
||||
} else if (name === 'review-fix') {
|
||||
const review = previousResults.find(r => r.command.includes('review'));
|
||||
const latest = review || previousResults.filter(r => r.session_id).pop();
|
||||
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
|
||||
|
||||
// TDD verify - takes execution session
|
||||
} else if (name === 'tdd-verify') {
|
||||
const exec = previousResults.find(r => r.command.includes('execute'));
|
||||
if (exec?.session_id) prompt += ` --session="${exec.session_id}"`;
|
||||
|
||||
// Session-based commands (test-cycle, review-session, plan-verify)
|
||||
} else if (name.includes('test') || name.includes('review') || name.includes('verify')) {
|
||||
const latest = previousResults.filter(r => r.session_id).pop();
|
||||
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
|
||||
}
|
||||
|
||||
return prompt;
|
||||
}
|
||||
|
||||
// Hook callback: Called when background CLI completes
|
||||
async function handleCliCompletion(sessionId, taskId, output) {
|
||||
const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
|
||||
const state = JSON.parse(Read(`${stateDir}/state.json`));
|
||||
|
||||
const pendingIdx = state.execution_results.findIndex(r => r.task_id === taskId);
|
||||
if (pendingIdx === -1) {
|
||||
console.error(`Unknown task_id: ${taskId}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed = parseOutput(output);
|
||||
const cmdIdx = state.execution_results[pendingIdx].index;
|
||||
|
||||
// Update result
|
||||
state.execution_results[pendingIdx] = {
|
||||
...state.execution_results[pendingIdx],
|
||||
status: parsed.sessionId ? 'completed' : 'failed',
|
||||
session_id: parsed.sessionId,
|
||||
artifacts: parsed.artifacts,
|
||||
completed_at: new Date().toISOString()
|
||||
};
|
||||
state.command_chain[cmdIdx].status = parsed.sessionId ? 'completed' : 'failed';
|
||||
state.updated_at = new Date().toISOString();
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
// Trigger next command or complete
|
||||
const nextIdx = cmdIdx + 1;
|
||||
if (nextIdx < state.command_chain.length) {
|
||||
await resumeChainExecution(sessionId, nextIdx);
|
||||
} else {
|
||||
state.status = 'completed';
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
console.log(`✅ Completed: ${sessionId}\n`);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse command output
|
||||
function parseOutput(output) {
|
||||
const sessionMatch = output.match(/WFS-[\w-]+/);
|
||||
const artifacts = [];
|
||||
output.matchAll(/\.workflow\/[^\s]+/g).forEach(m => artifacts.push(m[0]));
|
||||
return { sessionId: sessionMatch?.[0] || null, artifacts };
|
||||
}
|
||||
```
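
`escapePrompt` is used when launching the CLI but is not defined in this file. A minimal sketch, assuming the prompt is interpolated into a double-quoted shell argument:

```javascript
// Sketch only: escape characters that would break the double-quoted shell argument.
// A real implementation may prefer passing the prompt via a temp file or stdin.
function escapePrompt(prompt) {
  return prompt
    .replace(/\\/g, '\\\\')  // backslashes first
    .replace(/"/g, '\\"')    // double quotes
    .replace(/\$/g, '\\$')   // prevent shell variable expansion
    .replace(/`/g, '\\`');   // prevent command substitution
}
```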
|
||||
|
||||
## State File Structure
|
||||
|
||||
**Location**: `.workflow/.ccw-coordinator/{session_id}/state.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"session_id": "ccw-coord-20250124-143025",
|
||||
"status": "running|waiting|completed|failed",
|
||||
"created_at": "2025-01-24T14:30:25Z",
|
||||
"updated_at": "2025-01-24T14:35:45Z",
|
||||
"analysis": {
|
||||
"goal": "Implement user registration",
|
||||
"scope": ["authentication", "user_management"],
|
||||
"constraints": ["no breaking changes"],
|
||||
"complexity": "medium"
|
||||
},
|
||||
"command_chain": [
|
||||
{
|
||||
"index": 0,
|
||||
"command": "/workflow:plan",
|
||||
"name": "plan",
|
||||
"description": "Detailed planning",
|
||||
"argumentHint": "[--explore] \"task\"",
|
||||
"status": "completed"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"command": "/workflow:execute",
|
||||
"name": "execute",
|
||||
"description": "Execute with state resume",
|
||||
"argumentHint": "[--resume-session=\"WFS-xxx\"]",
|
||||
"status": "completed"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"command": "/workflow:test-cycle-execute",
|
||||
"name": "test-cycle-execute",
|
||||
"status": "pending"
|
||||
}
|
||||
],
|
||||
"execution_results": [
|
||||
{
|
||||
"index": 0,
|
||||
"command": "/workflow:plan",
|
||||
"status": "completed",
|
||||
"task_id": "task-001",
|
||||
"session_id": "WFS-plan-20250124",
|
||||
"artifacts": ["IMPL_PLAN.md", "exploration-architecture.json"],
|
||||
"timestamp": "2025-01-24T14:30:25Z",
|
||||
"completed_at": "2025-01-24T14:30:45Z"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"command": "/workflow:execute",
|
||||
"status": "in-progress",
|
||||
"task_id": "task-002",
|
||||
"session_id": null,
|
||||
"artifacts": [],
|
||||
"timestamp": "2025-01-24T14:32:00Z",
|
||||
"completed_at": null
|
||||
}
|
||||
],
|
||||
"prompts_used": [
|
||||
{
|
||||
"index": 0,
|
||||
"command": "/workflow:plan",
|
||||
"prompt": "/workflow:plan -y \"Implement user registration...\"\n\nTask: Implement user registration..."
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"command": "/workflow:execute",
|
||||
"prompt": "/workflow:execute -y --resume-session=\"WFS-plan-20250124\"\n\nTask: Implement user registration\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
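
While the orchestrator is paused, the state file can be inspected directly, for example (assuming `jq` is available):

```bash
# Show overall status and per-command progress for a session
SESSION=ccw-coord-20250124-143025
jq '{status, chain: [.command_chain[] | {command, status}]}' \
  .workflow/.ccw-coordinator/$SESSION/state.json

# List the prompts that have been sent so far
jq -r '.prompts_used[] | "\(.index): \(.command)"' \
  .workflow/.ccw-coordinator/$SESSION/state.json
```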
|
||||
|
||||
### Status Flow
|
||||
|
||||
```
|
||||
running → waiting → [hook callback] → waiting → [hook callback] → completed
|
||||
↓ ↑
|
||||
failed ←────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Status Values**:
|
||||
- `running`: Orchestrator actively executing (launching CLI commands)
|
||||
- `waiting`: Paused, waiting for hook callbacks to trigger continuation
|
||||
- `completed`: All commands finished successfully
|
||||
- `failed`: User aborted or unrecoverable error
|
||||
|
||||
### Field Descriptions
|
||||
|
||||
**execution_results[] fields**:
|
||||
- `index`: Command position in chain (0-indexed)
|
||||
- `command`: Full command string (e.g., `/workflow:plan`)
|
||||
- `status`: `in-progress` | `completed` | `skipped` | `failed`
|
||||
- `task_id`: Background task identifier (from Bash tool)
|
||||
- `session_id`: Workflow session ID (e.g., `WFS-*`) or null if failed
|
||||
- `artifacts`: Generated files/directories
|
||||
- `timestamp`: Command start time (ISO 8601)
|
||||
- `completed_at`: Command completion time or null if pending
|
||||
|
||||
**command_chain[] status values**:
|
||||
- `pending`: Not started yet
|
||||
- `running`: Currently executing
|
||||
- `completed`: Successfully finished
|
||||
- `failed`: Failed to execute
|
||||
|
||||
## CommandRegistry Integration
|
||||
|
||||
Sole CCW tool for command discovery:
|
||||
|
||||
```javascript
|
||||
import { CommandRegistry } from 'ccw/tools/command-registry';
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
// Get all commands
|
||||
const allCommands = registry.getAllCommandsSummary();
|
||||
// Map<"/workflow:lite-plan" => {name, description}>
|
||||
|
||||
// Get categorized
|
||||
const byCategory = registry.getAllCommandsByCategory();
|
||||
// {planning, execution, testing, review, other}
|
||||
|
||||
// Get single command metadata
|
||||
const cmd = registry.getCommand('lite-plan');
|
||||
// {name, command, description, argumentHint, allowedTools, filePath}
|
||||
```
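
A small usage sketch built only on the methods listed above; the per-entry shape inside each category is an assumption consistent with `getAllCommandsSummary()`:

```javascript
// Sketch: print a category overview to support the recommendation phase.
const registry = new CommandRegistry();
const byCategory = registry.getAllCommandsByCategory(); // {planning, execution, testing, review, other}

for (const [category, commands] of Object.entries(byCategory)) {
  console.log(`${category}:`);
  for (const cmd of commands) {
    console.log(`  ${cmd.name} - ${cmd.description}`);
  }
}
```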
|
||||
|
||||
## Universal Prompt Template
|
||||
|
||||
### Standard Format
|
||||
|
||||
```bash
|
||||
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
|
||||
```
|
||||
|
||||
### Prompt Content Template
|
||||
|
||||
```
|
||||
/workflow:<command> -y <command_parameters>
|
||||
|
||||
Task: <task_description>
|
||||
|
||||
<optional_previous_results>
|
||||
```
|
||||
|
||||
### Template Variables
|
||||
|
||||
| Variable | Description | Examples |
|
||||
|----------|-------------|----------|
|
||||
| `<command>` | Workflow command name | `plan`, `lite-execute`, `test-cycle-execute` |
|
||||
| `-y` | Auto-confirm flag (inside prompt) | Always include for automation |
|
||||
| `<command_parameters>` | Command-specific parameters | Task description, session ID, flags |
|
||||
| `<task_description>` | Brief task description | "Implement user authentication", "Fix memory leak" |
|
||||
| `<optional_previous_results>` | Context from previous commands | "Previous results:\n- /workflow:plan: WFS-xxx" |
|
||||
|
||||
### Command Parameter Patterns
|
||||
|
||||
| Command Type | Parameter Pattern | Example |
|
||||
|--------------|------------------|---------|
|
||||
| **Planning** | `"task description"` | `/workflow:plan -y "Implement OAuth2"` |
|
||||
| **Execution (with plan)** | `--resume-session="WFS-xxx"` | `/workflow:execute -y --resume-session="WFS-plan-001"` |
|
||||
| **Execution (standalone)** | `--in-memory` or `"task"` | `/workflow:lite-execute -y --in-memory` |
|
||||
| **Session-based** | `--session="WFS-xxx"` | `/workflow:test-fix-gen -y --session="WFS-impl-001"` |
|
||||
| **Fix/Debug** | `"problem description"` | `/workflow:lite-fix -y "Fix timeout bug"` |
|
||||
|
||||
### Complete Examples
|
||||
|
||||
**Planning Command**:
|
||||
```bash
|
||||
ccw cli -p '/workflow:plan -y "Implement user registration with email validation"
|
||||
|
||||
Task: Implement user registration' --tool claude --mode write
|
||||
```
|
||||
|
||||
**Execution with Context**:
|
||||
```bash
|
||||
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
|
||||
|
||||
Task: Implement user registration
|
||||
|
||||
Previous results:
|
||||
- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)' --tool claude --mode write
|
||||
```
|
||||
|
||||
**Standalone Lite Execution**:
|
||||
```bash
|
||||
ccw cli -p '/workflow:lite-fix -y "Fix login timeout in auth module"
|
||||
|
||||
Task: Fix login timeout' --tool claude --mode write
|
||||
```
|
||||
|
||||
## Execution Flow
|
||||
|
||||
```javascript
|
||||
// Main entry point
|
||||
async function ccwCoordinator(taskDescription) {
|
||||
// Phase 1
|
||||
const analysis = await analyzeRequirements(taskDescription);
|
||||
|
||||
// Phase 2
|
||||
const chain = await recommendCommandChain(analysis);
|
||||
const confirmedChain = await getUserConfirmation(chain);
|
||||
|
||||
// Phase 3
|
||||
const state = await executeCommandChain(confirmedChain, analysis);
|
||||
|
||||
console.log(`✅ Complete! Session: ${state.session_id}`);
|
||||
console.log(`State: .workflow/.ccw-coordinator/${state.session_id}/state.json`);
|
||||
}
|
||||
```
|
||||
|
||||
## Key Design Principles
|
||||
|
||||
1. **No Fixed Logic** - Claude intelligently decides based on analysis
|
||||
2. **Dynamic Discovery** - CommandRegistry retrieves available commands
|
||||
3. **Smart Parameters** - Command args assembled based on previous results
|
||||
4. **Full State Tracking** - All execution recorded to state.json
|
||||
5. **User Control** - Confirmation + error handling with user choice
|
||||
6. **Context Passing** - Each prompt includes previous results
|
||||
7. **Resumable** - Can load state.json to continue
|
||||
8. **Serial Blocking** - Commands execute one-by-one with hook-based continuation
|
||||
|
||||
## CLI Execution Model
|
||||
|
||||
### CLI Invocation Format
|
||||
|
||||
**IMPORTANT**: The `ccw cli` command executes prompts through external tools. The format is:
|
||||
|
||||
```bash
|
||||
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `-p "PROMPT_CONTENT"`: The prompt content to execute (required)
|
||||
- `--tool <tool>`: CLI tool to use (e.g., `claude`, `gemini`, `qwen`)
|
||||
- `--mode <mode>`: Execution mode (`analysis` or `write`)
|
||||
|
||||
**Note**: `-y` is a **command parameter inside the prompt**, NOT a `ccw cli` parameter.
|
||||
|
||||
### Prompt Assembly
|
||||
|
||||
The prompt content MUST start with the workflow command, followed by task context:
|
||||
|
||||
```
|
||||
/workflow:<command> -y <parameters>
|
||||
|
||||
Task: <description>
|
||||
|
||||
<optional_context>
|
||||
```
|
||||
|
||||
**Examples**:
|
||||
```bash
|
||||
# Planning command
|
||||
ccw cli -p '/workflow:plan -y "Implement user registration feature"
|
||||
|
||||
Task: Implement user registration' --tool claude --mode write
|
||||
|
||||
# Execution command (with session reference)
|
||||
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
|
||||
|
||||
Task: Implement user registration
|
||||
|
||||
Previous results:
|
||||
- /workflow:plan: WFS-plan-20250124' --tool claude --mode write
|
||||
|
||||
# Lite execution (in-memory from previous plan)
|
||||
ccw cli -p '/workflow:lite-execute -y --in-memory
|
||||
|
||||
Task: Implement user registration' --tool claude --mode write
|
||||
```
|
||||
|
||||
### Serial Blocking
|
||||
|
||||
**CRITICAL**: Commands execute one-by-one. After launching CLI in background:
|
||||
1. Orchestrator stops immediately (`break`)
|
||||
2. Wait for hook callback - **DO NOT use TaskOutput polling**
|
||||
3. Hook callback triggers next command
|
||||
|
||||
**Prompt Structure**: Command must be first in prompt content
|
||||
|
||||
```javascript
|
||||
// Example: Execute command and stop
|
||||
const prompt = '/workflow:plan -y "Implement user authentication"\n\nTask: Implement user auth system';
|
||||
const taskId = Bash(`ccw cli -p "${prompt}" --tool claude --mode write`, { run_in_background: true }).task_id;
|
||||
state.execution_results.push({ status: 'in-progress', task_id: taskId, ... });
|
||||
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
break; // ⚠️ STOP HERE - DO NOT use TaskOutput polling
|
||||
|
||||
// Hook callback will call handleCliCompletion(sessionId, taskId, output) when done
|
||||
// → Updates state → Triggers next command via resumeChainExecution()
|
||||
```
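
`resumeChainExecution` is referenced by the hook callback but not defined here. A minimal sketch (an assumption; the previous-results context block is abbreviated) that relaunches the next command of a paused session:

```javascript
// Sketch only: relaunch the next command of a paused session.
// Mirrors the per-step launch logic of executeCommandChain.
async function resumeChainExecution(sessionId, nextIdx) {
  const stateDir = `.workflow/.ccw-coordinator/${sessionId}`;
  const state = JSON.parse(Read(`${stateDir}/state.json`));
  const cmd = state.command_chain[nextIdx];

  const prompt = `${formatCommand(cmd, state.execution_results, state.analysis)}\n\nTask: ${state.analysis.goal}`;
  const taskId = Bash(
    `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
    { run_in_background: true }
  ).task_id;

  cmd.status = 'running';
  state.execution_results.push({
    index: nextIdx, command: cmd.command, status: 'in-progress',
    task_id: taskId, session_id: null, artifacts: [], timestamp: new Date().toISOString()
  });
  state.updated_at = new Date().toISOString();
  Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
}
```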
|
||||
|
||||
|
||||
## Available Commands
|
||||
|
||||
All from `~/.claude/commands/workflow/`:
|
||||
|
||||
**Planning**: lite-plan, plan, multi-cli-plan, plan-verify, tdd-plan
|
||||
**Execution**: lite-execute, execute, develop-with-file
|
||||
**Testing**: test-cycle-execute, test-gen, test-fix-gen, tdd-verify
|
||||
**Review**: review, review-session-cycle, review-module-cycle, review-fix
|
||||
**Bug Fixes**: lite-fix, debug, debug-with-file
|
||||
**Brainstorming**: brainstorm:auto-parallel, brainstorm:artifacts, brainstorm:synthesis
|
||||
**Design**: ui-design:*, animation-extract, layout-extract, style-extract, codify-style
|
||||
**Session Management**: session:start, session:resume, session:complete, session:solidify, session:list
|
||||
**Tools**: context-gather, test-context-gather, task-generate, conflict-resolution, action-plan-verify
|
||||
**Utility**: clean, init, replan
|
||||
|
||||
### Testing Commands Distinction

| Command | Purpose | Output | Follow-up |
|---------|---------|--------|-----------|
| **test-gen** | Generate a broad test suite and run it | test-tasks (IMPL-001, IMPL-002) | `/workflow:execute` |
| **test-fix-gen** | Generate targeted tests for specific issues and fix them during testing | test-tasks | `/workflow:test-cycle-execute` |
| **test-cycle-execute** | Run the test cycle (iterate test and fix) | test-passed | N/A (endpoint) |

**Flow notes**:
- **test-gen → execute**: Generates a comprehensive test suite; execute performs the generation and runs the tests
- **test-fix-gen → test-cycle-execute**: Generates fix tasks for specific issues; test-cycle-execute iterates test-and-fix until tests pass

### Task Type Routing (Pipeline Summary)

**Note**: `【 】` marks Minimum Execution Units (最小执行单元) - these commands must execute together.

| Task Type | Pipeline | Minimum Units |
|-----------|----------|---------------|
| **feature** (simple) | Requirement →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests passed | Quick Implementation + Test Validation |
| **feature** (complex) | Requirement →【plan → plan-verify】→ validate → execute → Code → review → fix | Full Planning + Code Review + Testing |
| **bugfix** | Bug report → lite-fix → Fixed code →【test-fix-gen → test-cycle-execute】→ Tests passed | Bug Fix + Test Validation |
| **tdd** | Requirement → tdd-plan → TDD tasks → execute → Code → tdd-verify | TDD Planning + Execution |
| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests passed | Test Validation |
| **test-gen** | Code/session →【test-gen → execute】→ Tests passed | Test Generation + Execution |
| **review** | Code →【review-* → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests passed | Code Review + Testing |
| **brainstorm** | Exploration topic → brainstorm → Analysis →【plan → plan-verify】→ execute → test | Exploration + Planning + Execution |
| **multi-cli** | Requirement → multi-cli-plan → Comparative analysis → lite-execute → test | Multi-Perspective + Testing |

Use `CommandRegistry.getAllCommandsSummary()` to discover all commands dynamically.

.claude/commands/ccw.md (new file, 486 lines)
@@ -0,0 +1,486 @@

---
name: ccw
description: Main workflow orchestrator - analyze intent, select workflow, execute command chain in main process
argument-hint: "\"task description\""
allowed-tools: SlashCommand(*), TodoWrite(*), AskUserQuestion(*), Read(*), Grep(*), Glob(*)
---

# CCW Command - Main Workflow Orchestrator

Main process orchestrator: intent analysis → workflow selection → command chain execution.

## Core Concept: Minimum Execution Units (最小执行单元)
|
||||
|
||||
**Definition**: A set of commands that must execute together as an atomic group to achieve a meaningful workflow milestone.
|
||||
|
||||
**Why This Matters**:
|
||||
- **Prevents Incomplete States**: Avoid stopping after task generation without execution
|
||||
- **User Experience**: User gets complete results, not intermediate artifacts requiring manual follow-up
|
||||
- **Workflow Integrity**: Maintains logical coherence of multi-step operations
|
||||
|
||||
**Key Units in CCW**:
|
||||
|
||||
| Unit Type | Pattern | Example |
|
||||
|-----------|---------|---------|
|
||||
| **Planning + Execution** | plan-cmd → execute-cmd | lite-plan → lite-execute |
|
||||
| **Testing** | test-gen-cmd → test-exec-cmd | test-fix-gen → test-cycle-execute |
|
||||
| **Review** | review-cmd → fix-cmd | review-session-cycle → review-fix |
|
||||
|
||||
**Atomic Rules**:
|
||||
1. CCW automatically groups commands into minimum units - never splits them
|
||||
2. Pipeline visualization shows units with `【 】` markers
|
||||
3. Error handling preserves unit boundaries (retry/skip affects whole unit)
|
||||
|
||||
## Execution Model
|
||||
|
||||
**Synchronous (Main Process)**: Commands execute via SlashCommand in main process, blocking until complete.
|
||||
|
||||
```
|
||||
User Input → Analyze Intent → Select Workflow → [Confirm] → Execute Chain
|
||||
↓
|
||||
SlashCommand (blocking)
|
||||
↓
|
||||
Update TodoWrite
|
||||
↓
|
||||
Next Command...
|
||||
```
|
||||
|
||||
**vs ccw-coordinator**: External CLI execution with background tasks and hook callbacks.
|
||||
|
||||
## 5-Phase Workflow
|
||||
|
||||
### Phase 1: Analyze Intent
|
||||
|
||||
```javascript
|
||||
function analyzeIntent(input) {
|
||||
return {
|
||||
goal: extractGoal(input),
|
||||
scope: extractScope(input),
|
||||
constraints: extractConstraints(input),
|
||||
task_type: detectTaskType(input), // bugfix|feature|tdd|review|exploration|...
|
||||
complexity: assessComplexity(input), // low|medium|high
|
||||
clarity_score: calculateClarity(input) // 0-3 (>=2 = clear)
|
||||
};
|
||||
}
|
||||
|
||||
// Task type detection (priority order; first match wins)
function detectTaskType(text) {
  const patterns = {
    // Compound types require both pattern groups to match
    'bugfix-hotfix':     t => /urgent|production|critical/.test(t) && /fix|bug/.test(t),
    'bugfix':            t => /fix|bug|error|crash|fail|debug/.test(t),
    'issue-batch':       t => /issues?|batch/.test(t) && /fix|resolve/.test(t),
    'exploration':       t => /uncertain|explore|research|what if/.test(t),
    'multi-perspective': t => /multi-perspective|compare|cross-verify/.test(t),
    'quick-task':        t => /quick|simple|small/.test(t) && /feature|function/.test(t),
    'ui-design':         t => /ui|design|component|style/.test(t),
    'tdd':               t => /tdd|test-driven|test first/.test(t),
    'test-fix':          t => /test fail|fix test|failing test/.test(t),
    'review':            t => /review|code review/.test(t),
    'documentation':     t => /docs|documentation|readme/.test(t)
  };
  for (const [type, matches] of Object.entries(patterns)) {
    if (matches(text)) return type;
  }
  return 'feature';
}
```
|
||||
|
||||
**Output**: `Type: [task_type] | Goal: [goal] | Complexity: [complexity] | Clarity: [clarity_score]/3`
|
||||
|
||||
---
|
||||
|
||||
### Phase 1.5: Requirement Clarification (if clarity_score < 2)
|
||||
|
||||
```javascript
|
||||
async function clarifyRequirements(analysis) {
|
||||
if (analysis.clarity_score >= 2) return analysis;
|
||||
|
||||
const questions = generateClarificationQuestions(analysis); // Goal, Scope, Constraints
|
||||
const answers = await AskUserQuestion({ questions });
|
||||
return updateAnalysis(analysis, answers);
|
||||
}
|
||||
```
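
`generateClarificationQuestions` is not specified; a minimal sketch (an assumption) using the question areas listed below (Goal, Scope, Constraints):

```javascript
// Sketch only: build the three clarification questions referenced below.
function generateClarificationQuestions(analysis) {
  const asOptions = labels => labels.map(label => ({ label, description: '' }));
  return [
    { question: 'What is the primary goal?', header: 'Goal',
      options: asOptions(['Create', 'Fix', 'Optimize', 'Analyze']) },
    { question: 'What is the scope of the change?', header: 'Scope',
      options: asOptions(['Single file', 'Module', 'Cross-module', 'System']) },
    { question: 'Any constraints?', header: 'Constraints',
      options: asOptions(['Backward compat', 'Skip tests', 'Urgent hotfix']) }
  ];
}
```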
|
||||
|
||||
**Questions**: Goal (Create/Fix/Optimize/Analyze), Scope (Single file/Module/Cross-module/System), Constraints (Backward compat/Skip tests/Urgent hotfix)
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Select Workflow & Build Command Chain
|
||||
|
||||
```javascript
|
||||
function selectWorkflow(analysis) {
|
||||
const levelMap = {
|
||||
'bugfix-hotfix': { level: 2, flow: 'bugfix.hotfix' },
|
||||
'bugfix': { level: 2, flow: 'bugfix.standard' },
|
||||
'issue-batch': { level: 'Issue', flow: 'issue' },
|
||||
'exploration': { level: 4, flow: 'full' },
|
||||
'quick-task': { level: 1, flow: 'lite-lite-lite' },
|
||||
'ui-design': { level: analysis.complexity === 'high' ? 4 : 3, flow: 'ui' },
|
||||
'tdd': { level: 3, flow: 'tdd' },
|
||||
'test-fix': { level: 3, flow: 'test-fix-gen' },
|
||||
'review': { level: 3, flow: 'review-fix' },
|
||||
'documentation': { level: 2, flow: 'docs' },
|
||||
'feature': { level: analysis.complexity === 'high' ? 3 : 2, flow: analysis.complexity === 'high' ? 'coupled' : 'rapid' }
|
||||
};
|
||||
|
||||
const selected = levelMap[analysis.task_type] || levelMap['feature'];
|
||||
return buildCommandChain(selected, analysis);
|
||||
}
|
||||
|
||||
// Build command chain (port-based matching with Minimum Execution Units)
|
||||
function buildCommandChain(workflow, analysis) {
|
||||
const chains = {
|
||||
// Level 1 - Rapid
|
||||
'lite-lite-lite': [
|
||||
{ cmd: '/workflow:lite-lite-lite', args: `"${analysis.goal}"` }
|
||||
],
|
||||
|
||||
// Level 2 - Lightweight
|
||||
'rapid': [
|
||||
// Unit: Quick Implementation【lite-plan → lite-execute】
|
||||
{ cmd: '/workflow:lite-plan', args: `"${analysis.goal}"`, unit: 'quick-impl' },
|
||||
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'quick-impl' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
...(analysis.constraints?.includes('skip-tests') ? [] : [
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
])
|
||||
],
|
||||
|
||||
'bugfix.standard': [
|
||||
// Unit: Bug Fix【lite-fix → lite-execute】
|
||||
{ cmd: '/workflow:lite-fix', args: `"${analysis.goal}"`, unit: 'bug-fix' },
|
||||
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'bug-fix' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
...(analysis.constraints?.includes('skip-tests') ? [] : [
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
])
|
||||
],
|
||||
|
||||
'bugfix.hotfix': [
|
||||
{ cmd: '/workflow:lite-fix', args: `--hotfix "${analysis.goal}"` }
|
||||
],
|
||||
|
||||
'multi-cli-plan': [
|
||||
// Unit: Multi-CLI Planning【multi-cli-plan → lite-execute】
|
||||
{ cmd: '/workflow:multi-cli-plan', args: `"${analysis.goal}"`, unit: 'multi-cli' },
|
||||
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'multi-cli' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
...(analysis.constraints?.includes('skip-tests') ? [] : [
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
])
|
||||
],
|
||||
|
||||
'docs': [
|
||||
// Unit: Quick Implementation【lite-plan → lite-execute】
|
||||
{ cmd: '/workflow:lite-plan', args: `"${analysis.goal}"`, unit: 'quick-impl' },
|
||||
{ cmd: '/workflow:lite-execute', args: '--in-memory', unit: 'quick-impl' }
|
||||
],
|
||||
|
||||
// Level 3 - Standard
|
||||
'coupled': [
|
||||
// Unit: Verified Planning【plan → plan-verify】
|
||||
{ cmd: '/workflow:plan', args: `"${analysis.goal}"`, unit: 'verified-planning' },
|
||||
{ cmd: '/workflow:plan-verify', args: '', unit: 'verified-planning' },
|
||||
// Execution
|
||||
{ cmd: '/workflow:execute', args: '' },
|
||||
// Unit: Code Review【review-session-cycle → review-fix】
|
||||
{ cmd: '/workflow:review-session-cycle', args: '', unit: 'code-review' },
|
||||
{ cmd: '/workflow:review-fix', args: '', unit: 'code-review' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
...(analysis.constraints?.includes('skip-tests') ? [] : [
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
])
|
||||
],
|
||||
|
||||
'tdd': [
|
||||
// Unit: TDD Planning + Execution【tdd-plan → execute】
|
||||
{ cmd: '/workflow:tdd-plan', args: `"${analysis.goal}"`, unit: 'tdd-planning' },
|
||||
{ cmd: '/workflow:execute', args: '', unit: 'tdd-planning' },
|
||||
// TDD Verification
|
||||
{ cmd: '/workflow:tdd-verify', args: '' }
|
||||
],
|
||||
|
||||
'test-fix-gen': [
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
{ cmd: '/workflow:test-fix-gen', args: `"${analysis.goal}"`, unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
],
|
||||
|
||||
'review-fix': [
|
||||
// Unit: Code Review【review-session-cycle → review-fix】
|
||||
{ cmd: '/workflow:review-session-cycle', args: '', unit: 'code-review' },
|
||||
{ cmd: '/workflow:review-fix', args: '', unit: 'code-review' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
],
|
||||
|
||||
'ui': [
|
||||
{ cmd: '/workflow:ui-design:explore-auto', args: `"${analysis.goal}"` },
|
||||
// Unit: Planning + Execution【plan → execute】
|
||||
{ cmd: '/workflow:plan', args: '', unit: 'plan-execute' },
|
||||
{ cmd: '/workflow:execute', args: '', unit: 'plan-execute' }
|
||||
],
|
||||
|
||||
// Level 4 - Brainstorm
|
||||
'full': [
|
||||
{ cmd: '/workflow:brainstorm:auto-parallel', args: `"${analysis.goal}"` },
|
||||
// Unit: Verified Planning【plan → plan-verify】
|
||||
{ cmd: '/workflow:plan', args: '', unit: 'verified-planning' },
|
||||
{ cmd: '/workflow:plan-verify', args: '', unit: 'verified-planning' },
|
||||
// Execution
|
||||
{ cmd: '/workflow:execute', args: '' },
|
||||
// Unit: Test Validation【test-fix-gen → test-cycle-execute】
|
||||
{ cmd: '/workflow:test-fix-gen', args: '', unit: 'test-validation' },
|
||||
{ cmd: '/workflow:test-cycle-execute', args: '', unit: 'test-validation' }
|
||||
],
|
||||
|
||||
// Issue Workflow
|
||||
'issue': [
|
||||
{ cmd: '/issue:discover', args: '' },
|
||||
{ cmd: '/issue:plan', args: '--all-pending' },
|
||||
{ cmd: '/issue:queue', args: '' },
|
||||
{ cmd: '/issue:execute', args: '' }
|
||||
]
|
||||
};
|
||||
|
||||
return chains[workflow.flow] || chains['rapid'];
|
||||
}
|
||||
```
|
||||
|
||||
**Output**: `Level [X] - [flow] | Pipeline: [...] | Commands: [1. /cmd1 2. /cmd2 ...]`
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: User Confirmation
|
||||
|
||||
```javascript
|
||||
async function getUserConfirmation(chain) {
|
||||
const response = await AskUserQuestion({
|
||||
questions: [{
|
||||
question: "Execute this command chain?",
|
||||
header: "Confirm",
|
||||
options: [
|
||||
{ label: "Confirm", description: "Start" },
|
||||
{ label: "Adjust", description: "Modify" },
|
||||
{ label: "Cancel", description: "Abort" }
|
||||
]
|
||||
}]
|
||||
});
|
||||
|
||||
if (response.error === "Cancel") throw new Error("Cancelled");
|
||||
if (response.error === "Adjust") return await adjustChain(chain);
|
||||
return chain;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Phase 4: Setup TODO Tracking
|
||||
|
||||
```javascript
|
||||
function setupTodoTracking(chain, workflow) {
|
||||
const todos = chain.map((step, i) => ({
|
||||
content: `CCW:${workflow}: [${i + 1}/${chain.length}] ${step.cmd}`,
|
||||
status: i === 0 ? 'in_progress' : 'pending',
|
||||
activeForm: `Executing ${step.cmd}`
|
||||
}));
|
||||
TodoWrite({ todos });
|
||||
}
|
||||
```
|
||||
|
||||
**Output**: `-> CCW:rapid: [1/3] /workflow:lite-plan | CCW:rapid: [2/3] /workflow:lite-execute | ...`
|
||||
|
||||
---
|
||||
|
||||
### Phase 5: Execute Command Chain
|
||||
|
||||
```javascript
|
||||
async function executeCommandChain(chain, workflow) {
|
||||
let previousResult = null;
|
||||
|
||||
for (let i = 0; i < chain.length; i++) {
|
||||
try {
|
||||
const fullCommand = assembleCommand(chain[i], previousResult);
|
||||
const result = await SlashCommand({ command: fullCommand });
|
||||
|
||||
previousResult = { ...result, success: true };
|
||||
updateTodoStatus(i, chain.length, workflow, 'completed');
|
||||
|
||||
} catch (error) {
|
||||
const action = await handleError(chain[i], error, i);
|
||||
if (action === 'retry') {
|
||||
i--; // Retry
|
||||
} else if (action === 'abort') {
|
||||
return { success: false, error: error.message };
|
||||
}
|
||||
// 'skip' - continue
|
||||
}
|
||||
}
|
||||
|
||||
return { success: true, completed: chain.length };
|
||||
}
|
||||
|
||||
// Assemble full command with session/plan parameters
|
||||
function assembleCommand(step, previousResult) {
|
||||
let command = step.cmd;
|
||||
if (step.args) {
|
||||
command += ` ${step.args}`;
|
||||
} else if (previousResult?.session_id) {
|
||||
command += ` --session="${previousResult.session_id}"`;
|
||||
}
|
||||
return command;
|
||||
}
|
||||
|
||||
// Update TODO: mark current as complete, next as in-progress
|
||||
function updateTodoStatus(index, total, workflow, status) {
|
||||
const todos = getAllCurrentTodos();
|
||||
const updated = todos.map(todo => {
|
||||
if (todo.content.startsWith(`CCW:${workflow}:`)) {
|
||||
const stepNum = extractStepIndex(todo.content);
|
||||
if (stepNum === index + 1) return { ...todo, status };
|
||||
if (stepNum === index + 2 && status === 'completed') return { ...todo, status: 'in_progress' };
|
||||
}
|
||||
return todo;
|
||||
});
|
||||
TodoWrite({ todos: updated });
|
||||
}
|
||||
|
||||
// Error handling: Retry/Skip/Abort
|
||||
async function handleError(step, error, index) {
|
||||
const response = await AskUserQuestion({
|
||||
questions: [{
|
||||
question: `${step.cmd} failed: ${error.message}`,
|
||||
header: "Error",
|
||||
options: [
|
||||
{ label: "Retry", description: "Re-execute" },
|
||||
{ label: "Skip", description: "Continue next" },
|
||||
{ label: "Abort", description: "Stop" }
|
||||
]
|
||||
}]
|
||||
});
|
||||
return { Retry: 'retry', Skip: 'skip', Abort: 'abort' }[response.Error] || 'abort';
|
||||
}
|
||||
```
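
`extractStepIndex` and `getAllCurrentTodos` are referenced above but not defined. `extractStepIndex` can simply parse the `[n/total]` marker in the todo content; `getAllCurrentTodos` is assumed to return the current TodoWrite list:

```javascript
// Sketch only: pull the step number out of "CCW:<flow>: [n/total] /workflow:cmd".
function extractStepIndex(todoContent) {
  const match = todoContent.match(/\[(\d+)\/\d+\]/);
  return match ? parseInt(match[1], 10) : -1;
}
```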
|
||||
|
||||
---
|
||||
|
||||
## Execution Flow Summary
|
||||
|
||||
```
|
||||
User Input
|
||||
|
|
||||
Phase 1: Analyze Intent
|
||||
|-- Extract: goal, scope, constraints, task_type, complexity, clarity
|
||||
+-- If clarity < 2 -> Phase 1.5: Clarify Requirements
|
||||
|
|
||||
Phase 2: Select Workflow & Build Chain
|
||||
|-- Map task_type -> Level (1/2/3/4/Issue)
|
||||
|-- Select flow based on complexity
|
||||
+-- Build command chain (port-based)
|
||||
|
|
||||
Phase 3: User Confirmation (optional)
|
||||
|-- Show pipeline visualization
|
||||
+-- Allow adjustment
|
||||
|
|
||||
Phase 4: Setup TODO Tracking
|
||||
+-- Create todos with CCW prefix
|
||||
|
|
||||
Phase 5: Execute Command Chain
|
||||
|-- For each command:
|
||||
| |-- Assemble full command
|
||||
| |-- Execute via SlashCommand
|
||||
| |-- Update TODO status
|
||||
| +-- Handle errors (retry/skip/abort)
|
||||
+-- Return workflow result
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pipeline Examples (with Minimum Execution Units)
|
||||
|
||||
**Note**: `【 】` marks Minimum Execution Units - commands execute together as atomic groups.
|
||||
|
||||
| Input | Type | Level | Pipeline (with Units) |
|
||||
|-------|------|-------|-----------------------|
|
||||
| "Add API endpoint" | feature (low) | 2 |【lite-plan → lite-execute】→【test-fix-gen → test-cycle-execute】|
|
||||
| "Fix login timeout" | bugfix | 2 |【lite-fix → lite-execute】→【test-fix-gen → test-cycle-execute】|
|
||||
| "OAuth2 system" | feature (high) | 3 |【plan → plan-verify】→ execute →【review-session-cycle → review-fix】→【test-fix-gen → test-cycle-execute】|
|
||||
| "Implement with TDD" | tdd | 3 |【tdd-plan → execute】→ tdd-verify |
|
||||
| "Uncertain: real-time arch" | exploration | 4 | brainstorm:auto-parallel →【plan → plan-verify】→ execute →【test-fix-gen → test-cycle-execute】|
|
||||
|
||||
---
|
||||
|
||||
## Key Design Principles
|
||||
|
||||
1. **Main Process Execution** - Use SlashCommand in main process, no external CLI
|
||||
2. **Intent-Driven** - Auto-select workflow based on task intent
|
||||
3. **Port-Based Chaining** - Build command chain using port matching
|
||||
4. **Minimum Execution Units** - Commands grouped into atomic units, never split (e.g., lite-plan → lite-execute)
|
||||
5. **Progressive Clarification** - Low clarity triggers clarification phase
|
||||
6. **TODO Tracking** - Use CCW prefix to isolate workflow todos
|
||||
7. **Unit-Aware Error Handling** - Retry/skip/abort affects whole unit, not individual commands
|
||||
8. **User Control** - Optional user confirmation at each phase
|
||||
|
||||

---

## State Management

**TodoWrite-Based Tracking**: All execution state is tracked via TodoWrite with the `CCW:` prefix.

```javascript
// Initial state
todos = [
  { content: "CCW:rapid: [1/3] /workflow:lite-plan", status: "in_progress" },
  { content: "CCW:rapid: [2/3] /workflow:lite-execute", status: "pending" },
  { content: "CCW:rapid: [3/3] /workflow:test-cycle-execute", status: "pending" }
];

// After command 1 completes
todos = [
  { content: "CCW:rapid: [1/3] /workflow:lite-plan", status: "completed" },
  { content: "CCW:rapid: [2/3] /workflow:lite-execute", status: "in_progress" },
  { content: "CCW:rapid: [3/3] /workflow:test-cycle-execute", status: "pending" }
];
```

**vs ccw-coordinator**: By contrast, ccw-coordinator maintains an extensive state.json with task_id, status transitions, and hook callbacks.

---

## Type Comparison: ccw vs ccw-coordinator

| Aspect | ccw | ccw-coordinator |
|--------|-----|-----------------|
| **Type** | Main process (SlashCommand) | External CLI (ccw cli + hook callbacks) |
| **Execution** | Synchronous blocking | Async background with hook completion |
| **Workflow** | Auto intent-based selection | Manual chain building |
| **Intent Analysis** | 5-phase clarity check | 3-phase requirement analysis |
| **State** | TodoWrite only (in-memory) | state.json + checkpoint/resume |
| **Error Handling** | Retry/skip/abort (interactive) | Retry/skip/abort (via AskUser) |
| **Use Case** | Auto workflow for any task | Manual orchestration, large chains |

---

## Usage

```bash
# Auto-select workflow
ccw "Add user authentication"

# Complex requirement (triggers clarification)
ccw "Optimize system performance"

# Bug fix
ccw "Fix memory leak in WebSocket handler"

# TDD development
ccw "Implement user registration with TDD"

# Exploratory task
ccw "Uncertain about architecture for real-time notifications"
```

@@ -1,45 +0,0 @@

# CCW Coordinator

Interactive command orchestration tool

## Usage

```
/ccw-coordinator
or
/coordinator
```

## Flow

1. The user describes the task
2. Claude recommends a command chain
3. The user confirms or adjusts it
4. The command chain is executed
5. A report is generated

## Examples

**Bug fix**
```
Task: fix the login bug
Recommended: lite-fix → test-cycle-execute
```

**New feature**
```
Task: implement the registration feature
Recommended: plan → execute → test-cycle-execute
```

## Files

| File | Purpose |
|------|---------|
| SKILL.md | Skill entry point |
| phases/orchestrator.md | Orchestration logic |
| phases/state-schema.md | State schema definitions |
| phases/actions/*.md | Action implementations |
| specs/specs.md | Command library, validation rules, registry |
| tools/chain-validate.cjs | Validation tool |
| tools/command-registry.cjs | Command registry tool |
@@ -1,320 +0,0 @@
|
||||
---
|
||||
name: ccw-coordinator
|
||||
description: Interactive command orchestration tool for building and executing Claude CLI command chains. Triggers on "coordinator", "ccw-coordinator", "命令编排", "command chain", "orchestrate commands", "编排CLI命令".
|
||||
allowed-tools: Task, AskUserQuestion, Read, Write, Bash, Glob, Grep
|
||||
---
|
||||
|
||||
# CCW Coordinator
|
||||
|
||||
交互式命令编排工具:允许用户依次选择命令,形成命令串,然后依次调用claude cli执行整个命令串。
|
||||
|
||||
支持灵活的工作流组合,提供交互式界面用于命令选择、编排和执行管理。
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Orchestrator (状态驱动决策) │
|
||||
│ 根据用户选择编排命令和执行流程 │
|
||||
└───────────────┬─────────────────────────────────────────────────┘
|
||||
│
|
||||
┌───────────┼───────────┬───────────────┐
|
||||
↓ ↓ ↓ ↓
|
||||
┌─────────┐ ┌──────────────┐ ┌────────────┐ ┌──────────┐
|
||||
│ Init │ │ Command │ │ Command │ │ Execute │
|
||||
│ │ │ Selection │ │ Build │ │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ 初始化 │ │ 选择命令 │ │ 编排调整 │ │ 执行链 │
|
||||
└─────────┘ └──────────────┘ └────────────┘ └──────────┘
|
||||
│ │ │ │
|
||||
└───────────────┼──────────────┴────────────┘
|
||||
│
|
||||
↓
|
||||
┌──────────────┐
|
||||
│ Complete │
|
||||
│ 生成报告 │
|
||||
└──────────────┘
|
||||
```
|
||||
|
||||
## Key Design Principles
|
||||
|
||||
1. **智能推荐**: Claude 根据用户任务描述,自动推荐最优命令链
|
||||
2. **交互式编排**: 用户通过交互式界面选择和编排命令,实时反馈
|
||||
3. **无状态动作**: 每个动作独立执行,通过共享状态进行通信
|
||||
4. **灵活的命令库**: 支持ccw workflow命令和标准claude cli命令
|
||||
5. **执行透明性**: 展示执行进度、结果和可能的错误
|
||||
6. **会话持久化**: 保存编排会话,支持中途暂停和恢复
|
||||
7. **智能提示词生成**: 根据任务上下文和前序产物自动生成 ccw cli 提示词
|
||||
8. **自动确认**: 所有命令自动添加 `-y` 参数,跳过交互式确认,实现无人值守执行
|
||||
|
||||
## Intelligent Prompt Generation
|
||||
|
||||
执行命令时,系统根据以下信息智能生成 `ccw cli -p` 提示词:
|
||||
|
||||
### 提示词构成
|
||||
|
||||
```javascript
|
||||
// 集成命令注册表 (~/.claude/tools/command-registry.js)
|
||||
const registry = new CommandRegistry();
|
||||
registry.buildRegistry();
|
||||
|
||||
function generatePrompt(cmd, state) {
|
||||
const cmdMeta = registry.getCommand(cmd.command);
|
||||
|
||||
let prompt = `任务: ${state.task_description}\n`;
|
||||
|
||||
if (state.execution_results.length > 0) {
|
||||
const previousOutputs = state.execution_results
|
||||
.filter(r => r.status === 'success')
|
||||
.map(r => {
|
||||
if (r.summary?.session) {
|
||||
return `- ${r.command}: ${r.summary.session} (${r.summary.files?.join(', ')})`;
|
||||
}
|
||||
return `- ${r.command}: 已完成`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
prompt += `\n前序完成:\n${previousOutputs}\n`;
|
||||
}
|
||||
|
||||
// 从 YAML 头提取命令元数据
|
||||
if (cmdMeta) {
|
||||
prompt += `\n命令: ${cmd.command}`;
|
||||
if (cmdMeta.argumentHint) {
|
||||
prompt += ` ${cmdMeta.argumentHint}`;
|
||||
}
|
||||
}
|
||||
|
||||
return prompt;
|
||||
}
|
||||
```
|
||||
|
||||
### 产物追踪
|
||||
|
||||
每个命令执行后自动提取关键产物:
|
||||
|
||||
```javascript
|
||||
{
|
||||
command: "/workflow:lite-plan",
|
||||
status: "success",
|
||||
output: "...",
|
||||
summary: {
|
||||
session: "WFS-plan-20250123", // 会话ID
|
||||
files: [".workflow/IMPL_PLAN.md"], // 产物文件
|
||||
timestamp: "2025-01-23T10:30:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 命令调用示例
|
||||
|
||||
```bash
|
||||
# 自动生成的智能提示词
|
||||
ccw cli -p "任务: 实现用户认证功能
|
||||
|
||||
前序完成:
|
||||
- /workflow:lite-plan: WFS-plan-20250123 (.workflow/IMPL_PLAN.md)
|
||||
|
||||
命令: /workflow:lite-execute [--resume-session=\"session-id\"]" /workflow:lite-execute
|
||||
```
|
||||
|
||||
### 命令注册表集成
|
||||
|
||||
- **位置**: `tools/command-registry.js` (skill 内置)
|
||||
- **工作模式**: 按需提取(只提取用户任务链中的命令)
|
||||
- **功能**: 自动查找全局 `.claude/commands/workflow` 目录,解析命令 YAML 头元数据
|
||||
- **作用**: 确保提示词包含准确的命令参数和上下文
|
||||
|
||||
详见 `tools/README.md`
|
||||
|
||||
---
|
||||
|
||||
## Execution Flow
|
||||
|
||||
### Orchestrator Execution Loop
|
||||
|
||||
```javascript
|
||||
1. 初始化会话
|
||||
↓
|
||||
2. 循环执行直到完成
|
||||
├─ 读取当前状态
|
||||
├─ 选择下一个动作(根据状态和用户意图)
|
||||
├─ 执行动作,更新状态
|
||||
└─ 检查终止条件
|
||||
↓
|
||||
3. 生成最终报告
|
||||
```
|
||||
|
||||
### Action Sequence (Typical)
|
||||
|
||||
```
|
||||
action-init
|
||||
↓ (status: pending → running)
|
||||
action-command-selection (可重复)
|
||||
↓ (添加命令到链)
|
||||
action-command-build (可选)
|
||||
↓ (调整命令顺序)
|
||||
action-command-execute
|
||||
↓ (依次执行所有命令)
|
||||
action-complete
|
||||
↓ (status: running → completed)
|
||||
```
|
||||
|
||||
## State Management
|
||||
|
||||
### Initial State
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "pending",
|
||||
"task_description": "",
|
||||
"command_chain": [],
|
||||
"confirmed": false,
|
||||
"error_count": 0,
|
||||
"execution_results": [],
|
||||
"current_command_index": 0,
|
||||
"started_at": null
|
||||
}
|
||||
```
|
||||
|
||||
### State Transitions
|
||||
|
||||
```
|
||||
pending → running (init) → running → completed (execute)
|
||||
↓
|
||||
aborted (error or user exit)
|
||||
```
|
||||
|
||||
## Directory Setup
|
||||
|
||||
```javascript
|
||||
const timestamp = new Date().toISOString().slice(0,19).replace(/[-:T]/g, '');
|
||||
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
|
||||
|
||||
Bash(`mkdir -p "${workDir}"`);
|
||||
Bash(`mkdir -p "${workDir}/commands"`);
|
||||
Bash(`mkdir -p "${workDir}/logs"`);
|
||||
```
|
||||
|
||||
## Output Structure
|
||||
|
||||
```
|
||||
.workflow/.ccw-coordinator/{timestamp}/
|
||||
├── state.json # 当前会话状态
|
||||
├── command-chain.json # 编排的完整命令链
|
||||
├── execution-log.md # 执行日志
|
||||
├── final-summary.md # 最终报告
|
||||
├── commands/ # 各命令执行详情
|
||||
│ ├── 01-command.log
|
||||
│ ├── 02-command.log
|
||||
│ └── ...
|
||||
└── logs/ # 错误和警告日志
|
||||
├── errors.log
|
||||
└── warnings.log
|
||||
```
|
||||
|
||||
## Reference Documents
|
||||
|
||||
| Document | Purpose |
|
||||
|----------|---------|
|
||||
| [phases/orchestrator.md](phases/orchestrator.md) | 编排器实现 |
|
||||
| [phases/state-schema.md](phases/state-schema.md) | 状态结构定义 |
|
||||
| [phases/actions/action-init.md](phases/actions/action-init.md) | 初始化动作 |
|
||||
| [phases/actions/action-command-selection.md](phases/actions/action-command-selection.md) | 命令选择动作 |
|
||||
| [phases/actions/action-command-build.md](phases/actions/action-command-build.md) | 命令编排动作 |
|
||||
| [phases/actions/action-command-execute.md](phases/actions/action-command-execute.md) | 命令执行动作 |
|
||||
| [phases/actions/action-complete.md](phases/actions/action-complete.md) | 完成动作 |
|
||||
| [phases/actions/action-abort.md](phases/actions/action-abort.md) | 中止动作 |
|
||||
| [specs/specs.md](specs/specs.md) | 命令库、验证规则、注册表 |
|
||||
| [tools/chain-validate.cjs](tools/chain-validate.cjs) | 验证工具 |
|
||||
| [tools/command-registry.cjs](tools/command-registry.cjs) | 命令注册表工具 |
|
||||
|
||||
---
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### 快速命令链
|
||||
|
||||
用户想要执行:规划 → 执行 → 测试
|
||||
|
||||
```
|
||||
1. 触发 /ccw-coordinator
|
||||
2. 描述任务:"实现用户注册功能"
|
||||
3. Claude推荐: plan → execute → test-cycle-execute
|
||||
4. 用户确认
|
||||
5. 执行命令链
|
||||
```
|
||||
|
||||
### 复杂工作流
|
||||
|
||||
用户想要执行:规划 → 执行 → 审查 → 修复
|
||||
|
||||
```
|
||||
1. 触发 /ccw-coordinator
|
||||
2. 描述任务:"重构认证模块"
|
||||
3. Claude推荐: plan → execute → review-session-cycle → review-fix
|
||||
4. 用户可调整命令顺序
|
||||
5. 确认执行
|
||||
6. 实时查看执行进度
|
||||
```
|
||||
|
||||
### 紧急修复
|
||||
|
||||
用户想要快速修复bug
|
||||
|
||||
```
|
||||
1. 触发 /ccw-coordinator
|
||||
2. 描述任务:"修复生产环境登录bug"
|
||||
3. Claude推荐: lite-fix --hotfix → test-cycle-execute
|
||||
4. 用户确认
|
||||
5. 快速执行修复
|
||||
```
|
||||
|
||||
## Constraints and Rules
|
||||
|
||||
### 1. 命令推荐约束
|
||||
|
||||
- **智能推荐优先**: 必须先基于用户任务描述进行智能推荐,而非直接展示命令库
|
||||
- **不使用静态映射**: 禁止使用查表或硬编码的推荐逻辑(如 `if task=bug则推荐lite-fix`)
|
||||
- **推荐必须说明理由**: Claude 推荐命令链时必须解释为什么这样推荐
|
||||
- **用户保留选择权**: 推荐后,用户可选择:使用推荐/调整/手动选择
|
||||
|
||||
### 2. 验证约束
|
||||
|
||||
- **执行前必须验证**: 使用 `chain-validate.js` 验证命令链合法性
|
||||
- **不合法必须提示**: 如果验证失败,必须明确告知用户错误原因和修复方法
|
||||
- **允许用户覆盖**: 验证失败时,询问用户是否仍要执行(警告模式)
|
||||
|
||||
### 3. 执行约束
|
||||
|
||||
- **顺序执行**: 命令必须严格按照 command_chain 中的 order 顺序执行
|
||||
- **错误处理**: 单个命令失败时,询问用户:重试/跳过/中止
|
||||
- **错误上限**: 连续 3 次错误自动中止会话
|
||||
- **实时反馈**: 每个命令执行时显示进度(如 `[2/5] 执行: lite-execute`)
|
||||
|
||||
### 4. 状态管理约束
|
||||
|
||||
- **状态持久化**: 每次状态更新必须立即写入磁盘
|
||||
- **单一数据源**: 状态只保存在 `state.json`,禁止多个状态文件
|
||||
- **原子操作**: 状态更新必须使用 read-modify-write 模式,避免并发冲突
|
||||
|
||||
### 5. 用户体验约束
|
||||
|
||||
- **最小交互**: 默认使用智能推荐 + 一次确认,避免多次询问
|
||||
- **清晰输出**: 每个步骤输出必须包含:当前状态、可用选项、建议操作
|
||||
- **可恢复性**: 会话中断后,用户可从上次状态恢复
|
||||
|
||||
### 6. 禁止行为
|
||||
|
||||
- ❌ **禁止跳过推荐步骤**: 不能直接进入手动选择,必须先尝试推荐
|
||||
- ❌ **禁止静态推荐**: 不能使用 recommended-chains.json 查表
|
||||
- ❌ **禁止无验证执行**: 不能跳过链条验证直接执行
|
||||
- ❌ **禁止静默失败**: 错误必须明确报告,不能静默跳过
|
||||
|
||||
## Notes
|
||||
|
||||
- 编排器使用状态机模式,确保执行流程的可靠性
|
||||
- 所有命令链和执行结果都被持久化保存,支持后续查询和调试
|
||||
- 支持用户中途修改命令链(在执行前)
|
||||
- 执行错误会自动记录,支持重试机制
|
||||
- Claude 智能推荐基于任务分析,非查表静态推荐
|
||||
@@ -1,9 +0,0 @@
|
||||
# action-abort
|
||||
|
||||
中止会话,保存状态
|
||||
|
||||
```javascript
|
||||
updateState({ status: 'aborted' });
|
||||
|
||||
console.log(`会话已中止: ${workDir}`);
|
||||
```
|
||||
@@ -1,40 +0,0 @@
|
||||
# action-command-build
|
||||
|
||||
调整命令链顺序或删除命令
|
||||
|
||||
## 流程
|
||||
|
||||
1. 显示当前命令链
|
||||
2. 让用户调整(重新排序、删除)
|
||||
3. 确认执行
|
||||
|
||||
## 伪代码
|
||||
|
||||
```javascript
|
||||
// 显示链
|
||||
console.log('命令链:');
|
||||
state.command_chain.forEach((cmd, i) => {
|
||||
console.log(`${i+1}. ${cmd.command}`);
|
||||
});
|
||||
|
||||
// 询问用户
|
||||
const action = await AskUserQuestion({
|
||||
options: [
|
||||
'继续执行',
|
||||
'删除命令',
|
||||
'重新排序',
|
||||
'返回选择'
|
||||
]
|
||||
});
|
||||
|
||||
// 处理用户操作
|
||||
if (action === '继续执行') {
|
||||
updateState({confirmed: true, status: 'executing'});
|
||||
}
|
||||
// ... 其他操作
|
||||
```
|
||||
|
||||
## 状态更新
|
||||
|
||||
- command_chain (可能修改)
|
||||
- confirmed = true 时状态转为 executing
|
||||
@@ -1,124 +0,0 @@
|
||||
# action-command-execute
|
||||
|
||||
依次执行命令链,智能生成 ccw cli 提示词
|
||||
|
||||
## 命令注册表集成
|
||||
|
||||
```javascript
|
||||
// 从 ./tools/command-registry.cjs 按需提取命令元数据
|
||||
const CommandRegistry = require('./tools/command-registry.cjs');
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
// 只提取当前任务链中的命令
|
||||
const commandNames = command_chain.map(cmd => cmd.command);
|
||||
const commandMeta = registry.getCommands(commandNames);
|
||||
```
|
||||
|
||||
## 提示词生成策略
|
||||
|
||||
```javascript
|
||||
function generatePrompt(cmd, state, commandMeta) {
|
||||
const { task_description, execution_results } = state;
|
||||
|
||||
// 获取命令元数据(从已提取的 commandMeta)
|
||||
const cmdInfo = commandMeta[cmd.command];
|
||||
|
||||
// 提取前序产物信息
|
||||
const previousOutputs = execution_results
|
||||
.filter(r => r.status === 'success')
|
||||
.map(r => {
|
||||
const summary = r.summary;
|
||||
if (summary?.session) {
|
||||
return `- ${r.command}: ${summary.session} (${summary.files?.join(', ') || '完成'})`;
|
||||
}
|
||||
return `- ${r.command}: 已完成`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
// 根据命令类型构建提示词
|
||||
let prompt = `任务: ${task_description}\n`;
|
||||
|
||||
if (previousOutputs) {
|
||||
prompt += `\n前序完成:\n${previousOutputs}\n`;
|
||||
}
|
||||
|
||||
// 添加命令元数据上下文
|
||||
if (cmdInfo) {
|
||||
prompt += `\n命令: ${cmd.command}`;
|
||||
if (cmdInfo.argumentHint) {
|
||||
prompt += ` ${cmdInfo.argumentHint}`;
|
||||
}
|
||||
}
|
||||
|
||||
return prompt;
|
||||
}
|
||||
```
|
||||
|
||||
## 执行逻辑
|
||||
|
||||
```javascript
|
||||
for (let i = current_command_index; i < command_chain.length; i++) {
|
||||
const cmd = command_chain[i];
|
||||
|
||||
console.log(`[${i+1}/${command_chain.length}] 执行: ${cmd.command}`);
|
||||
|
||||
// 生成智能提示词
|
||||
const prompt = generatePrompt(cmd, state, commandMeta);
|
||||
|
||||
try {
|
||||
// 使用 ccw cli 执行(添加 -y 参数跳过确认)
|
||||
const result = Bash(`ccw cli -p "${prompt.replace(/"/g, '\\"')}" ${cmd.command} -y`, {
|
||||
run_in_background: true
|
||||
});
|
||||
|
||||
execution_results.push({
|
||||
command: cmd.command,
|
||||
status: result.exit_code === 0 ? 'success' : 'failed',
|
||||
exit_code: result.exit_code,
|
||||
output: result.stdout,
|
||||
summary: extractSummary(result.stdout) // 提取关键产物
|
||||
});
|
||||
|
||||
command_chain[i].status = 'completed';
|
||||
current_command_index = i + 1;
|
||||
|
||||
} catch (error) {
|
||||
error_count++;
|
||||
command_chain[i].status = 'failed';
|
||||
|
||||
if (error_count >= 3) break;
|
||||
|
||||
const action = await AskUserQuestion({
|
||||
options: ['重试', '跳过', '中止']
|
||||
});
|
||||
|
||||
if (action === '重试') i--;
|
||||
if (action === '中止') break;
|
||||
}
|
||||
|
||||
updateState({ command_chain, execution_results, current_command_index, error_count });
|
||||
}
|
||||
```
|
||||
|
||||
## 产物提取
|
||||
|
||||
```javascript
|
||||
function extractSummary(output) {
|
||||
// 从输出提取关键产物信息
|
||||
// 例如: 会话ID, 文件路径, 任务完成状态等
|
||||
const sessionMatch = output.match(/WFS-\w+-\d+/);
|
||||
const fileMatch = output.match(/\.workflow\/[^\s]+/g);
|
||||
|
||||
return {
|
||||
session: sessionMatch?.[0],
|
||||
files: fileMatch || [],
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
## 状态更新
|
||||
|
||||
- execution_results (包含 summary 产物信息)
|
||||
- command_chain[].status
|
||||
- current_command_index
|
||||
@@ -1,48 +0,0 @@
|
||||
# action-command-selection
|
||||
|
||||
## 流程
|
||||
|
||||
1. 问用户任务
|
||||
2. Claude推荐命令链
|
||||
3. 用户确认/手动选择
|
||||
4. 添加到command_chain
|
||||
|
||||
## 伪代码
|
||||
|
||||
```javascript
|
||||
// 1. 获取用户任务描述
|
||||
const taskInput = await AskUserQuestion({
|
||||
question: '请描述您的任务',
|
||||
options: [
|
||||
{ label: '手动选择命令', value: 'manual' }
|
||||
]
|
||||
});
|
||||
|
||||
// 保存任务描述到状态
|
||||
updateState({ task_description: taskInput.text || taskInput.value });
|
||||
|
||||
// 2. 若用户描述任务,Claude推荐
|
||||
if (taskInput.text) {
|
||||
console.log('推荐: ', recommendChain(taskInput.text));
|
||||
const confirm = await AskUserQuestion({
|
||||
question: '是否使用推荐链?',
|
||||
options: ['使用推荐', '调整', '手动选择']
|
||||
});
|
||||
if (confirm === '使用推荐') {
|
||||
addCommandsToChain(recommendedChain);
|
||||
updateState({ confirmed: true });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 手动选择
|
||||
const commands = loadCommandLibrary();
|
||||
const selected = await AskUserQuestion(commands);
|
||||
addToChain(selected);
|
||||
```
|
||||
|
||||
## 状态更新
|
||||
|
||||
- task_description = 用户任务描述
|
||||
- command_chain.push(newCommand)
|
||||
- 如果用户确认: confirmed = true
|
||||
@@ -1,25 +0,0 @@
|
||||
# action-complete
|
||||
|
||||
生成执行报告
|
||||
|
||||
```javascript
|
||||
const success = execution_results.filter(r => r.status === 'success').length;
|
||||
const failed = execution_results.filter(r => r.status === 'failed').length;
|
||||
const duration = Date.now() - new Date(started_at).getTime();
|
||||
|
||||
const report = `
|
||||
# 执行报告
|
||||
|
||||
- 会话: ${session_id}
|
||||
- 耗时: ${Math.round(duration/1000)}s
|
||||
- 成功: ${success}
|
||||
- 失败: ${failed}
|
||||
|
||||
## 命令详情
|
||||
|
||||
${command_chain.map((c, i) => `${i+1}. ${c.command} - ${c.status}`).join('\n')}
|
||||
`;
|
||||
|
||||
Write(`${workDir}/final-report.md`, report);
|
||||
updateState({ status: 'completed' });
|
||||
```
|
||||
@@ -1,26 +0,0 @@
|
||||
# action-init
|
||||
|
||||
初始化编排会话
|
||||
|
||||
```javascript
|
||||
const timestamp = Date.now();
|
||||
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
|
||||
|
||||
Bash(`mkdir -p "${workDir}"`);
|
||||
|
||||
const state = {
|
||||
session_id: `coord-${timestamp}`,
|
||||
status: 'running',
|
||||
started_at: new Date().toISOString(),
|
||||
task_description: '',
|
||||
command_chain: [],
|
||||
current_command_index: 0,
|
||||
execution_results: [],
|
||||
confirmed: false,
|
||||
error_count: 0
|
||||
};
|
||||
|
||||
Write(`${workDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
console.log(`会话已初始化: ${workDir}`);
|
||||
```
|
||||
@@ -1,59 +0,0 @@
|
||||
# Orchestrator
|
||||
|
||||
状态驱动编排:读状态 → 选动作 → 执行 → 更新状态
|
||||
|
||||
## 决策逻辑
|
||||
|
||||
```javascript
|
||||
function selectNextAction(state) {
|
||||
if (['completed', 'aborted'].includes(state.status)) return null;
|
||||
if (state.error_count >= 3) return 'action-abort';
|
||||
|
||||
switch (state.status) {
|
||||
case 'pending':
|
||||
return 'action-init';
|
||||
case 'running':
|
||||
return state.confirmed && state.command_chain.length > 0
|
||||
? 'action-command-execute'
|
||||
: 'action-command-selection';
|
||||
case 'executing':
|
||||
const pending = state.command_chain.filter(c => c.status === 'pending');
|
||||
return pending.length === 0 ? 'action-complete' : 'action-command-execute';
|
||||
default:
|
||||
return 'action-abort';
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 执行循环
|
||||
|
||||
```javascript
|
||||
const timestamp = Date.now();
|
||||
const workDir = `.workflow/.ccw-coordinator/${timestamp}`;
|
||||
Bash(`mkdir -p "${workDir}"`);
|
||||
|
||||
const state = {
|
||||
session_id: `coord-${timestamp}`,
|
||||
status: 'pending',
|
||||
started_at: new Date().toISOString(),
|
||||
task_description: '', // 从 action-command-selection 获取
|
||||
command_chain: [],
|
||||
current_command_index: 0,
|
||||
execution_results: [],
|
||||
confirmed: false,
|
||||
error_count: 0
|
||||
};
|
||||
Write(`${workDir}/state.json`, JSON.stringify(state, null, 2));
|
||||
|
||||
let iterations = 0;
|
||||
while (iterations < 50) {
|
||||
const state = JSON.parse(Read(`${workDir}/state.json`));
|
||||
const nextAction = selectNextAction(state);
|
||||
if (!nextAction) break;
|
||||
|
||||
console.log(`[${nextAction}]`);
|
||||
// 执行 phases/actions/{nextAction}.md
|
||||
|
||||
iterations++;
|
||||
}
|
||||
```
|
||||
@@ -1,66 +0,0 @@
|
||||
# State Schema
|
||||
|
||||
```typescript
|
||||
interface State {
|
||||
session_id: string;
|
||||
status: 'pending' | 'running' | 'executing' | 'completed' | 'aborted';
|
||||
started_at: string;
|
||||
task_description: string; // 用户任务描述
|
||||
command_chain: Command[];
|
||||
current_command_index: number;
|
||||
execution_results: ExecutionResult[];
|
||||
confirmed: boolean;
|
||||
error_count: number;
|
||||
}
|
||||
|
||||
interface Command {
|
||||
id: string;
|
||||
order: number;
|
||||
command: string;
|
||||
status: 'pending' | 'running' | 'completed' | 'failed';
|
||||
result?: ExecutionResult;
|
||||
}
|
||||
|
||||
interface ExecutionResult {
|
||||
command: string;
|
||||
status: 'success' | 'failed';
|
||||
exit_code: number;
|
||||
output?: string;
|
||||
summary?: { // 提取的关键产物
|
||||
session?: string;
|
||||
files?: string[];
|
||||
timestamp: string;
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
## 状态转移
|
||||
|
||||
```
|
||||
pending → running → executing → completed
|
||||
↓ ↓
|
||||
(abort) (error → abort)
|
||||
```
|
||||
|
||||
## 初始化
|
||||
|
||||
```javascript
|
||||
{
|
||||
session_id: generateId(),
|
||||
status: 'pending',
|
||||
started_at: new Date().toISOString(),
|
||||
task_description: '', // 从用户输入获取
|
||||
command_chain: [],
|
||||
current_command_index: 0,
|
||||
execution_results: [],
|
||||
confirmed: false,
|
||||
error_count: 0
|
||||
}
|
||||
```
|
||||
|
||||
## 更新
|
||||
|
||||
- 添加命令: `command_chain.push(cmd)`
|
||||
- 确认执行: `confirmed = true, status = 'executing'`
|
||||
- 记录执行: `execution_results.push(...), current_command_index++`
|
||||
- 错误计数: `error_count++`
|
||||
@@ -1,66 +0,0 @@
|
||||
{
|
||||
"skill_name": "ccw-coordinator",
|
||||
"display_name": "CCW Coordinator",
|
||||
"description": "Interactive command orchestration - select, build, and execute workflow command chains",
|
||||
"execution_mode": "autonomous",
|
||||
"version": "1.0.0",
|
||||
"triggers": [
|
||||
"coordinator",
|
||||
"ccw-coordinator",
|
||||
"命令编排",
|
||||
"command chain"
|
||||
],
|
||||
"allowed_tools": [
|
||||
"Task",
|
||||
"AskUserQuestion",
|
||||
"Read",
|
||||
"Write",
|
||||
"Bash"
|
||||
],
|
||||
"actions": [
|
||||
{
|
||||
"id": "action-init",
|
||||
"name": "Init",
|
||||
"description": "Initialize orchestration session"
|
||||
},
|
||||
{
|
||||
"id": "action-command-selection",
|
||||
"name": "Select Commands",
|
||||
"description": "Interactive command selection from library"
|
||||
},
|
||||
{
|
||||
"id": "action-command-build",
|
||||
"name": "Build Chain",
|
||||
"description": "Adjust and confirm command chain"
|
||||
},
|
||||
{
|
||||
"id": "action-command-execute",
|
||||
"name": "Execute",
|
||||
"description": "Execute command chain sequentially"
|
||||
},
|
||||
{
|
||||
"id": "action-complete",
|
||||
"name": "Complete",
|
||||
"description": "Generate final report"
|
||||
},
|
||||
{
|
||||
"id": "action-abort",
|
||||
"name": "Abort",
|
||||
"description": "Abort session and save state"
|
||||
}
|
||||
],
|
||||
"termination_conditions": [
|
||||
"user_exit",
|
||||
"task_completed",
|
||||
"error"
|
||||
],
|
||||
"output": {
|
||||
"location": ".workflow/.ccw-coordinator/{timestamp}",
|
||||
"artifacts": [
|
||||
"state.json",
|
||||
"command-chain.json",
|
||||
"execution-log.md",
|
||||
"final-report.md"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
# Command Library
|
||||
|
||||
CCW Coordinator 支持的命令库。基于 CCW workflow 命令系统。
|
||||
|
||||
## Command Categories
|
||||
|
||||
### Planning Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-plan` | 轻量级规划 | L2 |
|
||||
| `/workflow:plan` | 标准规划 | L3 |
|
||||
| `/workflow:multi-cli-plan` | 多CLI协作规划 | L2 |
|
||||
| `/workflow:brainstorm:auto-parallel` | 头脑风暴规划 | L4 |
|
||||
| `/workflow:tdd-plan` | TDD规划 | L3 |
|
||||
|
||||
### Execution Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-execute` | 轻量级执行 | L2 |
|
||||
| `/workflow:execute` | 标准执行 | L3 |
|
||||
| `/workflow:test-cycle-execute` | 测试循环执行 | L3 |
|
||||
|
||||
### BugFix Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-fix` | 轻量级修复 | L2 |
|
||||
| `/workflow:lite-fix --hotfix` | 紧急修复 | L2 |
|
||||
|
||||
### Testing Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:test-gen` | 测试生成 | L3 |
|
||||
| `/workflow:test-fix-gen` | 测试修复生成 | L3 |
|
||||
| `/workflow:tdd-verify` | TDD验证 | L3 |
|
||||
|
||||
### Review Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:review-session-cycle` | 会话审查 | L3 |
|
||||
| `/workflow:review-module-cycle` | 模块审查 | L3 |
|
||||
| `/workflow:review-fix` | 审查修复 | L3 |
|
||||
| `/workflow:plan-verify` | 计划验证 | L3 |
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/memory:docs` | 生成文档 | L2 |
|
||||
| `/memory:update-related` | 更新相关文档 | L2 |
|
||||
| `/memory:update-full` | 全面更新文档 | L2 |
|
||||
|
||||
### Issue Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/issue:discover` | 发现Issue | Supplementary |
|
||||
| `/issue:discover-by-prompt` | 基于提示发现Issue | Supplementary |
|
||||
| `/issue:plan --all-pending` | 规划所有待处理Issue | Supplementary |
|
||||
| `/issue:queue` | 排队Issue | Supplementary |
|
||||
| `/issue:execute` | 执行Issue | Supplementary |
|
||||
|
||||
## Command Chains (Recommended)
|
||||
|
||||
### 标准开发流程
|
||||
|
||||
```
|
||||
1. /workflow:lite-plan
|
||||
2. /workflow:lite-execute
|
||||
3. /workflow:test-cycle-execute
|
||||
```
|
||||
|
||||
### 完整规划流程
|
||||
|
||||
```
|
||||
1. /workflow:plan
|
||||
2. /workflow:plan-verify
|
||||
3. /workflow:execute
|
||||
4. /workflow:review-session-cycle
|
||||
```
|
||||
|
||||
### TDD 流程
|
||||
|
||||
```
|
||||
1. /workflow:tdd-plan
|
||||
2. /workflow:execute
|
||||
3. /workflow:tdd-verify
|
||||
```
|
||||
|
||||
### Issue 批处理流程
|
||||
|
||||
```
|
||||
1. /issue:plan --all-pending
|
||||
2. /issue:queue
|
||||
3. /issue:execute
|
||||
```
|
||||
|
||||
## JSON Format
|
||||
|
||||
```json
|
||||
{
|
||||
"workflow_commands": [
|
||||
{
|
||||
"category": "Planning",
|
||||
"commands": [
|
||||
{ "name": "/workflow:lite-plan", "description": "轻量级规划" },
|
||||
{ "name": "/workflow:plan", "description": "标准规划" },
|
||||
{ "name": "/workflow:multi-cli-plan", "description": "多CLI协作规划" },
|
||||
{ "name": "/workflow:brainstorm:auto-parallel", "description": "头脑风暴" },
|
||||
{ "name": "/workflow:tdd-plan", "description": "TDD规划" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "Execution",
|
||||
"commands": [
|
||||
{ "name": "/workflow:lite-execute", "description": "轻量级执行" },
|
||||
{ "name": "/workflow:execute", "description": "标准执行" },
|
||||
{ "name": "/workflow:test-cycle-execute", "description": "测试循环执行" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "BugFix",
|
||||
"commands": [
|
||||
{ "name": "/workflow:lite-fix", "description": "轻量级修复" },
|
||||
{ "name": "/workflow:lite-fix --hotfix", "description": "紧急修复" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "Testing",
|
||||
"commands": [
|
||||
{ "name": "/workflow:test-gen", "description": "测试生成" },
|
||||
{ "name": "/workflow:test-fix-gen", "description": "测试修复" },
|
||||
{ "name": "/workflow:tdd-verify", "description": "TDD验证" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "Review",
|
||||
"commands": [
|
||||
{ "name": "/workflow:review-session-cycle", "description": "会话审查" },
|
||||
{ "name": "/workflow:review-module-cycle", "description": "模块审查" },
|
||||
{ "name": "/workflow:review-fix", "description": "审查修复" },
|
||||
{ "name": "/workflow:plan-verify", "description": "计划验证" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "Documentation",
|
||||
"commands": [
|
||||
{ "name": "/memory:docs", "description": "生成文档" },
|
||||
{ "name": "/memory:update-related", "description": "更新相关文档" },
|
||||
{ "name": "/memory:update-full", "description": "全面更新文档" }
|
||||
]
|
||||
},
|
||||
{
|
||||
"category": "Issues",
|
||||
"commands": [
|
||||
{ "name": "/issue:discover", "description": "发现Issue" },
|
||||
{ "name": "/issue:discover-by-prompt", "description": "基于提示发现Issue" },
|
||||
{ "name": "/issue:plan --all-pending", "description": "规划所有待处理Issue" },
|
||||
{ "name": "/issue:queue", "description": "排队Issue" },
|
||||
{ "name": "/issue:execute", "description": "执行Issue" }
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -1,362 +0,0 @@
|
||||
# CCW Coordinator Specifications
|
||||
|
||||
命令库、验证规则和注册表一体化规范。
|
||||
|
||||
---
|
||||
|
||||
## 命令库
|
||||
|
||||
### Planning Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-plan` | 轻量级规划 | L2 |
|
||||
| `/workflow:plan` | 标准规划 | L3 |
|
||||
| `/workflow:multi-cli-plan` | 多CLI协作规划 | L2 |
|
||||
| `/workflow:brainstorm:auto-parallel` | 头脑风暴规划 | L4 |
|
||||
| `/workflow:tdd-plan` | TDD规划 | L3 |
|
||||
|
||||
### Execution Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-execute` | 轻量级执行 | L2 |
|
||||
| `/workflow:execute` | 标准执行 | L3 |
|
||||
| `/workflow:test-cycle-execute` | 测试循环执行 | L3 |
|
||||
|
||||
### BugFix Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:lite-fix` | 轻量级修复 | L2 |
|
||||
| `/workflow:lite-fix --hotfix` | 紧急修复 | L2 |
|
||||
|
||||
### Testing Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:test-gen` | 测试生成 | L3 |
|
||||
| `/workflow:test-fix-gen` | 测试修复生成 | L3 |
|
||||
| `/workflow:tdd-verify` | TDD验证 | L3 |
|
||||
|
||||
### Review Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/workflow:review-session-cycle` | 会话审查 | L3 |
|
||||
| `/workflow:review-module-cycle` | 模块审查 | L3 |
|
||||
| `/workflow:review-fix` | 审查修复 | L3 |
|
||||
| `/workflow:plan-verify` | 计划验证 | L3 |
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
| Command | Description | Level |
|
||||
|---------|-------------|-------|
|
||||
| `/memory:docs` | 生成文档 | L2 |
|
||||
| `/memory:update-related` | 更新相关文档 | L2 |
|
||||
| `/memory:update-full` | 全面更新文档 | L2 |
|
||||
|
||||
### Issue Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/issue:discover` | 发现Issue |
|
||||
| `/issue:discover-by-prompt` | 基于提示发现Issue |
|
||||
| `/issue:plan --all-pending` | 规划所有待处理Issue |
|
||||
| `/issue:queue` | 排队Issue |
|
||||
| `/issue:execute` | 执行Issue |
|
||||
|
||||
---
|
||||
|
||||
## 命令链推荐
|
||||
|
||||
### 标准开发流程
|
||||
|
||||
```
|
||||
1. /workflow:lite-plan
|
||||
2. /workflow:lite-execute
|
||||
3. /workflow:test-cycle-execute
|
||||
```
|
||||
|
||||
### 完整规划流程
|
||||
|
||||
```
|
||||
1. /workflow:plan
|
||||
2. /workflow:plan-verify
|
||||
3. /workflow:execute
|
||||
4. /workflow:review-session-cycle
|
||||
```
|
||||
|
||||
### TDD 流程
|
||||
|
||||
```
|
||||
1. /workflow:tdd-plan
|
||||
2. /workflow:execute
|
||||
3. /workflow:tdd-verify
|
||||
```
|
||||
|
||||
### Issue 批处理流程
|
||||
|
||||
```
|
||||
1. /issue:plan --all-pending
|
||||
2. /issue:queue
|
||||
3. /issue:execute
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 验证规则
|
||||
|
||||
### Rule 1: Single Planning Command
|
||||
|
||||
每条链最多包含一个规划命令。
|
||||
|
||||
| 有效 | 无效 |
|
||||
|------|------|
|
||||
| `plan → execute` | `plan → lite-plan → execute` |
|
||||
|
||||
### Rule 2: Compatible Pairs
|
||||
|
||||
规划和执行命令必须兼容。
|
||||
|
||||
| Planning | Execution | 兼容 |
|
||||
|----------|-----------|------|
|
||||
| lite-plan | lite-execute | ✓ |
|
||||
| lite-plan | execute | ✗ |
|
||||
| multi-cli-plan | lite-execute | ✓ |
|
||||
| multi-cli-plan | execute | ✓ |
|
||||
| plan | execute | ✓ |
|
||||
| plan | lite-execute | ✗ |
|
||||
| tdd-plan | execute | ✓ |
|
||||
| tdd-plan | lite-execute | ✗ |
|
||||
|
||||
### Rule 3: Testing After Execution
|
||||
|
||||
测试命令必须在执行命令之后。
|
||||
|
||||
| 有效 | 无效 |
|
||||
|------|------|
|
||||
| `execute → test-cycle-execute` | `test-cycle-execute → execute` |
|
||||
|
||||
### Rule 4: Review After Execution
|
||||
|
||||
审查命令必须在执行命令之后。
|
||||
|
||||
| 有效 | 无效 |
|
||||
|------|------|
|
||||
| `execute → review-session-cycle` | `review-session-cycle → execute` |
|
||||
|
||||
### Rule 5: BugFix Standalone
|
||||
|
||||
`lite-fix` 必须单独执行,不能与其他命令组合。
|
||||
|
||||
| 有效 | 无效 |
|
||||
|------|------|
|
||||
| `lite-fix` | `plan → lite-fix → execute` |
|
||||
| `lite-fix --hotfix` | `lite-fix → test-cycle-execute` |
|
||||
|
||||
### Rule 6: Dependency Satisfaction
|
||||
|
||||
每个命令的依赖必须在前面执行。
|
||||
|
||||
```javascript
|
||||
test-fix-gen → test-cycle-execute ✓
|
||||
test-cycle-execute ✗
|
||||
```
|
||||
|
||||
### Rule 7: No Redundancy
|
||||
|
||||
链条中不能有重复的命令。
|
||||
|
||||
| 有效 | 无效 |
|
||||
|------|------|
|
||||
| `plan → execute → test` | `plan → plan → execute` |
|
||||
|
||||
### Rule 8: Command Exists
|
||||
|
||||
所有命令必须在此规范中定义。
|
||||
|
||||
---
|
||||
|
||||
## 反模式(避免)
|
||||
|
||||
### ❌ Pattern 1: Multiple Planning
|
||||
|
||||
```
|
||||
plan → lite-plan → execute
|
||||
```
|
||||
**问题**: 重复分析,浪费时间
|
||||
**修复**: 选一个规划命令
|
||||
|
||||
### ❌ Pattern 2: Test Without Context
|
||||
|
||||
```
|
||||
test-cycle-execute (独立执行)
|
||||
```
|
||||
**问题**: 没有执行上下文,无法工作
|
||||
**修复**: 先执行 `execute` 或 `test-fix-gen`
|
||||
|
||||
### ❌ Pattern 3: BugFix with Planning
|
||||
|
||||
```
|
||||
plan → execute → lite-fix
|
||||
```
|
||||
**问题**: lite-fix 是独立命令,不应与规划混合
|
||||
**修复**: 用 `lite-fix` 单独修复,或用 `plan → execute` 做大改
|
||||
|
||||
### ❌ Pattern 4: Review Without Changes
|
||||
|
||||
```
|
||||
review-session-cycle (独立执行)
|
||||
```
|
||||
**问题**: 没有 git 改动可审查
|
||||
**修复**: 先执行 `execute` 生成改动
|
||||
|
||||
### ❌ Pattern 5: TDD Misuse
|
||||
|
||||
```
|
||||
tdd-plan → lite-execute
|
||||
```
|
||||
**问题**: lite-execute 无法处理 TDD 任务结构
|
||||
**修复**: 用 `tdd-plan → execute → tdd-verify`
|
||||
|
||||
---
|
||||
|
||||
## 命令注册表
|
||||
|
||||
### 命令元数据结构
|
||||
|
||||
```json
|
||||
{
|
||||
"command_name": {
|
||||
"category": "Planning|Execution|Testing|Review|BugFix|Maintenance",
|
||||
"level": "L0|L1|L2|L3",
|
||||
"description": "命令描述",
|
||||
"inputs": ["input1", "input2"],
|
||||
"outputs": ["output1", "output2"],
|
||||
"dependencies": ["依赖命令"],
|
||||
"parameters": [
|
||||
{"name": "--flag", "type": "string|boolean|number", "default": "value"}
|
||||
],
|
||||
"chain_position": "start|middle|middle_or_end|end|standalone",
|
||||
"next_recommended": ["推荐的下一个命令"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 命令分组
|
||||
|
||||
| Group | Commands |
|
||||
|-------|----------|
|
||||
| planning | lite-plan, multi-cli-plan, plan, tdd-plan |
|
||||
| execution | lite-execute, execute, develop-with-file |
|
||||
| testing | test-gen, test-fix-gen, test-cycle-execute, tdd-verify |
|
||||
| review | review-session-cycle, review-module-cycle, review-fix |
|
||||
| bugfix | lite-fix, debug, debug-with-file |
|
||||
| maintenance | clean, replan |
|
||||
| verification | plan-verify, tdd-verify |
|
||||
|
||||
### 兼容性矩阵
|
||||
|
||||
| 组合 | 状态 |
|
||||
|------|------|
|
||||
| lite-plan + lite-execute | ✓ compatible |
|
||||
| lite-plan + execute | ✗ incompatible - use plan |
|
||||
| multi-cli-plan + lite-execute | ✓ compatible |
|
||||
| plan + execute | ✓ compatible |
|
||||
| plan + lite-execute | ✗ incompatible - use lite-plan |
|
||||
| tdd-plan + execute | ✓ compatible |
|
||||
| execute + test-cycle-execute | ✓ compatible |
|
||||
| lite-execute + test-cycle-execute | ✓ compatible |
|
||||
| test-fix-gen + test-cycle-execute | ✓ required |
|
||||
| review-session-cycle + review-fix | ✓ compatible |
|
||||
| lite-fix + test-cycle-execute | ✗ incompatible - lite-fix standalone |
|
||||
|
||||
---
|
||||
|
||||
## 验证工具
|
||||
|
||||
### chain-validate.cjs
|
||||
|
||||
位置: `tools/chain-validate.cjs`
|
||||
|
||||
验证命令链合法性:
|
||||
|
||||
```bash
|
||||
node tools/chain-validate.cjs plan execute test-cycle-execute
|
||||
```
|
||||
|
||||
输出:
|
||||
```
|
||||
{
|
||||
"valid": true,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
```
|
||||
|
||||
## 命令注册表
|
||||
|
||||
### 工具位置
|
||||
|
||||
位置: `tools/command-registry.cjs` (skill 内置)
|
||||
|
||||
### 工作模式
|
||||
|
||||
**按需提取**: 只提取用户确定的任务链中的命令,不是全量扫描。
|
||||
|
||||
```javascript
|
||||
// 用户任务链: [lite-plan, lite-execute]
|
||||
const commandNames = command_chain.map(cmd => cmd.command);
|
||||
const commandMeta = registry.getCommands(commandNames);
|
||||
// 只提取这 2 个命令的元数据
|
||||
```
|
||||
|
||||
### 功能
|
||||
|
||||
- 自动查找全局 `.claude/commands/workflow` 目录(相对路径 > 用户 home)
|
||||
- 按需提取指定命令的 YAML 头元数据
|
||||
- 缓存机制避免重复读取
|
||||
- 提供批量查询接口
|
||||
|
||||
### 集成方式
|
||||
|
||||
在 action-command-execute 中自动集成:
|
||||
|
||||
```javascript
|
||||
const CommandRegistry = require('./tools/command-registry.cjs');
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
// 只提取任务链中的命令
|
||||
const commandNames = command_chain.map(cmd => cmd.command);
|
||||
const commandMeta = registry.getCommands(commandNames);
|
||||
|
||||
// 使用元数据生成提示词
|
||||
const cmdInfo = commandMeta[cmd.command];
|
||||
// {
|
||||
// name: 'lite-plan',
|
||||
// description: '轻量级规划...',
|
||||
// argumentHint: '[-e|--explore] "task description"',
|
||||
// allowedTools: [...],
|
||||
// filePath: '...'
|
||||
// }
|
||||
```
|
||||
|
||||
### 提示词生成
|
||||
|
||||
智能提示词自动包含:
|
||||
|
||||
1. **任务上下文**: 用户任务描述
|
||||
2. **前序产物**: 已完成命令的产物信息
|
||||
3. **命令元数据**: 命令的参数提示和描述
|
||||
|
||||
```
|
||||
任务: 实现用户注册功能
|
||||
|
||||
前序完成:
|
||||
- /workflow:lite-plan: WFS-plan-001 (IMPL_PLAN.md)
|
||||
|
||||
命令: /workflow:lite-execute [--resume-session="session-id"]
|
||||
```
|
||||
|
||||
详见 `tools/README.md`。
|
||||
@@ -1,95 +0,0 @@
|
||||
# CCW Coordinator Tools
|
||||
|
||||
## command-registry.cjs
|
||||
|
||||
命令注册表工具:获取和提取命令元数据。
|
||||
|
||||
### 功能
|
||||
|
||||
- **按需提取**: 只提取指定命令的完整信息(name, description, argumentHint, allowedTools 等)
|
||||
- **全量获取**: 获取所有命令的名称和描述(快速查询)
|
||||
- **自动查找**: 从全局 `.claude/commands/workflow` 目录读取(项目相对路径 > 用户 home)
|
||||
- **缓存机制**: 避免重复读取文件
|
||||
|
||||
### 编程接口
|
||||
|
||||
```javascript
|
||||
const CommandRegistry = require('./tools/command-registry.cjs');
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
// 1. 获取所有命令的名称和描述(快速)
|
||||
const allCommands = registry.getAllCommandsSummary();
|
||||
// {
|
||||
// "/workflow:lite-plan": {
|
||||
// name: 'lite-plan',
|
||||
// description: '轻量级规划...'
|
||||
// },
|
||||
// "/workflow:lite-execute": { ... }
|
||||
// }
|
||||
|
||||
// 2. 按需提取指定命令的完整信息
|
||||
const commands = registry.getCommands([
|
||||
'/workflow:lite-plan',
|
||||
'/workflow:lite-execute'
|
||||
]);
|
||||
// {
|
||||
// "/workflow:lite-plan": {
|
||||
// name: 'lite-plan',
|
||||
// description: '...',
|
||||
// argumentHint: '[-e|--explore] "task description"',
|
||||
// allowedTools: [...],
|
||||
// filePath: '...'
|
||||
// },
|
||||
// ...
|
||||
// }
|
||||
```
|
||||
|
||||
### 命令行接口
|
||||
|
||||
```bash
|
||||
# 获取所有命令的名称和描述
|
||||
node .claude/skills/ccw-coordinator/tools/command-registry.cjs
|
||||
node .claude/skills/ccw-coordinator/tools/command-registry.cjs --all
|
||||
|
||||
# 输出: 23 个命令的简明列表 (name + description)
|
||||
```
|
||||
|
||||
```bash
|
||||
# 按需提取指定命令的完整信息
|
||||
node .claude/skills/ccw-coordinator/tools/command-registry.cjs lite-plan lite-execute
|
||||
|
||||
# 输出: 完整信息 (name, description, argumentHint, allowedTools, filePath)
|
||||
```
|
||||
|
||||
### 集成用途
|
||||
|
||||
在 `action-command-execute` 中使用:
|
||||
|
||||
```javascript
|
||||
// 1. 初始化时只提取任务链中的命令(完整信息)
|
||||
const commandNames = command_chain.map(cmd => cmd.command);
|
||||
const commandMeta = registry.getCommands(commandNames);
|
||||
|
||||
// 2. 生成提示词时使用
|
||||
function generatePrompt(cmd, state, commandMeta) {
|
||||
const cmdInfo = commandMeta[cmd.command];
|
||||
let prompt = `任务: ${state.task_description}\n`;
|
||||
|
||||
if (cmdInfo?.argumentHint) {
|
||||
prompt += `命令: ${cmd.command} ${cmdInfo.argumentHint}`;
|
||||
}
|
||||
|
||||
return prompt;
|
||||
}
|
||||
```
|
||||
|
||||
确保 `ccw cli -p "..."` 提示词包含准确的命令参数提示。
|
||||
|
||||
### 目录查找逻辑
|
||||
|
||||
自动查找顺序:
|
||||
1. `.claude/commands/workflow` (相对于当前工作目录)
|
||||
2. `~/.claude/commands/workflow` (用户 home 目录)
|
||||
|
||||
|
||||
|
||||
@@ -1,320 +0,0 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
/**
|
||||
* Chain Validation Tool
|
||||
*
|
||||
* Validates workflow command chains against defined rules.
|
||||
*
|
||||
* Usage:
|
||||
* node chain-validate.js plan execute test-cycle-execute
|
||||
* node chain-validate.js --json "plan,execute,test-cycle-execute"
|
||||
* node chain-validate.js --file custom-chain.json
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
// Optional registry loading - gracefully degrade if not found
|
||||
let registry = null;
|
||||
try {
|
||||
const registryPath = path.join(__dirname, '..', 'specs', 'chain-registry.json');
|
||||
if (fs.existsSync(registryPath)) {
|
||||
registry = JSON.parse(fs.readFileSync(registryPath, 'utf8'));
|
||||
}
|
||||
} catch (error) {
|
||||
// Registry not available - dependency validation will be skipped
|
||||
}
|
||||
|
||||
class ChainValidator {
|
||||
constructor(registry) {
|
||||
this.registry = registry;
|
||||
this.errors = [];
|
||||
this.warnings = [];
|
||||
}
|
||||
|
||||
validate(chain) {
|
||||
this.errors = [];
|
||||
this.warnings = [];
|
||||
|
||||
this.validateSinglePlanning(chain);
|
||||
this.validateCompatiblePairs(chain);
|
||||
this.validateTestingPosition(chain);
|
||||
this.validateReviewPosition(chain);
|
||||
this.validateBugfixStandalone(chain);
|
||||
this.validateDependencies(chain);
|
||||
this.validateNoRedundancy(chain);
|
||||
this.validateCommandExistence(chain);
|
||||
|
||||
return {
|
||||
valid: this.errors.length === 0,
|
||||
errors: this.errors,
|
||||
warnings: this.warnings
|
||||
};
|
||||
}
|
||||
|
||||
validateSinglePlanning(chain) {
|
||||
const planningCommands = chain.filter(cmd =>
|
||||
['plan', 'lite-plan', 'multi-cli-plan', 'tdd-plan'].includes(cmd)
|
||||
);
|
||||
|
||||
if (planningCommands.length > 1) {
|
||||
this.errors.push({
|
||||
rule: 'Single Planning Command',
|
||||
message: `Too many planning commands: ${planningCommands.join(', ')}`,
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
validateCompatiblePairs(chain) {
|
||||
const compatibility = {
|
||||
'lite-plan': ['lite-execute'],
|
||||
'multi-cli-plan': ['lite-execute', 'execute'],
|
||||
'plan': ['execute'],
|
||||
'tdd-plan': ['execute']
|
||||
};
|
||||
|
||||
const planningCmd = chain.find(cmd =>
|
||||
['plan', 'lite-plan', 'multi-cli-plan', 'tdd-plan'].includes(cmd)
|
||||
);
|
||||
|
||||
const executionCmd = chain.find(cmd =>
|
||||
['execute', 'lite-execute'].includes(cmd)
|
||||
);
|
||||
|
||||
if (planningCmd && executionCmd) {
|
||||
const compatible = compatibility[planningCmd] || [];
|
||||
if (!compatible.includes(executionCmd)) {
|
||||
this.errors.push({
|
||||
rule: 'Compatible Pairs',
|
||||
message: `${planningCmd} incompatible with ${executionCmd}`,
|
||||
fix: `Use ${planningCmd} with ${compatible.join(' or ')}`,
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateTestingPosition(chain) {
|
||||
const executionIdx = chain.findIndex(cmd =>
|
||||
['execute', 'lite-execute', 'develop-with-file'].includes(cmd)
|
||||
);
|
||||
|
||||
const testingIdx = chain.findIndex(cmd =>
|
||||
['test-cycle-execute', 'tdd-verify', 'test-gen', 'test-fix-gen'].includes(cmd)
|
||||
);
|
||||
|
||||
if (testingIdx !== -1 && executionIdx !== -1 && executionIdx > testingIdx) {
|
||||
this.errors.push({
|
||||
rule: 'Testing After Execution',
|
||||
message: 'Testing commands must come after execution',
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
|
||||
if (testingIdx !== -1 && executionIdx === -1) {
|
||||
const hasTestGen = chain.some(cmd => ['test-gen', 'test-fix-gen'].includes(cmd));
|
||||
if (!hasTestGen) {
|
||||
this.warnings.push({
|
||||
rule: 'Testing After Execution',
|
||||
message: 'test-cycle-execute without execution context - needs test-gen or execute first',
|
||||
severity: 'warning'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateReviewPosition(chain) {
|
||||
const executionIdx = chain.findIndex(cmd =>
|
||||
['execute', 'lite-execute'].includes(cmd)
|
||||
);
|
||||
|
||||
const reviewIdx = chain.findIndex(cmd =>
|
||||
cmd.includes('review')
|
||||
);
|
||||
|
||||
if (reviewIdx !== -1 && executionIdx !== -1 && executionIdx > reviewIdx) {
|
||||
this.errors.push({
|
||||
rule: 'Review After Changes',
|
||||
message: 'Review commands must come after execution',
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
|
||||
if (reviewIdx !== -1 && executionIdx === -1) {
|
||||
const isModuleReview = chain[reviewIdx] === 'review-module-cycle';
|
||||
if (!isModuleReview) {
|
||||
this.warnings.push({
|
||||
rule: 'Review After Changes',
|
||||
message: 'Review without execution - needs git changes to review',
|
||||
severity: 'warning'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateBugfixStandalone(chain) {
|
||||
if (chain.includes('lite-fix')) {
|
||||
const others = chain.filter(cmd => cmd !== 'lite-fix');
|
||||
if (others.length > 0) {
|
||||
this.errors.push({
|
||||
rule: 'BugFix Standalone',
|
||||
message: 'lite-fix must be standalone, cannot combine with other commands',
|
||||
fix: 'Use lite-fix alone OR use plan + execute for larger changes',
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateDependencies(chain) {
|
||||
// Skip if registry not available
|
||||
if (!this.registry || !this.registry.commands) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (let i = 0; i < chain.length; i++) {
|
||||
const cmd = chain[i];
|
||||
const cmdMeta = this.registry.commands[cmd];
|
||||
|
||||
if (!cmdMeta) continue;
|
||||
|
||||
const deps = cmdMeta.dependencies || [];
|
||||
const depsOptional = cmdMeta.dependencies_optional || false;
|
||||
|
||||
if (deps.length > 0 && !depsOptional) {
|
||||
const hasDependency = deps.some(dep => {
|
||||
const depIdx = chain.indexOf(dep);
|
||||
return depIdx !== -1 && depIdx < i;
|
||||
});
|
||||
|
||||
if (!hasDependency) {
|
||||
this.errors.push({
|
||||
rule: 'Dependency Satisfaction',
|
||||
message: `${cmd} requires ${deps.join(' or ')} before it`,
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateNoRedundancy(chain) {
|
||||
const seen = new Set();
|
||||
const duplicates = [];
|
||||
|
||||
for (const cmd of chain) {
|
||||
if (seen.has(cmd)) {
|
||||
duplicates.push(cmd);
|
||||
}
|
||||
seen.add(cmd);
|
||||
}
|
||||
|
||||
if (duplicates.length > 0) {
|
||||
this.errors.push({
|
||||
rule: 'No Redundant Commands',
|
||||
message: `Duplicate commands: ${duplicates.join(', ')}`,
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
validateCommandExistence(chain) {
|
||||
// Skip if registry not available
|
||||
if (!this.registry || !this.registry.commands) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const cmd of chain) {
|
||||
if (!this.registry.commands[cmd]) {
|
||||
this.errors.push({
|
||||
rule: 'Command Existence',
|
||||
message: `Unknown command: ${cmd}`,
|
||||
severity: 'error'
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function main() {
|
||||
const args = process.argv.slice(2);
|
||||
|
||||
if (args.length === 0) {
|
||||
console.log('Usage:');
|
||||
console.log(' chain-validate.js <command1> <command2> ...');
|
||||
console.log(' chain-validate.js --json "cmd1,cmd2,cmd3"');
|
||||
console.log(' chain-validate.js --file chain.json');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
let chain;
|
||||
|
||||
if (args[0] === '--json') {
|
||||
chain = args[1].split(',').map(s => s.trim());
|
||||
} else if (args[0] === '--file') {
|
||||
const filePath = args[1];
|
||||
|
||||
// SEC-001: 路径遍历验证 - 只允许访问工作目录下的文件
|
||||
const resolvedPath = path.resolve(filePath);
|
||||
const workDir = path.resolve('.');
|
||||
if (!resolvedPath.startsWith(workDir)) {
|
||||
console.error('Error: File path must be within current working directory');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// CORR-001: JSON 解析错误处理
|
||||
let fileContent;
|
||||
try {
|
||||
fileContent = JSON.parse(fs.readFileSync(resolvedPath, 'utf8'));
|
||||
} catch (error) {
|
||||
console.error(`Error: Failed to parse JSON file ${filePath}: ${error.message}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// CORR-002: 嵌套属性 null 检查
|
||||
chain = fileContent.chain || fileContent.steps?.map(s => s.command) || [];
|
||||
if (chain.length === 0) {
|
||||
console.error('Error: No valid chain found in file (expected "chain" array or "steps" with "command" fields)');
|
||||
process.exit(1);
|
||||
}
|
||||
} else {
|
||||
chain = args;
|
||||
}
|
||||
|
||||
const validator = new ChainValidator(registry);
|
||||
const result = validator.validate(chain);
|
||||
|
||||
console.log('\n=== Chain Validation Report ===\n');
|
||||
console.log('Chain:', chain.join(' → '));
|
||||
console.log('');
|
||||
|
||||
if (result.valid) {
|
||||
console.log('✓ Chain is valid!\n');
|
||||
} else {
|
||||
console.log('✗ Chain has errors:\n');
|
||||
result.errors.forEach(err => {
|
||||
console.log(` [${err.rule}] ${err.message}`);
|
||||
if (err.fix) {
|
||||
console.log(` Fix: ${err.fix}`);
|
||||
}
|
||||
});
|
||||
console.log('');
|
||||
}
|
||||
|
||||
if (result.warnings.length > 0) {
|
||||
console.log('⚠ Warnings:\n');
|
||||
result.warnings.forEach(warn => {
|
||||
console.log(` [${warn.rule}] ${warn.message}`);
|
||||
});
|
||||
console.log('');
|
||||
}
|
||||
|
||||
process.exit(result.valid ? 0 : 1);
|
||||
}
|
||||
|
||||
if (require.main === module) {
|
||||
main();
|
||||
}
|
||||
|
||||
module.exports = { ChainValidator };
|
||||
@@ -1,255 +0,0 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
/**
|
||||
* Command Registry Tool
|
||||
*
|
||||
* 功能:
|
||||
* 1. 根据命令名称查找并提取 YAML 头
|
||||
* 2. 从全局 .claude/commands/workflow 目录读取
|
||||
* 3. 支持按需提取(不是全量扫描)
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const os = require('os');
|
||||
|
||||
class CommandRegistry {
|
||||
constructor(commandDir = null) {
|
||||
// 优先使用传入的目录
|
||||
if (commandDir) {
|
||||
this.commandDir = commandDir;
|
||||
} else {
|
||||
// 自动查找 .claude/commands/workflow
|
||||
this.commandDir = this.findCommandDir();
|
||||
}
|
||||
this.cache = {};
|
||||
}
|
||||
|
||||
/**
|
||||
* 自动查找 .claude/commands/workflow 目录
|
||||
* 支持: 项目相对路径、用户 home 目录
|
||||
*/
|
||||
findCommandDir() {
|
||||
// 1. 尝试相对于当前工作目录
|
||||
const relativePath = path.join('.claude', 'commands', 'workflow');
|
||||
if (fs.existsSync(relativePath)) {
|
||||
return path.resolve(relativePath);
|
||||
}
|
||||
|
||||
// 2. 尝试用户 home 目录
|
||||
const homeDir = os.homedir();
|
||||
const homeCommandDir = path.join(homeDir, '.claude', 'commands', 'workflow');
|
||||
if (fs.existsSync(homeCommandDir)) {
|
||||
return homeCommandDir;
|
||||
}
|
||||
|
||||
// 未找到时返回 null,后续操作会失败并提示
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析 YAML 头 (简化版本)
|
||||
*
|
||||
* 限制:
|
||||
* - 只支持简单的 key: value 对 (单行值)
|
||||
* - 不支持多行值、嵌套对象、复杂列表
|
||||
* - allowed-tools 字段支持逗号分隔的字符串,自动转为数组
|
||||
*
|
||||
* 示例:
|
||||
* ---
|
||||
* name: lite-plan
|
||||
* description: "Lightweight planning workflow"
|
||||
* allowed-tools: Read, Write, Bash
|
||||
* ---
|
||||
*/
|
||||
parseYamlHeader(content) {
|
||||
// 处理 Windows 行结尾 (\r\n)
|
||||
const match = content.match(/^---[\r\n]+([\s\S]*?)[\r\n]+---/);
|
||||
if (!match) return null;
|
||||
|
||||
const yamlContent = match[1];
|
||||
const result = {};
|
||||
|
||||
try {
|
||||
const lines = yamlContent.split(/[\r\n]+/);
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed.startsWith('#')) continue; // 跳过空行和注释
|
||||
|
||||
const colonIndex = trimmed.indexOf(':');
|
||||
if (colonIndex === -1) continue;
|
||||
|
||||
const key = trimmed.substring(0, colonIndex).trim();
|
||||
const value = trimmed.substring(colonIndex + 1).trim();
|
||||
|
||||
        if (!key) continue; // skip invalid lines

        // Strip surrounding quotes (single or double)
        let cleanValue = value.replace(/^["']|["']$/g, '');

        // Special case: allowed-tools is converted to an array.
        // Supported formats: "Read, Write, Bash" or "Read,Write,Bash"
        if (key === 'allowed-tools') {
          cleanValue = Array.isArray(cleanValue)
            ? cleanValue
            : cleanValue.split(',').map(t => t.trim()).filter(t => t);
        }

        result[key] = cleanValue;
      }
    } catch (error) {
      console.error('YAML parsing error:', error.message);
      return null;
    }

    return result;
  }

  /**
   * Get metadata for a single command
   * @param {string} commandName Command name (e.g., "lite-plan" or "/workflow:lite-plan")
   * @returns {object|null} Command info, or null
   */
  getCommand(commandName) {
    if (!this.commandDir) {
      console.error('ERROR: .claude/commands/workflow directory not found');
      return null;
    }

    // Normalize the command name
    const normalized = commandName.startsWith('/workflow:')
      ? commandName.substring('/workflow:'.length)
      : commandName;

    // Check the cache
    if (this.cache[normalized]) {
      return this.cache[normalized];
    }

    // Read the command file
    const filePath = path.join(this.commandDir, `${normalized}.md`);
    if (!fs.existsSync(filePath)) {
      return null;
    }

    try {
      const content = fs.readFileSync(filePath, 'utf-8');
      const header = this.parseYamlHeader(content);

      if (header && header.name) {
        const result = {
          name: header.name,
          command: `/workflow:${header.name}`,
          description: header.description || '',
          argumentHint: header['argument-hint'] || '',
          allowedTools: Array.isArray(header['allowed-tools'])
            ? header['allowed-tools']
            : (header['allowed-tools'] ? [header['allowed-tools']] : []),
          filePath: filePath
        };

        // Cache the result
        this.cache[normalized] = result;
        return result;
      }
    } catch (error) {
      console.error(`Failed to read command ${filePath}:`, error.message);
    }

    return null;
  }

  /**
   * Get metadata for multiple commands at once
   * @param {array} commandNames Array of command names
   * @returns {object} Map of command info keyed by full command name
   */
  getCommands(commandNames) {
    const result = {};

    for (const name of commandNames) {
      const cmd = this.getCommand(name);
      if (cmd) {
        result[cmd.command] = cmd;
      }
    }

    return result;
  }

  /**
   * Get the name and description of every available command
   * @returns {object} Map of command names to descriptions
   */
  getAllCommandsSummary() {
    const result = {};
    const commandDir = this.commandDir;

    if (!commandDir) {
      return result;
    }

    try {
      const files = fs.readdirSync(commandDir);

      for (const file of files) {
        if (!file.endsWith('.md')) continue;

        const filePath = path.join(commandDir, file);
        const stat = fs.statSync(filePath);

        if (stat.isDirectory()) continue;

        try {
          const content = fs.readFileSync(filePath, 'utf-8');
          const header = this.parseYamlHeader(content);

          if (header && header.name) {
            const commandName = `/workflow:${header.name}`;
            result[commandName] = {
              name: header.name,
              description: header.description || ''
            };
          }
        } catch (error) {
          // Skip files that fail to read
          continue;
        }
      }
    } catch (error) {
      // Directory read failed
      return result;
    }

    return result;
  }

  /**
   * Serialize the registry (or a given command map) as JSON
   */
  toJSON(commands = null) {
    const data = commands || this.cache;
    return JSON.stringify(data, null, 2);
  }
}

// CLI mode
if (require.main === module) {
  const args = process.argv.slice(2);

  if (args.length === 0 || args[0] === '--all') {
    // List name and description for every command
    const registry = new CommandRegistry();
    const commands = registry.getAllCommandsSummary();
    console.log(JSON.stringify(commands, null, 2));
    process.exit(0);
  }

  const registry = new CommandRegistry();
  const commands = registry.getCommands(args);

  console.log(JSON.stringify(commands, null, 2));
}

module.exports = CommandRegistry;
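For orientation, a minimal usage sketch of the registry above. The file name and call site are illustrative and not part of the diff; only methods defined above are used.

```javascript
// Illustrative only — mirrors the API defined above.
const CommandRegistry = require('./command-registry'); // assumed file name

const registry = new CommandRegistry();            // auto-detects .claude/commands/workflow
const litePlan = registry.getCommand('lite-plan'); // same as registry.getCommand('/workflow:lite-plan')

if (litePlan) {
  console.log(litePlan.command);      // "/workflow:lite-plan"
  console.log(litePlan.allowedTools); // e.g. ["Task(*)", "Read(*)", ...]
}

// Batch lookup, equivalent to the CLI mode: node command-registry.js lite-plan execute
console.log(registry.toJSON(registry.getCommands(['lite-plan', 'execute'])));
```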
@@ -1,522 +0,0 @@
---
name: ccw
description: Stateless workflow orchestrator. Auto-selects optimal workflow based on task intent. Triggers "ccw", "workflow".
allowed-tools: Task(*), SlashCommand(*), AskUserQuestion(*), Read(*), Bash(*), Grep(*), TodoWrite(*)
---

# CCW - Claude Code Workflow Orchestrator

Stateless workflow coordinator that automatically selects the optimal workflow from task intent.

## Workflow System Overview

CCW provides two workflow systems, **Main Workflow** and **Issue Workflow**, which together cover the full software development lifecycle.

```
┌───────────────────────────────────────────────────────────────────────┐
│                             Main Workflow                             │
│                                                                       │
│ ┌─────────────┐   ┌─────────────┐   ┌─────────────┐   ┌─────────────┐ │
│ │  Level 1    │ → │  Level 2    │ → │  Level 3    │ → │  Level 4    │ │
│ │  Rapid      │   │ Lightweight │   │  Standard   │   │ Brainstorm  │ │
│ │             │   │             │   │             │   │             │ │
│ │ lite-lite-  │   │ lite-plan   │   │ plan        │   │ brainstorm  │ │
│ │ lite        │   │ lite-fix    │   │ tdd-plan    │   │ :auto-      │ │
│ │             │   │ multi-cli-  │   │ test-fix-   │   │ parallel    │ │
│ │             │   │ plan        │   │ gen         │   │     ↓       │ │
│ │             │   │             │   │             │   │   plan      │ │
│ └─────────────┘   └─────────────┘   └─────────────┘   └─────────────┘ │
│                                                                       │
│ Complexity: ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━▶ │
│             Low                                                 High  │
└───────────────────────────────────────────────────────────────────────┘
                                    │
                                    │ After development
                                    ▼
┌───────────────────────────────────────────────────────────────────────┐
│                             Issue Workflow                            │
│                                                                       │
│ ┌──────────────┐   ┌──────────────┐   ┌──────────────┐                │
│ │  Accumulate  │ → │     Plan     │ → │   Execute    │                │
│ │  Discover &  │   │    Batch     │   │   Parallel   │                │
│ │  Collect     │   │   Planning   │   │  Execution   │                │
│ └──────────────┘   └──────────────┘   └──────────────┘                │
│                                                                       │
│ Supplementary role: maintain main branch stability, worktree isolation│
└───────────────────────────────────────────────────────────────────────┘
```

## Architecture

```
┌──────────────────────────────────────────────────────────────────┐
│ CCW Orchestrator (CLI-Enhanced + Requirement Analysis)            │
├──────────────────────────────────────────────────────────────────┤
│ Phase 1    │ Input Analysis (rule-based, fast path)               │
│ Phase 1.5  │ CLI Classification (semantic, smart path)            │
│ Phase 1.75 │ Requirement Clarification (clarity < 2)              │
│ Phase 2    │ Level Selection (intent → level → workflow)          │
│ Phase 2.5  │ CLI Action Planning (high complexity)                │
│ Phase 3    │ User Confirmation (optional)                         │
│ Phase 4    │ TODO Tracking Setup                                  │
│ Phase 5    │ Execution Loop                                       │
└──────────────────────────────────────────────────────────────────┘
```

## Level Quick Reference

| Level | Name | Workflows | Artifacts | Execution |
|-------|------|-----------|-----------|-----------|
| **1** | Rapid | `lite-lite-lite` | None | Direct execute |
| **2** | Lightweight | `lite-plan`, `lite-fix`, `multi-cli-plan` | Memory/Lightweight files | → `lite-execute` |
| **3** | Standard | `plan`, `tdd-plan`, `test-fix-gen` | Session persistence | → `execute` / `test-cycle-execute` |
| **4** | Brainstorm | `brainstorm:auto-parallel` → `plan` | Multi-role analysis + Session | → `execute` |
| **-** | Issue | `discover` → `plan` → `queue` → `execute` | Issue records | Worktree isolation (optional) |

## Workflow Selection Decision Tree

```
Start
  │
  ├─ Is it post-development maintenance?
  │    ├─ Yes → Issue Workflow
  │    └─ No ↓
  │
  ├─ Are requirements clear?
  │    ├─ Uncertain → Level 4 (brainstorm:auto-parallel)
  │    └─ Clear ↓
  │
  ├─ Need a persistent session?
  │    ├─ Yes → Level 3 (plan / tdd-plan / test-fix-gen)
  │    └─ No ↓
  │
  ├─ Need multi-perspective analysis / solution comparison?
  │    ├─ Yes → Level 2 (multi-cli-plan)
  │    └─ No ↓
  │
  ├─ Is it a bug fix?
  │    ├─ Yes → Level 2 (lite-fix)
  │    └─ No ↓
  │
  ├─ Need planning?
  │    ├─ Yes → Level 2 (lite-plan)
  │    └─ No → Level 1 (lite-lite-lite)
```

## Intent Classification

### Priority Order (with Level Mapping)

| Priority | Intent | Patterns | Level | Flow |
|----------|--------|----------|-------|------|
| 1 | bugfix/hotfix | `urgent,production,critical` + bug | L2 | `bugfix.hotfix` |
| 1 | bugfix | `fix,bug,error,crash,fail` | L2 | `bugfix.standard` |
| 2 | issue batch | `issues,batch` + `fix,resolve` | Issue | `issue` |
| 3 | exploration | `不确定,explore,研究,what if` | L4 | `full` |
| 3 | multi-perspective | `多视角,权衡,比较方案,cross-verify` | L2 | `multi-cli-plan` |
| 4 | quick-task | `快速,简单,small,quick` + feature | L1 | `lite-lite-lite` |
| 5 | ui design | `ui,design,component,style` | L3/L4 | `ui` |
| 6 | tdd | `tdd,test-driven,先写测试` | L3 | `tdd` |
| 7 | test-fix | `测试失败,test fail,fix test` | L3 | `test-fix-gen` |
| 8 | review | `review,审查,code review` | L3 | `review-fix` |
| 9 | documentation | `文档,docs,readme` | L2 | `docs` |
| 99 | feature | complexity-based | L2/L3 | `rapid`/`coupled` |

### Quick Selection Guide

| Scenario | Recommended Workflow | Level |
|----------|---------------------|-------|
| Quick fixes, config adjustments | `lite-lite-lite` | 1 |
| Clear single-module features | `lite-plan → lite-execute` | 2 |
| Bug diagnosis and fix | `lite-fix` | 2 |
| Production emergencies | `lite-fix --hotfix` | 2 |
| Technology selection, solution comparison | `multi-cli-plan → lite-execute` | 2 |
| Multi-module changes, refactoring | `plan → verify → execute` | 3 |
| Test-driven development | `tdd-plan → execute → tdd-verify` | 3 |
| Test failure fixes | `test-fix-gen → test-cycle-execute` | 3 |
| New features, architecture design | `brainstorm:auto-parallel → plan → execute` | 4 |
| Post-development issue fixes | Issue Workflow | - |

### Complexity Assessment

```javascript
function assessComplexity(text) {
  let score = 0
  if (/refactor|重构|migrate|迁移|architect|架构|system|系统/.test(text)) score += 2
  if (/multiple|多个|across|跨|all|所有|entire|整个/.test(text)) score += 2
  if (/integrate|集成|api|database|数据库/.test(text)) score += 1
  if (/security|安全|performance|性能|scale|扩展/.test(text)) score += 1
  return score >= 4 ? 'high' : score >= 2 ? 'medium' : 'low'
}
```

| Complexity | Flow |
|------------|------|
| high | `coupled` (plan → verify → execute) |
| medium/low | `rapid` (lite-plan → lite-execute) |
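To make the thresholds concrete, a few hypothetical inputs run through `assessComplexity` (illustrative only; the scores follow the weights in the function above):

```javascript
// Illustrative inputs — not from the original file
assessComplexity('refactor the auth system across all modules'); // 2 + 2 = 4     → 'high' → coupled
assessComplexity('migrate the database and tune performance');   // 2 + 1 + 1 = 4 → 'high' → coupled
assessComplexity('integrate the payments api');                  // 1             → 'low'  → rapid
```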
### Dimension Extraction (WHAT/WHERE/WHY/HOW)

Four dimensions are extracted from the user input; they drive requirement clarification and workflow selection:

| Dimension | Extracted content | Example patterns |
|------|----------|----------|
| **WHAT** | action + target | `创建/修复/重构/优化/分析` + target object |
| **WHERE** | scope + paths | `file/module/system` + file paths |
| **WHY** | goal + motivation | `为了.../因为.../目的是...` |
| **HOW** | constraints + preferences | `必须.../不要.../应该...` |

**Clarity Score** (0-3):
- +0.5: has an explicit action
- +0.5: has a concrete target
- +0.5: includes file paths
- +0.5: scope is not unknown
- +0.5: has an explicit goal
- +0.5: has constraints
- -0.5: contains uncertainty words (`不知道/maybe/怎么`)

### Requirement Clarification

Requirement clarification is triggered when `clarity_score < 2`:

```javascript
if (dimensions.clarity_score < 2) {
  const questions = generateClarificationQuestions(dimensions)
  // Generated questions: What is the target? What is the scope? What constraints apply?
  AskUserQuestion({ questions })
}
```

**Clarification question types**:
- Unclear target → "What do you want to operate on?"
- Unclear scope → "What is the scope of the operation?"
- Unclear goal → "What is the main goal of this operation?"
- Complex operation → "Are there special requirements or constraints?"

## TODO Tracking Protocol

### CRITICAL: Append-Only Rule

Todos created by CCW **must be appended to the existing list**; they must never overwrite the user's other todos.

### Implementation

```javascript
// 1. Use a CCW prefix to isolate workflow todos
const prefix = `CCW:${flowName}`

// 2. New todos use the prefixed format
TodoWrite({
  todos: [
    ...existingNonCCWTodos, // keep the user's own todos
    { content: `${prefix}: [1/N] /command:step1`, status: "in_progress", activeForm: "..." },
    { content: `${prefix}: [2/N] /command:step2`, status: "pending", activeForm: "..." }
  ]
})

// 3. When updating status, only modify todos that match the prefix
```

### Todo Format

```
CCW:{flow}: [{N}/{Total}] /command:name
```

### Visual Example

```
✓ CCW:rapid: [1/2] /workflow:lite-plan
→ CCW:rapid: [2/2] /workflow:lite-execute
  the user's own todo (left untouched)
```

### Status Management

- Workflow start: create a todo for every step, first step `in_progress`
- Step finished: current step `completed`, next step `in_progress`
- Workflow end: mark all CCW todos `completed`
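A minimal sketch of the prefix-scoped status update described above (the helper name is illustrative and not part of the original command file):

```javascript
// Sketch: mark the active CCW step completed and promote the next CCW step.
// Todos without the prefix (the user's own) pass through untouched.
function advanceCcwTodos(todos, prefix) {
  const updated = todos.map(t => ({ ...t }));
  const ccwTodos = updated.filter(t => t.content.startsWith(prefix));

  const current = ccwTodos.find(t => t.status === 'in_progress');
  if (!current) return updated;

  current.status = 'completed';
  const next = ccwTodos.find(t => t.status === 'pending');
  if (next) next.status = 'in_progress';
  return updated;
}
```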
## Execution Flow

```javascript
// 1. Check for an explicit command
if (input.startsWith('/workflow:') || input.startsWith('/issue:')) {
  SlashCommand(input)
  return
}

// 2. Classify intent
const intent = classifyIntent(input) // See command.json intent_rules

// 3. Select flow
const flow = selectFlow(intent) // See command.json flows

// 4. Create todos with the CCW prefix
createWorkflowTodos(flow)

// 5. Dispatch the first command
SlashCommand(flow.steps[0].command, { args: input })
```
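For intuition, a simplified sketch of what `classifyIntent` might look like when driven by the priority table above. The real patterns and priorities live in command.json; the entries below are an illustrative subset.

```javascript
// Simplified, rule-based intent classification (illustrative subset of intent_rules)
const INTENT_RULES = [
  { intent: 'bugfix.hotfix',   priority: 1, pattern: /urgent|production|critical/i },
  { intent: 'bugfix.standard', priority: 1, pattern: /fix|bug|error|crash|fail/i },
  { intent: 'exploration',     priority: 3, pattern: /explore|what if|不确定|研究/i },
  { intent: 'quick-task',      priority: 4, pattern: /quick|simple|small|快速|简单/i }
];

function classifyIntent(input) {
  const matches = INTENT_RULES.filter(rule => rule.pattern.test(input));
  if (matches.length === 0) {
    // Fallback: route generic features by complexity (priority 99 above)
    return { intent: 'feature', complexity: assessComplexity(input) };
  }
  matches.sort((a, b) => a.priority - b.priority);
  return { intent: matches[0].intent };
}
```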
## CLI Tool Integration

CCW automatically injects CLI calls under specific conditions:

| Condition | CLI Inject |
|-----------|------------|
| Large code context (≥50k chars) | `gemini --mode analysis` |
| High-complexity task | `gemini --mode analysis` |
| Bug diagnosis | `gemini --mode analysis` |
| Multi-task execution (≥3 tasks) | `codex --mode write` |

### CLI Enhancement Phases

**Phase 1.5: CLI-Assisted Classification**

When rule matching is ambiguous, a CLI assists with classification:

| Trigger | Description |
|----------|------|
| matchCount < 2 | Multiple intent patterns matched |
| complexity = high | High-complexity task |
| input > 100 chars | Long input needs semantic understanding |

**Phase 2.5: CLI-Assisted Action Planning**

Workflow optimization for high-complexity tasks:

| Trigger | Description |
|----------|------|
| complexity = high | High-complexity task |
| steps >= 3 | Multi-step workflow |
| input > 200 chars | Complex requirement description |

The CLI may return a recommendation: `use_default` | `modify` (adjust steps) | `upgrade` (upgrade the workflow)

## Continuation Commands

User control commands during workflow execution:

| Command | Effect |
|------|------|
| `continue` | Continue to the next step |
| `skip` | Skip the current step |
| `abort` | Abort the workflow |
| `/workflow:*` | Switch to the specified command |
| Natural language | Re-analyze intent |

## Workflow Flow Details

### Issue Workflow (Supplement to the Main Workflow)

The Issue Workflow is a **supplementary mechanism** to the Main Workflow, focused on post-development continuous maintenance.

#### Design Rationale

| Aspect | Main Workflow | Issue Workflow |
|------|---------------|----------------|
| **Purpose** | Primary development cycle | Post-development maintenance |
| **Timing** | Feature development phase | After the main workflow completes |
| **Scope** | Full feature implementation | Targeted fixes/enhancements |
| **Parallelism** | Dependency analysis → parallel agents | Worktree isolation (optional) |
| **Branch model** | Work on the current branch | May use isolated worktrees |

#### Why doesn't the Main Workflow use worktrees automatically?

**Dependency analysis already solves the parallelism problem**:
1. The planning phase (`/workflow:plan`) performs dependency analysis
2. Task dependencies and the critical path are identified automatically
3. Tasks are split into **parallel groups** (independent tasks) and **serial chains** (dependent tasks)
4. Agents execute independent tasks in parallel; no filesystem isolation is needed

#### Two-Phase Lifecycle

```
┌───────────────────────────────────────────────────────────────────────┐
│ Phase 1: Accumulation                                                 │
│                                                                       │
│ Triggers: post-task review, code review findings, test failures      │
│                                                                       │
│   ┌────────────┐    ┌────────────┐    ┌────────────┐                  │
│   │  discover  │    │ discover-  │    │    new     │                  │
│   │  Auto-find │    │ by-prompt  │    │   Manual   │                  │
│   └────────────┘    └────────────┘    └────────────┘                  │
│                                                                       │
│ Issues keep accumulating into the pending queue                       │
└───────────────────────────────────────────────────────────────────────┘
                                    │
                                    │ once enough have accumulated
                                    ▼
┌───────────────────────────────────────────────────────────────────────┐
│ Phase 2: Batch Resolution                                             │
│                                                                       │
│   ┌────────────┐     ┌────────────┐     ┌────────────┐                │
│   │    plan    │ ──→ │   queue    │ ──→ │  execute   │                │
│   │   --all-   │     │  Optimize  │     │  Parallel  │                │
│   │   pending  │     │   order    │     │ execution  │                │
│   └────────────┘     └────────────┘     └────────────┘                │
│                                                                       │
│ Supports worktree isolation to keep the main branch stable            │
└───────────────────────────────────────────────────────────────────────┘
```

#### Collaboration with the Main Workflow

```
Development iteration loop

  ┌─────────┐                                            ┌─────────┐
  │ Feature │ ──→ Main Workflow (Level 1-4) ──→ Done ───→│ Review  │
  │ Request │                                            └────┬────┘
  └─────────┘                                                 │ issues found
       ▲                                                      ▼
       │                                                 ┌─────────┐
       │ continue with                                   │  Issue  │
       │ new features                                    │ Workflow│
       │                                                 └────┬────┘
       │                                                      │ fixes complete
  ┌────┴────┐                                                 │
  │  Main   │◀────────────────── merge back ──────────────────┘
  │ Branch  │
  └─────────┘
```

#### Command List

**Accumulation phase:**
```bash
/issue:discover            # Multi-perspective automatic discovery
/issue:discover-by-prompt  # Prompt-driven discovery
/issue:new                 # Manual creation
```

**Batch resolution phase:**
```bash
/issue:plan --all-pending  # Batch-plan all pending issues
/issue:queue               # Generate an optimized execution queue
/issue:execute             # Parallel execution
```

### lite-lite-lite vs multi-cli-plan

| Dimension | lite-lite-lite | multi-cli-plan |
|------|---------------|----------------|
| **Artifacts** | No files | IMPL_PLAN.md + plan.json + synthesis.json |
| **State** | Stateless | Persistent session |
| **CLI selection** | Auto-selected from task-type analysis | Configuration-driven |
| **Iteration** | Via AskUser | Multi-round convergence |
| **Execution** | Direct execution | Via lite-execute |
| **Best for** | Quick fixes, simple features | Complex multi-step implementations |

**Selection guide**:
- Clear task, small change surface → `lite-lite-lite`
- Needs multi-perspective analysis or complex architecture → `multi-cli-plan`

### multi-cli-plan vs lite-plan

| Dimension | multi-cli-plan | lite-plan |
|------|---------------|-----------|
| **Context** | ACE semantic search | Manual file patterns |
| **Analysis** | Multi-CLI cross-verification | Single-pass planning |
| **Iteration** | Multiple rounds until convergence | Single round |
| **Confidence** | High (consensus-driven) | Medium (single perspective) |
| **Best for** | Complex tasks needing multiple perspectives | Straightforward, well-defined implementations |

**Selection guide**:
- Clear requirements and paths → `lite-plan`
- Trade-offs or solution comparison needed → `multi-cli-plan`

## Artifact Flow Protocol

Automatic hand-off of workflow artifacts, supporting intent extraction and completion evaluation across different output formats.

### Output Formats

| Command | Output location | Format | Key fields |
|------|----------|------|----------|
| `/workflow:lite-plan` | memory://plan | structured_plan | tasks, files, dependencies |
| `/workflow:plan` | .workflow/{session}/IMPL_PLAN.md | markdown_plan | phases, tasks, risks |
| `/workflow:execute` | execution_log.json | execution_report | completed_tasks, errors |
| `/workflow:test-cycle-execute` | test_results.json | test_report | pass_rate, failures, coverage |
| `/workflow:review-session-cycle` | review_report.md | review_report | findings, severity_counts |

### Intent Extraction

When handing off to the next step, key information is extracted automatically:

```
plan → execute:
  extract: tasks (incomplete), priority_order, files_to_modify, context_summary

execute → test:
  extract: modified_files, test_scope (inferred), pending_verification

test → fix:
  condition: pass_rate < 0.95
  extract: failures, error_messages, affected_files, suggested_fixes

review → fix:
  condition: critical > 0 OR high > 3
  extract: findings (critical/high), fix_priority, affected_files
```

### Completion Evaluation

**Test completion routing**:
```
pass_rate >= 0.95 AND coverage >= 0.80 → complete
pass_rate >= 0.95 AND coverage < 0.80  → add_more_tests
pass_rate >= 0.80                      → fix_failures_then_continue
pass_rate < 0.80                       → major_fix_required
```

**Review completion routing**:
```
critical == 0 AND high <= 3 → complete_or_optional_fix
critical > 0                → mandatory_fix
high > 3                    → recommended_fix
```
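The same routing, expressed as a compact sketch (the thresholds mirror the tables above; the function name is illustrative, not part of the original file):

```javascript
// Sketch: map a test or review report onto the routing decisions above
function evaluateCompletion(kind, report) {
  if (kind === 'test') {
    const { pass_rate, coverage } = report;
    if (pass_rate >= 0.95) return coverage >= 0.80 ? 'complete' : 'add_more_tests';
    return pass_rate >= 0.80 ? 'fix_failures_then_continue' : 'major_fix_required';
  }
  if (kind === 'review') {
    const { critical = 0, high = 0 } = report.severity_counts || {};
    if (critical > 0) return 'mandatory_fix';
    return high > 3 ? 'recommended_fix' : 'complete_or_optional_fix';
  }
  return 'unknown';
}
```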
### Hand-off Decision Patterns

**plan_execute_test**:
```
plan → execute → test
          ↓ (if tests fail)
  extract_failures → fix → test (max 3 iterations)
          ↓ (if still failing)
  manual_intervention
```

**iterative_improvement**:
```
execute → test → fix → test → ...
loop until: pass_rate >= 0.95 OR iterations >= 3
```

### Usage Example

```javascript
// After execution completes, decide the next step from the artifacts
const result = await execute(plan)

// Extract intent and hand off to testing
const testContext = extractIntent('execute_to_test', result)
// testContext = { modified_files, test_scope, pending_verification }

// After tests complete, route based on completion
const testResult = await test(testContext)
const nextStep = evaluateCompletion('test', testResult)
// nextStep = 'fix_failures_then_continue' when pass_rate = 0.85
```

## Reference

- [command.json](command.json) - Command metadata, flow definitions, intent rules, artifact flow
@@ -1,641 +0,0 @@
|
||||
{
|
||||
"_metadata": {
|
||||
"version": "2.0.0",
|
||||
"description": "Unified CCW command index with capabilities, flows, and intent rules"
|
||||
},
|
||||
|
||||
"capabilities": {
|
||||
"explore": {
|
||||
"description": "Codebase exploration and context gathering",
|
||||
"commands": ["/workflow:init", "/workflow:tools:gather", "/memory:load"],
|
||||
"agents": ["cli-explore-agent", "context-search-agent"]
|
||||
},
|
||||
"brainstorm": {
|
||||
"description": "Multi-perspective analysis and ideation",
|
||||
"commands": ["/workflow:brainstorm:auto-parallel", "/workflow:brainstorm:artifacts", "/workflow:brainstorm:synthesis"],
|
||||
"roles": ["product-manager", "system-architect", "ux-expert", "data-architect", "api-designer"]
|
||||
},
|
||||
"plan": {
|
||||
"description": "Task planning and decomposition",
|
||||
"commands": ["/workflow:lite-plan", "/workflow:plan", "/workflow:tdd-plan", "/task:create", "/task:breakdown"],
|
||||
"agents": ["cli-lite-planning-agent", "action-planning-agent"]
|
||||
},
|
||||
"verify": {
|
||||
"description": "Plan and quality verification",
|
||||
"commands": ["/workflow:plan-verify", "/workflow:tdd-verify"]
|
||||
},
|
||||
"execute": {
|
||||
"description": "Task execution and implementation",
|
||||
"commands": ["/workflow:lite-execute", "/workflow:execute", "/task:execute"],
|
||||
"agents": ["code-developer", "cli-execution-agent", "universal-executor"]
|
||||
},
|
||||
"bugfix": {
|
||||
"description": "Bug diagnosis and fixing",
|
||||
"commands": ["/workflow:lite-fix"],
|
||||
"agents": ["code-developer"]
|
||||
},
|
||||
"test": {
|
||||
"description": "Test generation and execution",
|
||||
"commands": ["/workflow:test-gen", "/workflow:test-fix-gen", "/workflow:test-cycle-execute"],
|
||||
"agents": ["test-fix-agent"]
|
||||
},
|
||||
"review": {
|
||||
"description": "Code review and quality analysis",
|
||||
"commands": ["/workflow:review-session-cycle", "/workflow:review-module-cycle", "/workflow:review", "/workflow:review-fix"]
|
||||
},
|
||||
"issue": {
|
||||
"description": "Issue lifecycle management - discover, accumulate, batch resolve",
|
||||
"commands": ["/issue:new", "/issue:discover", "/issue:discover-by-prompt", "/issue:plan", "/issue:queue", "/issue:execute", "/issue:manage"],
|
||||
"agents": ["issue-plan-agent", "issue-queue-agent", "cli-explore-agent"],
|
||||
"lifecycle": {
|
||||
"accumulation": {
|
||||
"description": "任务完成后进行需求扩展、bug分析、测试发现",
|
||||
"triggers": ["post-task review", "code review findings", "test failures"],
|
||||
"commands": ["/issue:discover", "/issue:discover-by-prompt", "/issue:new"]
|
||||
},
|
||||
"batch_resolution": {
|
||||
"description": "积累的issue集中规划和并行执行",
|
||||
"flow": ["plan", "queue", "execute"],
|
||||
"commands": ["/issue:plan --all-pending", "/issue:queue", "/issue:execute"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"ui-design": {
|
||||
"description": "UI design and prototyping",
|
||||
"commands": ["/workflow:ui-design:explore-auto", "/workflow:ui-design:imitate-auto", "/workflow:ui-design:design-sync"],
|
||||
"agents": ["ui-design-agent"]
|
||||
},
|
||||
"memory": {
|
||||
"description": "Documentation and knowledge management",
|
||||
"commands": ["/memory:docs", "/memory:update-related", "/memory:update-full", "/memory:skill-memory"],
|
||||
"agents": ["doc-generator", "memory-bridge"]
|
||||
}
|
||||
},
|
||||
|
||||
"flows": {
|
||||
"_level_guide": {
|
||||
"L1": "Rapid - No artifacts, direct execution",
|
||||
"L2": "Lightweight - Memory/lightweight files, → lite-execute",
|
||||
"L3": "Standard - Session persistence, → execute/test-cycle-execute",
|
||||
"L4": "Brainstorm - Multi-role analysis + Session, → execute"
|
||||
},
|
||||
"lite-lite-lite": {
|
||||
"name": "Ultra-Rapid Execution",
|
||||
"level": "L1",
|
||||
"description": "零文件 + 自动CLI选择 + 语义描述 + 直接执行",
|
||||
"complexity": ["low"],
|
||||
"artifacts": "none",
|
||||
"steps": [
|
||||
{ "phase": "clarify", "description": "需求澄清 (AskUser if needed)" },
|
||||
{ "phase": "auto-select", "description": "任务分析 → 自动选择CLI组合" },
|
||||
{ "phase": "multi-cli", "description": "并行多CLI分析" },
|
||||
{ "phase": "decision", "description": "展示结果 → AskUser决策" },
|
||||
{ "phase": "execute", "description": "直接执行 (无中间文件)" }
|
||||
],
|
||||
"cli_hints": {
|
||||
"analysis": { "tool": "auto", "mode": "analysis", "parallel": true },
|
||||
"execution": { "tool": "auto", "mode": "write" }
|
||||
},
|
||||
"estimated_time": "10-30 min"
|
||||
},
|
||||
"rapid": {
|
||||
"name": "Rapid Iteration",
|
||||
"level": "L2",
|
||||
"description": "内存规划 + 直接执行",
|
||||
"complexity": ["low", "medium"],
|
||||
"artifacts": "memory://plan",
|
||||
"steps": [
|
||||
{ "command": "/workflow:lite-plan", "optional": false, "auto_continue": true },
|
||||
{ "command": "/workflow:lite-execute", "optional": false }
|
||||
],
|
||||
"cli_hints": {
|
||||
"explore_phase": { "tool": "gemini", "mode": "analysis", "trigger": "needs_exploration" },
|
||||
"execution": { "tool": "codex", "mode": "write", "trigger": "complexity >= medium" }
|
||||
},
|
||||
"estimated_time": "15-45 min"
|
||||
},
|
||||
"multi-cli-plan": {
|
||||
"name": "Multi-CLI Collaborative Planning",
|
||||
"level": "L2",
|
||||
"description": "ACE上下文 + 多CLI协作分析 + 迭代收敛 + 计划生成",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/.multi-cli-plan/{session}/",
|
||||
"steps": [
|
||||
{ "command": "/workflow:multi-cli-plan", "optional": false, "phases": [
|
||||
"context_gathering: ACE语义搜索",
|
||||
"multi_cli_discussion: cli-discuss-agent多轮分析",
|
||||
"present_options: 展示解决方案",
|
||||
"user_decision: 用户选择",
|
||||
"plan_generation: cli-lite-planning-agent生成计划"
|
||||
]},
|
||||
{ "command": "/workflow:lite-execute", "optional": false }
|
||||
],
|
||||
"vs_lite_plan": {
|
||||
"context": "ACE semantic search vs Manual file patterns",
|
||||
"analysis": "Multi-CLI cross-verification vs Single-pass planning",
|
||||
"iteration": "Multiple rounds until convergence vs Single round",
|
||||
"confidence": "High (consensus-based) vs Medium (single perspective)",
|
||||
"best_for": "Complex tasks needing multiple perspectives vs Straightforward implementations"
|
||||
},
|
||||
"agents": ["cli-discuss-agent", "cli-lite-planning-agent"],
|
||||
"cli_hints": {
|
||||
"discussion": { "tools": ["gemini", "codex", "claude"], "mode": "analysis", "parallel": true },
|
||||
"planning": { "tool": "gemini", "mode": "analysis" }
|
||||
},
|
||||
"estimated_time": "30-90 min"
|
||||
},
|
||||
"coupled": {
|
||||
"name": "Standard Planning",
|
||||
"level": "L3",
|
||||
"description": "完整规划 + 验证 + 执行",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/active/{session}/",
|
||||
"steps": [
|
||||
{ "command": "/workflow:plan", "optional": false },
|
||||
{ "command": "/workflow:plan-verify", "optional": false, "auto_continue": true },
|
||||
{ "command": "/workflow:execute", "optional": false },
|
||||
{ "command": "/workflow:review", "optional": true }
|
||||
],
|
||||
"cli_hints": {
|
||||
"pre_analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
|
||||
"execution": { "tool": "codex", "mode": "write", "trigger": "always" }
|
||||
},
|
||||
"estimated_time": "2-4 hours"
|
||||
},
|
||||
"full": {
|
||||
"name": "Full Exploration (Brainstorm)",
|
||||
"level": "L4",
|
||||
"description": "头脑风暴 + 规划 + 执行",
|
||||
"complexity": ["high"],
|
||||
"artifacts": ".workflow/active/{session}/.brainstorming/",
|
||||
"steps": [
|
||||
{ "command": "/workflow:brainstorm:auto-parallel", "optional": false, "confirm_before": true },
|
||||
{ "command": "/workflow:plan", "optional": false },
|
||||
{ "command": "/workflow:plan-verify", "optional": true, "auto_continue": true },
|
||||
{ "command": "/workflow:execute", "optional": false }
|
||||
],
|
||||
"cli_hints": {
|
||||
"role_analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
|
||||
"execution": { "tool": "codex", "mode": "write", "trigger": "task_count >= 3" }
|
||||
},
|
||||
"estimated_time": "1-3 hours"
|
||||
},
|
||||
"bugfix": {
|
||||
"name": "Bug Fix",
|
||||
"level": "L2",
|
||||
"description": "智能诊断 + 修复 (5 phases)",
|
||||
"complexity": ["low", "medium"],
|
||||
"artifacts": ".workflow/.lite-fix/{bug-slug}-{date}/",
|
||||
"variants": {
|
||||
"standard": [{ "command": "/workflow:lite-fix", "optional": false }],
|
||||
"hotfix": [{ "command": "/workflow:lite-fix --hotfix", "optional": false }]
|
||||
},
|
||||
"phases": [
|
||||
"Phase 1: Bug Analysis & Diagnosis (severity pre-assessment)",
|
||||
"Phase 2: Clarification (optional, AskUserQuestion)",
|
||||
"Phase 3: Fix Planning (Low/Medium → Claude, High/Critical → cli-lite-planning-agent)",
|
||||
"Phase 4: Confirmation & Selection",
|
||||
"Phase 5: Execute (→ lite-execute --mode bugfix)"
|
||||
],
|
||||
"cli_hints": {
|
||||
"diagnosis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
|
||||
"fix": { "tool": "codex", "mode": "write", "trigger": "severity >= medium" }
|
||||
},
|
||||
"estimated_time": "10-30 min"
|
||||
},
|
||||
"issue": {
|
||||
"name": "Issue Lifecycle",
|
||||
"level": "Supplementary",
|
||||
"description": "发现积累 → 批量规划 → 队列优化 → 并行执行 (Main Workflow 补充机制)",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/.issues/",
|
||||
"purpose": "Post-development continuous maintenance, maintain main branch stability",
|
||||
"phases": {
|
||||
"accumulation": {
|
||||
"description": "项目迭代中持续发现和积累issue",
|
||||
"commands": ["/issue:discover", "/issue:discover-by-prompt", "/issue:new"],
|
||||
"trigger": "post-task, code-review, test-failure"
|
||||
},
|
||||
"resolution": {
|
||||
"description": "集中规划和执行积累的issue",
|
||||
"steps": [
|
||||
{ "command": "/issue:plan --all-pending", "optional": false },
|
||||
{ "command": "/issue:queue", "optional": false },
|
||||
{ "command": "/issue:execute", "optional": false }
|
||||
]
|
||||
}
|
||||
},
|
||||
"worktree_support": {
|
||||
"description": "可选的 worktree 隔离,保持主分支稳定",
|
||||
"use_case": "主开发完成后的 issue 修复"
|
||||
},
|
||||
"cli_hints": {
|
||||
"discovery": { "tool": "gemini", "mode": "analysis", "trigger": "perspective_analysis", "parallel": true },
|
||||
"solution_generation": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
|
||||
"batch_execution": { "tool": "codex", "mode": "write", "trigger": "always" }
|
||||
},
|
||||
"estimated_time": "1-4 hours"
|
||||
},
|
||||
"tdd": {
|
||||
"name": "Test-Driven Development",
|
||||
"level": "L3",
|
||||
"description": "TDD规划 + 执行 + 验证 (6 phases)",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/active/{session}/",
|
||||
"steps": [
|
||||
{ "command": "/workflow:tdd-plan", "optional": false },
|
||||
{ "command": "/workflow:plan-verify", "optional": true, "auto_continue": true },
|
||||
{ "command": "/workflow:execute", "optional": false },
|
||||
{ "command": "/workflow:tdd-verify", "optional": false }
|
||||
],
|
||||
"tdd_structure": {
|
||||
"description": "Each IMPL task contains complete internal Red-Green-Refactor cycle",
|
||||
"meta": "tdd_workflow: true",
|
||||
"flow_control": "implementation_approach contains 3 steps (red/green/refactor)"
|
||||
},
|
||||
"cli_hints": {
|
||||
"test_strategy": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
|
||||
"red_green_refactor": { "tool": "codex", "mode": "write", "trigger": "always" }
|
||||
},
|
||||
"estimated_time": "1-3 hours"
|
||||
},
|
||||
"test-fix": {
|
||||
"name": "Test Fix Generation",
|
||||
"level": "L3",
|
||||
"description": "测试修复生成 + 执行循环 (5 phases)",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/active/WFS-test-{session}/",
|
||||
"dual_mode": {
|
||||
"session_mode": { "input": "WFS-xxx", "context_source": "Source session summaries" },
|
||||
"prompt_mode": { "input": "Text/file path", "context_source": "Direct codebase analysis" }
|
||||
},
|
||||
"steps": [
|
||||
{ "command": "/workflow:test-fix-gen", "optional": false },
|
||||
{ "command": "/workflow:test-cycle-execute", "optional": false }
|
||||
],
|
||||
"task_structure": [
|
||||
"IMPL-001.json (test understanding & generation)",
|
||||
"IMPL-001.5-review.json (quality gate)",
|
||||
"IMPL-002.json (test execution & fix cycle)"
|
||||
],
|
||||
"cli_hints": {
|
||||
"analysis": { "tool": "gemini", "mode": "analysis", "trigger": "always" },
|
||||
"fix_cycle": { "tool": "codex", "mode": "write", "trigger": "pass_rate < 0.95" }
|
||||
},
|
||||
"estimated_time": "1-2 hours"
|
||||
},
|
||||
"ui": {
|
||||
"name": "UI-First Development",
|
||||
"level": "L3/L4",
|
||||
"description": "UI设计 + 规划 + 执行",
|
||||
"complexity": ["medium", "high"],
|
||||
"artifacts": ".workflow/active/{session}/",
|
||||
"variants": {
|
||||
"explore": [
|
||||
{ "command": "/workflow:ui-design:explore-auto", "optional": false },
|
||||
{ "command": "/workflow:ui-design:design-sync", "optional": false, "auto_continue": true },
|
||||
{ "command": "/workflow:plan", "optional": false },
|
||||
{ "command": "/workflow:execute", "optional": false }
|
||||
],
|
||||
"imitate": [
|
||||
{ "command": "/workflow:ui-design:imitate-auto", "optional": false },
|
||||
{ "command": "/workflow:ui-design:design-sync", "optional": false, "auto_continue": true },
|
||||
{ "command": "/workflow:plan", "optional": false },
|
||||
{ "command": "/workflow:execute", "optional": false }
|
||||
]
|
||||
},
|
||||
"estimated_time": "2-4 hours"
|
||||
},
|
||||
"review-fix": {
|
||||
"name": "Review and Fix",
|
||||
"level": "L3",
|
||||
"description": "多维审查 + 自动修复",
|
||||
"complexity": ["medium"],
|
||||
"artifacts": ".workflow/active/{session}/review_report.md",
|
||||
"steps": [
|
||||
{ "command": "/workflow:review-session-cycle", "optional": false },
|
||||
{ "command": "/workflow:review-fix", "optional": true }
|
||||
],
|
||||
"cli_hints": {
|
||||
"multi_dimension_review": { "tool": "gemini", "mode": "analysis", "trigger": "always", "parallel": true },
|
||||
"auto_fix": { "tool": "codex", "mode": "write", "trigger": "findings_count >= 3" }
|
||||
},
|
||||
"estimated_time": "30-90 min"
|
||||
},
|
||||
"docs": {
|
||||
"name": "Documentation",
|
||||
"level": "L2",
|
||||
"description": "批量文档生成",
|
||||
"complexity": ["low", "medium"],
|
||||
"variants": {
|
||||
"incremental": [{ "command": "/memory:update-related", "optional": false }],
|
||||
"full": [
|
||||
{ "command": "/memory:docs", "optional": false },
|
||||
{ "command": "/workflow:execute", "optional": false }
|
||||
]
|
||||
},
|
||||
"estimated_time": "15-60 min"
|
||||
}
|
||||
},
|
||||
|
||||
"intent_rules": {
|
||||
"_level_mapping": {
|
||||
"description": "Intent → Level → Flow mapping guide",
|
||||
"L1": ["lite-lite-lite"],
|
||||
"L2": ["rapid", "bugfix", "multi-cli-plan", "docs"],
|
||||
"L3": ["coupled", "tdd", "test-fix", "review-fix", "ui"],
|
||||
"L4": ["full"],
|
||||
"Supplementary": ["issue"]
|
||||
},
|
||||
"bugfix": {
|
||||
"priority": 1,
|
||||
"level": "L2",
|
||||
"variants": {
|
||||
"hotfix": {
|
||||
"patterns": ["hotfix", "urgent", "production", "critical", "emergency", "紧急", "生产环境", "线上"],
|
||||
"flow": "bugfix.hotfix"
|
||||
},
|
||||
"standard": {
|
||||
"patterns": ["fix", "bug", "error", "issue", "crash", "broken", "fail", "wrong", "修复", "错误", "崩溃"],
|
||||
"flow": "bugfix.standard"
|
||||
}
|
||||
}
|
||||
},
|
||||
"issue_batch": {
|
||||
"priority": 2,
|
||||
"level": "Supplementary",
|
||||
"patterns": {
|
||||
"batch": ["issues", "batch", "queue", "多个", "批量"],
|
||||
"action": ["fix", "resolve", "处理", "解决"]
|
||||
},
|
||||
"require_both": true,
|
||||
"flow": "issue"
|
||||
},
|
||||
"exploration": {
|
||||
"priority": 3,
|
||||
"level": "L4",
|
||||
"patterns": ["不确定", "不知道", "explore", "研究", "分析一下", "怎么做", "what if", "探索"],
|
||||
"flow": "full"
|
||||
},
|
||||
"multi_perspective": {
|
||||
"priority": 3,
|
||||
"level": "L2",
|
||||
"patterns": ["多视角", "权衡", "比较方案", "cross-verify", "多CLI", "协作分析"],
|
||||
"flow": "multi-cli-plan"
|
||||
},
|
||||
"quick_task": {
|
||||
"priority": 4,
|
||||
"level": "L1",
|
||||
"patterns": ["快速", "简单", "small", "quick", "simple", "trivial", "小改动"],
|
||||
"flow": "lite-lite-lite"
|
||||
},
|
||||
"ui_design": {
|
||||
"priority": 5,
|
||||
"level": "L3/L4",
|
||||
"patterns": ["ui", "界面", "design", "设计", "component", "组件", "style", "样式", "layout", "布局"],
|
||||
"variants": {
|
||||
"imitate": { "triggers": ["参考", "模仿", "像", "类似"], "flow": "ui.imitate" },
|
||||
"explore": { "triggers": [], "flow": "ui.explore" }
|
||||
}
|
||||
},
|
||||
"tdd": {
|
||||
"priority": 6,
|
||||
"level": "L3",
|
||||
"patterns": ["tdd", "test-driven", "测试驱动", "先写测试", "test first"],
|
||||
"flow": "tdd"
|
||||
},
|
||||
"test_fix": {
|
||||
"priority": 7,
|
||||
"level": "L3",
|
||||
"patterns": ["测试失败", "test fail", "fix test", "test error", "pass rate", "coverage gap"],
|
||||
"flow": "test-fix"
|
||||
},
|
||||
"review": {
|
||||
"priority": 8,
|
||||
"level": "L3",
|
||||
"patterns": ["review", "审查", "检查代码", "code review", "质量检查"],
|
||||
"flow": "review-fix"
|
||||
},
|
||||
"documentation": {
|
||||
"priority": 9,
|
||||
"level": "L2",
|
||||
"patterns": ["文档", "documentation", "docs", "readme"],
|
||||
"variants": {
|
||||
"incremental": { "triggers": ["更新", "增量"], "flow": "docs.incremental" },
|
||||
"full": { "triggers": ["全部", "完整"], "flow": "docs.full" }
|
||||
}
|
||||
},
|
||||
"feature": {
|
||||
"priority": 99,
|
||||
"complexity_map": {
|
||||
"high": { "level": "L3", "flow": "coupled" },
|
||||
"medium": { "level": "L2", "flow": "rapid" },
|
||||
"low": { "level": "L1", "flow": "lite-lite-lite" }
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"complexity_indicators": {
|
||||
"high": {
|
||||
"threshold": 4,
|
||||
"patterns": {
|
||||
"architecture": { "keywords": ["refactor", "重构", "migrate", "迁移", "architect", "架构", "system", "系统"], "weight": 2 },
|
||||
"multi_module": { "keywords": ["multiple", "多个", "across", "跨", "all", "所有", "entire", "整个"], "weight": 2 },
|
||||
"integration": { "keywords": ["integrate", "集成", "api", "database", "数据库"], "weight": 1 },
|
||||
"quality": { "keywords": ["security", "安全", "performance", "性能", "scale", "扩展"], "weight": 1 }
|
||||
}
|
||||
},
|
||||
"medium": { "threshold": 2 },
|
||||
"low": { "threshold": 0 }
|
||||
},
|
||||
|
||||
"cli_tools": {
|
||||
"gemini": {
|
||||
"strengths": ["超长上下文", "深度分析", "架构理解", "执行流追踪"],
|
||||
"triggers": ["分析", "理解", "设计", "架构", "诊断"],
|
||||
"mode": "analysis"
|
||||
},
|
||||
"qwen": {
|
||||
"strengths": ["代码模式识别", "多维度分析"],
|
||||
"triggers": ["评估", "对比", "验证"],
|
||||
"mode": "analysis"
|
||||
},
|
||||
"codex": {
|
||||
"strengths": ["精确代码生成", "自主执行"],
|
||||
"triggers": ["实现", "重构", "修复", "生成"],
|
||||
"mode": "write"
|
||||
}
|
||||
},
|
||||
|
||||
"cli_injection_rules": {
|
||||
"context_gathering": { "trigger": "file_read >= 50k OR module_count >= 5", "inject": "gemini --mode analysis" },
|
||||
"pre_planning_analysis": { "trigger": "complexity === high", "inject": "gemini --mode analysis" },
|
||||
"debug_diagnosis": { "trigger": "intent === bugfix AND root_cause_unclear", "inject": "gemini --mode analysis" },
|
||||
"code_review": { "trigger": "step === review", "inject": "gemini --mode analysis" },
|
||||
"implementation": { "trigger": "step === execute AND task_count >= 3", "inject": "codex --mode write" }
|
||||
},
|
||||
|
||||
"artifact_flow": {
|
||||
"_description": "定义工作流产出的格式、意图提取和流转规则",
|
||||
|
||||
"outputs": {
|
||||
"/workflow:lite-plan": {
|
||||
"artifact": "memory://plan",
|
||||
"format": "structured_plan",
|
||||
"fields": ["tasks", "files", "dependencies", "approach"]
|
||||
},
|
||||
"/workflow:plan": {
|
||||
"artifact": ".workflow/{session}/IMPL_PLAN.md",
|
||||
"format": "markdown_plan",
|
||||
"fields": ["phases", "tasks", "dependencies", "risks", "test_strategy"]
|
||||
},
|
||||
"/workflow:multi-cli-plan": {
|
||||
"artifact": ".workflow/.multi-cli-plan/{session}/",
|
||||
"format": "multi_file",
|
||||
"files": ["IMPL_PLAN.md", "plan.json", "synthesis.json"],
|
||||
"fields": ["consensus", "divergences", "recommended_approach", "tasks"]
|
||||
},
|
||||
"/workflow:lite-execute": {
|
||||
"artifact": "git_changes",
|
||||
"format": "code_diff",
|
||||
"fields": ["modified_files", "added_files", "deleted_files", "build_status"]
|
||||
},
|
||||
"/workflow:execute": {
|
||||
"artifact": ".workflow/{session}/execution_log.json",
|
||||
"format": "execution_report",
|
||||
"fields": ["completed_tasks", "pending_tasks", "errors", "warnings"]
|
||||
},
|
||||
"/workflow:test-cycle-execute": {
|
||||
"artifact": ".workflow/{session}/test_results.json",
|
||||
"format": "test_report",
|
||||
"fields": ["pass_rate", "failures", "coverage", "duration"]
|
||||
},
|
||||
"/workflow:review-session-cycle": {
|
||||
"artifact": ".workflow/{session}/review_report.md",
|
||||
"format": "review_report",
|
||||
"fields": ["findings", "severity_counts", "recommendations"]
|
||||
},
|
||||
"/workflow:lite-fix": {
|
||||
"artifact": "git_changes",
|
||||
"format": "fix_report",
|
||||
"fields": ["root_cause", "fix_applied", "files_modified", "verification_status"]
|
||||
}
|
||||
},
|
||||
|
||||
"intent_extraction": {
|
||||
"plan_to_execute": {
|
||||
"from": ["lite-plan", "plan", "multi-cli-plan"],
|
||||
"to": ["lite-execute", "execute"],
|
||||
"extract": {
|
||||
"tasks": "$.tasks[] | filter(status != 'completed')",
|
||||
"priority_order": "$.tasks | sort_by(priority)",
|
||||
"files_to_modify": "$.tasks[].files | flatten | unique",
|
||||
"dependencies": "$.dependencies",
|
||||
"context_summary": "$.approach OR $.recommended_approach"
|
||||
}
|
||||
},
|
||||
"execute_to_test": {
|
||||
"from": ["lite-execute", "execute"],
|
||||
"to": ["test-cycle-execute", "test-fix-gen"],
|
||||
"extract": {
|
||||
"modified_files": "$.modified_files",
|
||||
"test_scope": "infer_from($.modified_files)",
|
||||
"build_status": "$.build_status",
|
||||
"pending_verification": "$.completed_tasks | needs_test"
|
||||
}
|
||||
},
|
||||
"test_to_fix": {
|
||||
"from": ["test-cycle-execute"],
|
||||
"to": ["lite-fix", "review-fix"],
|
||||
"condition": "$.pass_rate < 0.95",
|
||||
"extract": {
|
||||
"failures": "$.failures",
|
||||
"error_messages": "$.failures[].message",
|
||||
"affected_files": "$.failures[].file",
|
||||
"suggested_fixes": "$.failures[].suggested_fix"
|
||||
}
|
||||
},
|
||||
"review_to_fix": {
|
||||
"from": ["review-session-cycle", "review-module-cycle"],
|
||||
"to": ["review-fix"],
|
||||
"condition": "$.severity_counts.critical > 0 OR $.severity_counts.high > 3",
|
||||
"extract": {
|
||||
"findings": "$.findings | filter(severity in ['critical', 'high'])",
|
||||
"fix_priority": "$.findings | group_by(category) | sort_by(severity)",
|
||||
"affected_files": "$.findings[].file | unique"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"completion_criteria": {
|
||||
"plan": {
|
||||
"required": ["has_tasks", "has_files"],
|
||||
"optional": ["has_tests", "no_blocking_risks"],
|
||||
"threshold": 0.8,
|
||||
"routing": {
|
||||
"complete": "proceed_to_execute",
|
||||
"incomplete": "clarify_requirements"
|
||||
}
|
||||
},
|
||||
"execute": {
|
||||
"required": ["all_tasks_attempted", "no_critical_errors"],
|
||||
"optional": ["build_passes", "lint_passes"],
|
||||
"threshold": 1.0,
|
||||
"routing": {
|
||||
"complete": "proceed_to_test_or_review",
|
||||
"partial": "continue_execution",
|
||||
"failed": "diagnose_and_retry"
|
||||
}
|
||||
},
|
||||
"test": {
|
||||
"metrics": {
|
||||
"pass_rate": { "target": 0.95, "minimum": 0.80 },
|
||||
"coverage": { "target": 0.80, "minimum": 0.60 }
|
||||
},
|
||||
"routing": {
|
||||
"pass_rate >= 0.95 AND coverage >= 0.80": "complete",
|
||||
"pass_rate >= 0.95 AND coverage < 0.80": "add_more_tests",
|
||||
"pass_rate >= 0.80": "fix_failures_then_continue",
|
||||
"pass_rate < 0.80": "major_fix_required"
|
||||
}
|
||||
},
|
||||
"review": {
|
||||
"metrics": {
|
||||
"critical_findings": { "target": 0, "maximum": 0 },
|
||||
"high_findings": { "target": 0, "maximum": 3 }
|
||||
},
|
||||
"routing": {
|
||||
"critical == 0 AND high <= 3": "complete_or_optional_fix",
|
||||
"critical > 0": "mandatory_fix",
|
||||
"high > 3": "recommended_fix"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"flow_decisions": {
|
||||
"_description": "根据产出完成度决定下一步",
|
||||
"patterns": {
|
||||
"plan_execute_test": {
|
||||
"sequence": ["plan", "execute", "test"],
|
||||
"on_test_fail": {
|
||||
"action": "extract_failures_and_fix",
|
||||
"max_iterations": 3,
|
||||
"fallback": "manual_intervention"
|
||||
}
|
||||
},
|
||||
"plan_execute_review": {
|
||||
"sequence": ["plan", "execute", "review"],
|
||||
"on_review_issues": {
|
||||
"action": "prioritize_and_fix",
|
||||
"auto_fix_threshold": "severity < high"
|
||||
}
|
||||
},
|
||||
"iterative_improvement": {
|
||||
"sequence": ["execute", "test", "fix"],
|
||||
"loop_until": "pass_rate >= 0.95 OR iterations >= 3",
|
||||
"on_loop_exit": "report_status"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
43
README.md
@@ -263,6 +263,49 @@ Open Dashboard via `ccw view`, manage indexes and execute searches in **CodexLen

## 💻 CCW CLI Commands

### 🌟 Recommended Commands (Main Features)

<div align="center">
<table>
<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>Auto workflow orchestrator - analyzes intent, selects workflow level, executes command chain in main process</td>
<td>✅ General tasks, auto workflow selection, quick development</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>Manual orchestrator - recommends command chains, executes via external CLI with state persistence</td>
<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
</tr>
</table>
</div>

**Quick Examples**:

```bash
# /ccw - Auto workflow selection (Main Process)
/ccw "Add user authentication"       # Auto-selects workflow based on intent
/ccw "Fix memory leak in WebSocket"  # Detects bugfix workflow
/ccw "Implement with TDD"            # Routes to TDD workflow

# /ccw-coordinator - Manual chain orchestration (External CLI)
/ccw-coordinator "Implement OAuth2 system"  # Analyzes → Recommends chain → User confirms → Executes
```

**Key Differences**:

| Aspect | /ccw | /ccw-coordinator |
|--------|------|------------------|
| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
| **Selection** | Auto intent-based | Manual chain confirmation |
| **State** | TodoWrite tracking | Persistent state.json |
| **Use Case** | General tasks, quick dev | Complex chains, resumable |

---

### Other CLI Commands

```bash
ccw install   # Install workflow files
ccw view      # Open dashboard
```
43
README_CN.md
@@ -263,6 +263,49 @@ codexlens index /path/to/project

## 💻 CCW CLI 命令

### 🌟 推荐命令(核心功能)

<div align="center">
<table>
<tr><th>命令</th><th>说明</th><th>适用场景</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>自动工作流编排器 - 分析意图、自动选择工作流级别、在主进程中执行命令链</td>
<td>✅ 通用任务、自动选择工作流、快速开发</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>手动编排器 - 推荐命令链、通过外部 CLI 执行、持久化状态</td>
<td>🔧 复杂多步骤工作流、自定义链、可恢复会话</td>
</tr>
</table>
</div>

**快速示例**:

```bash
# /ccw - 自动工作流选择(主进程)
/ccw "添加用户认证"                  # 自动根据意图选择工作流
/ccw "修复 WebSocket 中的内存泄漏"   # 识别为 bugfix 工作流
/ccw "使用 TDD 方式实现"             # 路由到 TDD 工作流

# /ccw-coordinator - 手动链编排(外部 CLI)
/ccw-coordinator "实现 OAuth2 系统"  # 分析 → 推荐链 → 用户确认 → 执行
```

**主要区别**:

| 方面 | /ccw | /ccw-coordinator |
|------|------|------------------|
| **执行方式** | 主进程(SlashCommand) | 外部 CLI(后台任务) |
| **选择方式** | 自动基于意图识别 | 手动链确认 |
| **状态管理** | TodoWrite 跟踪 | 持久化 state.json |
| **适用场景** | 通用任务、快速开发 | 复杂链条、可恢复 |

---

### 其他 CLI 命令

```bash
ccw install   # 安装工作流文件
ccw view      # 打开 Dashboard
```
669
ccw/src/tools/command-registry.test.ts
Normal file
@@ -0,0 +1,669 @@
|
||||
/**
|
||||
* CommandRegistry Tests
|
||||
*
|
||||
* Test coverage:
|
||||
* - YAML header parsing
|
||||
* - Command metadata extraction
|
||||
* - Directory detection (relative and home)
|
||||
* - Caching mechanism
|
||||
* - Batch operations
|
||||
* - Categorization
|
||||
* - Error handling
|
||||
*/
|
||||
|
||||
import { CommandRegistry, createCommandRegistry, getAllCommandsSync, getCommandSync } from './command-registry';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
|
||||
// Mock fs module
|
||||
jest.mock('fs');
|
||||
jest.mock('os');
|
||||
|
||||
describe('CommandRegistry', () => {
|
||||
const mockReadFileSync = fs.readFileSync as jest.MockedFunction<typeof fs.readFileSync>;
|
||||
const mockExistsSync = fs.existsSync as jest.MockedFunction<typeof fs.existsSync>;
|
||||
const mockReaddirSync = fs.readdirSync as jest.MockedFunction<typeof fs.readdirSync>;
|
||||
const mockStatSync = fs.statSync as jest.MockedFunction<typeof fs.statSync>;
|
||||
const mockHomedir = os.homedir as jest.MockedFunction<typeof os.homedir>;
|
||||
|
||||
// Sample YAML headers
|
||||
const sampleLitePlanYaml = `---
|
||||
name: lite-plan
|
||||
description: Quick planning for simple features
|
||||
argument-hint: "\"feature description\""
|
||||
allowed-tools: Task(*), Read(*), Write(*), Bash(*)
|
||||
---
|
||||
|
||||
# Content here`;
|
||||
|
||||
const sampleExecuteYaml = `---
|
||||
name: execute
|
||||
description: Execute implementation from plan
|
||||
argument-hint: "--resume-session=\"WFS-xxx\""
|
||||
allowed-tools: Task(*), Bash(*)
|
||||
---
|
||||
|
||||
# Content here`;
|
||||
|
||||
const sampleTestYaml = `---
|
||||
name: test-cycle-execute
|
||||
description: Run tests and fix failures
|
||||
argument-hint: "--session=\"WFS-xxx\""
|
||||
allowed-tools: Task(*), Bash(*)
|
||||
---
|
||||
|
||||
# Content here`;
|
||||
|
||||
const sampleReviewYaml = `---
|
||||
name: review
|
||||
description: Code review workflow
|
||||
argument-hint: "--session=\"WFS-xxx\""
|
||||
allowed-tools: Task(*), Read(*)
|
||||
---
|
||||
|
||||
# Content here`;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('constructor & directory detection', () => {
|
||||
it('should use provided command directory', () => {
|
||||
const customDir = '/custom/path';
|
||||
const registry = new CommandRegistry(customDir);
|
||||
|
||||
expect((registry as any).commandDir).toBe(customDir);
|
||||
});
|
||||
|
||||
it('should auto-detect relative .claude/commands/workflow directory', () => {
|
||||
mockExistsSync.mockImplementation((path: string) => {
|
||||
return path === '.claude/commands/workflow';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
expect((registry as any).commandDir).toBe('.claude/commands/workflow');
|
||||
expect(mockExistsSync).toHaveBeenCalledWith('.claude/commands/workflow');
|
||||
});
|
||||
|
||||
it('should auto-detect home directory ~/.claude/commands/workflow', () => {
|
||||
mockExistsSync.mockImplementation((checkPath: string) => {
|
||||
return checkPath === path.join('/home/user', '.claude', 'commands', 'workflow');
|
||||
});
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
expect((registry as any).commandDir).toBe(
|
||||
path.join('/home/user', '.claude', 'commands', 'workflow')
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null if no command directory found', () => {
|
||||
mockExistsSync.mockReturnValue(false);
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
|
||||
expect((registry as any).commandDir).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseYamlHeader', () => {
|
||||
it('should parse simple YAML header with Unix line endings', () => {
|
||||
const yaml = `---
|
||||
name: test-command
|
||||
description: Test description
|
||||
argument-hint: "\"test\""
|
||||
allowed-tools: Task(*), Read(*)
|
||||
---
|
||||
|
||||
Content here`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toEqual({
|
||||
name: 'test-command',
|
||||
description: 'Test description',
|
||||
'argument-hint': '"test"',
|
||||
'allowed-tools': 'Task(*), Read(*)'
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse YAML header with Windows line endings (\\r\\n)', () => {
|
||||
const yaml = `---\r\nname: test-command\r\ndescription: Test\r\n---`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toEqual({
|
||||
name: 'test-command',
|
||||
description: 'Test'
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle quoted values', () => {
|
||||
const yaml = `---
|
||||
name: "cmd"
|
||||
description: 'double quoted'
|
||||
---`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toEqual({
|
||||
name: 'cmd',
|
||||
description: 'double quoted'
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse allowed-tools and trim spaces', () => {
|
||||
const yaml = `---
|
||||
name: test
|
||||
allowed-tools: Task(*), Read(*) , Write(*), Bash(*)
|
||||
---`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result['allowed-tools']).toBe('Task(*), Read(*), Write(*), Bash(*)');
|
||||
});
|
||||
|
||||
it('should skip comments and empty lines', () => {
|
||||
const yaml = `---
|
||||
# This is a comment
|
||||
name: test-command
|
||||
|
||||
# Another comment
|
||||
description: Test
|
||||
|
||||
---`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toEqual({
|
||||
name: 'test-command',
|
||||
description: 'Test'
|
||||
});
|
||||
});
|
||||
|
||||
it('should return null for missing YAML markers', () => {
|
||||
const yaml = `name: test-command
|
||||
description: Test`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null for malformed YAML', () => {
|
||||
const yaml = `---
|
||||
invalid yaml content without colons
|
||||
---`;
|
||||
|
||||
const registry = new CommandRegistry('/fake/path');
|
||||
const result = (registry as any).parseYamlHeader(yaml);
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCommand', () => {
|
||||
it('should get command metadata by name', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockImplementation((checkPath: string) => {
|
||||
return checkPath === path.join(cmdDir, 'lite-plan.md');
|
||||
});
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('lite-plan');
|
||||
|
||||
expect(result).toEqual({
|
||||
name: 'lite-plan',
|
||||
command: '/workflow:lite-plan',
|
||||
description: 'Quick planning for simple features',
|
||||
argumentHint: '"feature description"',
|
||||
allowedTools: ['Task(*)', 'Read(*)', 'Write(*)', 'Bash(*)'],
|
||||
filePath: path.join(cmdDir, 'lite-plan.md')
|
||||
});
|
||||
});
|
||||
|
||||
it('should normalize /workflow: prefix', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('/workflow:lite-plan');
|
||||
|
||||
expect(result?.name).toBe('lite-plan');
|
||||
});
|
||||
|
||||
it('should use cache for repeated requests', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
|
||||
registry.getCommand('lite-plan');
|
||||
registry.getCommand('lite-plan');
|
||||
|
||||
// readFileSync should only be called once due to cache
|
||||
expect(mockReadFileSync).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should return null if command file not found', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(false);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('nonexistent');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null if no command directory', () => {
|
||||
mockExistsSync.mockReturnValue(false);
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
const result = registry.getCommand('lite-plan');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null if YAML header is invalid', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue('No YAML header here');
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('lite-plan');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should parse allowedTools correctly', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(sampleExecuteYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('execute');
|
||||
|
||||
expect(result?.allowedTools).toEqual(['Task(*)', 'Bash(*)']);
|
||||
});
|
||||
|
||||
it('should handle empty allowedTools', () => {
|
||||
const yaml = `---
|
||||
name: minimal-cmd
|
||||
description: Minimal command
|
||||
---`;
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(yaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('minimal-cmd');
|
||||
|
||||
expect(result?.allowedTools).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCommands', () => {
|
||||
it('should get multiple commands', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
if (filePath.includes('execute')) return sampleExecuteYaml;
|
||||
return '';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommands(['lite-plan', 'execute', 'nonexistent']);
|
||||
|
||||
expect(result.size).toBe(2);
|
||||
expect(result.has('/workflow:lite-plan')).toBe(true);
|
||||
expect(result.has('/workflow:execute')).toBe(true);
|
||||
});
|
||||
|
||||
it('should skip nonexistent commands', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(false);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommands(['nonexistent1', 'nonexistent2']);
|
||||
|
||||
expect(result.size).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCommandsSummary', () => {
|
||||
it('should get all commands summary', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
if (filePath.includes('execute')) return sampleExecuteYaml;
|
||||
if (filePath.includes('test')) return sampleTestYaml;
|
||||
return '';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(3);
|
||||
expect(result.get('/workflow:lite-plan')).toEqual({
|
||||
name: 'lite-plan',
|
||||
description: 'Quick planning for simple features'
|
||||
});
|
||||
});
|
||||
|
||||
it('should skip directories', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['file.md', 'directory'] as any);
|
||||
mockStatSync.mockImplementation((filePath: string) => ({
|
||||
isDirectory: () => filePath.includes('directory')
|
||||
} as any));
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
// Only file.md should be processed
|
||||
expect(mockReadFileSync).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should skip files with invalid YAML headers', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['valid.md', 'invalid.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('valid')) return sampleLitePlanYaml;
|
||||
return 'No YAML header';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(1);
|
||||
});
|
||||
|
||||
it('should return empty map if no command directory', () => {
|
||||
mockExistsSync.mockReturnValue(false);
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const registry = new CommandRegistry();
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle directory read errors gracefully', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockImplementation(() => {
|
||||
throw new Error('Permission denied');
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCommandsByCategory', () => {
|
||||
it('should categorize commands by name patterns', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test-cycle-execute.md', 'review.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
if (filePath.includes('execute')) return sampleExecuteYaml;
|
||||
if (filePath.includes('test')) return sampleTestYaml;
|
||||
if (filePath.includes('review')) return sampleReviewYaml;
|
||||
return '';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsByCategory();
|
||||
|
||||
expect(result.planning.length).toBe(1);
|
||||
expect(result.execution.length).toBe(1);
|
||||
expect(result.testing.length).toBe(1);
|
||||
expect(result.review.length).toBe(1);
|
||||
expect(result.other.length).toBe(0);
|
||||
|
||||
expect(result.planning[0].name).toBe('lite-plan');
|
||||
expect(result.execution[0].name).toBe('execute');
|
||||
});
|
||||
|
||||
it('should handle commands matching multiple patterns', () => {
|
||||
const yamlMultiMatch = `---
|
||||
name: test-plan
|
||||
description: TDD planning
|
||||
allowed-tools: Task(*)
|
||||
---`;
|
||||
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['test-plan.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockReturnValue(yamlMultiMatch);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsByCategory();
|
||||
|
||||
// Should match 'plan' pattern (planning)
|
||||
expect(result.planning.length).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('toJSON', () => {
|
||||
it('should serialize cached commands to JSON', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
registry.getCommand('lite-plan');
|
||||
|
||||
const json = registry.toJSON();
|
||||
|
||||
expect(json['/workflow:lite-plan']).toEqual({
|
||||
name: 'lite-plan',
|
||||
command: '/workflow:lite-plan',
|
||||
description: 'Quick planning for simple features',
|
||||
argumentHint: '"feature description"',
|
||||
allowedTools: ['Task(*)', 'Read(*)', 'Write(*)', 'Bash(*)'],
|
||||
filePath: path.join(cmdDir, 'lite-plan.md')
|
||||
});
|
||||
});
|
||||
|
||||
it('should only include cached commands', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
return sampleExecuteYaml;
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
registry.getCommand('lite-plan');
|
||||
// Don't call getCommand for 'execute'
|
||||
|
||||
const json = registry.toJSON();
|
||||
|
||||
expect(Object.keys(json).length).toBe(1);
|
||||
expect(json['/workflow:lite-plan']).toBeDefined();
|
||||
expect(json['/workflow:execute']).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('exported functions', () => {
|
||||
it('createCommandRegistry should create new instance', () => {
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
|
||||
const registry = createCommandRegistry('/custom/path');
|
||||
|
||||
expect((registry as any).commandDir).toBe('/custom/path');
|
||||
});
|
||||
|
||||
it('getAllCommandsSync should return all commands', () => {
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const result = getAllCommandsSync();
|
||||
|
||||
expect(result.size).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
it('getCommandSync should return specific command', () => {
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
mockHomedir.mockReturnValue('/home/user');
|
||||
|
||||
const result = getCommandSync('lite-plan');
|
||||
|
||||
expect(result?.name).toBe('lite-plan');
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle file read errors', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReadFileSync.mockImplementation(() => {
|
||||
throw new Error('File read error');
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('lite-plan');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle YAML parsing errors', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
// Return something that will cause parsing to fail
|
||||
mockReadFileSync.mockReturnValue('---\ninvalid: : : yaml\n---');
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getCommand('lite-plan');
|
||||
|
||||
// Should return null since name is not in result
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle empty command directory', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue([] as any);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle non-md files in command directory', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md', 'readme.txt', '.gitignore'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockReturnValue(sampleLitePlanYaml);
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
const result = registry.getAllCommandsSummary();
|
||||
|
||||
expect(result.size).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('integration tests', () => {
|
||||
it('should work with full workflow', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md', 'test-cycle-execute.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
if (filePath.includes('execute')) return sampleExecuteYaml;
|
||||
if (filePath.includes('test')) return sampleTestYaml;
|
||||
return '';
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
|
||||
// Get all summary
|
||||
const summary = registry.getAllCommandsSummary();
|
||||
expect(summary.size).toBe(3);
|
||||
|
||||
// Get by category
|
||||
const byCategory = registry.getAllCommandsByCategory();
|
||||
expect(byCategory.planning.length).toBe(1);
|
||||
expect(byCategory.execution.length).toBe(1);
|
||||
expect(byCategory.testing.length).toBe(1);
|
||||
|
||||
// Get specific command
|
||||
const cmd = registry.getCommand('lite-plan');
|
||||
expect(cmd?.name).toBe('lite-plan');
|
||||
|
||||
// Get multiple commands
|
||||
const multiple = registry.getCommands(['lite-plan', 'execute']);
|
||||
expect(multiple.size).toBe(2);
|
||||
|
||||
// Convert to JSON
|
||||
const json = registry.toJSON();
|
||||
expect(Object.keys(json).length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('should maintain cache across operations', () => {
|
||||
const cmdDir = '/workflows';
|
||||
mockExistsSync.mockReturnValue(true);
|
||||
mockReaddirSync.mockReturnValue(['lite-plan.md', 'execute.md'] as any);
|
||||
mockStatSync.mockReturnValue({ isDirectory: () => false } as any);
|
||||
mockReadFileSync.mockImplementation((filePath: string) => {
|
||||
if (filePath.includes('lite-plan')) return sampleLitePlanYaml;
|
||||
return sampleExecuteYaml;
|
||||
});
|
||||
|
||||
const registry = new CommandRegistry(cmdDir);
|
||||
|
||||
// First call
|
||||
registry.getCommand('lite-plan');
|
||||
const initialCallCount = mockReadFileSync.mock.calls.length;
|
||||
|
||||
// getAllCommandsSummary will read all files
|
||||
registry.getAllCommandsSummary();
|
||||
const afterSummaryCallCount = mockReadFileSync.mock.calls.length;
|
||||
|
||||
// Second getCommand should use cache
|
||||
registry.getCommand('lite-plan');
|
||||
const finalCallCount = mockReadFileSync.mock.calls.length;
|
||||
|
||||
// lite-plan.md should only be read twice:
|
||||
// 1. Initial getCommand
|
||||
// 2. getAllCommandsSummary (must read all files)
|
||||
// Not again in second getCommand due to cache
|
||||
expect(finalCallCount).toBe(afterSummaryCallCount);
|
||||
});
|
||||
});
|
||||
});
ccw/src/tools/command-registry.ts (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
/**
|
||||
* Command Registry Tool
|
||||
*
|
||||
* Features:
|
||||
* 1. Scan and parse YAML headers from command files
|
||||
* 2. Read from global ~/.claude/commands/workflow directory
|
||||
* 3. Support on-demand extraction (not full scan)
|
||||
* 4. Cache parsed metadata for performance
|
||||
*/
|
||||
|
||||
import { existsSync, readdirSync, readFileSync, statSync } from 'fs';
|
||||
import { join } from 'path';
|
||||
import { homedir } from 'os';
|
||||
|
||||
export interface CommandMetadata {
|
||||
name: string;
|
||||
command: string;
|
||||
description: string;
|
||||
argumentHint: string;
|
||||
allowedTools: string[];
|
||||
filePath: string;
|
||||
}
|
||||
|
||||
export interface CommandSummary {
|
||||
name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export class CommandRegistry {
|
||||
private commandDir: string | null;
|
||||
private cache: Map<string, CommandMetadata>;
|
||||
|
||||
constructor(commandDir?: string) {
|
||||
this.cache = new Map();
|
||||
|
||||
if (commandDir) {
|
||||
this.commandDir = commandDir;
|
||||
} else {
|
||||
this.commandDir = this.findCommandDir();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Auto-detect ~/.claude/commands/workflow directory
|
||||
*/
|
||||
private findCommandDir(): string | null {
|
||||
// Try relative to current working directory
|
||||
const relativePath = join('.claude', 'commands', 'workflow');
|
||||
if (existsSync(relativePath)) {
|
||||
return relativePath;
|
||||
}
|
||||
|
||||
// Try user home directory
|
||||
const homeDir = homedir();
|
||||
const homeCommandDir = join(homeDir, '.claude', 'commands', 'workflow');
|
||||
if (existsSync(homeCommandDir)) {
|
||||
return homeCommandDir;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse YAML header (simplified version)
|
||||
*
|
||||
* Limitations:
|
||||
* - Only supports simple key: value pairs (single-line values)
|
||||
* - No support for multi-line values, nested objects, complex lists
|
||||
* - allowed-tools field converts comma-separated strings to arrays
|
||||
*/
|
||||
private parseYamlHeader(content: string): Record<string, any> | null {
|
||||
// Handle Windows line endings (\r\n)
|
||||
const match = content.match(/^---[\r\n]+([\s\S]*?)[\r\n]+---/);
|
||||
if (!match) return null;
|
||||
|
||||
const yamlContent = match[1];
|
||||
const result: Record<string, any> = {};
|
||||
|
||||
try {
|
||||
const lines = yamlContent.split(/[\r\n]+/);
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed.startsWith('#')) continue; // Skip empty lines and comments
|
||||
|
||||
const colonIndex = trimmed.indexOf(':');
|
||||
if (colonIndex === -1) continue;
|
||||
|
||||
const key = trimmed.substring(0, colonIndex).trim();
|
||||
let value = trimmed.substring(colonIndex + 1).trim();
|
||||
|
||||
if (!key) continue; // Skip invalid lines
|
||||
|
||||
// Remove quotes (single or double)
|
||||
let cleanValue = value.replace(/^["']|["']$/g, '');
|
||||
|
||||
// Special handling for allowed-tools field: convert to array
|
||||
// Supports format: "Read, Write, Bash" or "Read,Write,Bash"
|
||||
if (key === 'allowed-tools') {
|
||||
cleanValue = cleanValue
|
||||
.split(',')
|
||||
.map(t => t.trim())
|
||||
.filter(t => t)
|
||||
.join(','); // Keep as comma-separated for now, will convert in getCommand
|
||||
}
|
||||
|
||||
result[key] = cleanValue;
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
console.error('YAML parsing error:', err.message);
|
||||
return null;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get single command metadata
|
||||
* @param commandName Command name (e.g., "lite-plan" or "/workflow:lite-plan")
|
||||
* @returns Command metadata or null
|
||||
*/
|
||||
public getCommand(commandName: string): CommandMetadata | null {
|
||||
if (!this.commandDir) {
|
||||
console.error('ERROR: ~/.claude/commands/workflow directory not found');
|
||||
return null;
|
||||
}
|
||||
|
||||
// Normalize command name
|
||||
const normalized = commandName.startsWith('/workflow:')
|
||||
? commandName.substring('/workflow:'.length)
|
||||
: commandName;
|
||||
|
||||
// Check cache
|
||||
const cached = this.cache.get(normalized);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
// Read command file
|
||||
const filePath = join(this.commandDir, `${normalized}.md`);
|
||||
if (!existsSync(filePath)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, 'utf-8');
|
||||
const header = this.parseYamlHeader(content);
|
||||
|
||||
if (header && header.name) {
|
||||
const toolsStr = header['allowed-tools'] || '';
|
||||
const allowedTools = toolsStr
|
||||
.split(',')
|
||||
.map((t: string) => t.trim())
|
||||
.filter((t: string) => t);
|
||||
|
||||
const result: CommandMetadata = {
|
||||
name: header.name,
|
||||
command: `/workflow:${header.name}`,
|
||||
description: header.description || '',
|
||||
argumentHint: header['argument-hint'] || '',
|
||||
allowedTools: allowedTools,
|
||||
filePath: filePath
|
||||
};
|
||||
|
||||
// Cache result
|
||||
this.cache.set(normalized, result);
|
||||
return result;
|
||||
}
|
||||
} catch (error) {
|
||||
const err = error as Error;
|
||||
console.error(`Failed to read command ${filePath}:`, err.message);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get multiple commands metadata
|
||||
* @param commandNames Array of command names
|
||||
* @returns Map of command metadata
|
||||
*/
|
||||
public getCommands(commandNames: string[]): Map<string, CommandMetadata> {
|
||||
const result = new Map<string, CommandMetadata>();
|
||||
|
||||
for (const name of commandNames) {
|
||||
const cmd = this.getCommand(name);
|
||||
if (cmd) {
|
||||
result.set(cmd.command, cmd);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all commands' names and descriptions
|
||||
* @returns Map of command names to summaries
|
||||
*/
|
||||
public getAllCommandsSummary(): Map<string, CommandSummary> {
|
||||
const result = new Map<string, CommandSummary>();
|
||||
|
||||
if (!this.commandDir) {
|
||||
return result;
|
||||
}
|
||||
|
||||
try {
|
||||
const files = readdirSync(this.commandDir);
|
||||
|
||||
for (const file of files) {
|
||||
if (!file.endsWith('.md')) continue;
|
||||
|
||||
const filePath = join(this.commandDir, file);
|
||||
const stat = statSync(filePath);
|
||||
|
||||
if (stat.isDirectory()) continue;
|
||||
|
||||
try {
|
||||
const content = readFileSync(filePath, 'utf-8');
|
||||
const header = this.parseYamlHeader(content);
|
||||
|
||||
if (header && header.name) {
|
||||
const commandName = `/workflow:${header.name}`;
|
||||
result.set(commandName, {
|
||||
name: header.name,
|
||||
description: header.description || ''
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// Skip files that fail to read
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Return empty map if directory read fails
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all commands organized by category/tags
|
||||
*/
|
||||
public getAllCommandsByCategory(): Record<string, CommandMetadata[]> {
|
||||
const summary = this.getAllCommandsSummary();
|
||||
const result: Record<string, CommandMetadata[]> = {
|
||||
planning: [],
|
||||
execution: [],
|
||||
testing: [],
|
||||
review: [],
|
||||
other: []
|
||||
};
|
||||
|
||||
for (const [cmdName] of summary) {
|
||||
const cmd = this.getCommand(cmdName);
|
||||
if (cmd) {
|
||||
// Categorize based on command name patterns
|
||||
if (cmd.name.includes('plan')) {
|
||||
result.planning.push(cmd);
|
||||
} else if (cmd.name.includes('execute')) {
|
||||
result.execution.push(cmd);
|
||||
} else if (cmd.name.includes('test')) {
|
||||
result.testing.push(cmd);
|
||||
} else if (cmd.name.includes('review')) {
|
||||
result.review.push(cmd);
|
||||
} else {
|
||||
result.other.push(cmd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert to JSON for serialization
|
||||
*/
|
||||
public toJSON(): Record<string, any> {
|
||||
const result: Record<string, CommandMetadata> = {};
|
||||
for (const [key, value] of this.cache) {
|
||||
result[`/workflow:${key}`] = value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Export function for direct usage
|
||||
*/
|
||||
export function createCommandRegistry(commandDir?: string): CommandRegistry {
|
||||
return new CommandRegistry(commandDir);
|
||||
}
|
||||
|
||||
/**
|
||||
* Export function to get all commands
|
||||
*/
|
||||
export function getAllCommandsSync(): Map<string, CommandSummary> {
|
||||
const registry = new CommandRegistry();
|
||||
return registry.getAllCommandsSummary();
|
||||
}
|
||||
|
||||
/**
|
||||
* Export function to get specific command
|
||||
*/
|
||||
export function getCommandSync(name: string): CommandMetadata | null {
|
||||
const registry = new CommandRegistry();
|
||||
return registry.getCommand(name);
|
||||
}
@@ -378,3 +378,7 @@ export { registerTool };

// Export ToolSchema type
export type { ToolSchema };

// Export CommandRegistry for direct import
export { CommandRegistry, createCommandRegistry, getAllCommandsSync, getCommandSync } from './command-registry.js';
export type { CommandMetadata, CommandSummary } from './command-registry.js';
@@ -19,5 +19,5 @@
     "noEmit": false
   },
   "include": ["src/**/*"],
-  "exclude": ["src/templates/**/*", "node_modules", "dist"]
+  "exclude": ["src/templates/**/*", "src/**/*.test.ts", "node_modules", "dist"]
 }
codex-lens/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 CodexLens Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
codex-lens/README.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# CodexLens

CodexLens is a multi-modal code analysis platform designed to provide comprehensive code understanding and analysis capabilities.

## Features

- **Multi-language Support**: Analyze code in Python, JavaScript, TypeScript and more using Tree-sitter parsers
- **Semantic Search**: Find relevant code snippets using semantic understanding with fastembed and HNSWLIB
- **Code Parsing**: Advanced code structure parsing with tree-sitter
- **Flexible Architecture**: Modular design for easy extension and customization

## Installation

### Basic Installation

```bash
pip install codex-lens
```

### With Semantic Search

```bash
pip install codex-lens[semantic]
```

### With GPU Acceleration (NVIDIA CUDA)

```bash
pip install codex-lens[semantic-gpu]
```

### With DirectML (Windows - NVIDIA/AMD/Intel)

```bash
pip install codex-lens[semantic-directml]
```

### With All Optional Features

```bash
pip install codex-lens[full]
```

## Requirements

- Python >= 3.10
- See `pyproject.toml` for detailed dependency list

## Development

This project uses setuptools for building and packaging.

## License

MIT License

## Authors

CodexLens Contributors
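The install variants above only change which optional dependencies are present; the package itself is the same. A minimal probe sketch, assuming the `semantic` extra pulls in the `fastembed` and `hnswlib` packages named in the feature list (the authoritative dependency set lives in `pyproject.toml`):

```python
"""Probe optional CodexLens capabilities (sketch; assumes the `semantic` extra -> fastembed/hnswlib)."""
import importlib.util


def has_semantic_support() -> bool:
    # The semantic-search extra is expected to provide fastembed (embeddings)
    # and hnswlib (vector index); both must be importable for semantic search.
    return all(
        importlib.util.find_spec(mod) is not None
        for mod in ("fastembed", "hnswlib")
    )


if __name__ == "__main__":
    print("semantic search available:", has_semantic_support())
```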
codex-lens/build/lib/codexlens/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""CodexLens package."""

from __future__ import annotations

from . import config, entities, errors
from .config import Config
from .entities import IndexedFile, SearchResult, SemanticChunk, Symbol
from .errors import CodexLensError, ConfigError, ParseError, SearchError, StorageError

__version__ = "0.1.0"

__all__ = [
    "__version__",
    "config",
    "entities",
    "errors",
    "Config",
    "IndexedFile",
    "SearchResult",
    "SemanticChunk",
    "Symbol",
    "CodexLensError",
    "ConfigError",
    "ParseError",
    "StorageError",
    "SearchError",
]
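Because the package root re-exports the core config, entity, and error types, downstream code can rely on a single import. A minimal sketch, grounded in the `__all__` list above:

```python
"""Minimal sketch of the top-level import surface (names taken from __all__ above)."""
import codexlens

# __version__ and the public names are re-exported at package level.
print("CodexLens version:", codexlens.__version__)
print("Public names:", ", ".join(codexlens.__all__))
```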
codex-lens/build/lib/codexlens/__main__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Module entrypoint for `python -m codexlens`."""

from __future__ import annotations

from codexlens.cli import app


def main() -> None:
    app()


if __name__ == "__main__":
    main()
codex-lens/build/lib/codexlens/api/__init__.py (new file, 88 lines)
@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.

This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.

Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result

Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity

Example:
    >>> from codexlens.api import (
    ...     DefinitionResult,
    ...     resolve_project,
    ...     normalize_relationship_type
    ... )
    >>> project = resolve_project("/path/to/project")
    >>> rel_type = normalize_relationship_type("calls")
    >>> print(rel_type)
    'call'
"""

from __future__ import annotations

# Dataclasses
from .models import (
    CallInfo,
    MethodContext,
    FileContextResult,
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
    SymbolInfo,
    HoverInfo,
    SemanticResult,
)

# Utility functions
from .utils import (
    resolve_project,
    normalize_relationship_type,
    rank_by_proximity,
    rank_by_score,
)

# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search

__all__ = [
    # Dataclasses
    "CallInfo",
    "MethodContext",
    "FileContextResult",
    "DefinitionResult",
    "ReferenceResult",
    "GroupedReferences",
    "SymbolInfo",
    "HoverInfo",
    "SemanticResult",
    # Utility functions
    "resolve_project",
    "normalize_relationship_type",
    "rank_by_proximity",
    "rank_by_score",
    # API functions
    "find_definition",
    "workspace_symbols",
    "get_hover",
    "file_context",
    "find_references",
    "semantic_search",
]
codex-lens/build/lib/codexlens/api/definition.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""find_definition API implementation.

This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity

logger = logging.getLogger(__name__)


def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """Find definition locations for a symbol.

    Uses a 3-stage fallback strategy:
    1. Exact match with kind filter
    2. Exact match without kind filter
    3. Prefix match

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to find
        symbol_kind: Optional symbol kind filter (class, function, etc.)
        file_context: Optional file path for proximity ranking
        limit: Maximum number of results to return

    Returns:
        List of DefinitionResult sorted by proximity if file_context provided

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Stage 1: Exact match with kind filter
    results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
    if results:
        logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    # Stage 2: Exact match without kind (if kind was specified)
    if symbol_kind:
        results = _search_with_kind(global_index, symbol_name, None, limit)
        if results:
            logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
            return _rank_and_convert(results, file_context)

    # Stage 3: Prefix match
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=limit,
        prefix_mode=True
    )
    if results:
        logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    logger.debug(f"No definitions found for {symbol_name}")
    return []


def _search_with_kind(
    global_index: GlobalSymbolIndex,
    symbol_name: str,
    symbol_kind: Optional[str],
    limit: int
) -> List[Symbol]:
    """Search for symbols with optional kind filter."""
    return global_index.search(
        name=symbol_name,
        kind=symbol_kind,
        limit=limit,
        prefix_mode=False
    )


def _rank_and_convert(
    symbols: List[Symbol],
    file_context: Optional[str]
) -> List[DefinitionResult]:
    """Convert symbols to DefinitionResult and rank by proximity."""
    results = [
        DefinitionResult(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            end_line=sym.range[1] if sym.range else 1,
            signature=None,  # Could extract from file if needed
            container=None,  # Could extract from parent symbol
            score=1.0
        )
        for sym in symbols
    ]
    return rank_by_proximity(results, file_context)
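The three-stage fallback (exact match with kind, exact match without kind, prefix match) is easiest to see from the caller's side. A minimal usage sketch, assuming the project has already been indexed; the paths and symbol names are placeholders:

```python
"""Usage sketch for find_definition (assumes the project was indexed beforehand)."""
from codexlens.api import find_definition
from codexlens.errors import IndexNotFoundError

try:
    defs = find_definition(
        project_root="/path/to/project",                  # placeholder path
        symbol_name="GlobalSymbolIndex",                  # placeholder symbol
        symbol_kind="class",                              # drives stage 1; dropped in stage 2 if empty
        file_context="/path/to/project/src/api.py",       # only used for proximity ranking
        limit=5,
    )
except IndexNotFoundError as exc:
    raise SystemExit(f"run the indexer first: {exc}")

for d in defs:
    print(f"{d.kind} {d.name} -> {d.file_path}:{d.line}-{d.end_line}")
```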
codex-lens/build/lib/codexlens/api/file_context.py (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
"""file_context API implementation.
|
||||
|
||||
This module provides the file_context() function for retrieving
|
||||
method call graphs from a source file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.dir_index import DirIndexStore
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import (
|
||||
FileContextResult,
|
||||
MethodContext,
|
||||
CallInfo,
|
||||
)
|
||||
from .utils import resolve_project, normalize_relationship_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def file_context(
|
||||
project_root: str,
|
||||
file_path: str,
|
||||
include_calls: bool = True,
|
||||
include_callers: bool = True,
|
||||
max_depth: int = 1,
|
||||
format: str = "brief"
|
||||
) -> FileContextResult:
|
||||
"""Get method call context for a code file.
|
||||
|
||||
Retrieves all methods/functions in the file along with their
|
||||
outgoing calls and incoming callers.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
file_path: Path to the code file to analyze
|
||||
include_calls: Whether to include outgoing calls
|
||||
include_callers: Whether to include incoming callers
|
||||
max_depth: Call chain depth (V1 only supports 1)
|
||||
format: Output format (brief | detailed | tree)
|
||||
|
||||
Returns:
|
||||
FileContextResult with method contexts and summary
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
FileNotFoundError: If file does not exist
|
||||
ValueError: If max_depth > 1 (V1 limitation)
|
||||
"""
|
||||
# V1 limitation: only depth=1 supported
|
||||
if max_depth > 1:
|
||||
raise ValueError(
|
||||
f"max_depth > 1 not supported in V1. "
|
||||
f"Requested: {max_depth}, supported: 1"
|
||||
)
|
||||
|
||||
project_path = resolve_project(project_root)
|
||||
file_path_resolved = Path(file_path).resolve()
|
||||
|
||||
# Validate file exists
|
||||
if not file_path_resolved.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path_resolved}")
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project(project_path)
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Get all symbols in the file
|
||||
symbols = global_index.get_file_symbols(str(file_path_resolved))
|
||||
|
||||
# Filter to functions, methods, and classes
|
||||
method_symbols = [
|
||||
s for s in symbols
|
||||
if s.kind in ("function", "method", "class")
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(method_symbols)} methods in {file_path}")
|
||||
|
||||
# Try to find dir_index for relationship queries
|
||||
dir_index = _find_dir_index(project_info, file_path_resolved)
|
||||
|
||||
# Build method contexts
|
||||
methods: List[MethodContext] = []
|
||||
outgoing_resolved = True
|
||||
incoming_resolved = True
|
||||
targets_resolved = True
|
||||
|
||||
for symbol in method_symbols:
|
||||
calls: List[CallInfo] = []
|
||||
callers: List[CallInfo] = []
|
||||
|
||||
if include_calls and dir_index:
|
||||
try:
|
||||
outgoing = dir_index.get_outgoing_calls(
|
||||
str(file_path_resolved),
|
||||
symbol.name
|
||||
)
|
||||
for target_name, rel_type, line, target_file in outgoing:
|
||||
calls.append(CallInfo(
|
||||
symbol_name=target_name,
|
||||
file_path=target_file,
|
||||
line=line,
|
||||
relationship=normalize_relationship_type(rel_type)
|
||||
))
|
||||
if target_file is None:
|
||||
targets_resolved = False
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get outgoing calls: {e}")
|
||||
outgoing_resolved = False
|
||||
|
||||
if include_callers and dir_index:
|
||||
try:
|
||||
incoming = dir_index.get_incoming_calls(symbol.name)
|
||||
for source_name, rel_type, line, source_file in incoming:
|
||||
callers.append(CallInfo(
|
||||
symbol_name=source_name,
|
||||
file_path=source_file,
|
||||
line=line,
|
||||
relationship=normalize_relationship_type(rel_type)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get incoming calls: {e}")
|
||||
incoming_resolved = False
|
||||
|
||||
methods.append(MethodContext(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
line_range=symbol.range if symbol.range else (1, 1),
|
||||
signature=None, # Could extract from source
|
||||
calls=calls,
|
||||
callers=callers
|
||||
))
|
||||
|
||||
# Detect language from file extension
|
||||
language = _detect_language(file_path_resolved)
|
||||
|
||||
# Generate summary
|
||||
summary = _generate_summary(file_path_resolved, methods, format)
|
||||
|
||||
return FileContextResult(
|
||||
file_path=str(file_path_resolved),
|
||||
language=language,
|
||||
methods=methods,
|
||||
summary=summary,
|
||||
discovery_status={
|
||||
"outgoing_resolved": outgoing_resolved,
|
||||
"incoming_resolved": incoming_resolved,
|
||||
"targets_resolved": targets_resolved
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
|
||||
"""Find the dir_index that contains the file.
|
||||
|
||||
Args:
|
||||
project_info: Project information from registry
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
DirIndexStore if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Look for _index.db in file's directory or parent directories
|
||||
current = file_path.parent
|
||||
while current != current.parent:
|
||||
index_db = current / "_index.db"
|
||||
if index_db.exists():
|
||||
return DirIndexStore(str(index_db))
|
||||
|
||||
# Also check in project's index_root
|
||||
relative = current.relative_to(project_info.source_root)
|
||||
index_in_cache = project_info.index_root / relative / "_index.db"
|
||||
if index_in_cache.exists():
|
||||
return DirIndexStore(str(index_in_cache))
|
||||
|
||||
current = current.parent
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to find dir_index: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _detect_language(file_path: Path) -> str:
|
||||
"""Detect programming language from file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Language name
|
||||
"""
|
||||
ext_map = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".jsx": "javascript",
|
||||
".tsx": "typescript",
|
||||
".go": "go",
|
||||
".rs": "rust",
|
||||
".java": "java",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
}
|
||||
return ext_map.get(file_path.suffix.lower(), "unknown")
|
||||
|
||||
|
||||
def _generate_summary(
|
||||
file_path: Path,
|
||||
methods: List[MethodContext],
|
||||
format: str
|
||||
) -> str:
|
||||
"""Generate human-readable summary of file context.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
methods: List of method contexts
|
||||
format: Output format (brief | detailed | tree)
|
||||
|
||||
Returns:
|
||||
Markdown-formatted summary
|
||||
"""
|
||||
lines = [f"## {file_path.name} ({len(methods)} methods)\n"]
|
||||
|
||||
for method in methods:
|
||||
start, end = method.line_range
|
||||
lines.append(f"### {method.name} (line {start}-{end})")
|
||||
|
||||
if method.calls:
|
||||
calls_str = ", ".join(
|
||||
f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
|
||||
if format == "detailed"
|
||||
else c.symbol_name
|
||||
for c in method.calls
|
||||
)
|
||||
lines.append(f"- Calls: {calls_str}")
|
||||
|
||||
if method.callers:
|
||||
callers_str = ", ".join(
|
||||
f"{c.symbol_name} ({c.file_path}:{c.line})"
|
||||
if format == "detailed"
|
||||
else c.symbol_name
|
||||
for c in method.callers
|
||||
)
|
||||
lines.append(f"- Called by: {callers_str}")
|
||||
|
||||
if not method.calls and not method.callers:
|
||||
lines.append("- (no call relationships)")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
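A caller-side sketch of the function above, assuming an already-indexed project; the paths are placeholders, and `max_depth` stays at 1 because of the V1 restriction enforced at the top of `file_context()`:

```python
"""Usage sketch for file_context (assumes an indexed project; paths are placeholders)."""
from codexlens.api import file_context

result = file_context(
    project_root="/path/to/project",
    file_path="/path/to/project/src/service.py",
    include_calls=True,
    include_callers=True,
    max_depth=1,            # V1 supports depth 1 only; larger values raise ValueError
    format="brief",
)

print(result.summary)            # markdown summary built by _generate_summary
print(result.discovery_status)   # e.g. {'outgoing_resolved': True, ...}
for method in result.methods:
    print(method.name, "calls", [c.symbol_name for c in method.calls])
```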
codex-lens/build/lib/codexlens/api/hover.py (new file, 148 lines)
@@ -0,0 +1,148 @@
|
||||
"""get_hover API implementation.
|
||||
|
||||
This module provides the get_hover() function for retrieving
|
||||
detailed hover information for symbols.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import HoverInfo
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hover(
|
||||
project_root: str,
|
||||
symbol_name: str,
|
||||
file_path: Optional[str] = None
|
||||
) -> Optional[HoverInfo]:
|
||||
"""Get detailed hover information for a symbol.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
symbol_name: Name of the symbol to look up
|
||||
file_path: Optional file path to disambiguate when symbol
|
||||
appears in multiple files
|
||||
|
||||
Returns:
|
||||
HoverInfo if symbol found, None otherwise
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
"""
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project(project_path)
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Search for the symbol
|
||||
results = global_index.search(
|
||||
name=symbol_name,
|
||||
kind=None,
|
||||
limit=50,
|
||||
prefix_mode=False
|
||||
)
|
||||
|
||||
if not results:
|
||||
logger.debug(f"No hover info found for {symbol_name}")
|
||||
return None
|
||||
|
||||
# If file_path provided, filter to that file
|
||||
if file_path:
|
||||
file_path_resolved = str(Path(file_path).resolve())
|
||||
matching = [s for s in results if s.file == file_path_resolved]
|
||||
if matching:
|
||||
results = matching
|
||||
|
||||
# Take the first result
|
||||
symbol = results[0]
|
||||
|
||||
# Build hover info
|
||||
return HoverInfo(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
signature=_extract_signature(symbol),
|
||||
documentation=_extract_documentation(symbol),
|
||||
file_path=symbol.file or "",
|
||||
line_range=symbol.range if symbol.range else (1, 1),
|
||||
type_info=_extract_type_info(symbol)
|
||||
)
|
||||
|
||||
|
||||
def _extract_signature(symbol: Symbol) -> str:
|
||||
"""Extract signature from symbol.
|
||||
|
||||
For now, generates a basic signature based on kind and name.
|
||||
In a full implementation, this would parse the actual source code.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract signature from
|
||||
|
||||
Returns:
|
||||
Signature string
|
||||
"""
|
||||
if symbol.kind == "function":
|
||||
return f"def {symbol.name}(...)"
|
||||
elif symbol.kind == "method":
|
||||
return f"def {symbol.name}(self, ...)"
|
||||
elif symbol.kind == "class":
|
||||
return f"class {symbol.name}"
|
||||
elif symbol.kind == "variable":
|
||||
return symbol.name
|
||||
elif symbol.kind == "constant":
|
||||
return f"{symbol.name} = ..."
|
||||
else:
|
||||
return f"{symbol.kind} {symbol.name}"
|
||||
|
||||
|
||||
def _extract_documentation(symbol: Symbol) -> Optional[str]:
|
||||
"""Extract documentation from symbol.
|
||||
|
||||
In a full implementation, this would parse docstrings from source.
|
||||
For now, returns None.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract documentation from
|
||||
|
||||
Returns:
|
||||
Documentation string if available, None otherwise
|
||||
"""
|
||||
# Would need to read source file and parse docstring
|
||||
# For V1, return None
|
||||
return None
|
||||
|
||||
|
||||
def _extract_type_info(symbol: Symbol) -> Optional[str]:
|
||||
"""Extract type information from symbol.
|
||||
|
||||
In a full implementation, this would parse type annotations.
|
||||
For now, returns None.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract type info from
|
||||
|
||||
Returns:
|
||||
Type info string if available, None otherwise
|
||||
"""
|
||||
# Would need to parse type annotations from source
|
||||
# For V1, return None
|
||||
return None
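A short caller-side sketch, assuming an indexed project; as `_extract_documentation()` and `_extract_type_info()` note, only the name, kind, signature, and location are populated in V1:

```python
"""Usage sketch for get_hover (assumes an indexed project; names are placeholders)."""
from codexlens.api import get_hover

info = get_hover(
    project_root="/path/to/project",
    symbol_name="find_definition",
    file_path=None,   # pass a file path to disambiguate duplicate symbol names
)

if info is None:
    print("symbol not found in the global index")
else:
    start, end = info.line_range
    print(f"{info.signature}  ({info.file_path}:{start}-{end})")
    # documentation/type_info stay None in V1 (see the helper functions above)
```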
codex-lens/build/lib/codexlens/api/models.py (new file, 281 lines)
@@ -0,0 +1,281 @@
|
||||
"""API dataclass definitions for codexlens LSP API.
|
||||
|
||||
This module defines all result dataclasses used by the public API layer,
|
||||
following the patterns established in mcp/schema.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.2: file_context dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class CallInfo:
|
||||
"""Call relationship information.
|
||||
|
||||
Attributes:
|
||||
symbol_name: Name of the called/calling symbol
|
||||
file_path: Target file path (may be None if unresolved)
|
||||
line: Line number of the call
|
||||
relationship: Type of relationship (call | import | inheritance)
|
||||
"""
|
||||
symbol_name: str
|
||||
file_path: Optional[str]
|
||||
line: int
|
||||
relationship: str # call | import | inheritance
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodContext:
|
||||
"""Method context with call relationships.
|
||||
|
||||
Attributes:
|
||||
name: Method/function name
|
||||
kind: Symbol kind (function | method | class)
|
||||
line_range: Start and end line numbers
|
||||
signature: Function signature (if available)
|
||||
calls: List of outgoing calls
|
||||
callers: List of incoming calls
|
||||
"""
|
||||
name: str
|
||||
kind: str # function | method | class
|
||||
line_range: Tuple[int, int]
|
||||
signature: Optional[str]
|
||||
calls: List[CallInfo] = field(default_factory=list)
|
||||
callers: List[CallInfo] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"line_range": list(self.line_range),
|
||||
"calls": [c.to_dict() for c in self.calls],
|
||||
"callers": [c.to_dict() for c in self.callers],
|
||||
}
|
||||
if self.signature is not None:
|
||||
result["signature"] = self.signature
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileContextResult:
|
||||
"""File context result with method summaries.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the analyzed file
|
||||
language: Programming language
|
||||
methods: List of method contexts
|
||||
summary: Human-readable summary
|
||||
discovery_status: Status flags for call resolution
|
||||
"""
|
||||
file_path: str
|
||||
language: str
|
||||
methods: List[MethodContext]
|
||||
summary: str
|
||||
discovery_status: Dict[str, bool] = field(default_factory=lambda: {
|
||||
"outgoing_resolved": False,
|
||||
"incoming_resolved": True,
|
||||
"targets_resolved": False
|
||||
})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"language": self.language,
|
||||
"methods": [m.to_dict() for m in self.methods],
|
||||
"summary": self.summary,
|
||||
"discovery_status": self.discovery_status,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.3: find_definition dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class DefinitionResult:
|
||||
"""Definition lookup result.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind (class, function, method, etc.)
|
||||
file_path: File where symbol is defined
|
||||
line: Start line number
|
||||
end_line: End line number
|
||||
        signature: Symbol signature (if available)
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str] = None
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================

@dataclass
class ReferenceResult:
    """Reference lookup result.

    Attributes:
        file_path: File containing the reference
        line: Line number
        column: Column number
        context_line: The line of code containing the reference
        relationship: Type of reference (call | import | type_annotation | inheritance)
    """
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class GroupedReferences:
    """References grouped by definition.

    Used when a symbol has multiple definitions (e.g., overloads).

    Attributes:
        definition: The definition this group refers to
        references: List of references to this definition
    """
    definition: DefinitionResult
    references: List[ReferenceResult] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "definition": self.definition.to_dict(),
            "references": [r.to_dict() for r in self.references],
        }


# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================

@dataclass
class SymbolInfo:
    """Symbol information for workspace search.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================

@dataclass
class HoverInfo:
    """Hover information for a symbol.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        signature: Symbol signature
        documentation: Documentation string (if available)
        file_path: File where symbol is defined
        line_range: Start and end line numbers
        type_info: Type information (if available)
    """
    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: Tuple[int, int]
    type_info: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "signature": self.signature,
            "file_path": self.file_path,
            "line_range": list(self.line_range),
        }
        if self.documentation is not None:
            result["documentation"] = self.documentation
        if self.type_info is not None:
            result["type_info"] = self.type_info
        return result


# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================

@dataclass
class SemanticResult:
    """Semantic search result.

    Attributes:
        symbol_name: Name of the matched symbol
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        vector_score: Vector similarity score (None if not available)
        structural_score: Structural match score (None if not available)
        fusion_score: Combined fusion score
        snippet: Code snippet
        match_reason: Explanation of why this matched (optional)
    """
    symbol_name: str
    kind: str
    file_path: str
    line: int
    vector_score: Optional[float]
    structural_score: Optional[float]
    fusion_score: float
    snippet: str
    match_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}
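A minimal usage sketch for the result dataclasses above, assuming they are importable as `codexlens.api.models` (the package path implied by the surrounding diff); the field values are invented for illustration and do not come from a real index.

```python
from codexlens.api.models import SemanticResult  # assumed import path

# Hypothetical values for illustration; real instances come from semantic_search().
result = SemanticResult(
    symbol_name="authenticate",
    kind="function",
    file_path="src/auth/handlers.py",
    line=42,
    vector_score=0.83,
    structural_score=None,  # unavailable score is dropped by to_dict()
    fusion_score=0.71,
    snippet="def authenticate(request): ...",
)

# to_dict() filters None values, so structural_score and match_reason are omitted.
print(result.to_dict())
```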
codex-lens/build/lib/codexlens/api/references.py (new file, 345 lines)
@@ -0,0 +1,345 @@
"""Find references API for codexlens.
|
||||
|
||||
This module implements the find_references() function that wraps
|
||||
ChainSearchEngine.search_references() with grouped result structure
|
||||
for multi-definition symbols.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from .models import (
|
||||
DefinitionResult,
|
||||
ReferenceResult,
|
||||
GroupedReferences,
|
||||
)
|
||||
from .utils import (
|
||||
resolve_project,
|
||||
normalize_relationship_type,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _read_line_from_file(file_path: str, line: int) -> str:
|
||||
"""Read a specific line from a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
line: Line number (1-based)
|
||||
|
||||
Returns:
|
||||
The line content, stripped of trailing whitespace.
|
||||
Returns empty string if file cannot be read or line doesn't exist.
|
||||
"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return ""
|
||||
|
||||
with path.open("r", encoding="utf-8", errors="replace") as f:
|
||||
for i, content in enumerate(f, 1):
|
||||
if i == line:
|
||||
return content.rstrip()
|
||||
return ""
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _transform_to_reference_result(
|
||||
raw_ref: "RawReferenceResult",
|
||||
) -> ReferenceResult:
|
||||
"""Transform raw ChainSearchEngine reference to API ReferenceResult.
|
||||
|
||||
Args:
|
||||
raw_ref: Raw reference result from ChainSearchEngine
|
||||
|
||||
Returns:
|
||||
API ReferenceResult with context_line and normalized relationship
|
||||
"""
|
||||
# Read the actual line from the file
|
||||
context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)
|
||||
|
||||
# Normalize relationship type
|
||||
relationship = normalize_relationship_type(raw_ref.relationship_type)
|
||||
|
||||
return ReferenceResult(
|
||||
file_path=raw_ref.file_path,
|
||||
line=raw_ref.line,
|
||||
column=raw_ref.column,
|
||||
context_line=context_line,
|
||||
relationship=relationship,
|
||||
)
|
||||
|
||||
|
||||
def find_references(
|
||||
project_root: str,
|
||||
symbol_name: str,
|
||||
symbol_kind: Optional[str] = None,
|
||||
include_definition: bool = True,
|
||||
group_by_definition: bool = True,
|
||||
limit: int = 100,
|
||||
) -> List[GroupedReferences]:
|
||||
"""Find all reference locations for a symbol.
|
||||
|
||||
Multi-definition case returns grouped results to resolve ambiguity.
|
||||
|
||||
This function wraps ChainSearchEngine.search_references() and groups
|
||||
the results by definition location. Each GroupedReferences contains
|
||||
a definition and all references that point to it.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory path
|
||||
symbol_name: Name of the symbol to find references for
|
||||
symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
|
||||
include_definition: Whether to include the definition location
|
||||
in the result (default True)
|
||||
group_by_definition: Whether to group references by definition.
|
||||
If False, returns a single group with all references.
|
||||
(default True)
|
||||
limit: Maximum number of references to return (default 100)
|
||||
|
||||
Returns:
|
||||
List of GroupedReferences. Each group contains:
|
||||
- definition: The DefinitionResult for this symbol definition
|
||||
- references: List of ReferenceResult pointing to this definition
|
||||
|
||||
Raises:
|
||||
ValueError: If project_root does not exist or is not a directory
|
||||
|
||||
Examples:
|
||||
>>> refs = find_references("/path/to/project", "authenticate")
|
||||
>>> for group in refs:
|
||||
... print(f"Definition: {group.definition.file_path}:{group.definition.line}")
|
||||
... for ref in group.references:
|
||||
... print(f" Reference: {ref.file_path}:{ref.line} ({ref.relationship})")
|
||||
|
||||
Note:
|
||||
Reference relationship types are normalized:
|
||||
- 'calls' -> 'call'
|
||||
- 'imports' -> 'import'
|
||||
- 'inherits' -> 'inheritance'
|
||||
"""
|
||||
# Validate and resolve project root
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
|
||||
from codexlens.entities import Symbol
|
||||
|
||||
# Initialize infrastructure
|
||||
config = Config()
|
||||
registry = RegistryStore()
|
||||
mapper = PathMapper(config.index_dir)
|
||||
|
||||
# Create chain search engine
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
|
||||
try:
|
||||
# Step 1: Find definitions for the symbol
|
||||
definitions: List[DefinitionResult] = []
|
||||
|
||||
if include_definition or group_by_definition:
|
||||
# Search for symbol definitions
|
||||
symbols = engine.search_symbols(
|
||||
name=symbol_name,
|
||||
source_path=project_path,
|
||||
kind=symbol_kind,
|
||||
)
|
||||
|
||||
# Convert Symbol to DefinitionResult
|
||||
for sym in symbols:
|
||||
# Only include exact name matches for definitions
|
||||
if sym.name != symbol_name:
|
||||
continue
|
||||
|
||||
# Optionally filter by kind
|
||||
if symbol_kind and sym.kind != symbol_kind:
|
||||
continue
|
||||
|
||||
definitions.append(DefinitionResult(
|
||||
name=sym.name,
|
||||
kind=sym.kind,
|
||||
file_path=sym.file or "",
|
||||
line=sym.range[0] if sym.range else 1,
|
||||
end_line=sym.range[1] if sym.range else 1,
|
||||
signature=None, # Not available from Symbol
|
||||
container=None, # Not available from Symbol
|
||||
score=1.0,
|
||||
))
|
||||
|
||||
# Step 2: Get all references using ChainSearchEngine
|
||||
raw_references = engine.search_references(
|
||||
symbol_name=symbol_name,
|
||||
source_path=project_path,
|
||||
depth=-1,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Step 3: Transform raw references to API ReferenceResult
|
||||
api_references: List[ReferenceResult] = []
|
||||
for raw_ref in raw_references:
|
||||
api_ref = _transform_to_reference_result(raw_ref)
|
||||
api_references.append(api_ref)
|
||||
|
||||
# Step 4: Group references by definition
|
||||
if group_by_definition and definitions:
|
||||
return _group_references_by_definition(
|
||||
definitions=definitions,
|
||||
references=api_references,
|
||||
include_definition=include_definition,
|
||||
)
|
||||
else:
|
||||
# Return single group with placeholder definition or first definition
|
||||
if definitions:
|
||||
definition = definitions[0]
|
||||
else:
|
||||
# Create placeholder definition when no definition found
|
||||
definition = DefinitionResult(
|
||||
name=symbol_name,
|
||||
kind=symbol_kind or "unknown",
|
||||
file_path="",
|
||||
line=0,
|
||||
end_line=0,
|
||||
signature=None,
|
||||
container=None,
|
||||
score=0.0,
|
||||
)
|
||||
|
||||
return [GroupedReferences(
|
||||
definition=definition,
|
||||
references=api_references,
|
||||
)]
|
||||
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
|
||||
def _group_references_by_definition(
|
||||
definitions: List[DefinitionResult],
|
||||
references: List[ReferenceResult],
|
||||
include_definition: bool = True,
|
||||
) -> List[GroupedReferences]:
|
||||
"""Group references by their likely definition.
|
||||
|
||||
Uses file proximity heuristic to assign references to definitions.
|
||||
References in the same file or directory as a definition are
|
||||
assigned to that definition.
|
||||
|
||||
Args:
|
||||
definitions: List of definition locations
|
||||
references: List of reference locations
|
||||
include_definition: Whether to include definition in results
|
||||
|
||||
Returns:
|
||||
List of GroupedReferences with references assigned to definitions
|
||||
"""
|
||||
import os
|
||||
|
||||
if not definitions:
|
||||
return []
|
||||
|
||||
if len(definitions) == 1:
|
||||
# Single definition - all references belong to it
|
||||
return [GroupedReferences(
|
||||
definition=definitions[0],
|
||||
references=references,
|
||||
)]
|
||||
|
||||
# Multiple definitions - group by proximity
|
||||
groups: Dict[int, List[ReferenceResult]] = {
|
||||
i: [] for i in range(len(definitions))
|
||||
}
|
||||
|
||||
for ref in references:
|
||||
# Find the closest definition by file proximity
|
||||
best_def_idx = 0
|
||||
best_score = -1
|
||||
|
||||
for i, defn in enumerate(definitions):
|
||||
score = _proximity_score(ref.file_path, defn.file_path)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_def_idx = i
|
||||
|
||||
groups[best_def_idx].append(ref)
|
||||
|
||||
# Build result groups
|
||||
result: List[GroupedReferences] = []
|
||||
for i, defn in enumerate(definitions):
|
||||
# Skip definitions with no references if not including definition itself
|
||||
if not include_definition and not groups[i]:
|
||||
continue
|
||||
|
||||
result.append(GroupedReferences(
|
||||
definition=defn,
|
||||
references=groups[i],
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _proximity_score(ref_path: str, def_path: str) -> int:
|
||||
"""Calculate proximity score between two file paths.
|
||||
|
||||
Args:
|
||||
ref_path: Reference file path
|
||||
def_path: Definition file path
|
||||
|
||||
Returns:
|
||||
Proximity score (higher = closer):
|
||||
- Same file: 1000
|
||||
- Same directory: 100
|
||||
- Otherwise: common path prefix length
|
||||
"""
|
||||
import os
|
||||
|
||||
if not ref_path or not def_path:
|
||||
return 0
|
||||
|
||||
# Normalize paths
|
||||
ref_path = os.path.normpath(ref_path)
|
||||
def_path = os.path.normpath(def_path)
|
||||
|
||||
# Same file
|
||||
if ref_path == def_path:
|
||||
return 1000
|
||||
|
||||
ref_dir = os.path.dirname(ref_path)
|
||||
def_dir = os.path.dirname(def_path)
|
||||
|
||||
# Same directory
|
||||
if ref_dir == def_dir:
|
||||
return 100
|
||||
|
||||
# Common path prefix
|
||||
try:
|
||||
common = os.path.commonpath([ref_path, def_path])
|
||||
return len(common)
|
||||
except ValueError:
|
||||
# No common path (different drives on Windows)
|
||||
return 0
|
||||
|
||||
|
||||
# Type alias for the raw reference from ChainSearchEngine
|
||||
class RawReferenceResult:
|
||||
"""Type stub for ChainSearchEngine.ReferenceResult.
|
||||
|
||||
This is only used for type hints and is replaced at runtime
|
||||
by the actual import.
|
||||
"""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context: str
|
||||
relationship_type: str
|
||||
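A usage sketch for find_references(), following the signature and docstring in the hunk above. The import path, project path, and symbol name are placeholders; an already-indexed project is assumed.

```python
from codexlens.api.references import find_references  # assumed import path

# Placeholder project path and symbol; requires an existing codexlens index.
groups = find_references("/path/to/project", "authenticate", symbol_kind="function", limit=50)

for group in groups:
    d = group.definition
    print(f"Definition: {d.file_path}:{d.line} ({d.kind})")
    for ref in group.references:
        # relationship is normalized to call | import | type_annotation | inheritance
        print(f"  {ref.relationship:<15} {ref.file_path}:{ref.line}  {ref.context_line}")
```

Grouping follows the file-proximity heuristic in _group_references_by_definition(): each reference is attached to the definition in the same file or directory, falling back to the longest common path prefix.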
codex-lens/build/lib/codexlens/api/semantic.py (new file, 471 lines)
@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.
|
||||
|
||||
This module provides the semantic_search() function for combining
|
||||
vector, structural, and keyword search with configurable fusion strategies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from .models import SemanticResult
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def semantic_search(
|
||||
project_root: str,
|
||||
query: str,
|
||||
mode: str = "fusion",
|
||||
vector_weight: float = 0.5,
|
||||
structural_weight: float = 0.3,
|
||||
keyword_weight: float = 0.2,
|
||||
fusion_strategy: str = "rrf",
|
||||
kind_filter: Optional[List[str]] = None,
|
||||
limit: int = 20,
|
||||
include_match_reason: bool = False,
|
||||
) -> List[SemanticResult]:
|
||||
"""Semantic search - combining vector and structural search.
|
||||
|
||||
This function provides a high-level API for semantic code search,
|
||||
combining vector similarity, structural (symbol + relationships),
|
||||
and keyword-based search methods with configurable fusion.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory
|
||||
query: Natural language query
|
||||
mode: Search mode
|
||||
- vector: Vector search only
|
||||
- structural: Structural search only (symbol + relationships)
|
||||
- fusion: Fusion search (default)
|
||||
vector_weight: Vector search weight [0, 1] (default 0.5)
|
||||
structural_weight: Structural search weight [0, 1] (default 0.3)
|
||||
keyword_weight: Keyword search weight [0, 1] (default 0.2)
|
||||
fusion_strategy: Fusion strategy (maps to chain_search.py)
|
||||
- rrf: Reciprocal Rank Fusion (recommended, default)
|
||||
- staged: Staged cascade -> staged_cascade_search
|
||||
- binary: Binary rerank cascade -> binary_cascade_search
|
||||
- hybrid: Hybrid cascade -> hybrid_cascade_search
|
||||
kind_filter: Symbol type filter (e.g., ["function", "class"])
|
||||
limit: Max return count (default 20)
|
||||
include_match_reason: Generate match reason (heuristic, not LLM)
|
||||
|
||||
Returns:
|
||||
Results sorted by fusion_score
|
||||
|
||||
Degradation:
|
||||
- No vector index: vector_score=None, uses FTS + structural search
|
||||
- No relationship data: structural_score=None, vector search only
|
||||
|
||||
Examples:
|
||||
>>> results = semantic_search(
|
||||
... "/path/to/project",
|
||||
... "authentication handler",
|
||||
... mode="fusion",
|
||||
... fusion_strategy="rrf"
|
||||
... )
|
||||
>>> for r in results:
|
||||
... print(f"{r.symbol_name}: {r.fusion_score:.3f}")
|
||||
"""
|
||||
# Validate and resolve project path
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Normalize weights to sum to 1.0
|
||||
total_weight = vector_weight + structural_weight + keyword_weight
|
||||
if total_weight > 0:
|
||||
vector_weight = vector_weight / total_weight
|
||||
structural_weight = structural_weight / total_weight
|
||||
keyword_weight = keyword_weight / total_weight
|
||||
else:
|
||||
# Default to equal weights if all zero
|
||||
vector_weight = structural_weight = keyword_weight = 1.0 / 3.0
|
||||
|
||||
# Initialize search infrastructure
|
||||
try:
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
|
||||
except ImportError as exc:
|
||||
logger.error("Failed to import search dependencies: %s", exc)
|
||||
return []
|
||||
|
||||
# Load config
|
||||
config = Config.load()
|
||||
|
||||
# Get or create registry and mapper
|
||||
try:
|
||||
registry = RegistryStore.default()
|
||||
mapper = PathMapper(registry)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize search infrastructure: %s", exc)
|
||||
return []
|
||||
|
||||
# Build search options based on mode
|
||||
search_options = _build_search_options(
|
||||
mode=mode,
|
||||
vector_weight=vector_weight,
|
||||
structural_weight=structural_weight,
|
||||
keyword_weight=keyword_weight,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Execute search based on fusion_strategy
|
||||
try:
|
||||
with ChainSearchEngine(registry, mapper, config=config) as engine:
|
||||
chain_result = _execute_search(
|
||||
engine=engine,
|
||||
query=query,
|
||||
source_path=project_path,
|
||||
fusion_strategy=fusion_strategy,
|
||||
options=search_options,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Search execution failed: %s", exc)
|
||||
return []
|
||||
|
||||
# Transform results to SemanticResult
|
||||
semantic_results = _transform_results(
|
||||
results=chain_result.results,
|
||||
mode=mode,
|
||||
vector_weight=vector_weight,
|
||||
structural_weight=structural_weight,
|
||||
keyword_weight=keyword_weight,
|
||||
kind_filter=kind_filter,
|
||||
include_match_reason=include_match_reason,
|
||||
query=query,
|
||||
)
|
||||
|
||||
return semantic_results[:limit]
|
||||
|
||||
|
||||
def _build_search_options(
|
||||
mode: str,
|
||||
vector_weight: float,
|
||||
structural_weight: float,
|
||||
keyword_weight: float,
|
||||
limit: int,
|
||||
) -> "SearchOptions":
|
||||
"""Build SearchOptions based on mode and weights.
|
||||
|
||||
Args:
|
||||
mode: Search mode (vector, structural, fusion)
|
||||
vector_weight: Vector search weight
|
||||
structural_weight: Structural search weight
|
||||
keyword_weight: Keyword search weight
|
||||
limit: Result limit
|
||||
|
||||
Returns:
|
||||
Configured SearchOptions
|
||||
"""
|
||||
from codexlens.search.chain_search import SearchOptions
|
||||
|
||||
# Default options
|
||||
options = SearchOptions(
|
||||
total_limit=limit * 2, # Fetch extra for filtering
|
||||
limit_per_dir=limit,
|
||||
include_symbols=True, # Always include symbols for structural
|
||||
)
|
||||
|
||||
if mode == "vector":
|
||||
# Pure vector mode
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = True
|
||||
options.pure_vector = True
|
||||
options.enable_fuzzy = False
|
||||
elif mode == "structural":
|
||||
# Structural only - use FTS + symbols
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = False
|
||||
options.enable_fuzzy = True
|
||||
options.include_symbols = True
|
||||
else:
|
||||
# Fusion mode (default)
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = vector_weight > 0
|
||||
options.enable_fuzzy = keyword_weight > 0
|
||||
options.include_symbols = structural_weight > 0
|
||||
|
||||
# Set custom weights for RRF
|
||||
if options.enable_vector and keyword_weight > 0:
|
||||
options.hybrid_weights = {
|
||||
"vector": vector_weight,
|
||||
"exact": keyword_weight * 0.7,
|
||||
"fuzzy": keyword_weight * 0.3,
|
||||
}
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _execute_search(
|
||||
engine: "ChainSearchEngine",
|
||||
query: str,
|
||||
source_path: Path,
|
||||
fusion_strategy: str,
|
||||
options: "SearchOptions",
|
||||
limit: int,
|
||||
) -> "ChainSearchResult":
|
||||
"""Execute search using appropriate strategy.
|
||||
|
||||
Maps fusion_strategy to ChainSearchEngine methods:
|
||||
- rrf: Standard hybrid search with RRF fusion
|
||||
- staged: staged_cascade_search
|
||||
- binary: binary_cascade_search
|
||||
- hybrid: hybrid_cascade_search
|
||||
|
||||
Args:
|
||||
engine: ChainSearchEngine instance
|
||||
query: Search query
|
||||
source_path: Project root path
|
||||
fusion_strategy: Strategy name
|
||||
options: Search options
|
||||
limit: Result limit
|
||||
|
||||
Returns:
|
||||
ChainSearchResult from the search
|
||||
"""
|
||||
from codexlens.search.chain_search import ChainSearchResult
|
||||
|
||||
if fusion_strategy == "staged":
|
||||
# Use staged cascade search (4-stage pipeline)
|
||||
return engine.staged_cascade_search(
|
||||
query=query,
|
||||
source_path=source_path,
|
||||
k=limit,
|
||||
coarse_k=limit * 5,
|
||||
options=options,
|
||||
)
|
||||
elif fusion_strategy == "binary":
|
||||
# Use binary cascade search (binary coarse + dense fine)
|
||||
return engine.binary_cascade_search(
|
||||
query=query,
|
||||
source_path=source_path,
|
||||
k=limit,
|
||||
coarse_k=limit * 5,
|
||||
options=options,
|
||||
)
|
||||
elif fusion_strategy == "hybrid":
|
||||
# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
|
||||
return engine.hybrid_cascade_search(
|
||||
query=query,
|
||||
source_path=source_path,
|
||||
k=limit,
|
||||
coarse_k=limit * 5,
|
||||
options=options,
|
||||
)
|
||||
else:
|
||||
# Default: rrf - Standard search with RRF fusion
|
||||
return engine.search(
|
||||
query=query,
|
||||
source_path=source_path,
|
||||
options=options,
|
||||
)
|
||||
|
||||
|
||||
def _transform_results(
|
||||
results: List,
|
||||
mode: str,
|
||||
vector_weight: float,
|
||||
structural_weight: float,
|
||||
keyword_weight: float,
|
||||
kind_filter: Optional[List[str]],
|
||||
include_match_reason: bool,
|
||||
query: str,
|
||||
) -> List[SemanticResult]:
|
||||
"""Transform ChainSearchEngine results to SemanticResult.
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
mode: Search mode
|
||||
vector_weight: Vector weight used
|
||||
structural_weight: Structural weight used
|
||||
keyword_weight: Keyword weight used
|
||||
kind_filter: Optional symbol kind filter
|
||||
include_match_reason: Whether to generate match reasons
|
||||
query: Original query (for match reason generation)
|
||||
|
||||
Returns:
|
||||
List of SemanticResult objects
|
||||
"""
|
||||
semantic_results = []
|
||||
|
||||
for result in results:
|
||||
# Extract symbol info
|
||||
symbol_name = getattr(result, "symbol_name", None)
|
||||
symbol_kind = getattr(result, "symbol_kind", None)
|
||||
start_line = getattr(result, "start_line", None)
|
||||
|
||||
# Use symbol object if available
|
||||
if hasattr(result, "symbol") and result.symbol:
|
||||
symbol_name = symbol_name or result.symbol.name
|
||||
symbol_kind = symbol_kind or result.symbol.kind
|
||||
if hasattr(result.symbol, "range") and result.symbol.range:
|
||||
start_line = start_line or result.symbol.range[0]
|
||||
|
||||
# Filter by kind if specified
|
||||
if kind_filter and symbol_kind:
|
||||
if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
|
||||
continue
|
||||
|
||||
# Determine scores based on mode and metadata
|
||||
metadata = getattr(result, "metadata", {}) or {}
|
||||
fusion_score = result.score
|
||||
|
||||
# Try to extract source scores from metadata
|
||||
source_scores = metadata.get("source_scores", {})
|
||||
vector_score: Optional[float] = None
|
||||
structural_score: Optional[float] = None
|
||||
|
||||
if mode == "vector":
|
||||
# In pure vector mode, the main score is the vector score
|
||||
vector_score = result.score
|
||||
structural_score = None
|
||||
elif mode == "structural":
|
||||
# In structural mode, no vector score
|
||||
vector_score = None
|
||||
structural_score = result.score
|
||||
else:
|
||||
# Fusion mode - try to extract individual scores
|
||||
if "vector" in source_scores:
|
||||
vector_score = source_scores["vector"]
|
||||
elif metadata.get("fusion_method") == "simple_weighted":
|
||||
# From weighted fusion
|
||||
vector_score = source_scores.get("vector")
|
||||
|
||||
# Structural score approximation (from exact/fuzzy FTS)
|
||||
fts_scores = []
|
||||
if "exact" in source_scores:
|
||||
fts_scores.append(source_scores["exact"])
|
||||
if "fuzzy" in source_scores:
|
||||
fts_scores.append(source_scores["fuzzy"])
|
||||
if "splade" in source_scores:
|
||||
fts_scores.append(source_scores["splade"])
|
||||
|
||||
if fts_scores:
|
||||
structural_score = max(fts_scores)
|
||||
|
||||
# Build snippet
|
||||
snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
|
||||
if len(snippet) > 500:
|
||||
snippet = snippet[:500] + "..."
|
||||
|
||||
# Generate match reason if requested
|
||||
match_reason = None
|
||||
if include_match_reason:
|
||||
match_reason = _generate_match_reason(
|
||||
query=query,
|
||||
symbol_name=symbol_name,
|
||||
symbol_kind=symbol_kind,
|
||||
snippet=snippet,
|
||||
vector_score=vector_score,
|
||||
structural_score=structural_score,
|
||||
)
|
||||
|
||||
semantic_result = SemanticResult(
|
||||
symbol_name=symbol_name or Path(result.path).stem,
|
||||
kind=symbol_kind or "unknown",
|
||||
file_path=result.path,
|
||||
line=start_line or 1,
|
||||
vector_score=vector_score,
|
||||
structural_score=structural_score,
|
||||
fusion_score=fusion_score,
|
||||
snippet=snippet,
|
||||
match_reason=match_reason,
|
||||
)
|
||||
|
||||
semantic_results.append(semantic_result)
|
||||
|
||||
# Sort by fusion_score descending
|
||||
semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)
|
||||
|
||||
return semantic_results
|
||||
|
||||
|
||||
def _generate_match_reason(
|
||||
query: str,
|
||||
symbol_name: Optional[str],
|
||||
symbol_kind: Optional[str],
|
||||
snippet: str,
|
||||
vector_score: Optional[float],
|
||||
structural_score: Optional[float],
|
||||
) -> str:
|
||||
"""Generate human-readable match reason heuristically.
|
||||
|
||||
This is a simple heuristic-based approach, not LLM-powered.
|
||||
|
||||
Args:
|
||||
query: Original search query
|
||||
symbol_name: Symbol name if available
|
||||
symbol_kind: Symbol kind if available
|
||||
snippet: Code snippet
|
||||
vector_score: Vector similarity score
|
||||
structural_score: Structural match score
|
||||
|
||||
Returns:
|
||||
Human-readable explanation string
|
||||
"""
|
||||
reasons = []
|
||||
|
||||
# Check for direct name match
|
||||
query_lower = query.lower()
|
||||
query_words = set(query_lower.split())
|
||||
|
||||
if symbol_name:
|
||||
name_lower = symbol_name.lower()
|
||||
# Direct substring match
|
||||
if query_lower in name_lower or name_lower in query_lower:
|
||||
reasons.append(f"Symbol name '{symbol_name}' matches query")
|
||||
# Word overlap
|
||||
name_words = set(_split_camel_case(symbol_name).lower().split())
|
||||
overlap = query_words & name_words
|
||||
if overlap and not reasons:
|
||||
reasons.append(f"Symbol name contains: {', '.join(overlap)}")
|
||||
|
||||
# Check snippet for keyword matches
|
||||
snippet_lower = snippet.lower()
|
||||
matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
|
||||
if matching_words and len(reasons) < 2:
|
||||
reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")
|
||||
|
||||
# Add score-based reasoning
|
||||
if vector_score is not None and vector_score > 0.7:
|
||||
reasons.append("High semantic similarity")
|
||||
elif vector_score is not None and vector_score > 0.5:
|
||||
reasons.append("Moderate semantic similarity")
|
||||
|
||||
if structural_score is not None and structural_score > 0.8:
|
||||
reasons.append("Strong structural match")
|
||||
|
||||
# Symbol kind context
|
||||
if symbol_kind and len(reasons) < 3:
|
||||
reasons.append(f"Matched {symbol_kind}")
|
||||
|
||||
if not reasons:
|
||||
reasons.append("Partial relevance based on content analysis")
|
||||
|
||||
return "; ".join(reasons[:3])
|
||||
|
||||
|
||||
def _split_camel_case(name: str) -> str:
|
||||
"""Split camelCase and PascalCase to words.
|
||||
|
||||
Args:
|
||||
name: Symbol name in camelCase or PascalCase
|
||||
|
||||
Returns:
|
||||
Space-separated words
|
||||
"""
|
||||
import re
|
||||
|
||||
# Insert space before uppercase letters
|
||||
result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
|
||||
# Insert space before uppercase followed by lowercase
|
||||
result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
|
||||
# Replace underscores with spaces
|
||||
result = result.replace("_", " ")
|
||||
|
||||
return result
|
||||
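The semantic.py hunk above defines semantic_search(); a call sketch based on its documented signature follows. The import path, project path, query, and filter values are placeholders.

```python
from codexlens.api.semantic import semantic_search  # assumed import path

# Placeholder query against an already-indexed project.
results = semantic_search(
    "/path/to/project",
    "authentication handler",
    mode="fusion",
    fusion_strategy="rrf",
    kind_filter=["function", "class"],
    limit=10,
    include_match_reason=True,
)

for r in results:
    print(f"{r.fusion_score:.3f}  {r.symbol_name} ({r.kind})  {r.file_path}:{r.line}")
    if r.match_reason:
        print(f"        reason: {r.match_reason}")
```

Per the docstring, the three weights are normalized to sum to 1.0, and missing vector or relationship data degrades gracefully by leaving the corresponding per-source score as None.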
codex-lens/build/lib/codexlens/api/symbols.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.
|
||||
|
||||
This module provides the workspace_symbols() function for searching
|
||||
symbols across the entire workspace with prefix matching.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import SymbolInfo
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def workspace_symbols(
|
||||
project_root: str,
|
||||
query: str,
|
||||
kind_filter: Optional[List[str]] = None,
|
||||
file_pattern: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[SymbolInfo]:
|
||||
"""Search for symbols across the entire workspace.
|
||||
|
||||
Uses prefix matching for efficient searching.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
query: Search query (prefix match)
|
||||
kind_filter: Optional list of symbol kinds to include
|
||||
(e.g., ["class", "function"])
|
||||
file_pattern: Optional glob pattern to filter by file path
|
||||
(e.g., "*.py", "src/**/*.ts")
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of SymbolInfo sorted by score
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
"""
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project(project_path)
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Search with prefix matching
|
||||
# If kind_filter has multiple kinds, we need to search for each
|
||||
all_results: List[Symbol] = []
|
||||
|
||||
if kind_filter and len(kind_filter) > 0:
|
||||
# Search for each kind separately
|
||||
for kind in kind_filter:
|
||||
results = global_index.search(
|
||||
name=query,
|
||||
kind=kind,
|
||||
limit=limit,
|
||||
prefix_mode=True
|
||||
)
|
||||
all_results.extend(results)
|
||||
else:
|
||||
# Search without kind filter
|
||||
all_results = global_index.search(
|
||||
name=query,
|
||||
kind=None,
|
||||
limit=limit,
|
||||
prefix_mode=True
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(all_results)} symbols matching '{query}'")
|
||||
|
||||
# Apply file pattern filter if specified
|
||||
if file_pattern:
|
||||
all_results = [
|
||||
sym for sym in all_results
|
||||
if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
|
||||
]
|
||||
logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")
|
||||
|
||||
# Convert to SymbolInfo and sort by relevance
|
||||
symbols = [
|
||||
SymbolInfo(
|
||||
name=sym.name,
|
||||
kind=sym.kind,
|
||||
file_path=sym.file or "",
|
||||
line=sym.range[0] if sym.range else 1,
|
||||
container=None, # Could extract from parent
|
||||
score=_calculate_score(sym.name, query)
|
||||
)
|
||||
for sym in all_results
|
||||
]
|
||||
|
||||
# Sort by score (exact matches first)
|
||||
symbols.sort(key=lambda s: s.score, reverse=True)
|
||||
|
||||
return symbols[:limit]
|
||||
|
||||
|
||||
def _calculate_score(symbol_name: str, query: str) -> float:
|
||||
"""Calculate relevance score for a symbol match.
|
||||
|
||||
Scoring:
|
||||
- Exact match: 1.0
|
||||
- Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
|
||||
- Case-insensitive match: 0.6
|
||||
|
||||
Args:
|
||||
symbol_name: The matched symbol name
|
||||
query: The search query
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
if symbol_name == query:
|
||||
return 1.0
|
||||
|
||||
if symbol_name.lower() == query.lower():
|
||||
return 0.9
|
||||
|
||||
if symbol_name.startswith(query):
|
||||
ratio = len(query) / len(symbol_name)
|
||||
return 0.8 + 0.2 * ratio
|
||||
|
||||
if symbol_name.lower().startswith(query.lower()):
|
||||
ratio = len(query) / len(symbol_name)
|
||||
return 0.6 + 0.2 * ratio
|
||||
|
||||
return 0.5
|
||||
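A call sketch for workspace_symbols(), using a placeholder path and query; the import path is assumed, and an indexed project is required (otherwise IndexNotFoundError is raised).

```python
from codexlens.api.symbols import workspace_symbols  # assumed import path

# Placeholder prefix query over an indexed project.
symbols = workspace_symbols(
    "/path/to/project",
    query="auth",
    kind_filter=["class", "function"],
    file_pattern="src/**/*.py",
    limit=20,
)

for s in symbols:
    print(f"{s.score:.2f}  {s.kind:<10} {s.name}  {s.file_path}:{s.line}")
```

With the scoring above, a case-sensitive prefix hit such as query "auth" against symbol "authenticate" scores 0.8 + 0.2 * 4/12, roughly 0.87, while an exact match scores 1.0.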
codex-lens/build/lib/codexlens/api/utils.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.
|
||||
|
||||
This module provides helper functions for:
|
||||
- Project resolution
|
||||
- Relationship type normalization
|
||||
- Result ranking by proximity
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, TypeVar, Callable
|
||||
|
||||
from .models import DefinitionResult
|
||||
|
||||
|
||||
# Type variable for generic ranking
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def resolve_project(project_root: str) -> Path:
|
||||
"""Resolve and validate project root path.
|
||||
|
||||
Args:
|
||||
project_root: Path to project root (relative or absolute)
|
||||
|
||||
Returns:
|
||||
Resolved absolute Path
|
||||
|
||||
Raises:
|
||||
ValueError: If path does not exist or is not a directory
|
||||
"""
|
||||
path = Path(project_root).resolve()
|
||||
if not path.exists():
|
||||
raise ValueError(f"Project root does not exist: {path}")
|
||||
if not path.is_dir():
|
||||
raise ValueError(f"Project root is not a directory: {path}")
|
||||
return path
|
||||
|
||||
|
||||
# Relationship type normalization mapping
|
||||
_RELATIONSHIP_NORMALIZATION = {
|
||||
# Plural to singular
|
||||
"calls": "call",
|
||||
"imports": "import",
|
||||
"inherits": "inheritance",
|
||||
"uses": "use",
|
||||
# Already normalized (passthrough)
|
||||
"call": "call",
|
||||
"import": "import",
|
||||
"inheritance": "inheritance",
|
||||
"use": "use",
|
||||
"type_annotation": "type_annotation",
|
||||
}
|
||||
|
||||
|
||||
def normalize_relationship_type(relationship: str) -> str:
|
||||
"""Normalize relationship type to canonical form.
|
||||
|
||||
Converts plural forms and variations to standard singular forms:
|
||||
- 'calls' -> 'call'
|
||||
- 'imports' -> 'import'
|
||||
- 'inherits' -> 'inheritance'
|
||||
- 'uses' -> 'use'
|
||||
|
||||
Args:
|
||||
relationship: Raw relationship type string
|
||||
|
||||
Returns:
|
||||
Normalized relationship type
|
||||
|
||||
Examples:
|
||||
>>> normalize_relationship_type('calls')
|
||||
'call'
|
||||
>>> normalize_relationship_type('inherits')
|
||||
'inheritance'
|
||||
>>> normalize_relationship_type('call')
|
||||
'call'
|
||||
"""
|
||||
return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)
|
||||
|
||||
|
||||
def rank_by_proximity(
|
||||
results: List[DefinitionResult],
|
||||
file_context: Optional[str] = None
|
||||
) -> List[DefinitionResult]:
|
||||
"""Rank results by file path proximity to context.
|
||||
|
||||
V1 Implementation: Uses path-based proximity scoring.
|
||||
|
||||
Scoring algorithm:
|
||||
1. Same directory: highest score (100)
|
||||
2. Otherwise: length of common path prefix
|
||||
|
||||
Args:
|
||||
results: List of definition results to rank
|
||||
file_context: Reference file path for proximity calculation.
|
||||
If None, returns results unchanged.
|
||||
|
||||
Returns:
|
||||
Results sorted by proximity score (highest first)
|
||||
|
||||
Examples:
|
||||
>>> results = [
|
||||
... DefinitionResult(name="foo", kind="function",
|
||||
... file_path="/a/b/c.py", line=1, end_line=10),
|
||||
... DefinitionResult(name="foo", kind="function",
|
||||
... file_path="/a/x/y.py", line=1, end_line=10),
|
||||
... ]
|
||||
>>> ranked = rank_by_proximity(results, "/a/b/test.py")
|
||||
>>> ranked[0].file_path
|
||||
'/a/b/c.py'
|
||||
"""
|
||||
if not file_context or not results:
|
||||
return results
|
||||
|
||||
def proximity_score(result: DefinitionResult) -> int:
|
||||
"""Calculate proximity score for a result."""
|
||||
result_dir = os.path.dirname(result.file_path)
|
||||
context_dir = os.path.dirname(file_context)
|
||||
|
||||
# Same directory gets highest score
|
||||
if result_dir == context_dir:
|
||||
return 100
|
||||
|
||||
# Otherwise, score by common path prefix length
|
||||
try:
|
||||
common = os.path.commonpath([result.file_path, file_context])
|
||||
return len(common)
|
||||
except ValueError:
|
||||
# No common path (different drives on Windows)
|
||||
return 0
|
||||
|
||||
return sorted(results, key=proximity_score, reverse=True)
|
||||
|
||||
|
||||
def rank_by_score(
|
||||
results: List[T],
|
||||
score_fn: Callable[[T], float],
|
||||
reverse: bool = True
|
||||
) -> List[T]:
|
||||
"""Generic ranking function by custom score.
|
||||
|
||||
Args:
|
||||
results: List of items to rank
|
||||
score_fn: Function to extract score from item
|
||||
reverse: If True, highest scores first (default)
|
||||
|
||||
Returns:
|
||||
Sorted list
|
||||
"""
|
||||
return sorted(results, key=score_fn, reverse=reverse)
|
||||
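Small sketches of the helpers above; the import paths are assumed, and the DefinitionResult values are invented for illustration.

```python
from codexlens.api.models import DefinitionResult  # assumed import path
from codexlens.api.utils import (                  # assumed import path
    normalize_relationship_type,
    rank_by_proximity,
    rank_by_score,
)

assert normalize_relationship_type("inherits") == "inheritance"
assert normalize_relationship_type("CALLS") == "call"           # lookup lowercases first
assert normalize_relationship_type("decorates") == "decorates"  # unknown types pass through

defs = [
    DefinitionResult(name="foo", kind="function", file_path="/a/x/y.py", line=1, end_line=10),
    DefinitionResult(name="foo", kind="function", file_path="/a/b/c.py", line=1, end_line=10),
]
ranked = rank_by_proximity(defs, file_context="/a/b/test.py")
print(ranked[0].file_path)  # /a/b/c.py, since the same directory as the context scores 100

# Generic ranking with a custom score: shortest path first.
shortest_first = rank_by_score(defs, score_fn=lambda d: len(d.file_path), reverse=False)
```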
codex-lens/build/lib/codexlens/cli/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""CLI package for CodexLens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Force UTF-8 encoding for Windows console
|
||||
# This ensures Chinese characters display correctly instead of GBK garbled text
|
||||
if sys.platform == "win32":
|
||||
# Set environment variable for Python I/O encoding
|
||||
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
|
||||
|
||||
# Reconfigure stdout/stderr to use UTF-8 if possible
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
# Fallback: some environments don't support reconfigure
|
||||
pass
|
||||
|
||||
from .commands import app
|
||||
|
||||
__all__ = ["app"]
|
||||
|
||||
codex-lens/build/lib/codexlens/cli/commands.py (new file, 4494 lines): diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/embedding_manager.py (new file, 2001 lines): diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/model_manager.py (new file, 1026 lines): diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/output.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""Rich and JSON output helpers for CodexLens CLI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Mapping, Sequence
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from codexlens.entities import SearchResult, Symbol
|
||||
|
||||
# Force UTF-8 encoding for Windows console to properly display Chinese text
|
||||
# Use force_terminal=True and legacy_windows=False to avoid GBK encoding issues
|
||||
console = Console(force_terminal=True, legacy_windows=False)
|
||||
|
||||
|
||||
def _to_jsonable(value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump()
|
||||
if is_dataclass(value):
|
||||
return asdict(value)
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
if isinstance(value, Mapping):
|
||||
return {k: _to_jsonable(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [_to_jsonable(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
|
||||
"""Print JSON output with optional additional fields.
|
||||
|
||||
Args:
|
||||
success: Whether the operation succeeded
|
||||
result: Result data (used when success=True)
|
||||
error: Error message (used when success=False)
|
||||
**kwargs: Additional fields to include in the payload (e.g., code, details)
|
||||
"""
|
||||
payload: dict[str, Any] = {"success": success}
|
||||
if success:
|
||||
payload["result"] = _to_jsonable(result)
|
||||
else:
|
||||
payload["error"] = error or "Unknown error"
|
||||
# Include additional error details if provided
|
||||
for key, value in kwargs.items():
|
||||
payload[key] = _to_jsonable(value)
|
||||
console.print_json(json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_search_results(
|
||||
results: Sequence[SearchResult], *, title: str = "Search Results", verbose: bool = False
|
||||
) -> None:
|
||||
"""Render search results with optional source tags in verbose mode.
|
||||
|
||||
Args:
|
||||
results: Search results to display
|
||||
title: Table title
|
||||
verbose: If True, show search source tags ([E], [F], [V]) and fusion scores
|
||||
"""
|
||||
table = Table(title=title, show_lines=False)
|
||||
|
||||
if verbose:
|
||||
# Verbose mode: show source tags
|
||||
table.add_column("Source", style="dim", width=6, justify="center")
|
||||
|
||||
table.add_column("Path", style="cyan", no_wrap=True)
|
||||
table.add_column("Score", style="magenta", justify="right")
|
||||
table.add_column("Excerpt", style="white")
|
||||
|
||||
for res in results:
|
||||
excerpt = res.excerpt or ""
|
||||
score_str = f"{res.score:.3f}"
|
||||
|
||||
if verbose:
|
||||
# Extract search source tag if available
|
||||
source = getattr(res, "search_source", None)
|
||||
source_tag = ""
|
||||
if source == "exact":
|
||||
source_tag = "[E]"
|
||||
elif source == "fuzzy":
|
||||
source_tag = "[F]"
|
||||
elif source == "vector":
|
||||
source_tag = "[V]"
|
||||
elif source == "fusion":
|
||||
source_tag = "[RRF]"
|
||||
table.add_row(source_tag, res.path, score_str, excerpt)
|
||||
else:
|
||||
table.add_row(res.path, score_str, excerpt)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def render_symbols(symbols: Sequence[Symbol], *, title: str = "Symbols") -> None:
|
||||
table = Table(title=title)
|
||||
table.add_column("Name", style="green")
|
||||
table.add_column("Kind", style="yellow")
|
||||
table.add_column("Range", style="white", justify="right")
|
||||
|
||||
for sym in symbols:
|
||||
start, end = sym.range
|
||||
table.add_row(sym.name, sym.kind, f"{start}-{end}")
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def render_status(stats: Mapping[str, Any]) -> None:
|
||||
table = Table(title="Index Status")
|
||||
table.add_column("Metric", style="cyan")
|
||||
table.add_column("Value", style="white")
|
||||
|
||||
for key, value in stats.items():
|
||||
if isinstance(value, Mapping):
|
||||
value_text = ", ".join(f"{k}:{v}" for k, v in value.items())
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value_text = ", ".join(str(v) for v in value)
|
||||
else:
|
||||
value_text = str(value)
|
||||
table.add_row(str(key), value_text)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> None:
|
||||
header = Text.assemble(("File: ", "bold"), (path, "cyan"), (" Language: ", "bold"), (language, "green"))
|
||||
console.print(header)
|
||||
render_symbols(list(symbols), title="Discovered Symbols")
|
||||
|
||||
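A sketch of the output helpers; the import path is assumed, and the SearchResult construction is illustrative only, since its actual constructor lives in codexlens.entities and is not shown in this diff.

```python
from codexlens.cli.output import print_json, render_search_results  # assumed import path
from codexlens.entities import SearchResult  # constructor shape assumed for illustration

# JSON envelopes: success carries "result", failure carries "error" plus any extra kwargs.
print_json(success=True, result={"indexed_files": 128})
print_json(success=False, error="index not found", code="E_NO_INDEX")

# Table rendering; verbose=True adds the [E]/[F]/[V]/[RRF] source column.
results = [SearchResult(path="src/auth/handlers.py", score=0.912, excerpt="def authenticate(...)")]
render_search_results(results, title="Demo Results", verbose=True)
```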
codex-lens/build/lib/codexlens/config.py (new file, 692 lines)
@@ -0,0 +1,692 @@
"""Configuration system for CodexLens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .errors import ConfigError
|
||||
|
||||
|
||||
# Workspace-local directory name
|
||||
WORKSPACE_DIR_NAME = ".codexlens"
|
||||
|
||||
# Settings file name
|
||||
SETTINGS_FILE_NAME = "settings.json"
|
||||
|
||||
# SPLADE index database name (centralized storage)
|
||||
SPLADE_DB_NAME = "_splade.db"
|
||||
|
||||
# Dense vector storage names (centralized storage)
|
||||
VECTORS_HNSW_NAME = "_vectors.hnsw"
|
||||
VECTORS_META_DB_NAME = "_vectors_meta.db"
|
||||
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_global_dir() -> Path:
|
||||
"""Get global CodexLens data directory."""
|
||||
env_override = os.getenv("CODEXLENS_DATA_DIR")
|
||||
if env_override:
|
||||
return Path(env_override).expanduser().resolve()
|
||||
return (Path.home() / ".codexlens").resolve()
|
||||
|
||||
|
||||
def find_workspace_root(start_path: Path) -> Optional[Path]:
|
||||
"""Find the workspace root by looking for .codexlens directory.
|
||||
|
||||
Searches from start_path upward to find an existing .codexlens directory.
|
||||
Returns None if not found.
|
||||
"""
|
||||
current = start_path.resolve()
|
||||
|
||||
# Search up to filesystem root
|
||||
while current != current.parent:
|
||||
workspace_dir = current / WORKSPACE_DIR_NAME
|
||||
if workspace_dir.is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
# Check root as well
|
||||
workspace_dir = current / WORKSPACE_DIR_NAME
|
||||
if workspace_dir.is_dir():
|
||||
return current
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Runtime configuration for CodexLens.
|
||||
|
||||
- data_dir: Base directory for all persistent CodexLens data.
|
||||
- venv_path: Optional virtualenv used for language tooling.
|
||||
- supported_languages: Language IDs and their associated file extensions.
|
||||
- parsing_rules: Per-language parsing and chunking hints.
|
||||
"""
|
||||
|
||||
data_dir: Path = field(default_factory=_default_global_dir)
|
||||
venv_path: Path = field(default_factory=lambda: _default_global_dir() / "venv")
|
||||
supported_languages: Dict[str, Dict[str, Any]] = field(
|
||||
default_factory=lambda: {
|
||||
# Source code languages (category: "code")
|
||||
"python": {"extensions": [".py"], "tree_sitter_language": "python", "category": "code"},
|
||||
"javascript": {"extensions": [".js", ".jsx"], "tree_sitter_language": "javascript", "category": "code"},
|
||||
"typescript": {"extensions": [".ts", ".tsx"], "tree_sitter_language": "typescript", "category": "code"},
|
||||
"java": {"extensions": [".java"], "tree_sitter_language": "java", "category": "code"},
|
||||
"go": {"extensions": [".go"], "tree_sitter_language": "go", "category": "code"},
|
||||
"zig": {"extensions": [".zig"], "tree_sitter_language": "zig", "category": "code"},
|
||||
"objective-c": {"extensions": [".m", ".mm"], "tree_sitter_language": "objc", "category": "code"},
|
||||
"c": {"extensions": [".c", ".h"], "tree_sitter_language": "c", "category": "code"},
|
||||
"cpp": {"extensions": [".cc", ".cpp", ".hpp", ".cxx"], "tree_sitter_language": "cpp", "category": "code"},
|
||||
"rust": {"extensions": [".rs"], "tree_sitter_language": "rust", "category": "code"},
|
||||
}
|
||||
)
|
||||
parsing_rules: Dict[str, Dict[str, Any]] = field(
|
||||
default_factory=lambda: {
|
||||
"default": {
|
||||
"max_chunk_chars": 4000,
|
||||
"max_chunk_lines": 200,
|
||||
"overlap_lines": 20,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
llm_enabled: bool = False
|
||||
llm_tool: str = "gemini"
|
||||
llm_timeout_ms: int = 300000
|
||||
llm_batch_size: int = 5
|
||||
|
||||
# Hybrid chunker configuration
|
||||
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
|
||||
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
|
||||
|
||||
# Embedding configuration
|
||||
embedding_backend: str = "fastembed" # "fastembed" (local) or "litellm" (API)
|
||||
embedding_model: str = "code" # For fastembed: profile (fast/code/multilingual/balanced)
|
||||
# For litellm: model name from config (e.g., "qwen3-embedding")
|
||||
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
|
||||
|
||||
# SPLADE sparse retrieval configuration
|
||||
enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead)
|
||||
splade_model: str = "naver/splade-cocondenser-ensembledistil"
|
||||
splade_threshold: float = 0.01 # Min weight to store in index
|
||||
splade_onnx_path: Optional[str] = None # Custom ONNX model path
|
||||
|
||||
# FTS fallback (disabled by default, available via --use-fts)
|
||||
use_fts_fallback: bool = True # Use FTS for sparse search (fast, SPLADE disabled)
|
||||
|
||||
# Indexing/search optimizations
|
||||
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
|
||||
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
|
||||
|
||||
# Graph expansion (search-time, uses precomputed neighbors)
|
||||
enable_graph_expansion: bool = False
|
||||
graph_expansion_depth: int = 2
|
||||
|
||||
# Optional search reranking (disabled by default)
|
||||
enable_reranking: bool = False
|
||||
reranking_top_k: int = 50
|
||||
symbol_boost_factor: float = 1.5
|
||||
|
||||
# Optional cross-encoder reranking (second stage; requires optional reranker deps)
|
||||
enable_cross_encoder_rerank: bool = False
|
||||
reranker_backend: str = "onnx"
|
||||
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
reranker_top_k: int = 50
|
||||
reranker_max_input_tokens: int = 8192 # Maximum tokens for reranker API batching
|
||||
reranker_chunk_type_weights: Optional[Dict[str, float]] = None # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
|
||||
reranker_test_file_penalty: float = 0.0 # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)
|
||||
|
||||
# Chunk stripping configuration (for semantic embedding)
|
||||
chunk_strip_comments: bool = True # Strip comments from code chunks
|
||||
chunk_strip_docstrings: bool = True # Strip docstrings from code chunks
|
||||
|
||||
# Cascade search configuration (two-stage retrieval)
|
||||
enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
|
||||
cascade_coarse_k: int = 100 # Number of coarse candidates from first stage
|
||||
cascade_fine_k: int = 10 # Number of final results after reranking
|
||||
cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
|
||||
|
||||
# Staged cascade search configuration (4-stage pipeline)
|
||||
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
|
||||
staged_lsp_depth: int = 2 # LSP relationship expansion depth in Stage 2
|
||||
staged_clustering_strategy: str = "auto" # "auto", "hdbscan", "dbscan", "frequency", "noop"
|
||||
staged_clustering_min_size: int = 3 # Minimum cluster size for Stage 3 grouping
|
||||
enable_staged_rerank: bool = True # Enable optional cross-encoder reranking in Stage 4
|
||||
|
||||
# RRF fusion configuration
|
||||
fusion_method: str = "rrf" # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
|
||||
rrf_k: int = 60 # RRF constant (default 60)
|
||||
|
||||
# Category-based filtering to separate code/doc results
|
||||
enable_category_filter: bool = True # Enable code/doc result separation
|
||||
|
||||
# Multi-endpoint configuration for litellm backend
|
||||
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
|
||||
embedding_pool_enabled: bool = False # Enable high availability pool for embeddings
|
||||
embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
|
||||
embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
|
||||
|
||||
# Reranker multi-endpoint configuration
|
||||
reranker_pool_enabled: bool = False # Enable high availability pool for reranker
|
||||
reranker_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
|
||||
reranker_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
|
||||
|
||||
# API concurrency settings
|
||||
api_max_workers: int = 4 # Max concurrent API calls for embedding/reranking
|
||||
api_batch_size: int = 8 # Batch size for API requests
|
||||
api_batch_size_dynamic: bool = False # Enable dynamic batch size calculation
|
||||
api_batch_size_utilization_factor: float = 0.8 # Use 80% of model token capacity
|
||||
api_batch_size_max: int = 2048 # Absolute upper limit for batch size
|
||||
chars_per_token_estimate: int = 4 # Characters per token estimation ratio
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
try:
|
||||
self.data_dir = self.data_dir.expanduser().resolve()
|
||||
self.venv_path = self.venv_path.expanduser().resolve()
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as exc:
|
||||
raise ConfigError(
|
||||
f"Permission denied initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||
f"[{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ConfigError(
|
||||
f"Filesystem error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||
f"[{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
raise ConfigError(
|
||||
f"Unexpected error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||
f"[{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
|
||||
@cached_property
|
||||
def cache_dir(self) -> Path:
|
||||
"""Directory for transient caches."""
|
||||
return self.data_dir / "cache"
|
||||
|
||||
@cached_property
|
||||
def index_dir(self) -> Path:
|
||||
"""Directory where index artifacts are stored."""
|
||||
return self.data_dir / "index"
|
||||
|
||||
@cached_property
|
||||
def db_path(self) -> Path:
|
||||
"""Default SQLite index path."""
|
||||
return self.index_dir / "codexlens.db"
|
||||
|
||||
def ensure_runtime_dirs(self) -> None:
|
||||
"""Create standard runtime directories if missing."""
|
||||
for directory in (self.cache_dir, self.index_dir):
|
||||
try:
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as exc:
|
||||
raise ConfigError(
|
||||
f"Permission denied creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ConfigError(
|
||||
f"Filesystem error creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
raise ConfigError(
|
||||
f"Unexpected error creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||
) from exc
|
||||
|
||||
def language_for_path(self, path: str | Path) -> str | None:
|
||||
"""Infer a supported language ID from a file path."""
|
||||
extension = Path(path).suffix.lower()
|
||||
for language_id, spec in self.supported_languages.items():
|
||||
extensions: List[str] = spec.get("extensions", [])
|
||||
if extension in extensions:
|
||||
return language_id
|
||||
return None
|
||||
|
||||
def category_for_path(self, path: str | Path) -> str | None:
|
||||
"""Get file category ('code' or 'doc') from a file path."""
|
||||
language = self.language_for_path(path)
|
||||
if language is None:
|
||||
return None
|
||||
spec = self.supported_languages.get(language, {})
|
||||
return spec.get("category")
|
||||
|
||||
def rules_for_language(self, language_id: str) -> Dict[str, Any]:
|
||||
"""Get parsing rules for a specific language, falling back to defaults."""
|
||||
return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
|
||||
|
||||
@cached_property
|
||||
def settings_path(self) -> Path:
|
||||
"""Path to the settings file."""
|
||||
return self.data_dir / SETTINGS_FILE_NAME
|
||||
|
||||
def save_settings(self) -> None:
|
||||
"""Save embedding and other settings to file."""
|
||||
embedding_config = {
|
||||
"backend": self.embedding_backend,
|
||||
"model": self.embedding_model,
|
||||
"use_gpu": self.embedding_use_gpu,
|
||||
"pool_enabled": self.embedding_pool_enabled,
|
||||
"strategy": self.embedding_strategy,
|
||||
"cooldown": self.embedding_cooldown,
|
||||
}
|
||||
# Include multi-endpoint config if present
|
||||
if self.embedding_endpoints:
|
||||
embedding_config["endpoints"] = self.embedding_endpoints
|
||||
|
||||
settings = {
|
||||
"embedding": embedding_config,
|
||||
"llm": {
|
||||
"enabled": self.llm_enabled,
|
||||
"tool": self.llm_tool,
|
||||
"timeout_ms": self.llm_timeout_ms,
|
||||
"batch_size": self.llm_batch_size,
|
||||
},
|
||||
"reranker": {
|
||||
"enabled": self.enable_cross_encoder_rerank,
|
||||
"backend": self.reranker_backend,
|
||||
"model": self.reranker_model,
|
||||
"top_k": self.reranker_top_k,
|
||||
"max_input_tokens": self.reranker_max_input_tokens,
|
||||
"pool_enabled": self.reranker_pool_enabled,
|
||||
"strategy": self.reranker_strategy,
|
||||
"cooldown": self.reranker_cooldown,
|
||||
},
|
||||
"cascade": {
|
||||
"strategy": self.cascade_strategy,
|
||||
"coarse_k": self.cascade_coarse_k,
|
||||
"fine_k": self.cascade_fine_k,
|
||||
},
|
||||
"api": {
|
||||
"max_workers": self.api_max_workers,
|
||||
"batch_size": self.api_batch_size,
|
||||
"batch_size_dynamic": self.api_batch_size_dynamic,
|
||||
"batch_size_utilization_factor": self.api_batch_size_utilization_factor,
|
||||
"batch_size_max": self.api_batch_size_max,
|
||||
"chars_per_token_estimate": self.chars_per_token_estimate,
|
||||
},
|
||||
}
|
||||
with open(self.settings_path, "w", encoding="utf-8") as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
def load_settings(self) -> None:
|
||||
"""Load settings from file if exists."""
|
||||
if not self.settings_path.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.settings_path, "r", encoding="utf-8") as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Load embedding settings
|
||||
embedding = settings.get("embedding", {})
|
||||
if "backend" in embedding:
|
||||
backend = embedding["backend"]
|
||||
# Support 'api' as alias for 'litellm'
|
||||
if backend == "api":
|
||||
backend = "litellm"
|
||||
if backend in {"fastembed", "litellm"}:
|
||||
self.embedding_backend = backend
|
||||
else:
|
||||
log.warning(
|
||||
"Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
|
||||
self.settings_path,
|
||||
embedding["backend"],
|
||||
)
|
||||
if "model" in embedding:
|
||||
self.embedding_model = embedding["model"]
|
||||
if "use_gpu" in embedding:
|
||||
self.embedding_use_gpu = embedding["use_gpu"]
|
||||
|
||||
# Load multi-endpoint configuration
|
||||
if "endpoints" in embedding:
|
||||
self.embedding_endpoints = embedding["endpoints"]
|
||||
if "pool_enabled" in embedding:
|
||||
self.embedding_pool_enabled = embedding["pool_enabled"]
|
||||
if "strategy" in embedding:
|
||||
self.embedding_strategy = embedding["strategy"]
|
||||
if "cooldown" in embedding:
|
||||
self.embedding_cooldown = embedding["cooldown"]
|
||||
|
||||
# Load LLM settings
|
||||
llm = settings.get("llm", {})
|
||||
if "enabled" in llm:
|
||||
self.llm_enabled = llm["enabled"]
|
||||
if "tool" in llm:
|
||||
self.llm_tool = llm["tool"]
|
||||
if "timeout_ms" in llm:
|
||||
self.llm_timeout_ms = llm["timeout_ms"]
|
||||
if "batch_size" in llm:
|
||||
self.llm_batch_size = llm["batch_size"]
|
||||
|
||||
# Load reranker settings
|
||||
reranker = settings.get("reranker", {})
|
||||
if "enabled" in reranker:
|
||||
self.enable_cross_encoder_rerank = reranker["enabled"]
|
||||
if "backend" in reranker:
|
||||
backend = reranker["backend"]
|
||||
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
|
||||
self.reranker_backend = backend
|
||||
else:
|
||||
log.warning(
|
||||
"Invalid reranker backend in %s: %r (expected 'fastembed', 'onnx', 'api', 'litellm', or 'legacy')",
|
||||
self.settings_path,
|
||||
backend,
|
||||
)
|
||||
if "model" in reranker:
|
||||
self.reranker_model = reranker["model"]
|
||||
if "top_k" in reranker:
|
||||
self.reranker_top_k = reranker["top_k"]
|
||||
if "max_input_tokens" in reranker:
|
||||
self.reranker_max_input_tokens = reranker["max_input_tokens"]
|
||||
if "pool_enabled" in reranker:
|
||||
self.reranker_pool_enabled = reranker["pool_enabled"]
|
||||
if "strategy" in reranker:
|
||||
self.reranker_strategy = reranker["strategy"]
|
||||
if "cooldown" in reranker:
|
||||
self.reranker_cooldown = reranker["cooldown"]
|
||||
|
||||
# Load cascade settings
|
||||
cascade = settings.get("cascade", {})
|
||||
if "strategy" in cascade:
|
||||
strategy = cascade["strategy"]
|
||||
if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
|
||||
self.cascade_strategy = strategy
|
||||
else:
|
||||
log.warning(
|
||||
"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
|
||||
self.settings_path,
|
||||
strategy,
|
||||
)
|
||||
if "coarse_k" in cascade:
|
||||
self.cascade_coarse_k = cascade["coarse_k"]
|
||||
if "fine_k" in cascade:
|
||||
self.cascade_fine_k = cascade["fine_k"]
|
||||
|
||||
# Load API settings
|
||||
api = settings.get("api", {})
|
||||
if "max_workers" in api:
|
||||
self.api_max_workers = api["max_workers"]
|
||||
if "batch_size" in api:
|
||||
self.api_batch_size = api["batch_size"]
|
||||
if "batch_size_dynamic" in api:
|
||||
self.api_batch_size_dynamic = api["batch_size_dynamic"]
|
||||
if "batch_size_utilization_factor" in api:
|
||||
self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
|
||||
if "batch_size_max" in api:
|
||||
self.api_batch_size_max = api["batch_size_max"]
|
||||
if "chars_per_token_estimate" in api:
|
||||
self.chars_per_token_estimate = api["chars_per_token_estimate"]
|
||||
except Exception as exc:
|
||||
log.warning(
|
||||
"Failed to load settings from %s (%s): %s",
|
||||
self.settings_path,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Apply .env overrides (highest priority)
|
||||
self._apply_env_overrides()
|
||||
|
||||
def _apply_env_overrides(self) -> None:
|
||||
"""Apply environment variable overrides from .env file.
|
||||
|
||||
Priority: default → settings.json → .env (highest)
|
||||
|
||||
Supported variables (with or without CODEXLENS_ prefix):
|
||||
EMBEDDING_MODEL: Override embedding model/profile
|
||||
EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
|
||||
EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
|
||||
EMBEDDING_STRATEGY: Load balance strategy for embedding
|
||||
EMBEDDING_COOLDOWN: Rate limit cooldown for embedding
|
||||
RERANKER_MODEL: Override reranker model
|
||||
RERANKER_BACKEND: Override reranker backend
|
||||
RERANKER_ENABLED: Override reranker enabled state (true/false)
|
||||
RERANKER_POOL_ENABLED: Enable reranker high availability pool
|
||||
RERANKER_STRATEGY: Load balance strategy for reranker
|
||||
RERANKER_COOLDOWN: Rate limit cooldown for reranker
|
||||
"""
|
||||
from .env_config import load_global_env
|
||||
|
||||
env_vars = load_global_env()
|
||||
if not env_vars:
|
||||
return
|
||||
|
||||
def get_env(key: str) -> str | None:
|
||||
"""Get env var with or without CODEXLENS_ prefix."""
|
||||
# Check prefixed version first (Dashboard format), then unprefixed
|
||||
return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)
|
||||
|
||||
# Embedding overrides
|
||||
embedding_model = get_env("EMBEDDING_MODEL")
|
||||
if embedding_model:
|
||||
self.embedding_model = embedding_model
|
||||
log.debug("Overriding embedding_model from .env: %s", self.embedding_model)
|
||||
|
||||
embedding_backend = get_env("EMBEDDING_BACKEND")
|
||||
if embedding_backend:
|
||||
backend = embedding_backend.lower()
|
||||
# Support 'api' as alias for 'litellm'
|
||||
if backend == "api":
|
||||
backend = "litellm"
|
||||
if backend in {"fastembed", "litellm"}:
|
||||
self.embedding_backend = backend
|
||||
log.debug("Overriding embedding_backend from .env: %s", backend)
|
||||
else:
|
||||
log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)
|
||||
|
||||
embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
|
||||
if embedding_pool:
|
||||
value = embedding_pool.lower()
|
||||
self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
|
||||
log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)
|
||||
|
||||
embedding_strategy = get_env("EMBEDDING_STRATEGY")
|
||||
if embedding_strategy:
|
||||
strategy = embedding_strategy.lower()
|
||||
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
|
||||
self.embedding_strategy = strategy
|
||||
log.debug("Overriding embedding_strategy from .env: %s", strategy)
|
||||
else:
|
||||
log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)
|
||||
|
||||
embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
|
||||
if embedding_cooldown:
|
||||
try:
|
||||
self.embedding_cooldown = float(embedding_cooldown)
|
||||
log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
|
||||
except ValueError:
|
||||
log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)
|
||||
|
||||
# Reranker overrides
|
||||
reranker_model = get_env("RERANKER_MODEL")
|
||||
if reranker_model:
|
||||
self.reranker_model = reranker_model
|
||||
log.debug("Overriding reranker_model from .env: %s", self.reranker_model)
|
||||
|
||||
reranker_backend = get_env("RERANKER_BACKEND")
|
||||
if reranker_backend:
|
||||
backend = reranker_backend.lower()
|
||||
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
|
||||
self.reranker_backend = backend
|
||||
log.debug("Overriding reranker_backend from .env: %s", backend)
|
||||
else:
|
||||
log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)
|
||||
|
||||
reranker_enabled = get_env("RERANKER_ENABLED")
|
||||
if reranker_enabled:
|
||||
value = reranker_enabled.lower()
|
||||
self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
|
||||
log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)
|
||||
|
||||
reranker_pool = get_env("RERANKER_POOL_ENABLED")
|
||||
if reranker_pool:
|
||||
value = reranker_pool.lower()
|
||||
self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
|
||||
log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)
|
||||
|
||||
reranker_strategy = get_env("RERANKER_STRATEGY")
|
||||
if reranker_strategy:
|
||||
strategy = reranker_strategy.lower()
|
||||
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
|
||||
self.reranker_strategy = strategy
|
||||
log.debug("Overriding reranker_strategy from .env: %s", strategy)
|
||||
else:
|
||||
log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)
|
||||
|
||||
reranker_cooldown = get_env("RERANKER_COOLDOWN")
|
||||
if reranker_cooldown:
|
||||
try:
|
||||
self.reranker_cooldown = float(reranker_cooldown)
|
||||
log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
|
||||
except ValueError:
|
||||
log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)
|
||||
|
||||
reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
|
||||
if reranker_max_tokens:
|
||||
try:
|
||||
self.reranker_max_input_tokens = int(reranker_max_tokens)
|
||||
log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
|
||||
except ValueError:
|
||||
log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)
|
||||
|
||||
# Reranker tuning from environment
|
||||
test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
|
||||
if test_penalty:
|
||||
try:
|
||||
self.reranker_test_file_penalty = float(test_penalty)
|
||||
log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
|
||||
except ValueError:
|
||||
log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)
|
||||
|
||||
docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
|
||||
if docstring_weight:
|
||||
try:
|
||||
weight = float(docstring_weight)
|
||||
self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
|
||||
log.debug("Overriding reranker docstring weight from .env: %s", weight)
|
||||
except ValueError:
|
||||
log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)
|
||||
|
||||
# Chunk stripping from environment
|
||||
strip_comments = get_env("CHUNK_STRIP_COMMENTS")
|
||||
if strip_comments:
|
||||
self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
|
||||
log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)
|
||||
|
||||
strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
|
||||
if strip_docstrings:
|
||||
self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
|
||||
log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> "Config":
|
||||
"""Load config with settings from file."""
|
||||
config = cls()
|
||||
config.load_settings()
|
||||
return config
|
||||
|
||||
|
||||
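# --- Editor's sketch (not part of this file): a minimal round-trip through the
# --- settings API shown above. The module path codexlens.config and the exact
# --- settings file name are assumptions, not confirmed by this diff.
from codexlens.config import Config

cfg = Config.load()                 # defaults -> settings file -> .env overrides
cfg.ensure_runtime_dirs()           # creates cache/ and index/ under data_dir
cfg.embedding_backend = "litellm"
cfg.save_settings()                 # writes the embedding/llm/reranker/cascade/api blocks
print(cfg.settings_path)            # settings file under data_dir
print(cfg.language_for_path("src/app.py"))  # language id if the extension is configured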
@dataclass
|
||||
class WorkspaceConfig:
|
||||
"""Workspace-local configuration for CodexLens.
|
||||
|
||||
Stores index data in project/.codexlens/ directory.
|
||||
"""
|
||||
|
||||
workspace_root: Path
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.workspace_root = Path(self.workspace_root).resolve()
|
||||
|
||||
@property
|
||||
def codexlens_dir(self) -> Path:
|
||||
"""The .codexlens directory in workspace root."""
|
||||
return self.workspace_root / WORKSPACE_DIR_NAME
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
"""SQLite index path for this workspace."""
|
||||
return self.codexlens_dir / "index.db"
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
"""Cache directory for this workspace."""
|
||||
return self.codexlens_dir / "cache"
|
||||
|
||||
@property
|
||||
def env_path(self) -> Path:
|
||||
"""Path to workspace .env file."""
|
||||
return self.codexlens_dir / ".env"
|
||||
|
||||
def load_env(self, *, override: bool = False) -> int:
|
||||
"""Load .env file and apply to os.environ.
|
||||
|
||||
Args:
|
||||
override: If True, override existing environment variables
|
||||
|
||||
Returns:
|
||||
Number of variables applied
|
||||
"""
|
||||
from .env_config import apply_workspace_env
|
||||
return apply_workspace_env(self.workspace_root, override=override)
|
||||
|
||||
def get_api_config(self, prefix: str) -> dict:
|
||||
"""Get API configuration from environment.
|
||||
|
||||
Args:
|
||||
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
|
||||
|
||||
Returns:
|
||||
Dictionary with api_key, api_base, model, etc.
|
||||
"""
|
||||
from .env_config import get_api_config
|
||||
return get_api_config(prefix, workspace_root=self.workspace_root)
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Create the .codexlens directory structure."""
|
||||
try:
|
||||
self.codexlens_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create .gitignore to exclude cache but keep index
|
||||
gitignore_path = self.codexlens_dir / ".gitignore"
|
||||
if not gitignore_path.exists():
|
||||
gitignore_path.write_text(
|
||||
"# CodexLens workspace data\n"
|
||||
"cache/\n"
|
||||
"*.log\n"
|
||||
".env\n" # Exclude .env from git
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if workspace is already initialized."""
|
||||
return self.codexlens_dir.is_dir() and self.db_path.exists()
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, path: Path) -> Optional["WorkspaceConfig"]:
|
||||
"""Create WorkspaceConfig from a path by finding workspace root.
|
||||
|
||||
Returns None if no workspace found.
|
||||
"""
|
||||
root = find_workspace_root(path)
|
||||
if root is None:
|
||||
return None
|
||||
return cls(workspace_root=root)
|
||||
|
||||
@classmethod
|
||||
def create_at(cls, path: Path) -> "WorkspaceConfig":
|
||||
"""Create a new workspace at the given path."""
|
||||
config = cls(workspace_root=path)
|
||||
config.initialize()
|
||||
return config
|
||||
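# --- Editor's sketch (not part of this file): locating or creating a workspace with
# --- the WorkspaceConfig class above. The module path codexlens.config is assumed.
from pathlib import Path
from codexlens.config import WorkspaceConfig

ws = WorkspaceConfig.from_path(Path.cwd())      # walk up to an existing .codexlens/ workspace
if ws is None:
    ws = WorkspaceConfig.create_at(Path.cwd())  # create .codexlens/ with cache/ and .gitignore
ws.load_env()                                   # apply .codexlens/.env to os.environ
print(ws.db_path)                               # <workspace>/.codexlens/index.db
print(ws.get_api_config("EMBEDDING"))           # api_key/api_base/model/... if set in env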
codex-lens/build/lib/codexlens/entities.py (new file, 128 lines)
@@ -0,0 +1,128 @@
|
||||
"""Pydantic entity models for CodexLens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class Symbol(BaseModel):
|
||||
"""A code symbol discovered in a file."""
|
||||
|
||||
name: str = Field(..., min_length=1)
|
||||
kind: str = Field(..., min_length=1)
|
||||
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
|
||||
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
|
||||
|
||||
@field_validator("range")
|
||||
@classmethod
|
||||
def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
|
||||
if len(value) != 2:
|
||||
raise ValueError("range must be a (start_line, end_line) tuple")
|
||||
start_line, end_line = value
|
||||
if start_line < 1 or end_line < 1:
|
||||
raise ValueError("range lines must be >= 1")
|
||||
if end_line < start_line:
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
return value
|
||||
|
||||
|
||||
class SemanticChunk(BaseModel):
|
||||
"""A semantically meaningful chunk of content, optionally embedded."""
|
||||
|
||||
content: str = Field(..., min_length=1)
|
||||
embedding: Optional[List[float]] = Field(default=None, description="Vector embedding for semantic search")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
id: Optional[int] = Field(default=None, description="Database row ID")
|
||||
file_path: Optional[str] = Field(default=None, description="Source file path")
|
||||
|
||||
@field_validator("embedding")
|
||||
@classmethod
|
||||
def validate_embedding(cls, value: Optional[List[float]]) -> Optional[List[float]]:
|
||||
if value is None:
|
||||
return value
|
||||
if not value:
|
||||
raise ValueError("embedding cannot be empty when provided")
|
||||
norm = math.sqrt(sum(x * x for x in value))
|
||||
epsilon = 1e-10
|
||||
if norm < epsilon:
|
||||
raise ValueError("embedding cannot be a zero vector")
|
||||
return value
|
||||
|
||||
|
||||
class IndexedFile(BaseModel):
|
||||
"""An indexed source file with symbols and optional semantic chunks."""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
language: str = Field(..., min_length=1)
|
||||
symbols: List[Symbol] = Field(default_factory=list)
|
||||
chunks: List[SemanticChunk] = Field(default_factory=list)
|
||||
relationships: List["CodeRelationship"] = Field(default_factory=list)
|
||||
|
||||
@field_validator("path", "language")
|
||||
@classmethod
|
||||
def strip_and_validate_nonempty(cls, value: str) -> str:
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("value cannot be blank")
|
||||
return cleaned
|
||||
|
||||
|
||||
class RelationshipType(str, Enum):
|
||||
"""Types of code relationships."""
|
||||
CALL = "calls"
|
||||
INHERITS = "inherits"
|
||||
IMPORTS = "imports"
|
||||
|
||||
|
||||
class CodeRelationship(BaseModel):
|
||||
"""A relationship between code symbols (e.g., function calls, inheritance)."""
|
||||
|
||||
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
|
||||
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
|
||||
relationship_type: RelationshipType = Field(..., description="Type of relationship (call, inherits, etc.)")
|
||||
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
|
||||
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
|
||||
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
|
||||
|
||||
|
||||
class AdditionalLocation(BaseModel):
|
||||
"""A pointer to another location where a similar result was found.
|
||||
|
||||
Used for grouping search results with similar scores and content,
|
||||
where the primary result is stored in SearchResult and secondary
|
||||
locations are stored in this model.
|
||||
"""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
score: float = Field(..., ge=0.0)
|
||||
start_line: Optional[int] = Field(default=None, description="Start line of the result (1-based)")
|
||||
end_line: Optional[int] = Field(default=None, description="End line of the result (1-based)")
|
||||
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""A unified search result for lexical or semantic search."""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
score: float = Field(..., ge=0.0)
|
||||
excerpt: Optional[str] = None
|
||||
content: Optional[str] = Field(default=None, description="Full content of matched code block")
|
||||
symbol: Optional[Symbol] = None
|
||||
chunk: Optional[SemanticChunk] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Additional context for complete code blocks
|
||||
start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
|
||||
end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
|
||||
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
|
||||
symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")
|
||||
|
||||
# Field for grouping similar results
|
||||
additional_locations: List["AdditionalLocation"] = Field(
|
||||
default_factory=list,
|
||||
description="Other locations for grouped results with similar scores and content."
|
||||
)
|
||||
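# --- Editor's sketch (not part of this file): constructing the entity models above.
# --- The module path codexlens.entities is assumed.
from codexlens.entities import IndexedFile, SearchResult, SemanticChunk, Symbol

sym = Symbol(name="load_settings", kind="method", range=(1, 42), file="codexlens/config.py")
chunk = SemanticChunk(content="def load_settings(self): ...", metadata={"chunk_type": "code"})
indexed = IndexedFile(path="codexlens/config.py", language="python", symbols=[sym], chunks=[chunk])
hit = SearchResult(path=indexed.path, score=0.87, symbol=sym, start_line=1, end_line=42)
# The validators reject inverted ranges, blank paths, and zero-vector embeddings.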
codex-lens/build/lib/codexlens/env_config.py (new file, 304 lines)
@@ -0,0 +1,304 @@
|
||||
"""Environment configuration loader for CodexLens.
|
||||
|
||||
Loads .env files from workspace .codexlens directory with fallback to project root.
|
||||
Provides unified access to API configurations.
|
||||
|
||||
Priority order:
|
||||
1. Environment variables (already set)
|
||||
2. .codexlens/.env (workspace-local)
|
||||
3. .env (project root)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Supported environment variables with descriptions
|
||||
ENV_VARS = {
|
||||
# Reranker configuration (overrides settings.json)
|
||||
"RERANKER_MODEL": "Reranker model name (overrides settings.json)",
|
||||
"RERANKER_BACKEND": "Reranker backend: fastembed, onnx, api, litellm, legacy",
|
||||
"RERANKER_ENABLED": "Enable reranker: true/false",
|
||||
"RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
|
||||
"RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
|
||||
"RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
|
||||
"RERANKER_POOL_ENABLED": "Enable reranker high availability pool: true/false",
|
||||
"RERANKER_STRATEGY": "Reranker load balance strategy: round_robin, latency_aware, weighted_random",
|
||||
"RERANKER_COOLDOWN": "Reranker rate limit cooldown in seconds",
|
||||
# Embedding configuration (overrides settings.json)
|
||||
"EMBEDDING_MODEL": "Embedding model/profile name (overrides settings.json)",
|
||||
"EMBEDDING_BACKEND": "Embedding backend: fastembed, litellm",
|
||||
"EMBEDDING_API_KEY": "API key for embedding service",
|
||||
"EMBEDDING_API_BASE": "Base URL for embedding API",
|
||||
"EMBEDDING_POOL_ENABLED": "Enable embedding high availability pool: true/false",
|
||||
"EMBEDDING_STRATEGY": "Embedding load balance strategy: round_robin, latency_aware, weighted_random",
|
||||
"EMBEDDING_COOLDOWN": "Embedding rate limit cooldown in seconds",
|
||||
# LiteLLM configuration
|
||||
"LITELLM_API_KEY": "API key for LiteLLM",
|
||||
"LITELLM_API_BASE": "Base URL for LiteLLM",
|
||||
"LITELLM_MODEL": "LiteLLM model name",
|
||||
# General configuration
|
||||
"CODEXLENS_DATA_DIR": "Custom data directory path",
|
||||
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
|
||||
# Chunking configuration
|
||||
"CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
|
||||
"CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
|
||||
# Reranker tuning
|
||||
"RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
|
||||
"RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
|
||||
}
|
||||
|
||||
|
||||
def _parse_env_line(line: str) -> tuple[str, str] | None:
|
||||
"""Parse a single .env line, returning (key, value) or None."""
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith("#"):
|
||||
return None
|
||||
|
||||
# Handle export prefix
|
||||
if line.startswith("export "):
|
||||
line = line[7:].strip()
|
||||
|
||||
# Split on first =
|
||||
if "=" not in line:
|
||||
return None
|
||||
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Remove surrounding quotes
|
||||
if len(value) >= 2:
|
||||
if (value.startswith('"') and value.endswith('"')) or \
|
||||
(value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1]
|
||||
|
||||
return key, value
|
||||
|
||||
|
||||
def load_env_file(env_path: Path) -> Dict[str, str]:
|
||||
"""Load environment variables from a .env file.
|
||||
|
||||
Args:
|
||||
env_path: Path to .env file
|
||||
|
||||
Returns:
|
||||
Dictionary of environment variables
|
||||
"""
|
||||
if not env_path.is_file():
|
||||
return {}
|
||||
|
||||
env_vars: Dict[str, str] = {}
|
||||
|
||||
try:
|
||||
content = env_path.read_text(encoding="utf-8")
|
||||
for line in content.splitlines():
|
||||
result = _parse_env_line(line)
|
||||
if result:
|
||||
key, value = result
|
||||
env_vars[key] = value
|
||||
except Exception as exc:
|
||||
log.warning("Failed to load .env file %s: %s", env_path, exc)
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def _get_global_data_dir() -> Path:
|
||||
"""Get global CodexLens data directory."""
|
||||
env_override = os.environ.get("CODEXLENS_DATA_DIR")
|
||||
if env_override:
|
||||
return Path(env_override).expanduser().resolve()
|
||||
return (Path.home() / ".codexlens").resolve()
|
||||
|
||||
|
||||
def load_global_env() -> Dict[str, str]:
|
||||
"""Load environment variables from global ~/.codexlens/.env file.
|
||||
|
||||
Returns:
|
||||
Dictionary of environment variables from global config
|
||||
"""
|
||||
global_env_path = _get_global_data_dir() / ".env"
|
||||
if global_env_path.is_file():
|
||||
env_vars = load_env_file(global_env_path)
|
||||
log.debug("Loaded %d vars from global %s", len(env_vars), global_env_path)
|
||||
return env_vars
|
||||
return {}
|
||||
|
||||
|
||||
def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
|
||||
"""Load environment variables from workspace .env files.
|
||||
|
||||
Priority (later overrides earlier):
|
||||
1. Global ~/.codexlens/.env (lowest priority)
|
||||
2. Project root .env
|
||||
3. .codexlens/.env (highest priority)
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace root directory. If None, uses current directory.
|
||||
|
||||
Returns:
|
||||
Merged dictionary of environment variables
|
||||
"""
|
||||
if workspace_root is None:
|
||||
workspace_root = Path.cwd()
|
||||
|
||||
workspace_root = Path(workspace_root).resolve()
|
||||
|
||||
env_vars: Dict[str, str] = {}
|
||||
|
||||
# Load from global ~/.codexlens/.env (lowest priority)
|
||||
global_vars = load_global_env()
|
||||
if global_vars:
|
||||
env_vars.update(global_vars)
|
||||
|
||||
# Load from project root .env (medium priority)
|
||||
root_env = workspace_root / ".env"
|
||||
if root_env.is_file():
|
||||
loaded = load_env_file(root_env)
|
||||
env_vars.update(loaded)
|
||||
log.debug("Loaded %d vars from %s", len(loaded), root_env)
|
||||
|
||||
# Load from .codexlens/.env (highest priority)
|
||||
codexlens_env = workspace_root / ".codexlens" / ".env"
|
||||
if codexlens_env.is_file():
|
||||
loaded = load_env_file(codexlens_env)
|
||||
env_vars.update(loaded)
|
||||
log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
|
||||
"""Load .env files and apply to os.environ.
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace root directory
|
||||
override: If True, override existing environment variables
|
||||
|
||||
Returns:
|
||||
Number of variables applied
|
||||
"""
|
||||
env_vars = load_workspace_env(workspace_root)
|
||||
applied = 0
|
||||
|
||||
for key, value in env_vars.items():
|
||||
if override or key not in os.environ:
|
||||
os.environ[key] = value
|
||||
applied += 1
|
||||
log.debug("Applied env var: %s", key)
|
||||
|
||||
return applied
|
||||
|
||||
|
||||
def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
|
||||
"""Get environment variable with .env file fallback.
|
||||
|
||||
Priority:
|
||||
1. os.environ (already set)
|
||||
2. .codexlens/.env
|
||||
3. .env
|
||||
4. default value
|
||||
|
||||
Args:
|
||||
key: Environment variable name
|
||||
default: Default value if not found
|
||||
workspace_root: Workspace root for .env file lookup
|
||||
|
||||
Returns:
|
||||
Value or default
|
||||
"""
|
||||
# Check os.environ first
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
# Load from .env files
|
||||
env_vars = load_workspace_env(workspace_root)
|
||||
if key in env_vars:
|
||||
return env_vars[key]
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def get_api_config(
|
||||
prefix: str,
|
||||
*,
|
||||
workspace_root: Path | None = None,
|
||||
defaults: Dict[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get API configuration from environment.
|
||||
|
||||
Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.
|
||||
|
||||
Args:
|
||||
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
|
||||
workspace_root: Workspace root for .env file lookup
|
||||
defaults: Default values
|
||||
|
||||
Returns:
|
||||
Dictionary with api_key, api_base, model, etc.
|
||||
"""
|
||||
defaults = defaults or {}
|
||||
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
# Standard API config fields
|
||||
field_mapping = {
|
||||
"api_key": f"{prefix}_API_KEY",
|
||||
"api_base": f"{prefix}_API_BASE",
|
||||
"model": f"{prefix}_MODEL",
|
||||
"provider": f"{prefix}_PROVIDER",
|
||||
"timeout": f"{prefix}_TIMEOUT",
|
||||
}
|
||||
|
||||
for field, env_key in field_mapping.items():
|
||||
value = get_env(env_key, workspace_root=workspace_root)
|
||||
if value is not None:
|
||||
# Type conversion for specific fields
|
||||
if field == "timeout":
|
||||
try:
|
||||
config[field] = float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
config[field] = value
|
||||
elif field in defaults:
|
||||
config[field] = defaults[field]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def generate_env_example() -> str:
|
||||
"""Generate .env.example content with all supported variables.
|
||||
|
||||
Returns:
|
||||
String content for .env.example file
|
||||
"""
|
||||
lines = [
|
||||
"# CodexLens Environment Configuration",
|
||||
"# Copy this file to .codexlens/.env and fill in your values",
|
||||
"",
|
||||
]
|
||||
|
||||
# Group by prefix
|
||||
groups: Dict[str, list] = {}
|
||||
for key, desc in ENV_VARS.items():
|
||||
prefix = key.split("_")[0]
|
||||
if prefix not in groups:
|
||||
groups[prefix] = []
|
||||
groups[prefix].append((key, desc))
|
||||
|
||||
for prefix, items in groups.items():
|
||||
lines.append(f"# {prefix} Configuration")
|
||||
for key, desc in items:
|
||||
lines.append(f"# {desc}")
|
||||
lines.append(f"# {key}=")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
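# --- Editor's sketch (not part of this file): typical use of the helpers above.
# --- The module path codexlens.env_config is assumed.
from pathlib import Path
from codexlens.env_config import apply_workspace_env, generate_env_example, get_api_config

Path(".codexlens").mkdir(exist_ok=True)
Path(".codexlens/.env.example").write_text(generate_env_example(), encoding="utf-8")

applied = apply_workspace_env(Path.cwd())   # global .env, then project .env, then .codexlens/.env
reranker = get_api_config("RERANKER", defaults={"provider": "siliconflow"})
# reranker holds api_key/api_base/model/provider/timeout for whichever variables are set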
codex-lens/build/lib/codexlens/errors.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""CodexLens exception hierarchy."""

from __future__ import annotations


class CodexLensError(Exception):
    """Base class for all CodexLens errors."""


class ConfigError(CodexLensError):
    """Raised when configuration is invalid or cannot be loaded."""


class ParseError(CodexLensError):
    """Raised when parsing or indexing a file fails."""


class StorageError(CodexLensError):
    """Raised when reading/writing index storage fails.

    Attributes:
        message: Human-readable error description
        db_path: Path to the database file (if applicable)
        operation: The operation that failed (e.g., 'query', 'initialize', 'migrate')
        details: Additional context for debugging
    """

    def __init__(
        self,
        message: str,
        db_path: str | None = None,
        operation: str | None = None,
        details: dict | None = None
    ) -> None:
        super().__init__(message)
        self.message = message
        self.db_path = db_path
        self.operation = operation
        self.details = details or {}

    def __str__(self) -> str:
        parts = [self.message]
        if self.db_path:
            parts.append(f"[db: {self.db_path}]")
        if self.operation:
            parts.append(f"[op: {self.operation}]")
        if self.details:
            detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
            parts.append(f"[{detail_str}]")
        return " ".join(parts)


class SearchError(CodexLensError):
    """Raised when a search operation fails."""


class IndexNotFoundError(CodexLensError):
    """Raised when a project's index cannot be found."""
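# --- Editor's sketch (not part of this file): how StorageError formats its context.
# --- The module path codexlens.errors is assumed.
from codexlens.errors import CodexLensError, StorageError

try:
    raise StorageError(
        "failed to open index",
        db_path="~/.codexlens/index/codexlens.db",
        operation="initialize",
        details={"retries": 3},
    )
except CodexLensError as exc:
    print(exc)  # failed to open index [db: ~/.codexlens/index/codexlens.db] [op: initialize] [retries=3]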
codex-lens/build/lib/codexlens/hybrid_search/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.

This module provides core data structures for hybrid search:
- CodeSymbolNode: Graph node representing a code symbol
- CodeAssociationGraph: Graph of code relationships
- SearchResultCluster: Clustered search results
- Range: Position range in source files
- CallHierarchyItem: LSP call hierarchy item

Note: The search engine is in codexlens.search.hybrid_search
LSP-based expansion is in codexlens.lsp module
"""

from codexlens.hybrid_search.data_structures import (
    CallHierarchyItem,
    CodeAssociationGraph,
    CodeSymbolNode,
    Range,
    SearchResultCluster,
)

__all__ = [
    "CallHierarchyItem",
    "CodeAssociationGraph",
    "CodeSymbolNode",
    "Range",
    "SearchResultCluster",
]
codex-lens/build/lib/codexlens/hybrid_search/data_structures.py (new file, 602 lines)
@@ -0,0 +1,602 @@
|
||||
"""Core data structures for the hybrid search system.
|
||||
|
||||
This module defines the fundamental data structures used throughout the
|
||||
hybrid search pipeline, including code symbol representations, association
|
||||
graphs, and clustered search results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import networkx as nx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
"""Position range within a source file.
|
||||
|
||||
Attributes:
|
||||
start_line: Starting line number (0-based).
|
||||
start_character: Starting character offset within the line.
|
||||
end_line: Ending line number (0-based).
|
||||
end_character: Ending character offset within the line.
|
||||
"""
|
||||
|
||||
start_line: int
|
||||
start_character: int
|
||||
end_line: int
|
||||
end_character: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate range values."""
|
||||
if self.start_line < 0:
|
||||
raise ValueError("start_line must be >= 0")
|
||||
if self.start_character < 0:
|
||||
raise ValueError("start_character must be >= 0")
|
||||
if self.end_line < 0:
|
||||
raise ValueError("end_line must be >= 0")
|
||||
if self.end_character < 0:
|
||||
raise ValueError("end_character must be >= 0")
|
||||
if self.end_line < self.start_line:
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
if self.end_line == self.start_line and self.end_character < self.start_character:
|
||||
raise ValueError("end_character must be >= start_character on the same line")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"start": {"line": self.start_line, "character": self.start_character},
|
||||
"end": {"line": self.end_line, "character": self.end_character},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Range:
|
||||
"""Create Range from dictionary representation."""
|
||||
return cls(
|
||||
start_line=data["start"]["line"],
|
||||
start_character=data["start"]["character"],
|
||||
end_line=data["end"]["line"],
|
||||
end_character=data["end"]["character"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
|
||||
"""Create Range from LSP Range object.
|
||||
|
||||
LSP Range format:
|
||||
{"start": {"line": int, "character": int},
|
||||
"end": {"line": int, "character": int}}
|
||||
"""
|
||||
return cls(
|
||||
start_line=lsp_range["start"]["line"],
|
||||
start_character=lsp_range["start"]["character"],
|
||||
end_line=lsp_range["end"]["line"],
|
||||
end_character=lsp_range["end"]["character"],
|
||||
)
|
||||
|
||||
|
||||
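# --- Editor's sketch (not part of this file): converting an LSP range with the class above.
lsp_range = {"start": {"line": 10, "character": 0}, "end": {"line": 12, "character": 4}}
rng = Range.from_lsp_range(lsp_range)   # 0-based positions, validated in __post_init__
assert rng.to_dict() == lsp_range       # round-trips through the same dict shape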
@dataclass
|
||||
class CallHierarchyItem:
|
||||
"""LSP CallHierarchyItem for representing callers/callees.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name (function, method, class name).
|
||||
kind: Symbol kind (function, method, class, etc.).
|
||||
file_path: Absolute file path where the symbol is defined.
|
||||
range: Position range in the source file.
|
||||
detail: Optional additional detail about the symbol.
|
||||
"""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
detail: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
}
|
||||
if self.detail:
|
||||
result["detail"] = self.detail
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from dictionary representation."""
|
||||
return cls(
|
||||
name=data["name"],
|
||||
kind=data["kind"],
|
||||
file_path=data["file_path"],
|
||||
range=Range.from_dict(data["range"]),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeSymbolNode:
|
||||
"""Graph node representing a code symbol.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier in format 'file_path:name:line'.
|
||||
name: Symbol name (function, class, variable name).
|
||||
kind: Symbol kind (function, class, method, variable, etc.).
|
||||
file_path: Absolute file path where symbol is defined.
|
||||
range: Start/end position in the source file.
|
||||
embedding: Optional vector embedding for semantic search.
|
||||
raw_code: Raw source code of the symbol.
|
||||
docstring: Documentation string (if available).
|
||||
score: Ranking score (used during reranking).
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
embedding: Optional[List[float]] = None
|
||||
raw_code: str = ""
|
||||
docstring: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate required fields."""
|
||||
if not self.id:
|
||||
raise ValueError("id cannot be empty")
|
||||
if not self.name:
|
||||
raise ValueError("name cannot be empty")
|
||||
if not self.kind:
|
||||
raise ValueError("kind cannot be empty")
|
||||
if not self.file_path:
|
||||
raise ValueError("file_path cannot be empty")
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on unique ID."""
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on unique ID."""
|
||||
if not isinstance(other, CodeSymbolNode):
|
||||
return False
|
||||
return self.id == other.id
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
"score": self.score,
|
||||
}
|
||||
if self.raw_code:
|
||||
result["raw_code"] = self.raw_code
|
||||
if self.docstring:
|
||||
result["docstring"] = self.docstring
|
||||
# Exclude embedding from serialization (too large for JSON responses)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
|
||||
"""Create CodeSymbolNode from dictionary representation."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
kind=data["kind"],
|
||||
file_path=data["file_path"],
|
||||
range=Range.from_dict(data["range"]),
|
||||
embedding=data.get("embedding"),
|
||||
raw_code=data.get("raw_code", ""),
|
||||
docstring=data.get("docstring", ""),
|
||||
score=data.get("score", 0.0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp_location(
|
||||
cls,
|
||||
uri: str,
|
||||
name: str,
|
||||
kind: str,
|
||||
lsp_range: Dict[str, Any],
|
||||
raw_code: str = "",
|
||||
docstring: str = "",
|
||||
) -> CodeSymbolNode:
|
||||
"""Create CodeSymbolNode from LSP location data.
|
||||
|
||||
Args:
|
||||
uri: File URI (file:// prefix will be stripped).
|
||||
name: Symbol name.
|
||||
kind: Symbol kind.
|
||||
lsp_range: LSP Range object.
|
||||
raw_code: Optional raw source code.
|
||||
docstring: Optional documentation string.
|
||||
|
||||
Returns:
|
||||
New CodeSymbolNode instance.
|
||||
"""
|
||||
# Strip file:// prefix if present
|
||||
file_path = uri
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
# Handle Windows paths (file:///C:/...)
|
||||
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||
file_path = file_path[1:]
|
||||
|
||||
range_obj = Range.from_lsp_range(lsp_range)
|
||||
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
|
||||
|
||||
return cls(
|
||||
id=symbol_id,
|
||||
name=name,
|
||||
kind=kind,
|
||||
file_path=file_path,
|
||||
range=range_obj,
|
||||
raw_code=raw_code,
|
||||
docstring=docstring,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_id(cls, file_path: str, name: str, line: int) -> str:
|
||||
"""Generate a unique symbol ID.
|
||||
|
||||
Args:
|
||||
file_path: Absolute file path.
|
||||
name: Symbol name.
|
||||
line: Start line number.
|
||||
|
||||
Returns:
|
||||
Unique ID string in format 'file_path:name:line'.
|
||||
"""
|
||||
return f"{file_path}:{name}:{line}"
|
||||
|
||||
|
||||
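# --- Editor's sketch (not part of this file): building a node from LSP location data.
node = CodeSymbolNode.from_lsp_location(
    uri="file:///repo/src/search.py",
    name="rank_results",
    kind="function",
    lsp_range={"start": {"line": 40, "character": 0}, "end": {"line": 80, "character": 0}},
)
assert node.id == CodeSymbolNode.create_id("/repo/src/search.py", "rank_results", 40)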
@dataclass
|
||||
class CodeAssociationGraph:
|
||||
"""Graph of code relationships between symbols.
|
||||
|
||||
This graph represents the association between code symbols discovered
|
||||
through LSP queries (references, call hierarchy, etc.).
|
||||
|
||||
Attributes:
|
||||
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
|
||||
edges: List of (from_id, to_id, relationship_type) tuples.
|
||||
relationship_type: 'calls', 'references', 'inherits', 'imports'.
|
||||
"""
|
||||
|
||||
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
|
||||
edges: List[Tuple[str, str, str]] = field(default_factory=list)
|
||||
|
||||
def add_node(self, node: CodeSymbolNode) -> None:
|
||||
"""Add a node to the graph.
|
||||
|
||||
Args:
|
||||
node: CodeSymbolNode to add. If a node with the same ID exists,
|
||||
it will be replaced.
|
||||
"""
|
||||
self.nodes[node.id] = node
|
||||
|
||||
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||
"""Add an edge to the graph.
|
||||
|
||||
Args:
|
||||
from_id: Source node ID.
|
||||
to_id: Target node ID.
|
||||
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
|
||||
|
||||
Raises:
|
||||
ValueError: If from_id or to_id not in graph nodes.
|
||||
"""
|
||||
if from_id not in self.nodes:
|
||||
raise ValueError(f"Source node '{from_id}' not found in graph")
|
||||
if to_id not in self.nodes:
|
||||
raise ValueError(f"Target node '{to_id}' not found in graph")
|
||||
|
||||
edge = (from_id, to_id, rel_type)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||
"""Add an edge without validating node existence.
|
||||
|
||||
Use this method during bulk graph construction where nodes may be
|
||||
added after edges, or when performance is critical.
|
||||
|
||||
Args:
|
||||
from_id: Source node ID.
|
||||
to_id: Target node ID.
|
||||
rel_type: Relationship type.
|
||||
"""
|
||||
edge = (from_id, to_id, rel_type)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
|
||||
"""Get a node by ID.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to look up.
|
||||
|
||||
Returns:
|
||||
CodeSymbolNode if found, None otherwise.
|
||||
"""
|
||||
return self.nodes.get(node_id)
|
||||
|
||||
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||
"""Get neighboring nodes connected by outgoing edges.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to find neighbors for.
|
||||
rel_type: Optional filter by relationship type.
|
||||
|
||||
Returns:
|
||||
List of neighboring CodeSymbolNode objects.
|
||||
"""
|
||||
neighbors = []
|
||||
for from_id, to_id, edge_rel in self.edges:
|
||||
if from_id == node_id:
|
||||
if rel_type is None or edge_rel == rel_type:
|
||||
node = self.nodes.get(to_id)
|
||||
if node:
|
||||
neighbors.append(node)
|
||||
return neighbors
|
||||
|
||||
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||
"""Get nodes connected by incoming edges.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to find incoming connections for.
|
||||
rel_type: Optional filter by relationship type.
|
||||
|
||||
Returns:
|
||||
List of CodeSymbolNode objects with edges pointing to node_id.
|
||||
"""
|
||||
incoming = []
|
||||
for from_id, to_id, edge_rel in self.edges:
|
||||
if to_id == node_id:
|
||||
if rel_type is None or edge_rel == rel_type:
|
||||
node = self.nodes.get(from_id)
|
||||
if node:
|
||||
incoming.append(node)
|
||||
return incoming
|
||||
|
||||
def to_networkx(self) -> "nx.DiGraph":
|
||||
"""Convert to NetworkX DiGraph for graph algorithms.
|
||||
|
||||
Returns:
|
||||
NetworkX directed graph with nodes and edges.
|
||||
|
||||
Raises:
|
||||
ImportError: If networkx is not installed.
|
||||
"""
|
||||
try:
|
||||
import networkx as nx
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"networkx is required for graph algorithms. "
|
||||
"Install with: pip install networkx"
|
||||
)
|
||||
|
||||
graph = nx.DiGraph()
|
||||
|
||||
# Add nodes with attributes
|
||||
for node_id, node in self.nodes.items():
|
||||
graph.add_node(
|
||||
node_id,
|
||||
name=node.name,
|
||||
kind=node.kind,
|
||||
file_path=node.file_path,
|
||||
score=node.score,
|
||||
)
|
||||
|
||||
# Add edges with relationship type
|
||||
for from_id, to_id, rel_type in self.edges:
|
||||
graph.add_edge(from_id, to_id, relationship=rel_type)
|
||||
|
||||
return graph
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'nodes' and 'edges' keys.
|
||||
"""
|
||||
return {
|
||||
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
|
||||
"edges": [
|
||||
{"from": from_id, "to": to_id, "relationship": rel_type}
|
||||
for from_id, to_id, rel_type in self.edges
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
|
||||
"""Create CodeAssociationGraph from dictionary representation.
|
||||
|
||||
Args:
|
||||
data: Dictionary with 'nodes' and 'edges' keys.
|
||||
|
||||
Returns:
|
||||
New CodeAssociationGraph instance.
|
||||
"""
|
||||
graph = cls()
|
||||
|
||||
# Load nodes
|
||||
for node_id, node_data in data.get("nodes", {}).items():
|
||||
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
|
||||
|
||||
# Load edges
|
||||
for edge_data in data.get("edges", []):
|
||||
graph.edges.append((
|
||||
edge_data["from"],
|
||||
edge_data["to"],
|
||||
edge_data["relationship"],
|
||||
))
|
||||
|
||||
return graph
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of nodes in the graph."""
|
||||
return len(self.nodes)
|
||||
|
||||
|
||||
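# --- Editor's sketch (not part of this file): a two-node association graph.
graph = CodeAssociationGraph()
caller = CodeSymbolNode(id="a.py:main:1", name="main", kind="function",
                        file_path="a.py", range=Range(0, 0, 10, 0))
callee = CodeSymbolNode(id="b.py:helper:5", name="helper", kind="function",
                        file_path="b.py", range=Range(4, 0, 20, 0))
graph.add_node(caller)
graph.add_node(callee)
graph.add_edge(caller.id, callee.id, "calls")
assert [n.name for n in graph.get_neighbors(caller.id, "calls")] == ["helper"]
# to_networkx() is also available when the optional networkx dependency is installed.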
@dataclass
|
||||
class SearchResultCluster:
|
||||
"""Clustered search result containing related code symbols.
|
||||
|
||||
Search results are grouped into clusters based on graph community
|
||||
detection or embedding similarity. Each cluster represents a
|
||||
conceptually related group of code symbols.
|
||||
|
||||
Attributes:
|
||||
cluster_id: Unique cluster identifier.
|
||||
score: Cluster relevance score (max of symbol scores).
|
||||
title: Human-readable cluster title/summary.
|
||||
symbols: List of CodeSymbolNode in this cluster.
|
||||
metadata: Additional cluster metadata.
|
||||
"""
|
||||
|
||||
cluster_id: str
|
||||
score: float
|
||||
title: str
|
||||
symbols: List[CodeSymbolNode] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate cluster fields."""
|
||||
if not self.cluster_id:
|
||||
raise ValueError("cluster_id cannot be empty")
|
||||
if self.score < 0:
|
||||
raise ValueError("score must be >= 0")
|
||||
|
||||
def add_symbol(self, symbol: CodeSymbolNode) -> None:
|
||||
"""Add a symbol to the cluster.
|
||||
|
||||
Args:
|
||||
symbol: CodeSymbolNode to add.
|
||||
"""
|
||||
self.symbols.append(symbol)
|
||||
|
||||
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
|
||||
"""Get top N symbols by score.
|
||||
|
||||
Args:
|
||||
n: Number of symbols to return.
|
||||
|
||||
Returns:
|
||||
List of top N CodeSymbolNode objects sorted by score descending.
|
||||
"""
|
||||
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
|
||||
return sorted_symbols[:n]
|
||||
|
||||
def update_score(self) -> None:
|
||||
"""Update cluster score to max of symbol scores."""
|
||||
if self.symbols:
|
||||
self.score = max(s.score for s in self.symbols)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the cluster.
|
||||
"""
|
||||
return {
|
||||
"cluster_id": self.cluster_id,
|
||||
"score": self.score,
|
||||
"title": self.title,
|
||||
"symbols": [s.to_dict() for s in self.symbols],
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
|
||||
"""Create SearchResultCluster from dictionary representation.
|
||||
|
||||
Args:
|
||||
data: Dictionary with cluster data.
|
||||
|
||||
Returns:
|
||||
New SearchResultCluster instance.
|
||||
"""
|
||||
return cls(
|
||||
cluster_id=data["cluster_id"],
|
||||
score=data["score"],
|
||||
title=data["title"],
|
||||
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of symbols in the cluster."""
|
||||
return len(self.symbols)
|
||||
|
||||
|
||||
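# --- Editor's sketch (not part of this file): grouping scored symbols into a cluster.
sym_a = CodeSymbolNode(id="a.py:login:1", name="login", kind="function",
                       file_path="a.py", range=Range(0, 0, 30, 0), score=0.9)
sym_b = CodeSymbolNode(id="a.py:logout:40", name="logout", kind="function",
                       file_path="a.py", range=Range(39, 0, 60, 0), score=0.4)
cluster = SearchResultCluster(cluster_id="auth", score=0.0, title="Authentication helpers")
cluster.add_symbol(sym_a)
cluster.add_symbol(sym_b)
cluster.update_score()                          # score becomes 0.9, the max member score
assert cluster.get_top_symbols(n=1)[0].name == "login"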
@dataclass
|
||||
class CallHierarchyItem:
|
||||
"""LSP CallHierarchyItem for representing callers/callees.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name (function, method, etc.).
|
||||
kind: Symbol kind (function, method, etc.).
|
||||
file_path: Absolute file path.
|
||||
range: Position range in the file.
|
||||
detail: Optional additional detail (e.g., signature).
|
||||
"""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
detail: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
}
|
||||
if self.detail:
|
||||
result["detail"] = self.detail
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from dictionary representation."""
|
||||
return cls(
|
||||
name=data.get("name", "unknown"),
|
||||
kind=data.get("kind", "unknown"),
|
||||
file_path=data.get("file_path", data.get("uri", "")),
|
||||
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from LSP response format.
|
||||
|
||||
LSP uses 0-based line numbers and 'character' instead of 'char'.
|
||||
"""
|
||||
uri = data.get("uri", data.get("file_path", ""))
|
||||
# Strip file:// prefix
|
||||
file_path = uri
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||
file_path = file_path[1:]
|
||||
|
||||
return cls(
|
||||
name=data.get("name", "unknown"),
|
||||
kind=str(data.get("kind", "unknown")),
|
||||
file_path=file_path,
|
||||
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
codex-lens/build/lib/codexlens/indexing/__init__.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Code indexing and symbol extraction."""
from codexlens.indexing.symbol_extractor import SymbolExtractor
from codexlens.indexing.embedding import (
    BinaryEmbeddingBackend,
    DenseEmbeddingBackend,
    CascadeEmbeddingBackend,
    get_cascade_embedder,
    binarize_embedding,
    pack_binary_embedding,
    unpack_binary_embedding,
    hamming_distance,
)

__all__ = [
    "SymbolExtractor",
    # Cascade embedding backends
    "BinaryEmbeddingBackend",
    "DenseEmbeddingBackend",
    "CascadeEmbeddingBackend",
    "get_cascade_embedder",
    # Utility functions
    "binarize_embedding",
    "pack_binary_embedding",
    "unpack_binary_embedding",
    "hamming_distance",
]
**codex-lens/build/lib/codexlens/indexing/embedding.py** (new file, 582 lines)
@@ -0,0 +1,582 @@

```python
"""Multi-type embedding backends for cascade retrieval.

This module provides embedding backends optimized for cascade retrieval:
1. BinaryEmbeddingBackend - Fast coarse filtering with binary vectors
2. DenseEmbeddingBackend - High-precision dense vectors for reranking
3. CascadeEmbeddingBackend - Combined binary + dense for two-stage retrieval

Cascade retrieval workflow:
1. Binary search (fast, ~32 bytes/vector) -> top-K candidates
2. Dense rerank (precise, ~8KB/vector) -> final results
"""

from __future__ import annotations

import logging
from typing import Iterable, List, Optional, Tuple

import numpy as np

from codexlens.semantic.base import BaseEmbedder

logger = logging.getLogger(__name__)


# =============================================================================
# Utility Functions
# =============================================================================


def binarize_embedding(embedding: np.ndarray) -> np.ndarray:
    """Convert float embedding to binary vector.

    Applies sign-based quantization: values > 0 become 1, values <= 0 become 0.

    Args:
        embedding: Float32 embedding of any dimension

    Returns:
        Binary vector (uint8 with values 0 or 1) of same dimension
    """
    return (embedding > 0).astype(np.uint8)


def pack_binary_embedding(binary_vector: np.ndarray) -> bytes:
    """Pack binary vector into compact bytes format.

    Packs 8 binary values into each byte for storage efficiency.
    For a 256-dim binary vector, output is 32 bytes.

    Args:
        binary_vector: Binary vector (uint8 with values 0 or 1)

    Returns:
        Packed bytes (length = ceil(dim / 8))
    """
    # Ensure vector length is multiple of 8 by padding if needed
    dim = len(binary_vector)
    padded_dim = ((dim + 7) // 8) * 8
    if padded_dim > dim:
        padded = np.zeros(padded_dim, dtype=np.uint8)
        padded[:dim] = binary_vector
        binary_vector = padded

    # Pack 8 bits per byte
    packed = np.packbits(binary_vector)
    return packed.tobytes()


def unpack_binary_embedding(packed_bytes: bytes, dim: int = 256) -> np.ndarray:
    """Unpack bytes back to binary vector.

    Args:
        packed_bytes: Packed binary data
        dim: Original vector dimension (default: 256)

    Returns:
        Binary vector (uint8 with values 0 or 1)
    """
    unpacked = np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))
    return unpacked[:dim]


def hamming_distance(a: bytes, b: bytes) -> int:
    """Compute Hamming distance between two packed binary vectors.

    Uses XOR and popcount for efficient distance computation.

    Args:
        a: First packed binary vector
        b: Second packed binary vector

    Returns:
        Hamming distance (number of differing bits)
    """
    a_arr = np.frombuffer(a, dtype=np.uint8)
    b_arr = np.frombuffer(b, dtype=np.uint8)
    xor = np.bitwise_xor(a_arr, b_arr)
    return int(np.unpackbits(xor).sum())


# =============================================================================
# Binary Embedding Backend
# =============================================================================


class BinaryEmbeddingBackend(BaseEmbedder):
    """Generate 256-dimensional binary embeddings for fast coarse retrieval.

    Uses a lightweight embedding model and applies sign-based quantization
    to produce compact binary vectors (32 bytes per embedding).

    Suitable for:
    - First-stage candidate retrieval
    - Hamming distance-based similarity search
    - Memory-constrained environments

    Model: sentence-transformers/all-MiniLM-L6-v2 (384 dim) -> quantized to 256 bits
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"  # 384 dim, fast
    BINARY_DIM = 256

    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
    ) -> None:
        """Initialize binary embedding backend.

        Args:
            model_name: Base embedding model name. Defaults to BAAI/bge-small-en-v1.5
            use_gpu: Whether to use GPU acceleration
        """
        from codexlens.semantic import SEMANTIC_AVAILABLE

        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        self._model_name = model_name or self.DEFAULT_MODEL
        self._use_gpu = use_gpu
        self._model = None

        # Projection matrix for dimension reduction (lazily initialized)
        self._projection_matrix: Optional[np.ndarray] = None

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Return binary embedding dimension (256)."""
        return self.BINARY_DIM

    @property
    def packed_bytes(self) -> int:
        """Return packed bytes size (32 bytes for 256 bits)."""
        return self.BINARY_DIM // 8

    def _load_model(self) -> None:
        """Lazy load the embedding model."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding
        from codexlens.semantic.gpu_support import get_optimal_providers

        providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
        try:
            self._model = TextEmbedding(
                model_name=self._model_name,
                providers=providers,
            )
        except TypeError:
            # Fallback for older fastembed versions
            self._model = TextEmbedding(model_name=self._model_name)

        logger.debug(f"BinaryEmbeddingBackend loaded model: {self._model_name}")

    def _get_projection_matrix(self, input_dim: int) -> np.ndarray:
        """Get or create projection matrix for dimension reduction.

        Uses random projection with fixed seed for reproducibility.

        Args:
            input_dim: Input embedding dimension from base model

        Returns:
            Projection matrix of shape (input_dim, BINARY_DIM)
        """
        if self._projection_matrix is not None:
            return self._projection_matrix

        # Fixed seed for reproducibility across sessions
        rng = np.random.RandomState(42)
        # Gaussian random projection
        self._projection_matrix = rng.randn(input_dim, self.BINARY_DIM).astype(np.float32)
        # Normalize columns for consistent scale
        norms = np.linalg.norm(self._projection_matrix, axis=0, keepdims=True)
        self._projection_matrix /= (norms + 1e-8)

        return self._projection_matrix

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate binary embeddings as numpy array.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Binary embeddings of shape (n_texts, 256) with values 0 or 1
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Get base float embeddings
        float_embeddings = np.array(list(self._model.embed(texts)))
        input_dim = float_embeddings.shape[1]

        # Project to target dimension if needed
        if input_dim != self.BINARY_DIM:
            projection = self._get_projection_matrix(input_dim)
            float_embeddings = float_embeddings @ projection

        # Binarize
        return binarize_embedding(float_embeddings)

    def embed_packed(self, texts: str | Iterable[str]) -> List[bytes]:
        """Generate packed binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            List of packed bytes (32 bytes each for 256-dim)
        """
        binary = self.embed_to_numpy(texts)
        return [pack_binary_embedding(vec) for vec in binary]


# =============================================================================
# Dense Embedding Backend
# =============================================================================


class DenseEmbeddingBackend(BaseEmbedder):
    """Generate high-dimensional dense embeddings for precise reranking.

    Uses large embedding models to produce 2048-dimensional float32 vectors
    for maximum retrieval quality.

    Suitable for:
    - Second-stage reranking
    - High-precision similarity search
    - Quality-critical applications

    Model: BAAI/bge-large-en-v1.5 (1024 dim) with optional expansion
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"  # 384 dim, use small for testing
    TARGET_DIM = 768  # Reduced target for faster testing

    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
        expand_dim: bool = True,
    ) -> None:
        """Initialize dense embedding backend.

        Args:
            model_name: Dense embedding model name. Defaults to BAAI/bge-large-en-v1.5
            use_gpu: Whether to use GPU acceleration
            expand_dim: If True, expand embeddings to TARGET_DIM using learned expansion
        """
        from codexlens.semantic import SEMANTIC_AVAILABLE

        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        self._model_name = model_name or self.DEFAULT_MODEL
        self._use_gpu = use_gpu
        self._expand_dim = expand_dim
        self._model = None
        self._native_dim: Optional[int] = None

        # Expansion matrix for dimension expansion (lazily initialized)
        self._expansion_matrix: Optional[np.ndarray] = None

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimension.

        Returns TARGET_DIM if expand_dim is True, otherwise native model dimension.
        """
        if self._expand_dim:
            return self.TARGET_DIM
        # Return cached native dim or estimate based on model
        if self._native_dim is not None:
            return self._native_dim
        # Model dimension estimates
        model_dims = {
            "BAAI/bge-large-en-v1.5": 1024,
            "BAAI/bge-base-en-v1.5": 768,
            "BAAI/bge-small-en-v1.5": 384,
            "intfloat/multilingual-e5-large": 1024,
        }
        return model_dims.get(self._model_name, 1024)

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit."""
        return 512  # Conservative default for large models

    def _load_model(self) -> None:
        """Lazy load the embedding model."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding
        from codexlens.semantic.gpu_support import get_optimal_providers

        providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
        try:
            self._model = TextEmbedding(
                model_name=self._model_name,
                providers=providers,
            )
        except TypeError:
            self._model = TextEmbedding(model_name=self._model_name)

        logger.debug(f"DenseEmbeddingBackend loaded model: {self._model_name}")

    def _get_expansion_matrix(self, input_dim: int) -> np.ndarray:
        """Get or create expansion matrix for dimension expansion.

        Uses random orthogonal projection for information-preserving expansion.

        Args:
            input_dim: Input embedding dimension from base model

        Returns:
            Expansion matrix of shape (input_dim, TARGET_DIM)
        """
        if self._expansion_matrix is not None:
            return self._expansion_matrix

        # Fixed seed for reproducibility
        rng = np.random.RandomState(123)

        # Create semi-orthogonal expansion matrix
        # First input_dim columns form identity-like structure
        self._expansion_matrix = np.zeros((input_dim, self.TARGET_DIM), dtype=np.float32)

        # Copy original dimensions
        copy_dim = min(input_dim, self.TARGET_DIM)
        self._expansion_matrix[:copy_dim, :copy_dim] = np.eye(copy_dim, dtype=np.float32)

        # Fill remaining with random projections
        if self.TARGET_DIM > input_dim:
            random_part = rng.randn(input_dim, self.TARGET_DIM - input_dim).astype(np.float32)
            # Normalize
            norms = np.linalg.norm(random_part, axis=0, keepdims=True)
            random_part /= (norms + 1e-8)
            self._expansion_matrix[:, input_dim:] = random_part

        return self._expansion_matrix

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate dense embeddings as numpy array.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, TARGET_DIM) as float32
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Get base float embeddings
        float_embeddings = np.array(list(self._model.embed(texts)), dtype=np.float32)
        self._native_dim = float_embeddings.shape[1]

        # Expand to target dimension if needed
        if self._expand_dim and self._native_dim < self.TARGET_DIM:
            expansion = self._get_expansion_matrix(self._native_dim)
            float_embeddings = float_embeddings @ expansion

        return float_embeddings


# =============================================================================
# Cascade Embedding Backend
# =============================================================================


class CascadeEmbeddingBackend(BaseEmbedder):
    """Combined binary + dense embedding backend for cascade retrieval.

    Generates both binary (for fast coarse filtering) and dense (for precise
    reranking) embeddings in a single pass, optimized for two-stage retrieval.

    Cascade workflow:
    1. encode_cascade() returns (binary_embeddings, dense_embeddings)
    2. Binary search: Use Hamming distance on binary vectors -> top-K candidates
    3. Dense rerank: Use cosine similarity on dense vectors -> final results

    Memory efficiency:
    - Binary: 32 bytes per vector (256 bits)
    - Dense: 8192 bytes per vector (2048 x float32)
    - Total: ~8KB per document for full cascade support
    """

    def __init__(
        self,
        binary_model: Optional[str] = None,
        dense_model: Optional[str] = None,
        use_gpu: bool = True,
    ) -> None:
        """Initialize cascade embedding backend.

        Args:
            binary_model: Model for binary embeddings. Defaults to BAAI/bge-small-en-v1.5
            dense_model: Model for dense embeddings. Defaults to BAAI/bge-large-en-v1.5
            use_gpu: Whether to use GPU acceleration
        """
        self._binary_backend = BinaryEmbeddingBackend(
            model_name=binary_model,
            use_gpu=use_gpu,
        )
        self._dense_backend = DenseEmbeddingBackend(
            model_name=dense_model,
            use_gpu=use_gpu,
            expand_dim=True,
        )
        self._use_gpu = use_gpu

    @property
    def model_name(self) -> str:
        """Return model names for both backends."""
        return f"cascade({self._binary_backend.model_name}, {self._dense_backend.model_name})"

    @property
    def embedding_dim(self) -> int:
        """Return dense embedding dimension (for compatibility)."""
        return self._dense_backend.embedding_dim

    @property
    def binary_dim(self) -> int:
        """Return binary embedding dimension."""
        return self._binary_backend.embedding_dim

    @property
    def dense_dim(self) -> int:
        """Return dense embedding dimension."""
        return self._dense_backend.embedding_dim

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate dense embeddings (for BaseEmbedder compatibility).

        For cascade embeddings, use encode_cascade() instead.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, dense_dim)
        """
        return self._dense_backend.embed_to_numpy(texts)

    def encode_cascade(
        self,
        texts: str | Iterable[str],
        batch_size: int = 32,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate both binary and dense embeddings.

        Args:
            texts: Single text or iterable of texts
            batch_size: Batch size for processing

        Returns:
            Tuple of:
            - binary_embeddings: Shape (n_texts, 256), uint8 values 0/1
            - dense_embeddings: Shape (n_texts, 2048), float32
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        binary_embeddings = self._binary_backend.embed_to_numpy(texts)
        dense_embeddings = self._dense_backend.embed_to_numpy(texts)

        return binary_embeddings, dense_embeddings

    def encode_binary(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate only binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Binary embeddings of shape (n_texts, 256)
        """
        return self._binary_backend.embed_to_numpy(texts)

    def encode_dense(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate only dense embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, 2048)
        """
        return self._dense_backend.embed_to_numpy(texts)

    def encode_binary_packed(self, texts: str | Iterable[str]) -> List[bytes]:
        """Generate packed binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            List of packed bytes (32 bytes each)
        """
        return self._binary_backend.embed_packed(texts)


# =============================================================================
# Factory Function
# =============================================================================


def get_cascade_embedder(
    binary_model: Optional[str] = None,
    dense_model: Optional[str] = None,
    use_gpu: bool = True,
) -> CascadeEmbeddingBackend:
    """Factory function to create a cascade embedder.

    Args:
        binary_model: Model for binary embeddings (default: BAAI/bge-small-en-v1.5)
        dense_model: Model for dense embeddings (default: BAAI/bge-large-en-v1.5)
        use_gpu: Whether to use GPU acceleration

    Returns:
        Configured CascadeEmbeddingBackend instance

    Example:
        >>> embedder = get_cascade_embedder()
        >>> binary, dense = embedder.encode_cascade(["hello world"])
        >>> binary.shape  # (1, 256)
        >>> dense.shape   # (1, 2048)
    """
    return CascadeEmbeddingBackend(
        binary_model=binary_model,
        dense_model=dense_model,
        use_gpu=use_gpu,
    )
```
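For orientation, a small hedged sketch of the two-stage cascade described in the module docstring above: pack binary vectors, shortlist by Hamming distance, then rerank the shortlist with dense cosine similarity. It assumes the `codexlens[semantic]` extra (fastembed, numpy) is installed; the corpus, query, and shortlist size are illustrative only.

```python
# Hedged sketch of the cascade workflow (binary coarse filter -> dense rerank).
# Assumes codexlens[semantic] is installed; corpus/query values are illustrative.
import numpy as np

from codexlens.indexing.embedding import (
    get_cascade_embedder,
    hamming_distance,
    pack_binary_embedding,
)

corpus = ["def read_file(path): ...", "class HttpClient: ...", "def hamming(a, b): ..."]
query = "function that reads a file from disk"

embedder = get_cascade_embedder(use_gpu=False)
doc_binary, doc_dense = embedder.encode_cascade(corpus)
query_binary, query_dense = embedder.encode_cascade(query)

# Stage 1: coarse filter on packed binary vectors via Hamming distance.
packed_docs = [pack_binary_embedding(v) for v in doc_binary]
packed_query = pack_binary_embedding(query_binary[0])
candidates = sorted(
    range(len(corpus)),
    key=lambda i: hamming_distance(packed_query, packed_docs[i]),
)[:2]

# Stage 2: rerank the shortlist with cosine similarity on dense vectors.
def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

best = max(candidates, key=lambda i: cosine(query_dense[0], doc_dense[i]))
print(corpus[best])
```

The shortlist keeps only packed 32-byte vectors in the hot path, which is where the memory savings quoted in the docstring come from.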
**codex-lens/build/lib/codexlens/indexing/symbol_extractor.py** (new file, 277 lines)
@@ -0,0 +1,277 @@

```python
"""Symbol and relationship extraction from source code."""
import re
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception:  # pragma: no cover - optional dependency / platform variance
    TreeSitterSymbolParser = None  # type: ignore[assignment]


class SymbolExtractor:
    """Extract symbols and relationships from source code using regex patterns."""

    # Pattern definitions for different languages
    PATTERNS = {
        'python': {
            'function': r'^(?:async\s+)?def\s+(\w+)\s*\(',
            'class': r'^class\s+(\w+)\s*[:\(]',
            'import': r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)',
            'call': r'(?<![.\w])(\w+)\s*\(',
        },
        'typescript': {
            'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<\(]',
            'class': r'(?:export\s+)?class\s+(\w+)',
            'import': r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
            'call': r'(?<![.\w])(\w+)\s*[<\(]',
        },
        'javascript': {
            'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(',
            'class': r'(?:export\s+)?class\s+(\w+)',
            'import': r"(?:import|require)\s*\(?['\"]([^'\"]+)['\"]",
            'call': r'(?<![.\w])(\w+)\s*\(',
        }
    }

    LANGUAGE_MAP = {
        '.py': 'python',
        '.ts': 'typescript',
        '.tsx': 'typescript',
        '.js': 'javascript',
        '.jsx': 'javascript',
    }

    def __init__(self, db_path: Path):
        self.db_path = db_path
        self.db_conn: Optional[sqlite3.Connection] = None

    def connect(self) -> None:
        """Connect to database and ensure schema exists."""
        self.db_conn = sqlite3.connect(str(self.db_path))
        self._ensure_tables()

    def __enter__(self) -> "SymbolExtractor":
        """Context manager entry: connect to database."""
        self.connect()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit: close database connection."""
        self.close()

    def _ensure_tables(self) -> None:
        """Create symbols and relationships tables if they don't exist."""
        if not self.db_conn:
            return
        cursor = self.db_conn.cursor()

        # Create symbols table with qualified_name
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS symbols (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                qualified_name TEXT NOT NULL,
                name TEXT NOT NULL,
                kind TEXT NOT NULL,
                file_path TEXT NOT NULL,
                start_line INTEGER NOT NULL,
                end_line INTEGER NOT NULL,
                UNIQUE(file_path, name, start_line)
            )
        ''')

        # Create relationships table with target_symbol_fqn
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS symbol_relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_symbol_id INTEGER NOT NULL,
                target_symbol_fqn TEXT NOT NULL,
                relationship_type TEXT NOT NULL,
                file_path TEXT NOT NULL,
                line INTEGER,
                FOREIGN KEY (source_symbol_id) REFERENCES symbols(id) ON DELETE CASCADE
            )
        ''')

        # Create performance indexes
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_path)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_source ON symbol_relationships(source_symbol_id)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_target ON symbol_relationships(target_symbol_fqn)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_type ON symbol_relationships(relationship_type)')

        self.db_conn.commit()

    def extract_from_file(self, file_path: Path, content: str) -> Tuple[List[Dict], List[Dict]]:
        """Extract symbols and relationships from file content.

        Args:
            file_path: Path to the source file
            content: File content as string

        Returns:
            Tuple of (symbols, relationships) where:
            - symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
            - relationships: List of relationship dicts with source_scope, target, type, file_path, line
        """
        ext = file_path.suffix.lower()
        lang = self.LANGUAGE_MAP.get(ext)

        if not lang or lang not in self.PATTERNS:
            return [], []

        patterns = self.PATTERNS[lang]
        symbols = []
        relationships: List[Dict] = []
        lines = content.split('\n')

        current_scope = None

        for line_num, line in enumerate(lines, 1):
            # Extract function/class definitions
            for kind in ['function', 'class']:
                if kind in patterns:
                    match = re.search(patterns[kind], line)
                    if match:
                        name = match.group(1)
                        qualified_name = f"{file_path.stem}.{name}"
                        symbols.append({
                            'qualified_name': qualified_name,
                            'name': name,
                            'kind': kind,
                            'file_path': str(file_path),
                            'start_line': line_num,
                            'end_line': line_num,  # Simplified - would need proper parsing for actual end
                        })
                        current_scope = name

        if TreeSitterSymbolParser is not None:
            try:
                ts_parser = TreeSitterSymbolParser(lang, file_path)
                if ts_parser.is_available():
                    indexed = ts_parser.parse(content, file_path)
                    if indexed is not None and indexed.relationships:
                        relationships = [
                            {
                                "source_scope": r.source_symbol,
                                "target": r.target_symbol,
                                "type": r.relationship_type.value,
                                "file_path": str(file_path),
                                "line": r.source_line,
                            }
                            for r in indexed.relationships
                        ]
            except Exception:
                relationships = []

        # Regex fallback for relationships (when tree-sitter is unavailable)
        if not relationships:
            current_scope = None
            for line_num, line in enumerate(lines, 1):
                for kind in ['function', 'class']:
                    if kind in patterns:
                        match = re.search(patterns[kind], line)
                        if match:
                            current_scope = match.group(1)

                # Extract imports
                if 'import' in patterns:
                    match = re.search(patterns['import'], line)
                    if match:
                        import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
                        if import_target and current_scope:
                            relationships.append({
                                'source_scope': current_scope,
                                'target': import_target.strip(),
                                'type': 'imports',
                                'file_path': str(file_path),
                                'line': line_num,
                            })

                # Extract function calls (simplified)
                if 'call' in patterns and current_scope:
                    for match in re.finditer(patterns['call'], line):
                        call_name = match.group(1)
                        # Skip common keywords and the current function
                        if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
                            relationships.append({
                                'source_scope': current_scope,
                                'target': call_name,
                                'type': 'calls',
                                'file_path': str(file_path),
                                'line': line_num,
                            })

        return symbols, relationships

    def save_symbols(self, symbols: List[Dict]) -> Dict[str, int]:
        """Save symbols to database and return name->id mapping.

        Args:
            symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line

        Returns:
            Dictionary mapping symbol name to database id
        """
        if not self.db_conn or not symbols:
            return {}

        cursor = self.db_conn.cursor()
        name_to_id = {}

        for sym in symbols:
            try:
                cursor.execute('''
                    INSERT OR IGNORE INTO symbols
                    (qualified_name, name, kind, file_path, start_line, end_line)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (sym['qualified_name'], sym['name'], sym['kind'],
                      sym['file_path'], sym['start_line'], sym['end_line']))

                # Get the id
                cursor.execute('''
                    SELECT id FROM symbols
                    WHERE file_path = ? AND name = ? AND start_line = ?
                ''', (sym['file_path'], sym['name'], sym['start_line']))

                row = cursor.fetchone()
                if row:
                    name_to_id[sym['name']] = row[0]
            except sqlite3.Error:
                continue

        self.db_conn.commit()
        return name_to_id

    def save_relationships(self, relationships: List[Dict], name_to_id: Dict[str, int]) -> None:
        """Save relationships to database.

        Args:
            relationships: List of relationship dicts with source_scope, target, type, file_path, line
            name_to_id: Dictionary mapping symbol names to database ids
        """
        if not self.db_conn or not relationships:
            return

        cursor = self.db_conn.cursor()

        for rel in relationships:
            source_id = name_to_id.get(rel['source_scope'])
            if source_id:
                try:
                    cursor.execute('''
                        INSERT INTO symbol_relationships
                        (source_symbol_id, target_symbol_fqn, relationship_type, file_path, line)
                        VALUES (?, ?, ?, ?, ?)
                    ''', (source_id, rel['target'], rel['type'], rel['file_path'], rel['line']))
                except sqlite3.Error:
                    continue

        self.db_conn.commit()

    def close(self) -> None:
        """Close database connection."""
        if self.db_conn:
            self.db_conn.close()
            self.db_conn = None
```
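A short, hedged usage sketch of the extractor above: run it over a single Python file and persist the results. Paths, file content, and the expected output are illustrative, and it assumes the codexlens package from this diff is importable (tree-sitter support is optional, since the class falls back to the regex patterns).

```python
# Hedged sketch: index one file with SymbolExtractor (illustrative paths/content).
from pathlib import Path

from codexlens.indexing.symbol_extractor import SymbolExtractor

source = Path("example.py")
content = "import os\n\ndef main():\n    helper()\n\ndef helper():\n    return os.getcwd()\n"

with SymbolExtractor(Path("symbols.db")) as extractor:
    symbols, relationships = extractor.extract_from_file(source, content)
    name_to_id = extractor.save_symbols(symbols)              # INSERT OR IGNORE, then id lookup
    extractor.save_relationships(relationships, name_to_id)   # keeps only rows whose source scope was saved

print([s["qualified_name"] for s in symbols])  # e.g. ["example.main", "example.helper"]
```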
**codex-lens/build/lib/codexlens/lsp/__init__.py** (new file, 34 lines)
@@ -0,0 +1,34 @@

```python
"""LSP module for real-time language server integration.

This module provides:
- LspBridge: HTTP bridge to VSCode language servers
- LspGraphBuilder: Build code association graphs via LSP
- Location: Position in a source file

Example:
    >>> from codexlens.lsp import LspBridge, LspGraphBuilder
    >>>
    >>> async with LspBridge() as bridge:
    ...     refs = await bridge.get_references(symbol)
    ...     graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
"""

from codexlens.lsp.lsp_bridge import (
    CacheEntry,
    Location,
    LspBridge,
)
from codexlens.lsp.lsp_graph_builder import (
    LspGraphBuilder,
)

# Alias for backward compatibility
GraphBuilder = LspGraphBuilder

__all__ = [
    "CacheEntry",
    "GraphBuilder",
    "Location",
    "LspBridge",
    "LspGraphBuilder",
]
```
**codex-lens/build/lib/codexlens/lsp/handlers.py** (new file, 551 lines)
@@ -0,0 +1,551 @@

```python
"""LSP request handlers for codex-lens.

This module contains handlers for LSP requests:
- textDocument/definition
- textDocument/completion
- workspace/symbol
- textDocument/didSave
- textDocument/hover
"""

from __future__ import annotations

import logging
import re
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote, unquote

try:
    from lsprotocol import types as lsp
except ImportError as exc:
    raise ImportError(
        "LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
    ) from exc

from codexlens.entities import Symbol
from codexlens.lsp.server import server

logger = logging.getLogger(__name__)

# Symbol kind mapping from codex-lens to LSP
SYMBOL_KIND_MAP = {
    "class": lsp.SymbolKind.Class,
    "function": lsp.SymbolKind.Function,
    "method": lsp.SymbolKind.Method,
    "variable": lsp.SymbolKind.Variable,
    "constant": lsp.SymbolKind.Constant,
    "property": lsp.SymbolKind.Property,
    "field": lsp.SymbolKind.Field,
    "interface": lsp.SymbolKind.Interface,
    "module": lsp.SymbolKind.Module,
    "namespace": lsp.SymbolKind.Namespace,
    "package": lsp.SymbolKind.Package,
    "enum": lsp.SymbolKind.Enum,
    "enum_member": lsp.SymbolKind.EnumMember,
    "struct": lsp.SymbolKind.Struct,
    "type": lsp.SymbolKind.TypeParameter,
    "type_alias": lsp.SymbolKind.TypeParameter,
}

# Completion kind mapping from codex-lens to LSP
COMPLETION_KIND_MAP = {
    "class": lsp.CompletionItemKind.Class,
    "function": lsp.CompletionItemKind.Function,
    "method": lsp.CompletionItemKind.Method,
    "variable": lsp.CompletionItemKind.Variable,
    "constant": lsp.CompletionItemKind.Constant,
    "property": lsp.CompletionItemKind.Property,
    "field": lsp.CompletionItemKind.Field,
    "interface": lsp.CompletionItemKind.Interface,
    "module": lsp.CompletionItemKind.Module,
    "enum": lsp.CompletionItemKind.Enum,
    "enum_member": lsp.CompletionItemKind.EnumMember,
    "struct": lsp.CompletionItemKind.Struct,
    "type": lsp.CompletionItemKind.TypeParameter,
    "type_alias": lsp.CompletionItemKind.TypeParameter,
}


def _path_to_uri(path: Union[str, Path]) -> str:
    """Convert a file path to a URI.

    Args:
        path: File path (string or Path object)

    Returns:
        File URI string
    """
    path_str = str(Path(path).resolve())
    # Handle Windows paths
    if path_str.startswith("/"):
        return f"file://{quote(path_str)}"
    else:
        return f"file:///{quote(path_str.replace(chr(92), '/'))}"


def _uri_to_path(uri: str) -> Path:
    """Convert a URI to a file path.

    Args:
        uri: File URI string

    Returns:
        Path object
    """
    path = uri.replace("file:///", "").replace("file://", "")
    return Path(unquote(path))


def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
    """Extract the word at the given position in the document.

    Args:
        document_text: Full document text
        line: 0-based line number
        character: 0-based character position

    Returns:
        Word at position, or None if no word found
    """
    lines = document_text.splitlines()
    if line >= len(lines):
        return None

    line_text = lines[line]
    if character > len(line_text):
        return None

    # Find word boundaries
    word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
    for match in word_pattern.finditer(line_text):
        if match.start() <= character <= match.end():
            return match.group()

    return None


def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
    """Extract the incomplete word prefix at the given position.

    Args:
        document_text: Full document text
        line: 0-based line number
        character: 0-based character position

    Returns:
        Prefix string (may be empty)
    """
    lines = document_text.splitlines()
    if line >= len(lines):
        return ""

    line_text = lines[line]
    if character > len(line_text):
        character = len(line_text)

    # Extract text before cursor
    before_cursor = line_text[:character]

    # Find the start of the current word
    match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
    if match:
        return match.group()

    return ""


def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
    """Convert a codex-lens Symbol to an LSP Location.

    Args:
        symbol: codex-lens Symbol object

    Returns:
        LSP Location, or None if symbol has no file
    """
    if not symbol.file:
        return None

    # LSP uses 0-based lines, codex-lens uses 1-based
    start_line = max(0, symbol.range[0] - 1)
    end_line = max(0, symbol.range[1] - 1)

    return lsp.Location(
        uri=_path_to_uri(symbol.file),
        range=lsp.Range(
            start=lsp.Position(line=start_line, character=0),
            end=lsp.Position(line=end_line, character=0),
        ),
    )


def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
    """Map codex-lens symbol kind to LSP SymbolKind.

    Args:
        kind: codex-lens symbol kind string

    Returns:
        LSP SymbolKind
    """
    return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)


def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
    """Map codex-lens symbol kind to LSP CompletionItemKind.

    Args:
        kind: codex-lens symbol kind string

    Returns:
        LSP CompletionItemKind
    """
    return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)


# -----------------------------------------------------------------------------
# LSP Request Handlers
# -----------------------------------------------------------------------------


@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
def lsp_definition(
    params: lsp.DefinitionParams,
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
    """Handle textDocument/definition request.

    Finds the definition of the symbol at the cursor position.
    """
    if not server.global_index:
        logger.debug("No global index available for definition lookup")
        return None

    # Get document
    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    # Get word at position
    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        logger.debug("No word found at position")
        return None

    logger.debug("Looking up definition for: %s", word)

    # Search for exact symbol match
    try:
        symbols = server.global_index.search(
            name=word,
            limit=10,
            prefix_mode=False,  # Exact match preferred
        )

        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == word]
        if not exact_matches:
            # Fall back to prefix search
            symbols = server.global_index.search(
                name=word,
                limit=10,
                prefix_mode=True,
            )
            exact_matches = [s for s in symbols if s.name == word]

        if not exact_matches:
            logger.debug("No definition found for: %s", word)
            return None

        # Convert to LSP locations
        locations = []
        for sym in exact_matches:
            loc = symbol_to_location(sym)
            if loc:
                locations.append(loc)

        if len(locations) == 1:
            return locations[0]
        elif locations:
            return locations
        else:
            return None

    except Exception as exc:
        logger.error("Error looking up definition: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
    """Handle textDocument/references request.

    Finds all references to the symbol at the cursor position using
    the code_relationships table for accurate call-site tracking.
    Falls back to same-name symbol search if search_engine is unavailable.
    """
    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        return None

    logger.debug("Finding references for: %s", word)

    try:
        # Try using search_engine.search_references() for accurate reference tracking
        if server.search_engine and server.workspace_root:
            references = server.search_engine.search_references(
                symbol_name=word,
                source_path=server.workspace_root,
                limit=200,
            )

            if references:
                locations = []
                for ref in references:
                    locations.append(
                        lsp.Location(
                            uri=_path_to_uri(ref.file_path),
                            range=lsp.Range(
                                start=lsp.Position(
                                    line=max(0, ref.line - 1),
                                    character=ref.column,
                                ),
                                end=lsp.Position(
                                    line=max(0, ref.line - 1),
                                    character=ref.column + len(word),
                                ),
                            ),
                        )
                    )
                return locations if locations else None

        # Fallback: search for symbols with same name using global_index
        if server.global_index:
            symbols = server.global_index.search(
                name=word,
                limit=100,
                prefix_mode=False,
            )

            # Filter for exact matches
            exact_matches = [s for s in symbols if s.name == word]

            locations = []
            for sym in exact_matches:
                loc = symbol_to_location(sym)
                if loc:
                    locations.append(loc)

            return locations if locations else None

        return None

    except Exception as exc:
        logger.error("Error finding references: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
    """Handle textDocument/completion request.

    Provides code completion suggestions based on indexed symbols.
    """
    if not server.global_index:
        return None

    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    prefix = _get_prefix_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not prefix or len(prefix) < 2:
        # Require at least 2 characters for completion
        return None

    logger.debug("Completing prefix: %s", prefix)

    try:
        symbols = server.global_index.search(
            name=prefix,
            limit=50,
            prefix_mode=True,
        )

        if not symbols:
            return None

        # Convert to completion items
        items = []
        seen_names = set()

        for sym in symbols:
            if sym.name in seen_names:
                continue
            seen_names.add(sym.name)

            items.append(
                lsp.CompletionItem(
                    label=sym.name,
                    kind=_symbol_kind_to_completion_kind(sym.kind),
                    detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
                    sort_text=sym.name.lower(),
                )
            )

        return lsp.CompletionList(
            is_incomplete=len(symbols) >= 50,
            items=items,
        )

    except Exception as exc:
        logger.error("Error getting completions: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
    """Handle textDocument/hover request.

    Provides hover information for the symbol at the cursor position
    using HoverProvider for rich symbol information including
    signature, documentation, and location.
    """
    if not server.global_index:
        return None

    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        return None

    logger.debug("Hover for: %s", word)

    try:
        # Use HoverProvider for rich symbol information
        from codexlens.lsp.providers import HoverProvider

        provider = HoverProvider(server.global_index, server.registry)
        info = provider.get_hover_info(word)

        if not info:
            return None

        # Format as markdown with signature and location
        content = provider.format_hover_markdown(info)

        return lsp.Hover(
            contents=lsp.MarkupContent(
                kind=lsp.MarkupKind.Markdown,
                value=content,
            ),
        )

    except Exception as exc:
        logger.error("Error getting hover info: %s", exc)
        return None


@server.feature(lsp.WORKSPACE_SYMBOL)
def lsp_workspace_symbol(
    params: lsp.WorkspaceSymbolParams,
) -> Optional[List[lsp.SymbolInformation]]:
    """Handle workspace/symbol request.

    Searches for symbols across the workspace.
    """
    if not server.global_index:
        return None

    query = params.query
    if not query or len(query) < 2:
        return None

    logger.debug("Workspace symbol search: %s", query)

    try:
        symbols = server.global_index.search(
            name=query,
            limit=100,
            prefix_mode=True,
        )

        if not symbols:
            return None

        result = []
        for sym in symbols:
            loc = symbol_to_location(sym)
            if loc:
                result.append(
                    lsp.SymbolInformation(
                        name=sym.name,
                        kind=_symbol_kind_to_lsp(sym.kind),
                        location=loc,
                        container_name=Path(sym.file).parent.name if sym.file else None,
                    )
                )

        return result if result else None

    except Exception as exc:
        logger.error("Error searching workspace symbols: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
    """Handle textDocument/didSave notification.

    Triggers incremental re-indexing of the saved file.
    Note: Full incremental indexing requires WatcherManager integration,
    which is planned for Phase 2.
    """
    file_path = _uri_to_path(params.text_document.uri)
    logger.info("File saved: %s", file_path)

    # Phase 1: Just log the save event
    # Phase 2 will integrate with WatcherManager for incremental indexing
    # if server.watcher_manager:
    #     server.watcher_manager.trigger_reindex(file_path)


@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
    """Handle textDocument/didOpen notification."""
    file_path = _uri_to_path(params.text_document.uri)
    logger.debug("File opened: %s", file_path)


@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
    """Handle textDocument/didClose notification."""
    file_path = _uri_to_path(params.text_document.uri)
    logger.debug("File closed: %s", file_path)
```
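The private cursor helpers above operate on plain strings, so their behaviour is easy to pin down in isolation. A hedged sketch follows, assuming `codex-lens[lsp]` is installed and `codexlens.lsp.handlers` imports cleanly in your environment (importing it pulls in `lsprotocol` and the `server` module); the sample text and positions are illustrative.

```python
# Hedged sketch of the cursor helpers above; positions follow the 0-based LSP convention.
from codexlens.lsp.handlers import _get_prefix_at_position, _get_word_at_position

text = "result = compute_total(items)\n"

# Cursor on line 0, character 12, i.e. inside "compute_total".
print(_get_word_at_position(text, line=0, character=12))    # "compute_total"

# Prefix extraction only looks left of the cursor, so a mid-word cursor yields a partial prefix.
print(_get_prefix_at_position(text, line=0, character=12))  # "com"
```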
**codex-lens/build/lib/codexlens/lsp/lsp_bridge.py** (new file, 834 lines; listing continues below)
@@ -0,0 +1,834 @@

```python
"""LspBridge service for real-time LSP communication with caching.

This module provides a bridge to communicate with language servers either via:
1. Standalone LSP Manager (direct subprocess communication - default)
2. VSCode Bridge extension (HTTP-based, legacy mode)

Features:
- Direct communication with language servers (no VSCode dependency)
- Cache with TTL and file modification time invalidation
- Graceful error handling with empty results on failure
- Support for definition, references, hover, and call hierarchy
"""

from __future__ import annotations

import asyncio
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from codexlens.lsp.standalone_manager import StandaloneLspManager

# Check for optional dependencies
try:
    import aiohttp
    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False

from codexlens.hybrid_search.data_structures import (
    CallHierarchyItem,
    CodeSymbolNode,
    Range,
)


@dataclass
class Location:
    """A location in a source file (LSP response format)."""

    file_path: str
    line: int
    character: int

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "file_path": self.file_path,
            "line": self.line,
            "character": self.character,
        }

    @classmethod
    def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
        """Create Location from LSP response format.

        Handles both direct format and VSCode URI format.
        """
        # Handle VSCode URI format (file:///path/to/file)
        uri = data.get("uri", data.get("file_path", ""))
        if uri.startswith("file:///"):
            # Windows: file:///C:/path -> C:/path
            # Unix: file:///path -> /path
            file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
        elif uri.startswith("file://"):
            file_path = uri[7:]
        else:
            file_path = uri

        # Get position from range or direct fields
        if "range" in data:
            range_data = data["range"]
            start = range_data.get("start", {})
            line = start.get("line", 0) + 1  # LSP is 0-based, convert to 1-based
            character = start.get("character", 0) + 1
        else:
            line = data.get("line", 1)
            character = data.get("character", 1)

        return cls(file_path=file_path, line=line, character=character)


@dataclass
class CacheEntry:
    """A cached LSP response with expiration metadata.

    Attributes:
        data: The cached response data
        file_mtime: File modification time when cached (for invalidation)
        cached_at: Unix timestamp when entry was cached
    """

    data: Any
    file_mtime: float
    cached_at: float


class LspBridge:
    """Bridge for real-time LSP communication with language servers.

    By default, uses StandaloneLspManager to directly spawn and communicate
    with language servers via JSON-RPC over stdio. No VSCode dependency required.

    For legacy mode, can use VSCode Bridge HTTP server (set use_vscode_bridge=True).

    Features:
    - Direct language server communication (default)
    - Response caching with TTL and file modification invalidation
    - Timeout handling
    - Graceful error handling returning empty results

    Example:
        # Default: standalone mode (no VSCode needed)
        async with LspBridge() as bridge:
            refs = await bridge.get_references(symbol)
            definition = await bridge.get_definition(symbol)

        # Legacy: VSCode Bridge mode
        async with LspBridge(use_vscode_bridge=True) as bridge:
            refs = await bridge.get_references(symbol)
    """

    DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
    DEFAULT_TIMEOUT = 30.0  # seconds (increased for standalone mode)
    DEFAULT_CACHE_TTL = 300  # 5 minutes
    DEFAULT_MAX_CACHE_SIZE = 1000  # Maximum cache entries

    def __init__(
        self,
        bridge_url: str = DEFAULT_BRIDGE_URL,
        timeout: float = DEFAULT_TIMEOUT,
        cache_ttl: int = DEFAULT_CACHE_TTL,
        max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
        use_vscode_bridge: bool = False,
        workspace_root: Optional[str] = None,
        config_file: Optional[str] = None,
    ):
        """Initialize LspBridge.

        Args:
            bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
            timeout: Request timeout in seconds
            cache_ttl: Cache time-to-live in seconds
            max_cache_size: Maximum number of cache entries (LRU eviction)
            use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
            workspace_root: Root directory for standalone LSP manager
            config_file: Path to lsp-servers.json configuration file
        """
        self.bridge_url = bridge_url
        self.timeout = timeout
        self.cache_ttl = cache_ttl
        self.max_cache_size = max_cache_size
        self.use_vscode_bridge = use_vscode_bridge
        self.workspace_root = workspace_root
        self.config_file = config_file

        self.cache: OrderedDict[str, CacheEntry] = OrderedDict()

        # VSCode Bridge mode (legacy)
        self._session: Optional["aiohttp.ClientSession"] = None

        # Standalone mode (default)
        self._manager: Optional["StandaloneLspManager"] = None
        self._manager_started = False

        # Validate dependencies
        if use_vscode_bridge and not HAS_AIOHTTP:
            raise ImportError(
                "aiohttp is required for VSCode Bridge mode: pip install aiohttp"
            )

    async def _ensure_manager(self) -> "StandaloneLspManager":
        """Ensure standalone LSP manager is started."""
        if self._manager is None:
            from codexlens.lsp.standalone_manager import StandaloneLspManager
            self._manager = StandaloneLspManager(
                workspace_root=self.workspace_root,
                config_file=self.config_file,
                timeout=self.timeout,
            )

        if not self._manager_started:
            await self._manager.start()
            self._manager_started = True

        return self._manager

    async def _get_session(self) -> "aiohttp.ClientSession":
        """Get or create the aiohttp session (VSCode Bridge mode only)."""
        if not HAS_AIOHTTP:
            raise ImportError("aiohttp required for VSCode Bridge mode")

        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def close(self) -> None:
        """Close connections and cleanup resources."""
        # Close VSCode Bridge session
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        # Stop standalone manager
        if self._manager and self._manager_started:
            await self._manager.stop()
            self._manager_started = False

    def _get_file_mtime(self, file_path: str) -> float:
        """Get file modification time, or 0 if file doesn't exist."""
        try:
            return os.path.getmtime(file_path)
        except OSError:
            return 0.0

    def _is_cached(self, cache_key: str, file_path: str) -> bool:
        """Check if cache entry is valid.

        Cache is invalid if:
        - Entry doesn't exist
        - TTL has expired
        - File has been modified since caching

        Args:
            cache_key: The cache key to check
            file_path: Path to source file for mtime check

        Returns:
            True if cache is valid and can be used
        """
        if cache_key not in self.cache:
            return False

        entry = self.cache[cache_key]
        now = time.time()

        # Check TTL
        if now - entry.cached_at > self.cache_ttl:
            del self.cache[cache_key]
            return False

        # Check file modification time
        current_mtime = self._get_file_mtime(file_path)
        if current_mtime != entry.file_mtime:
            del self.cache[cache_key]
            return False

        # Move to end on access (LRU behavior)
        self.cache.move_to_end(cache_key)
        return True

    def _cache(self, key: str, file_path: str, data: Any) -> None:
        """Store data in cache with LRU eviction.

        Args:
            key: Cache key
            file_path: Path to source file (for mtime tracking)
            data: Data to cache
        """
        # Remove oldest entries if at capacity
        while len(self.cache) >= self.max_cache_size:
            self.cache.popitem(last=False)  # Remove oldest (FIFO order)

        # Move to end if key exists (update access order)
        if key in self.cache:
            self.cache.move_to_end(key)

        self.cache[key] = CacheEntry(
            data=data,
            file_mtime=self._get_file_mtime(file_path),
            cached_at=time.time(),
        )

    def clear_cache(self) -> None:
        """Clear all cached entries."""
        self.cache.clear()

    async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
        """Make HTTP request to VSCode Bridge (legacy mode).

        Args:
            action: The endpoint/action name (e.g., "get_definition")
            params: Request parameters

        Returns:
            Response data on success, None on failure
        """
        url = f"{self.bridge_url}/{action}"

        try:
            session = await self._get_session()
            async with session.post(url, json=params) as response:
                if response.status != 200:
                    return None

                data = await response.json()
                if data.get("success") is False:
                    return None

                return data.get("result")

        except asyncio.TimeoutError:
            return None
        except Exception:
            return None

    async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
        """Get all references to a symbol via real-time LSP.

        Args:
            symbol: The code symbol to find references for

        Returns:
            List of Location objects where the symbol is referenced.
            Returns empty list on error or timeout.
        """
        cache_key = f"refs:{symbol.id}"

        if self._is_cached(cache_key, symbol.file_path):
            return self.cache[cache_key].data

        locations: List[Location] = []

        if self.use_vscode_bridge:
            # Legacy: VSCode Bridge HTTP mode
            result = await self._request_vscode_bridge("get_references", {
                "file_path": symbol.file_path,
                "line": symbol.range.start_line,
                "character": symbol.range.start_character,
            })

            # Don't cache on connection error (result is None)
            if result is None:
                return locations

            if isinstance(result, list):
                for item in result:
                    try:
                        locations.append(Location.from_lsp_response(item))
                    except (KeyError, TypeError):
                        continue
        else:
            # Default: Standalone mode
            manager = await self._ensure_manager()
            result = await manager.get_references(
                file_path=symbol.file_path,
```
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
for item in result:
|
||||
try:
|
||||
locations.append(Location.from_lsp_response(item))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
|
||||
self._cache(cache_key, symbol.file_path, locations)
|
||||
return locations
|
||||
|
||||
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
|
||||
"""Get symbol definition location.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to find definition for
|
||||
|
||||
Returns:
|
||||
Location of the definition, or None if not found
|
||||
"""
|
||||
cache_key = f"def:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
location: Optional[Location] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_definition", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
try:
|
||||
location = Location.from_lsp_response(result[0])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
elif isinstance(result, dict):
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_definition(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if result:
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
self._cache(cache_key, symbol.file_path, location)
|
||||
return location
|
||||
|
||||
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
|
||||
"""Get incoming/outgoing calls for a symbol.
|
||||
|
||||
If call hierarchy is not supported by the language server,
|
||||
falls back to using references.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get call hierarchy for
|
||||
|
||||
Returns:
|
||||
List of CallHierarchyItem representing callers/callees.
|
||||
Returns empty list on error or if not supported.
|
||||
"""
|
||||
cache_key = f"calls:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
items: List[CallHierarchyItem] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_call_hierarchy", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result is None:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
elif isinstance(result, list):
|
||||
for item in result:
|
||||
try:
|
||||
range_data = item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=item.get("name", "unknown"),
|
||||
kind=item.get("kind", "unknown"),
|
||||
file_path=item.get("file_path", item.get("uri", "")),
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=item.get("detail"),
|
||||
))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
|
||||
# Try to get call hierarchy items
|
||||
hierarchy_items = await manager.get_call_hierarchy_items(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if hierarchy_items:
|
||||
# Get incoming calls for each item
|
||||
for h_item in hierarchy_items:
|
||||
incoming = await manager.get_incoming_calls(h_item)
|
||||
for call in incoming:
|
||||
from_item = call.get("from", {})
|
||||
range_data = from_item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
# Parse URI
|
||||
uri = from_item.get("uri", "")
|
||||
if uri.startswith("file:///"):
|
||||
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||
elif uri.startswith("file://"):
|
||||
fp = uri[7:]
|
||||
else:
|
||||
fp = uri
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=from_item.get("name", "unknown"),
|
||||
kind=str(from_item.get("kind", "unknown")),
|
||||
file_path=fp,
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=from_item.get("detail"),
|
||||
))
|
||||
else:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
|
||||
self._cache(cache_key, symbol.file_path, items)
|
||||
return items
|
||||
|
||||
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""Get all symbols in a document (batch operation).
|
||||
|
||||
This is more efficient than individual hover queries when processing
|
||||
multiple locations in the same file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file
|
||||
|
||||
Returns:
|
||||
List of symbol dictionaries with name, kind, range, etc.
|
||||
Returns empty list on error or timeout.
|
||||
"""
|
||||
cache_key = f"symbols:{file_path}"
|
||||
|
||||
if self._is_cached(cache_key, file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
symbols: List[Dict[str, Any]] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_document_symbols", {
|
||||
"file_path": file_path,
|
||||
})
|
||||
|
||||
if isinstance(result, list):
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_document_symbols(file_path)
|
||||
|
||||
if result:
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
|
||||
self._cache(cache_key, file_path, symbols)
|
||||
return symbols
|
||||
|
||||
def _flatten_document_symbols(
|
||||
self, symbols: List[Dict[str, Any]], parent_name: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Flatten nested document symbols into a flat list.
|
||||
|
||||
Document symbols can be nested (e.g., methods inside classes).
|
||||
This flattens them for easier lookup by line number.
|
||||
|
||||
Args:
|
||||
symbols: List of symbol dictionaries (may be nested)
|
||||
parent_name: Name of parent symbol for qualification
|
||||
|
||||
Returns:
|
||||
Flat list of all symbols with their ranges
|
||||
"""
|
||||
flat: List[Dict[str, Any]] = []
|
||||
|
||||
for sym in symbols:
|
||||
# Add the symbol itself
|
||||
symbol_entry = {
|
||||
"name": sym.get("name", "unknown"),
|
||||
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
|
||||
"range": sym.get("range", sym.get("location", {}).get("range", {})),
|
||||
"selection_range": sym.get("selectionRange", {}),
|
||||
"detail": sym.get("detail", ""),
|
||||
"parent": parent_name,
|
||||
}
|
||||
flat.append(symbol_entry)
|
||||
|
||||
# Recursively process children
|
||||
children = sym.get("children", [])
|
||||
if children:
|
||||
qualified_name = sym.get("name", "")
|
||||
if parent_name:
|
||||
qualified_name = f"{parent_name}.{qualified_name}"
|
||||
flat.extend(self._flatten_document_symbols(children, qualified_name))
|
||||
|
||||
return flat
|
||||
|
||||
def _symbol_kind_to_string(self, kind: int) -> str:
|
||||
"""Convert LSP SymbolKind integer to string.
|
||||
|
||||
Args:
|
||||
kind: LSP SymbolKind enum value
|
||||
|
||||
Returns:
|
||||
Human-readable string representation
|
||||
"""
|
||||
# LSP SymbolKind enum (1-indexed)
|
||||
kinds = {
|
||||
1: "file",
|
||||
2: "module",
|
||||
3: "namespace",
|
||||
4: "package",
|
||||
5: "class",
|
||||
6: "method",
|
||||
7: "property",
|
||||
8: "field",
|
||||
9: "constructor",
|
||||
10: "enum",
|
||||
11: "interface",
|
||||
12: "function",
|
||||
13: "variable",
|
||||
14: "constant",
|
||||
15: "string",
|
||||
16: "number",
|
||||
17: "boolean",
|
||||
18: "array",
|
||||
19: "object",
|
||||
20: "key",
|
||||
21: "null",
|
||||
22: "enum_member",
|
||||
23: "struct",
|
||||
24: "event",
|
||||
25: "operator",
|
||||
26: "type_parameter",
|
||||
}
|
||||
return kinds.get(kind, "unknown")
|
||||
|
||||
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
|
||||
"""Get hover documentation for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get hover info for
|
||||
|
||||
Returns:
|
||||
Hover documentation as string, or None if not available
|
||||
"""
|
||||
cache_key = f"hover:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
hover_text: Optional[str] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_hover", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
hover_text = self._parse_hover_result(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
hover_text = await manager.get_hover(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
self._cache(cache_key, symbol.file_path, hover_text)
|
||||
return hover_text
|
||||
|
||||
def _parse_hover_result(self, result: Any) -> Optional[str]:
|
||||
"""Parse hover result into string."""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
elif isinstance(result, list):
|
||||
parts = []
|
||||
for item in result:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
value = item.get("value", item.get("contents", ""))
|
||||
if value:
|
||||
parts.append(str(value))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
elif isinstance(result, dict):
|
||||
contents = result.get("contents", result.get("value", ""))
|
||||
if isinstance(contents, str):
|
||||
return contents
|
||||
elif isinstance(contents, list):
|
||||
parts = []
|
||||
for c in contents:
|
||||
if isinstance(c, str):
|
||||
parts.append(c)
|
||||
elif isinstance(c, dict):
|
||||
parts.append(str(c.get("value", "")))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "LspBridge":
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit - close connections."""
|
||||
await self.close()
|
||||
|
||||
|
||||
# Simple test
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
async def test_lsp_bridge():
|
||||
"""Simple test of LspBridge functionality."""
|
||||
print("Testing LspBridge (Standalone Mode)...")
|
||||
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
|
||||
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
|
||||
print()
|
||||
|
||||
# Create a test symbol pointing to this file
|
||||
test_file = os.path.abspath(__file__)
|
||||
test_symbol = CodeSymbolNode(
|
||||
id=f"{test_file}:LspBridge:96",
|
||||
name="LspBridge",
|
||||
kind="class",
|
||||
file_path=test_file,
|
||||
range=Range(
|
||||
start_line=96,
|
||||
start_character=1,
|
||||
end_line=200,
|
||||
end_character=1,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
|
||||
print()
|
||||
|
||||
# Use standalone mode (default)
|
||||
async with LspBridge(
|
||||
workspace_root=str(Path(__file__).parent.parent.parent.parent),
|
||||
) as bridge:
|
||||
print("1. Testing get_document_symbols...")
|
||||
try:
|
||||
symbols = await bridge.get_document_symbols(test_file)
|
||||
print(f" Found {len(symbols)} symbols")
|
||||
for sym in symbols[:5]:
|
||||
print(f" - {sym.get('name')} ({sym.get('kind')})")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("2. Testing get_definition...")
|
||||
try:
|
||||
definition = await bridge.get_definition(test_symbol)
|
||||
if definition:
|
||||
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
|
||||
else:
|
||||
print(" No definition found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("3. Testing get_references...")
|
||||
try:
|
||||
refs = await bridge.get_references(test_symbol)
|
||||
print(f" Found {len(refs)} references")
|
||||
for ref in refs[:3]:
|
||||
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("4. Testing get_hover...")
|
||||
try:
|
||||
hover = await bridge.get_hover(test_symbol)
|
||||
if hover:
|
||||
print(f" Hover: {hover[:100]}...")
|
||||
else:
|
||||
print(" No hover info found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("5. Testing get_call_hierarchy...")
|
||||
try:
|
||||
calls = await bridge.get_call_hierarchy(test_symbol)
|
||||
print(f" Found {len(calls)} call hierarchy items")
|
||||
for call in calls[:3]:
|
||||
print(f" - {call.name} in {os.path.basename(call.file_path)}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("6. Testing cache...")
|
||||
print(f" Cache entries: {len(bridge.cache)}")
|
||||
for key in list(bridge.cache.keys())[:5]:
|
||||
print(f" - {key}")
|
||||
|
||||
print()
|
||||
print("Test complete!")
|
||||
|
||||
# Run the test
|
||||
# Note: On Windows, use default ProactorEventLoop (supports subprocess creation)
|
||||
|
||||
asyncio.run(test_lsp_bridge())
|
||||
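The `__main__` test block above doubles as a usage reference: `LspBridge` is an async context manager, results are cached per file, and standalone mode needs no VSCode process. A minimal sketch of the same pattern, assuming codex-lens is installed and that both paths below (placeholders) point at a real indexed workspace and file:

```python
# Minimal sketch: query document symbols through the standalone LSP bridge.
# Assumes codex-lens is installed; replace both placeholder paths with real ones.
import asyncio

from codexlens.lsp.lsp_bridge import LspBridge


async def main() -> None:
    async with LspBridge(workspace_root="/path/to/project") as bridge:
        symbols = await bridge.get_document_symbols("/path/to/project/app.py")
        print(f"found {len(symbols)} symbols")
        for sym in symbols[:5]:
            print(f"- {sym.get('name')} ({sym.get('kind')})")


asyncio.run(main())
```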
codex-lens/build/lib/codexlens/lsp/lsp_graph_builder.py (new file, 375 lines)
@@ -0,0 +1,375 @@
|
||||
"""Graph builder for code association graphs via LSP."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from codexlens.hybrid_search.data_structures import (
|
||||
CallHierarchyItem,
|
||||
CodeAssociationGraph,
|
||||
CodeSymbolNode,
|
||||
Range,
|
||||
)
|
||||
from codexlens.lsp.lsp_bridge import (
|
||||
Location,
|
||||
LspBridge,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LspGraphBuilder:
|
||||
"""Builds code association graph by expanding from seed symbols using LSP."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_depth: int = 2,
|
||||
max_nodes: int = 100,
|
||||
max_concurrent: int = 10,
|
||||
):
|
||||
"""Initialize GraphBuilder.
|
||||
|
||||
Args:
|
||||
max_depth: Maximum depth for BFS expansion from seeds.
|
||||
max_nodes: Maximum number of nodes in the graph.
|
||||
max_concurrent: Maximum concurrent LSP requests.
|
||||
"""
|
||||
self.max_depth = max_depth
|
||||
self.max_nodes = max_nodes
|
||||
self.max_concurrent = max_concurrent
|
||||
# Cache for document symbols per file (avoids per-location hover queries)
|
||||
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
async def build_from_seeds(
|
||||
self,
|
||||
seeds: List[CodeSymbolNode],
|
||||
lsp_bridge: LspBridge,
|
||||
) -> CodeAssociationGraph:
|
||||
"""Build association graph by BFS expansion from seeds.
|
||||
|
||||
For each seed:
|
||||
1. Get references via LSP
|
||||
2. Get call hierarchy via LSP
|
||||
3. Add nodes and edges to graph
|
||||
4. Continue expanding until max_depth or max_nodes reached
|
||||
|
||||
Args:
|
||||
seeds: Initial seed symbols to expand from.
|
||||
lsp_bridge: LSP bridge for querying language servers.
|
||||
|
||||
Returns:
|
||||
CodeAssociationGraph with expanded nodes and relationships.
|
||||
"""
|
||||
graph = CodeAssociationGraph()
|
||||
visited: Set[str] = set()
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
# Initialize queue with seeds at depth 0
|
||||
queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]
|
||||
|
||||
# Add seed nodes to graph
|
||||
for seed in seeds:
|
||||
graph.add_node(seed)
|
||||
|
||||
# BFS expansion
|
||||
while queue and len(graph.nodes) < self.max_nodes:
|
||||
# Take a batch of nodes from queue
|
||||
batch_size = min(self.max_concurrent, len(queue))
|
||||
batch = queue[:batch_size]
|
||||
queue = queue[batch_size:]
|
||||
|
||||
# Expand nodes in parallel
|
||||
tasks = [
|
||||
self._expand_node(
|
||||
node, depth, graph, lsp_bridge, visited, semaphore
|
||||
)
|
||||
for node, depth in batch
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process results and add new nodes to queue
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning("Error expanding node: %s", result)
|
||||
continue
|
||||
if result:
|
||||
# Add new nodes to queue if not at max depth
|
||||
for new_node, new_depth in result:
|
||||
if (
|
||||
new_depth <= self.max_depth
|
||||
and len(graph.nodes) < self.max_nodes
|
||||
):
|
||||
queue.append((new_node, new_depth))
|
||||
|
||||
return graph
|
||||
|
||||
async def _expand_node(
|
||||
self,
|
||||
node: CodeSymbolNode,
|
||||
depth: int,
|
||||
graph: CodeAssociationGraph,
|
||||
lsp_bridge: LspBridge,
|
||||
visited: Set[str],
|
||||
semaphore: asyncio.Semaphore,
|
||||
) -> List[Tuple[CodeSymbolNode, int]]:
|
||||
"""Expand a single node, return new nodes to process.
|
||||
|
||||
Args:
|
||||
node: Node to expand.
|
||||
depth: Current depth in BFS.
|
||||
graph: Graph to add nodes and edges to.
|
||||
lsp_bridge: LSP bridge for queries.
|
||||
visited: Set of visited node IDs.
|
||||
semaphore: Semaphore for concurrency control.
|
||||
|
||||
Returns:
|
||||
List of (new_node, new_depth) tuples to add to queue.
|
||||
"""
|
||||
# Skip if already visited or at max depth
|
||||
if node.id in visited:
|
||||
return []
|
||||
if depth > self.max_depth:
|
||||
return []
|
||||
if len(graph.nodes) >= self.max_nodes:
|
||||
return []
|
||||
|
||||
visited.add(node.id)
|
||||
new_nodes: List[Tuple[CodeSymbolNode, int]] = []
|
||||
|
||||
async with semaphore:
|
||||
# Get relationships in parallel
|
||||
try:
|
||||
refs_task = lsp_bridge.get_references(node)
|
||||
calls_task = lsp_bridge.get_call_hierarchy(node)
|
||||
|
||||
refs, calls = await asyncio.gather(
|
||||
refs_task, calls_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle reference results
|
||||
if isinstance(refs, Exception):
|
||||
logger.debug(
|
||||
"Failed to get references for %s: %s", node.id, refs
|
||||
)
|
||||
refs = []
|
||||
|
||||
# Handle call hierarchy results
|
||||
if isinstance(calls, Exception):
|
||||
logger.debug(
|
||||
"Failed to get call hierarchy for %s: %s",
|
||||
node.id,
|
||||
calls,
|
||||
)
|
||||
calls = []
|
||||
|
||||
# Process references
|
||||
for ref in refs:
|
||||
if len(graph.nodes) >= self.max_nodes:
|
||||
break
|
||||
|
||||
ref_node = await self._location_to_node(ref, lsp_bridge)
|
||||
if ref_node and ref_node.id != node.id:
|
||||
if ref_node.id not in graph.nodes:
|
||||
graph.add_node(ref_node)
|
||||
new_nodes.append((ref_node, depth + 1))
|
||||
# Use add_edge since both nodes should exist now
|
||||
graph.add_edge(node.id, ref_node.id, "references")
|
||||
|
||||
# Process call hierarchy (incoming calls)
|
||||
for call in calls:
|
||||
if len(graph.nodes) >= self.max_nodes:
|
||||
break
|
||||
|
||||
call_node = await self._call_hierarchy_to_node(
|
||||
call, lsp_bridge
|
||||
)
|
||||
if call_node and call_node.id != node.id:
|
||||
if call_node.id not in graph.nodes:
|
||||
graph.add_node(call_node)
|
||||
new_nodes.append((call_node, depth + 1))
|
||||
# Incoming call: call_node calls node
|
||||
graph.add_edge(call_node.id, node.id, "calls")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error during node expansion for %s: %s", node.id, e
|
||||
)
|
||||
|
||||
return new_nodes
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the document symbols cache.
|
||||
|
||||
Call this between searches to free memory and ensure fresh data.
|
||||
"""
|
||||
self._document_symbols_cache.clear()
|
||||
|
||||
async def _get_symbol_at_location(
|
||||
self,
|
||||
file_path: str,
|
||||
line: int,
|
||||
lsp_bridge: LspBridge,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Find symbol at location using cached document symbols.
|
||||
|
||||
This is much more efficient than individual hover queries because
|
||||
document symbols are fetched once per file and cached.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file.
|
||||
line: Line number (1-based).
|
||||
lsp_bridge: LSP bridge for fetching document symbols.
|
||||
|
||||
Returns:
|
||||
Symbol dictionary with name, kind, range, etc., or None if not found.
|
||||
"""
|
||||
# Get or fetch document symbols for this file
|
||||
if file_path not in self._document_symbols_cache:
|
||||
symbols = await lsp_bridge.get_document_symbols(file_path)
|
||||
self._document_symbols_cache[file_path] = symbols
|
||||
|
||||
symbols = self._document_symbols_cache[file_path]
|
||||
|
||||
# Find symbol containing this line (best match = smallest range)
|
||||
best_match: Optional[Dict[str, Any]] = None
|
||||
best_range_size = float("inf")
|
||||
|
||||
for symbol in symbols:
|
||||
sym_range = symbol.get("range", {})
|
||||
start = sym_range.get("start", {})
|
||||
end = sym_range.get("end", {})
|
||||
|
||||
# LSP ranges are 0-based, our line is 1-based
|
||||
start_line = start.get("line", 0) + 1
|
||||
end_line = end.get("line", 0) + 1
|
||||
|
||||
if start_line <= line <= end_line:
|
||||
range_size = end_line - start_line
|
||||
if range_size < best_range_size:
|
||||
best_match = symbol
|
||||
best_range_size = range_size
|
||||
|
||||
return best_match
|
||||
|
||||
async def _location_to_node(
|
||||
self,
|
||||
location: Location,
|
||||
lsp_bridge: LspBridge,
|
||||
) -> Optional[CodeSymbolNode]:
|
||||
"""Convert LSP location to CodeSymbolNode.
|
||||
|
||||
Uses cached document symbols instead of individual hover queries
|
||||
for better performance.
|
||||
|
||||
Args:
|
||||
location: LSP location to convert.
|
||||
lsp_bridge: LSP bridge for additional queries.
|
||||
|
||||
Returns:
|
||||
CodeSymbolNode or None if conversion fails.
|
||||
"""
|
||||
try:
|
||||
file_path = location.file_path
|
||||
start_line = location.line
|
||||
|
||||
# Try to find symbol info from cached document symbols (fast)
|
||||
symbol_info = await self._get_symbol_at_location(
|
||||
file_path, start_line, lsp_bridge
|
||||
)
|
||||
|
||||
if symbol_info:
|
||||
name = symbol_info.get("name", f"symbol_L{start_line}")
|
||||
kind = symbol_info.get("kind", "unknown")
|
||||
|
||||
# Extract range from symbol if available
|
||||
sym_range = symbol_info.get("range", {})
|
||||
start = sym_range.get("start", {})
|
||||
end = sym_range.get("end", {})
|
||||
|
||||
location_range = Range(
|
||||
start_line=start.get("line", start_line - 1) + 1,
|
||||
start_character=start.get("character", location.character - 1) + 1,
|
||||
end_line=end.get("line", start_line - 1) + 1,
|
||||
end_character=end.get("character", location.character - 1) + 1,
|
||||
)
|
||||
else:
|
||||
# Fallback to basic node without symbol info
|
||||
name = f"symbol_L{start_line}"
|
||||
kind = "unknown"
|
||||
location_range = Range(
|
||||
start_line=location.line,
|
||||
start_character=location.character,
|
||||
end_line=location.line,
|
||||
end_character=location.character,
|
||||
)
|
||||
|
||||
node_id = self._create_node_id(file_path, name, start_line)
|
||||
|
||||
return CodeSymbolNode(
|
||||
id=node_id,
|
||||
name=name,
|
||||
kind=kind,
|
||||
file_path=file_path,
|
||||
range=location_range,
|
||||
docstring="", # Skip hover for performance
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Failed to convert location to node: %s", e)
|
||||
return None
|
||||
|
||||
async def _call_hierarchy_to_node(
|
||||
self,
|
||||
call_item: CallHierarchyItem,
|
||||
lsp_bridge: LspBridge,
|
||||
) -> Optional[CodeSymbolNode]:
|
||||
"""Convert CallHierarchyItem to CodeSymbolNode.
|
||||
|
||||
Args:
|
||||
call_item: Call hierarchy item to convert.
|
||||
lsp_bridge: LSP bridge (unused, kept for API consistency).
|
||||
|
||||
Returns:
|
||||
CodeSymbolNode or None if conversion fails.
|
||||
"""
|
||||
try:
|
||||
file_path = call_item.file_path
|
||||
name = call_item.name
|
||||
start_line = call_item.range.start_line
|
||||
# CallHierarchyItem.kind is already a string
|
||||
kind = call_item.kind
|
||||
|
||||
node_id = self._create_node_id(file_path, name, start_line)
|
||||
|
||||
return CodeSymbolNode(
|
||||
id=node_id,
|
||||
name=name,
|
||||
kind=kind,
|
||||
file_path=file_path,
|
||||
range=call_item.range,
|
||||
docstring=call_item.detail or "",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to convert call hierarchy item to node: %s", e
|
||||
)
|
||||
return None
|
||||
|
||||
def _create_node_id(
|
||||
self, file_path: str, name: str, line: int
|
||||
) -> str:
|
||||
"""Create unique node ID.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
name: Symbol name.
|
||||
line: Line number (0-based).
|
||||
|
||||
Returns:
|
||||
Unique node ID string.
|
||||
"""
|
||||
return f"{file_path}:{name}:{line}"
|
||||
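The builder above is driven entirely by a list of seed symbols plus a shared `LspBridge`; reference lookup, call-hierarchy expansion and node conversion all happen inside `build_from_seeds`. A minimal sketch, assuming an indexed workspace; the seed's path, name and line numbers are placeholders:

```python
# Sketch: expand a code-association graph from a single seed symbol.
# All paths, names and line numbers below are illustrative placeholders.
import asyncio

from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp.lsp_bridge import LspBridge
from codexlens.lsp.lsp_graph_builder import LspGraphBuilder


async def main() -> None:
    seed = CodeSymbolNode(
        id="/path/to/project/app.py:handle_request:42",
        name="handle_request",
        kind="function",
        file_path="/path/to/project/app.py",
        range=Range(start_line=42, start_character=1, end_line=60, end_character=1),
    )
    builder = LspGraphBuilder(max_depth=2, max_nodes=50)
    async with LspBridge(workspace_root="/path/to/project") as bridge:
        graph = await builder.build_from_seeds([seed], bridge)
    print(f"graph contains {len(graph.nodes)} nodes")


asyncio.run(main())
```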
codex-lens/build/lib/codexlens/lsp/providers.py (new file, 177 lines)
@@ -0,0 +1,177 @@
"""LSP feature providers."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.storage.registry import RegistryStore

logger = logging.getLogger(__name__)


@dataclass
class HoverInfo:
    """Hover information for a symbol."""

    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: tuple  # (start_line, end_line)


class HoverProvider:
    """Provides hover information for symbols."""

    def __init__(
        self,
        global_index: "GlobalSymbolIndex",
        registry: Optional["RegistryStore"] = None,
    ) -> None:
        """Initialize hover provider.

        Args:
            global_index: Global symbol index for lookups
            registry: Optional registry store for index path resolution
        """
        self.global_index = global_index
        self.registry = registry

    def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
        """Get hover information for a symbol.

        Args:
            symbol_name: Name of the symbol to look up

        Returns:
            HoverInfo or None if symbol not found
        """
        # Look up symbol in global index using exact match
        symbols = self.global_index.search(
            name=symbol_name,
            limit=1,
            prefix_mode=False,
        )

        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == symbol_name]

        if not exact_matches:
            return None

        symbol = exact_matches[0]

        # Extract signature from source file
        signature = self._extract_signature(symbol)

        # Symbol uses 'file' attribute and 'range' tuple
        file_path = symbol.file or ""
        start_line, end_line = symbol.range

        return HoverInfo(
            name=symbol.name,
            kind=symbol.kind,
            signature=signature,
            documentation=None,  # Symbol doesn't have docstring field
            file_path=file_path,
            line_range=(start_line, end_line),
        )

    def _extract_signature(self, symbol) -> str:
        """Extract function/class signature from source file.

        Args:
            symbol: Symbol object with file and range information

        Returns:
            Extracted signature string or fallback kind + name
        """
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return f"{symbol.kind} {symbol.name}"

            content = file_path.read_text(encoding="utf-8", errors="ignore")
            lines = content.split("\n")

            # Extract signature lines (first line of definition + continuation)
            start_line = symbol.range[0] - 1  # Convert 1-based to 0-based
            if start_line >= len(lines) or start_line < 0:
                return f"{symbol.kind} {symbol.name}"

            signature_lines = []
            first_line = lines[start_line]
            signature_lines.append(first_line)

            # Continue if multiline signature (no closing paren + colon yet)
            # Look for patterns like "def func(", "class Foo(", etc.
            i = start_line + 1
            max_lines = min(start_line + 5, len(lines))
            while i < max_lines:
                line = signature_lines[-1]
                # Stop if we see closing pattern
                if "):" in line or line.rstrip().endswith(":"):
                    break
                signature_lines.append(lines[i])
                i += 1

            return "\n".join(signature_lines)

        except Exception as e:
            logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
            return f"{symbol.kind} {symbol.name}"

    def format_hover_markdown(self, info: HoverInfo) -> str:
        """Format hover info as Markdown.

        Args:
            info: HoverInfo object to format

        Returns:
            Markdown-formatted hover content
        """
        parts = []

        # Detect language for code fence based on file extension
        ext = Path(info.file_path).suffix.lower() if info.file_path else ""
        lang_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".tsx": "typescript",
            ".jsx": "javascript",
            ".java": "java",
            ".go": "go",
            ".rs": "rust",
            ".c": "c",
            ".cpp": "cpp",
            ".h": "c",
            ".hpp": "cpp",
            ".cs": "csharp",
            ".rb": "ruby",
            ".php": "php",
        }
        lang = lang_map.get(ext, "")

        # Code block with signature
        parts.append(f"```{lang}\n{info.signature}\n```")

        # Documentation if available
        if info.documentation:
            parts.append(f"\n---\n\n{info.documentation}")

        # Location info
        file_name = Path(info.file_path).name if info.file_path else "unknown"
        parts.append(
            f"\n---\n\n*{info.kind}* defined in "
            f"`{file_name}` "
            f"(line {info.line_range[0]})"
        )

        return "\n".join(parts)
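`HoverProvider` can also be used outside the LSP server, for example for ad-hoc documentation lookups against an existing index. A minimal sketch; the index directory and project id are placeholders that would normally come from the project registry:

```python
# Sketch: render hover Markdown for a symbol from an existing codex-lens index.
# The index directory and project id below are placeholders.
from pathlib import Path

from codexlens.lsp.providers import HoverProvider
from codexlens.storage.global_index import GlobalSymbolIndex

index = GlobalSymbolIndex(Path("/path/to/index") / GlobalSymbolIndex.DEFAULT_DB_NAME, 1)
index.initialize()

provider = HoverProvider(global_index=index)
info = provider.get_hover_info("ChainSearchEngine")
if info:
    print(provider.format_hover_markdown(info))
```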
codex-lens/build/lib/codexlens/lsp/server.py (new file, 263 lines)
@@ -0,0 +1,263 @@
|
||||
"""codex-lens LSP Server implementation using pygls.
|
||||
|
||||
This module provides the main Language Server class and entry point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from lsprotocol import types as lsp
|
||||
from pygls.lsp.server import LanguageServer
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||
) from exc
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodexLensLanguageServer(LanguageServer):
|
||||
"""Language Server for codex-lens code indexing.
|
||||
|
||||
Provides IDE features using codex-lens symbol index:
|
||||
- Go to Definition
|
||||
- Find References
|
||||
- Code Completion
|
||||
- Hover Information
|
||||
- Workspace Symbol Search
|
||||
|
||||
Attributes:
|
||||
registry: Global project registry for path lookups
|
||||
mapper: Path mapper for source/index conversions
|
||||
global_index: Project-wide symbol index
|
||||
search_engine: Chain search engine for symbol search
|
||||
workspace_root: Current workspace root path
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="codexlens-lsp", version="0.1.0")
|
||||
|
||||
self.registry: Optional[RegistryStore] = None
|
||||
self.mapper: Optional[PathMapper] = None
|
||||
self.global_index: Optional[GlobalSymbolIndex] = None
|
||||
self.search_engine: Optional[ChainSearchEngine] = None
|
||||
self.workspace_root: Optional[Path] = None
|
||||
self._config: Optional[Config] = None
|
||||
|
||||
def initialize_components(self, workspace_root: Path) -> bool:
|
||||
"""Initialize codex-lens components for the workspace.
|
||||
|
||||
Args:
|
||||
workspace_root: Root path of the workspace
|
||||
|
||||
Returns:
|
||||
True if initialization succeeded, False otherwise
|
||||
"""
|
||||
self.workspace_root = workspace_root.resolve()
|
||||
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
|
||||
|
||||
try:
|
||||
# Initialize registry
|
||||
self.registry = RegistryStore()
|
||||
self.registry.initialize()
|
||||
|
||||
# Initialize path mapper
|
||||
self.mapper = PathMapper()
|
||||
|
||||
# Try to find project in registry
|
||||
project_info = self.registry.find_by_source_path(str(self.workspace_root))
|
||||
|
||||
if project_info:
|
||||
project_id = int(project_info["id"])
|
||||
index_root = Path(project_info["index_root"])
|
||||
|
||||
# Initialize global symbol index
|
||||
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
|
||||
self.global_index = GlobalSymbolIndex(global_db, project_id)
|
||||
self.global_index.initialize()
|
||||
|
||||
# Initialize search engine
|
||||
self._config = Config()
|
||||
self.search_engine = ChainSearchEngine(
|
||||
registry=self.registry,
|
||||
mapper=self.mapper,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Workspace not indexed by codex-lens: %s. "
|
||||
"Run 'codexlens index %s' to index first.",
|
||||
self.workspace_root,
|
||||
self.workspace_root,
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize codex-lens: %s", exc)
|
||||
return False
|
||||
|
||||
def shutdown_components(self) -> None:
|
||||
"""Clean up codex-lens components."""
|
||||
if self.global_index:
|
||||
try:
|
||||
self.global_index.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing global index: %s", exc)
|
||||
self.global_index = None
|
||||
|
||||
if self.search_engine:
|
||||
try:
|
||||
self.search_engine.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing search engine: %s", exc)
|
||||
self.search_engine = None
|
||||
|
||||
if self.registry:
|
||||
try:
|
||||
self.registry.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing registry: %s", exc)
|
||||
self.registry = None
|
||||
|
||||
|
||||
# Create server instance
|
||||
server = CodexLensLanguageServer()
|
||||
|
||||
|
||||
@server.feature(lsp.INITIALIZE)
|
||||
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
|
||||
"""Handle LSP initialize request."""
|
||||
logger.info("LSP initialize request received")
|
||||
|
||||
# Get workspace root
|
||||
workspace_root: Optional[Path] = None
|
||||
if params.root_uri:
|
||||
workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
|
||||
elif params.root_path:
|
||||
workspace_root = Path(params.root_path)
|
||||
|
||||
if workspace_root:
|
||||
server.initialize_components(workspace_root)
|
||||
|
||||
# Declare server capabilities
|
||||
return lsp.InitializeResult(
|
||||
capabilities=lsp.ServerCapabilities(
|
||||
text_document_sync=lsp.TextDocumentSyncOptions(
|
||||
open_close=True,
|
||||
change=lsp.TextDocumentSyncKind.Incremental,
|
||||
save=lsp.SaveOptions(include_text=False),
|
||||
),
|
||||
definition_provider=True,
|
||||
references_provider=True,
|
||||
completion_provider=lsp.CompletionOptions(
|
||||
trigger_characters=[".", ":"],
|
||||
resolve_provider=False,
|
||||
),
|
||||
hover_provider=True,
|
||||
workspace_symbol_provider=True,
|
||||
),
|
||||
server_info=lsp.ServerInfo(
|
||||
name="codexlens-lsp",
|
||||
version="0.1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@server.feature(lsp.SHUTDOWN)
|
||||
def lsp_shutdown(params: None) -> None:
|
||||
"""Handle LSP shutdown request."""
|
||||
logger.info("LSP shutdown request received")
|
||||
server.shutdown_components()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Entry point for codexlens-lsp command.
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success)
|
||||
"""
|
||||
# Import handlers to register them with the server
|
||||
# This must be done before starting the server
|
||||
import codexlens.lsp.handlers # noqa: F401
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="codex-lens Language Server",
|
||||
prog="codexlens-lsp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stdio",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use stdio for communication (default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tcp",
|
||||
action="store_true",
|
||||
help="Use TCP for communication",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default="127.0.0.1",
|
||||
help="TCP host (default: 127.0.0.1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=2087,
|
||||
help="TCP port (default: 2087)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
default="INFO",
|
||||
help="Log level (default: INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
help="Log file path (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging
|
||||
log_handlers = []
|
||||
if args.log_file:
|
||||
log_handlers.append(logging.FileHandler(args.log_file))
|
||||
else:
|
||||
log_handlers.append(logging.StreamHandler(sys.stderr))
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=log_handlers,
|
||||
)
|
||||
|
||||
logger.info("Starting codexlens-lsp server")
|
||||
|
||||
if args.tcp:
|
||||
logger.info("Starting TCP server on %s:%d", args.host, args.port)
|
||||
server.start_tcp(args.host, args.port)
|
||||
else:
|
||||
logger.info("Starting stdio server")
|
||||
server.start_io()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
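For local testing, the server can be driven in TCP mode using only the flags declared above. A sketch, assuming the `codex-lens[lsp]` extra is installed so the `codexlens-lsp` entry point is on PATH:

```python
# Sketch: start codexlens-lsp over TCP for an editor-integration test, then stop it.
import subprocess
import time

proc = subprocess.Popen(
    ["codexlens-lsp", "--tcp", "--host", "127.0.0.1", "--port", "2087", "--log-level", "DEBUG"]
)
time.sleep(1)  # give the server a moment to bind the port
print(f"codexlens-lsp listening on 127.0.0.1:2087 (pid {proc.pid})")
# ... connect an LSP client here ...
proc.terminate()
proc.wait()
```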
codex-lens/build/lib/codexlens/lsp/standalone_manager.py (new file, 1159 lines)
File diff suppressed because it is too large
codex-lens/build/lib/codexlens/mcp/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt

__all__ = [
    "MCPContext",
    "SymbolInfo",
    "ReferenceInfo",
    "RelatedSymbol",
    "MCPProvider",
    "HookManager",
    "create_context_for_prompt",
]
codex-lens/build/lib/codexlens/mcp/hooks.py (new file, 170 lines)
@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING

from codexlens.mcp.schema import MCPContext

if TYPE_CHECKING:
    from codexlens.mcp.provider import MCPProvider

logger = logging.getLogger(__name__)


class HookManager:
    """Manages hook registration and execution."""

    def __init__(self, mcp_provider: "MCPProvider") -> None:
        self.mcp_provider = mcp_provider
        self._pre_hooks: Dict[str, Callable] = {}
        self._post_hooks: Dict[str, Callable] = {}

        # Register default hooks
        self._register_default_hooks()

    def _register_default_hooks(self) -> None:
        """Register built-in hooks."""
        self._pre_hooks["explain"] = self._pre_explain_hook
        self._pre_hooks["refactor"] = self._pre_refactor_hook
        self._pre_hooks["document"] = self._pre_document_hook

    def execute_pre_hook(
        self,
        action: str,
        params: Dict[str, Any],
    ) -> Optional[MCPContext]:
        """Execute pre-tool hook to gather context.

        Args:
            action: The action being performed (e.g., "explain", "refactor")
            params: Parameters for the action

        Returns:
            MCPContext to inject into prompt, or None
        """
        hook = self._pre_hooks.get(action)

        if not hook:
            logger.debug(f"No pre-hook for action: {action}")
            return None

        try:
            return hook(params)
        except Exception as e:
            logger.error(f"Pre-hook failed for {action}: {e}")
            return None

    def execute_post_hook(
        self,
        action: str,
        result: Any,
    ) -> None:
        """Execute post-tool hook for proactive caching.

        Args:
            action: The action that was performed
            result: Result of the action
        """
        hook = self._post_hooks.get(action)

        if not hook:
            return

        try:
            hook(result)
        except Exception as e:
            logger.error(f"Post-hook failed for {action}: {e}")

    def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'explain' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="symbol_explanation",
            include_references=True,
            include_related=True,
        )

    def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'refactor' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="refactor_context",
            include_references=True,
            include_related=True,
            max_references=20,
        )

    def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'document' action."""
        symbol_name = params.get("symbol")
        file_path = params.get("file_path")

        if symbol_name:
            return self.mcp_provider.build_context(
                symbol_name=symbol_name,
                context_type="documentation_context",
                include_references=False,
                include_related=True,
            )
        elif file_path:
            return self.mcp_provider.build_context_for_file(
                Path(file_path),
                context_type="file_documentation",
            )

        return None

    def register_pre_hook(
        self,
        action: str,
        hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
    ) -> None:
        """Register a custom pre-tool hook."""
        self._pre_hooks[action] = hook

    def register_post_hook(
        self,
        action: str,
        hook: Callable[[Any], None],
    ) -> None:
        """Register a custom post-tool hook."""
        self._post_hooks[action] = hook


def create_context_for_prompt(
    mcp_provider: "MCPProvider",
    action: str,
    params: Dict[str, Any],
) -> str:
    """Create context string for prompt injection.

    This is the main entry point for Claude Code hook integration.

    Args:
        mcp_provider: The MCP provider instance
        action: Action being performed
        params: Action parameters

    Returns:
        Formatted context string for prompt injection
    """
    manager = HookManager(mcp_provider)
    context = manager.execute_pre_hook(action, params)

    if context:
        return context.to_prompt_injection()

    return ""
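Wiring this entry point up end to end mirrors the initialization in `server.py`. A sketch, assuming an already indexed project; the index path and project id are placeholders normally read from the registry entry created by `codexlens index`:

```python
# Sketch: build prompt-injection context for an "explain" action.
from pathlib import Path

from codexlens.config import Config
from codexlens.mcp.hooks import create_context_for_prompt
from codexlens.mcp.provider import MCPProvider
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

registry = RegistryStore()
registry.initialize()

# Placeholder index location and project id; both come from the registry in practice.
index = GlobalSymbolIndex(Path("/path/to/index") / GlobalSymbolIndex.DEFAULT_DB_NAME, 1)
index.initialize()

engine = ChainSearchEngine(registry=registry, mapper=PathMapper(), config=Config())
provider = MCPProvider(global_index=index, search_engine=engine, registry=registry)

print(create_context_for_prompt(provider, "explain", {"symbol": "LspBridge"}))
```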
codex-lens/build/lib/codexlens/mcp/provider.py (new file, 202 lines)
@@ -0,0 +1,202 @@
|
||||
"""MCP context provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from codexlens.mcp.schema import (
|
||||
MCPContext,
|
||||
SymbolInfo,
|
||||
ReferenceInfo,
|
||||
RelatedSymbol,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPProvider:
|
||||
"""Builds MCP context objects from codex-lens data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
global_index: "GlobalSymbolIndex",
|
||||
search_engine: "ChainSearchEngine",
|
||||
registry: "RegistryStore",
|
||||
) -> None:
|
||||
self.global_index = global_index
|
||||
self.search_engine = search_engine
|
||||
self.registry = registry
|
||||
|
||||
def build_context(
|
||||
self,
|
||||
symbol_name: str,
|
||||
context_type: str = "symbol_explanation",
|
||||
include_references: bool = True,
|
||||
include_related: bool = True,
|
||||
max_references: int = 10,
|
||||
) -> Optional[MCPContext]:
|
||||
"""Build comprehensive context for a symbol.
|
||||
|
||||
Args:
|
||||
symbol_name: Name of the symbol to contextualize
|
||||
context_type: Type of context being requested
|
||||
include_references: Whether to include reference locations
|
||||
include_related: Whether to include related symbols
|
||||
max_references: Maximum number of references to include
|
||||
|
||||
Returns:
|
||||
MCPContext object or None if symbol not found
|
||||
"""
|
||||
# Look up symbol
|
||||
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
|
||||
|
||||
if not symbols:
|
||||
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
|
||||
return None
|
||||
|
||||
symbol = symbols[0]
|
||||
|
||||
# Build SymbolInfo
|
||||
symbol_info = SymbolInfo(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
file_path=symbol.file or "",
|
||||
line_start=symbol.range[0],
|
||||
line_end=symbol.range[1],
|
||||
signature=None, # Symbol entity doesn't have signature
|
||||
documentation=None, # Symbol entity doesn't have docstring
|
||||
)
|
||||
|
||||
# Extract definition source code
|
||||
definition = self._extract_definition(symbol)
|
||||
|
||||
# Get references
|
||||
references = []
|
||||
if include_references:
|
||||
refs = self.search_engine.search_references(
|
||||
symbol_name,
|
||||
limit=max_references,
|
||||
)
|
||||
references = [
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
context=r.context,
|
||||
relationship_type=r.relationship_type,
|
||||
)
|
||||
for r in refs
|
||||
]
|
||||
|
||||
# Get related symbols
|
||||
related_symbols = []
|
||||
if include_related:
|
||||
related_symbols = self._get_related_symbols(symbol)
|
||||
|
||||
return MCPContext(
|
||||
context_type=context_type,
|
||||
symbol=symbol_info,
|
||||
definition=definition,
|
||||
references=references,
|
||||
related_symbols=related_symbols,
|
||||
metadata={
|
||||
"source": "codex-lens",
|
||||
},
|
||||
)
|
||||
|
||||
def _extract_definition(self, symbol) -> Optional[str]:
|
||||
"""Extract source code for symbol definition."""
|
||||
try:
|
||||
file_path = Path(symbol.file) if symbol.file else None
|
||||
if not file_path or not file_path.exists():
|
||||
return None
|
||||
|
||||
content = file_path.read_text(encoding='utf-8', errors='ignore')
|
||||
lines = content.split("\n")
|
||||
|
||||
start = symbol.range[0] - 1
|
||||
end = symbol.range[1]
|
||||
|
||||
if start >= len(lines):
|
||||
return None
|
||||
|
||||
return "\n".join(lines[start:end])
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract definition: {e}")
|
||||
return None
|
||||
|
||||
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
|
||||
"""Get symbols related to the given symbol."""
|
||||
related = []
|
||||
|
||||
try:
|
||||
# Search for symbols that might be related by name patterns
|
||||
# This is a simplified implementation - could be enhanced with relationship data
|
||||
|
||||
# Look for imports/callers via reference search
|
||||
refs = self.search_engine.search_references(symbol.name, limit=20)
|
||||
|
||||
seen_names = set()
|
||||
for ref in refs:
|
||||
# Extract potential symbol name from context
|
||||
if ref.relationship_type and ref.relationship_type not in seen_names:
|
||||
related.append(RelatedSymbol(
|
||||
name=f"{Path(ref.file_path).stem}",
|
||||
kind="module",
|
||||
relationship=ref.relationship_type,
|
||||
file_path=ref.file_path,
|
||||
))
|
||||
seen_names.add(ref.relationship_type)
|
||||
if len(related) >= 10:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get related symbols: {e}")
|
||||
|
||||
return related
|
||||
|
||||
def build_context_for_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
context_type: str = "file_overview",
|
||||
) -> MCPContext:
|
||||
"""Build context for an entire file."""
|
||||
# Try to get symbols by searching with file path
|
||||
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
|
||||
symbols = []
|
||||
|
||||
# Search for common symbols that might be in this file
|
||||
# This is a simplified approach - a full implementation would query by file path
|
||||
try:
|
||||
# Use the global index to search for symbols from this file
|
||||
file_str = str(file_path.resolve())
|
||||
# Get all symbols and filter by file path (not efficient but works)
|
||||
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
|
||||
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get file symbols: {e}")
|
||||
|
||||
related = [
|
||||
RelatedSymbol(
|
||||
name=s.name,
|
||||
kind=s.kind,
|
||||
relationship="defines",
|
||||
)
|
||||
for s in symbols
|
||||
]
|
||||
|
||||
return MCPContext(
|
||||
context_type=context_type,
|
||||
related_symbols=related,
|
||||
metadata={
|
||||
"file_path": str(file_path),
|
||||
"symbol_count": len(symbols),
|
||||
},
|
||||
)
|
||||
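A minimal usage sketch of the builder methods above. The builder object is passed in as a parameter because its class name and constructor are not visible in this hunk; only its methods (`build_context_for_file`, `to_prompt_injection` on the returned `MCPContext`) appear in the diff.

```python
# Hedged sketch: "builder" is whatever instance owns the methods shown above.
from pathlib import Path

def print_file_context(builder, file_path: str) -> None:
    """Build a file-level MCP context and emit it as a prompt block."""
    context = builder.build_context_for_file(Path(file_path))
    print(f"{context.metadata['symbol_count']} symbols from {context.metadata['file_path']}")
    print(context.to_prompt_injection())
```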
113  codex-lens/build/lib/codexlens/mcp/schema.py  (normal file)
@@ -0,0 +1,113 @@
"""MCP data models."""

from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional


@dataclass
class SymbolInfo:
    """Information about a code symbol."""
    name: str
    kind: str
    file_path: str
    line_start: int
    line_end: int
    signature: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class ReferenceInfo:
    """Information about a symbol reference."""
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class RelatedSymbol:
    """Related symbol (import, call target, etc.)."""
    name: str
    kind: str
    relationship: str  # "imports", "calls", "inherits", "uses"
    file_path: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MCPContext:
    """Model Context Protocol context object.

    This is the structured context that gets injected into
    LLM prompts to provide code understanding.
    """
    version: str = "1.0"
    context_type: str = "code_context"
    symbol: Optional[SymbolInfo] = None
    definition: Optional[str] = None
    references: List[ReferenceInfo] = field(default_factory=list)
    related_symbols: List[RelatedSymbol] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        result = {
            "version": self.version,
            "context_type": self.context_type,
            "metadata": self.metadata,
        }

        if self.symbol:
            result["symbol"] = self.symbol.to_dict()
        if self.definition:
            result["definition"] = self.definition
        if self.references:
            result["references"] = [r.to_dict() for r in self.references]
        if self.related_symbols:
            result["related_symbols"] = [s.to_dict() for s in self.related_symbols]

        return result

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    def to_prompt_injection(self) -> str:
        """Format for injection into LLM prompt."""
        parts = ["<code_context>"]

        if self.symbol:
            parts.append(f"## Symbol: {self.symbol.name}")
            parts.append(f"Type: {self.symbol.kind}")
            parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")

        if self.definition:
            parts.append("\n## Definition")
            parts.append(f"```\n{self.definition}\n```")

        if self.references:
            parts.append(f"\n## References ({len(self.references)} found)")
            for ref in self.references[:5]:  # Limit to 5
                parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
                parts.append(f"  ```\n  {ref.context}\n  ```")

        if self.related_symbols:
            parts.append("\n## Related Symbols")
            for sym in self.related_symbols[:10]:  # Limit to 10
                parts.append(f"- {sym.name} ({sym.relationship})")

        parts.append("</code_context>")
        return "\n".join(parts)
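A short sketch of how these dataclasses compose, assuming the package imports as `codexlens.mcp.schema`; the field values below are illustrative, not taken from the diff.

```python
from codexlens.mcp.schema import MCPContext, ReferenceInfo, SymbolInfo

ctx = MCPContext(
    symbol=SymbolInfo(
        name="detect_encoding", kind="function",
        file_path="codexlens/parsers/encoding.py",
        line_start=42, line_end=80,  # illustrative positions
    ),
    definition="def detect_encoding(content_bytes, confidence_threshold=0.7): ...",
    references=[
        ReferenceInfo(
            file_path="codexlens/parsers/factory.py", line=12, column=4,
            context="encoding = detect_encoding(raw)", relationship_type="calls",
        )
    ],
)
print(ctx.to_json())              # structured JSON for tooling
print(ctx.to_prompt_injection())  # <code_context> block for LLM prompts
```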
8  codex-lens/build/lib/codexlens/parsers/__init__.py  (normal file)
@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""

from __future__ import annotations

from .factory import ParserFactory

__all__ = ["ParserFactory"]
202  codex-lens/build/lib/codexlens/parsers/encoding.py  (normal file)
@@ -0,0 +1,202 @@
|
||||
"""Optional encoding detection module for CodexLens.
|
||||
|
||||
Provides automatic encoding detection with graceful fallback to UTF-8.
|
||||
Install with: pip install codexlens[encoding]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Feature flag for encoding detection availability
|
||||
ENCODING_DETECTION_AVAILABLE = False
|
||||
_import_error: Optional[str] = None
|
||||
|
||||
|
||||
def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
|
||||
"""Detect if chardet or charset-normalizer is available."""
|
||||
try:
|
||||
import chardet
|
||||
return True, None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from charset_normalizer import from_bytes
|
||||
return True, None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return False, "chardet not available. Install with: pip install codexlens[encoding]"
|
||||
|
||||
|
||||
# Initialize on module load
|
||||
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()
|
||||
|
||||
|
||||
def check_encoding_available() -> Tuple[bool, Optional[str]]:
|
||||
"""Check if encoding detection dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (available, error_message)
|
||||
"""
|
||||
return ENCODING_DETECTION_AVAILABLE, _import_error
|
||||
|
||||
|
||||
def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
|
||||
"""Detect encoding from file content bytes.
|
||||
|
||||
Uses chardet or charset-normalizer with configurable confidence threshold.
|
||||
Falls back to UTF-8 if confidence is too low or detection unavailable.
|
||||
|
||||
Args:
|
||||
content_bytes: Raw file content as bytes
|
||||
confidence_threshold: Minimum confidence (0.0-1.0) to accept detection
|
||||
|
||||
Returns:
|
||||
Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk')
|
||||
Returns 'utf-8' as fallback if detection fails or confidence too low
|
||||
"""
|
||||
if not ENCODING_DETECTION_AVAILABLE:
|
||||
log.debug("Encoding detection not available, using UTF-8 fallback")
|
||||
return "utf-8"
|
||||
|
||||
if not content_bytes:
|
||||
return "utf-8"
|
||||
|
||||
try:
|
||||
# Try chardet first
|
||||
try:
|
||||
import chardet
|
||||
result = chardet.detect(content_bytes)
|
||||
encoding = result.get("encoding")
|
||||
confidence = result.get("confidence", 0.0)
|
||||
|
||||
if encoding and confidence >= confidence_threshold:
|
||||
log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
|
||||
# Normalize encoding name: replace underscores with hyphens
|
||||
return encoding.lower().replace('_', '-')
|
||||
else:
|
||||
log.debug(
|
||||
f"Low confidence encoding detection: {encoding} "
|
||||
f"(confidence: {confidence:.2f}), using UTF-8 fallback"
|
||||
)
|
||||
return "utf-8"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fallback to charset-normalizer
|
||||
try:
|
||||
from charset_normalizer import from_bytes
|
||||
results = from_bytes(content_bytes)
|
||||
if results:
|
||||
best = results.best()
|
||||
if best and best.encoding:
|
||||
log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
|
||||
# Normalize encoding name: replace underscores with hyphens
|
||||
return best.encoding.lower().replace('_', '-')
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")
|
||||
|
||||
return "utf-8"
|
||||
|
||||
|
||||
def read_file_safe(
|
||||
path: Path | str,
|
||||
confidence_threshold: float = 0.7,
|
||||
max_detection_bytes: int = 100_000
|
||||
) -> Tuple[str, str]:
|
||||
"""Read file with automatic encoding detection and safe decoding.
|
||||
|
||||
Reads file bytes, detects encoding, and decodes with error replacement
|
||||
to preserve file structure even with encoding issues.
|
||||
|
||||
Args:
|
||||
path: Path to file to read
|
||||
confidence_threshold: Minimum confidence for encoding detection
|
||||
max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)
|
||||
|
||||
Returns:
|
||||
Tuple of (content, detected_encoding)
|
||||
- content: Decoded file content (with <20> for unmappable bytes)
|
||||
- detected_encoding: Detected encoding name
|
||||
|
||||
Raises:
|
||||
OSError: If file cannot be read
|
||||
IsADirectoryError: If path is a directory
|
||||
"""
|
||||
file_path = Path(path) if isinstance(path, str) else path
|
||||
|
||||
# Read file bytes
|
||||
try:
|
||||
content_bytes = file_path.read_bytes()
|
||||
except Exception as e:
|
||||
log.error(f"Failed to read file {file_path}: {e}")
|
||||
raise
|
||||
|
||||
# Detect encoding from first N bytes for performance
|
||||
detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
|
||||
encoding = detect_encoding(detection_sample, confidence_threshold)
|
||||
|
||||
# Decode with error replacement to preserve structure
|
||||
try:
|
||||
content = content_bytes.decode(encoding, errors='replace')
|
||||
log.debug(f"Successfully decoded {file_path} using {encoding}")
|
||||
return content, encoding
|
||||
except Exception as e:
|
||||
# Final fallback to UTF-8 with replacement
|
||||
log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
|
||||
content = content_bytes.decode('utf-8', errors='replace')
|
||||
return content, 'utf-8'
|
||||
|
||||
|
||||
def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
|
||||
"""Check if file is likely binary by sampling first bytes.
|
||||
|
||||
Uses heuristic: if >30% of sample bytes are null or non-text, consider binary.
|
||||
|
||||
Args:
|
||||
path: Path to file to check
|
||||
sample_size: Number of bytes to sample (default 8KB)
|
||||
|
||||
Returns:
|
||||
True if file appears to be binary, False otherwise
|
||||
"""
|
||||
file_path = Path(path) if isinstance(path, str) else path
|
||||
|
||||
try:
|
||||
with file_path.open('rb') as f:
|
||||
sample = f.read(sample_size)
|
||||
|
||||
if not sample:
|
||||
return False
|
||||
|
||||
# Count null bytes and non-printable characters
|
||||
null_count = sample.count(b'\x00')
|
||||
non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))
|
||||
|
||||
# If >30% null bytes or >50% non-text, consider binary
|
||||
null_ratio = null_count / len(sample)
|
||||
non_text_ratio = non_text_count / len(sample)
|
||||
|
||||
return null_ratio > 0.3 or non_text_ratio > 0.5
|
||||
|
||||
except Exception as e:
|
||||
log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
|
||||
return False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ENCODING_DETECTION_AVAILABLE",
|
||||
"check_encoding_available",
|
||||
"detect_encoding",
|
||||
"read_file_safe",
|
||||
"is_binary_file",
|
||||
]
|
||||
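A minimal sketch of how the encoding helpers above are meant to compose, assuming the module imports as `codexlens.parsers.encoding`; the file path is illustrative.

```python
from pathlib import Path
from codexlens.parsers.encoding import (
    check_encoding_available,
    is_binary_file,
    read_file_safe,
)

available, error = check_encoding_available()
if not available:
    print(f"chardet/charset-normalizer missing, UTF-8 fallback in use: {error}")

path = Path("src/example.py")  # hypothetical file
if not is_binary_file(path):
    content, encoding = read_file_safe(path, confidence_threshold=0.7)
    print(f"Read {len(content)} chars decoded as {encoding}")
```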
385  codex-lens/build/lib/codexlens/parsers/factory.py  (normal file)
@@ -0,0 +1,385 @@
|
||||
"""Parser factory for CodexLens.
|
||||
|
||||
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
|
||||
available. Regex fallbacks are retained to preserve the existing parser
|
||||
interface and behavior in minimal environments.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
|
||||
|
||||
class Parser(Protocol):
|
||||
def parse(self, text: str, path: Path) -> IndexedFile: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleRegexParser:
|
||||
language_id: str
|
||||
|
||||
def parse(self, text: str, path: Path) -> IndexedFile:
|
||||
# Try tree-sitter first for supported languages
|
||||
if self.language_id in {"python", "javascript", "typescript"}:
|
||||
ts_parser = TreeSitterSymbolParser(self.language_id, path)
|
||||
if ts_parser.is_available():
|
||||
indexed = ts_parser.parse(text, path)
|
||||
if indexed is not None:
|
||||
return indexed
|
||||
|
||||
# Fallback to regex parsing
|
||||
if self.language_id == "python":
|
||||
symbols = _parse_python_symbols_regex(text)
|
||||
relationships = _parse_python_relationships_regex(text, path)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
symbols = _parse_js_ts_symbols_regex(text)
|
||||
relationships = _parse_js_ts_relationships_regex(text, path)
|
||||
elif self.language_id == "java":
|
||||
symbols = _parse_java_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "go":
|
||||
symbols = _parse_go_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "markdown":
|
||||
symbols = _parse_markdown_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "text":
|
||||
symbols = _parse_text_symbols(text)
|
||||
relationships = []
|
||||
else:
|
||||
symbols = _parse_generic_symbols(text)
|
||||
relationships = []
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
relationships=relationships,
|
||||
)
|
||||
|
||||
|
||||
class ParserFactory:
|
||||
def __init__(self, config: Config) -> None:
|
||||
self.config = config
|
||||
self._parsers: Dict[str, Parser] = {}
|
||||
|
||||
def get_parser(self, language_id: str) -> Parser:
|
||||
if language_id not in self._parsers:
|
||||
self._parsers[language_id] = SimpleRegexParser(language_id)
|
||||
return self._parsers[language_id]
|
||||
|
||||
|
||||
# Regex-based fallback parsers
|
||||
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
|
||||
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
|
||||
|
||||
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
|
||||
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||
|
||||
|
||||
|
||||
|
||||
def _parse_python_symbols(text: str) -> List[Symbol]:
|
||||
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
|
||||
ts_parser = TreeSitterSymbolParser("python")
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_python_symbols_regex(text)
|
||||
|
||||
|
||||
def _parse_js_ts_symbols(
|
||||
text: str,
|
||||
language_id: str = "javascript",
|
||||
path: Optional[Path] = None,
|
||||
) -> List[Symbol]:
|
||||
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
|
||||
ts_parser = TreeSitterSymbolParser(language_id, path)
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_js_ts_symbols_regex(text)
|
||||
|
||||
|
||||
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
current_class_indent: Optional[int] = None
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _PY_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
current_class_indent = len(line) - len(line.lstrip(" "))
|
||||
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||
continue
|
||||
def_match = _PY_DEF_RE.match(line)
|
||||
if def_match:
|
||||
indent = len(line) - len(line.lstrip(" "))
|
||||
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
|
||||
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
|
||||
continue
|
||||
if current_class_indent is not None:
|
||||
indent = len(line) - len(line.lstrip(" "))
|
||||
if line.strip() and indent <= current_class_indent:
|
||||
current_class_indent = None
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||
relationships: List[CodeRelationship] = []
|
||||
current_scope: str | None = None
|
||||
source_file = str(path.resolve())
|
||||
|
||||
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _PY_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
current_scope = class_match.group(1)
|
||||
continue
|
||||
|
||||
def_match = _PY_DEF_RE.match(line)
|
||||
if def_match:
|
||||
current_scope = def_match.group(1)
|
||||
continue
|
||||
|
||||
if current_scope is None:
|
||||
continue
|
||||
|
||||
import_match = _PY_IMPORT_RE.search(line)
|
||||
if import_match:
|
||||
import_target = import_match.group(1) or import_match.group(2)
|
||||
if import_target:
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=import_target.strip(),
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
for call_match in _PY_CALL_RE.finditer(line):
|
||||
call_name = call_match.group(1)
|
||||
if call_name in {
|
||||
"if",
|
||||
"for",
|
||||
"while",
|
||||
"return",
|
||||
"print",
|
||||
"len",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"list",
|
||||
"dict",
|
||||
"set",
|
||||
"tuple",
|
||||
current_scope,
|
||||
}:
|
||||
continue
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=call_name,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
|
||||
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
|
||||
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
|
||||
_JS_ARROW_RE = re.compile(
|
||||
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
|
||||
)
|
||||
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
|
||||
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
|
||||
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||
|
||||
|
||||
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
in_class = False
|
||||
class_brace_depth = 0
|
||||
brace_depth = 0
|
||||
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
brace_depth += line.count("{") - line.count("}")
|
||||
|
||||
class_match = _JS_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||
in_class = True
|
||||
class_brace_depth = brace_depth
|
||||
continue
|
||||
|
||||
if in_class and brace_depth < class_brace_depth:
|
||||
in_class = False
|
||||
|
||||
func_match = _JS_FUNC_RE.match(line)
|
||||
if func_match:
|
||||
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
|
||||
continue
|
||||
|
||||
arrow_match = _JS_ARROW_RE.match(line)
|
||||
if arrow_match:
|
||||
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
|
||||
continue
|
||||
|
||||
if in_class:
|
||||
method_match = _JS_METHOD_RE.match(line)
|
||||
if method_match:
|
||||
name = method_match.group(1)
|
||||
if name != "constructor":
|
||||
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
|
||||
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||
relationships: List[CodeRelationship] = []
|
||||
current_scope: str | None = None
|
||||
source_file = str(path.resolve())
|
||||
|
||||
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _JS_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
current_scope = class_match.group(1)
|
||||
continue
|
||||
|
||||
func_match = _JS_FUNC_RE.match(line)
|
||||
if func_match:
|
||||
current_scope = func_match.group(1)
|
||||
continue
|
||||
|
||||
arrow_match = _JS_ARROW_RE.match(line)
|
||||
if arrow_match:
|
||||
current_scope = arrow_match.group(1)
|
||||
continue
|
||||
|
||||
if current_scope is None:
|
||||
continue
|
||||
|
||||
import_match = _JS_IMPORT_RE.search(line)
|
||||
if import_match:
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=import_match.group(1),
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
for call_match in _JS_CALL_RE.finditer(line):
|
||||
call_name = call_match.group(1)
|
||||
if call_name in {current_scope}:
|
||||
continue
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=call_name,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
|
||||
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
|
||||
_JAVA_METHOD_RE = re.compile(
|
||||
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
|
||||
)
|
||||
|
||||
|
||||
def _parse_java_symbols(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _JAVA_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||
continue
|
||||
method_match = _JAVA_METHOD_RE.match(line)
|
||||
if method_match:
|
||||
symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
|
||||
return symbols
|
||||
|
||||
|
||||
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
|
||||
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
|
||||
|
||||
|
||||
def _parse_go_symbols(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
type_match = _GO_TYPE_RE.match(line)
|
||||
if type_match:
|
||||
symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
|
||||
continue
|
||||
func_match = _GO_FUNC_RE.match(line)
|
||||
if func_match:
|
||||
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
|
||||
return symbols
|
||||
|
||||
|
||||
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
|
||||
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
|
||||
|
||||
|
||||
def _parse_generic_symbols(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _GENERIC_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||
continue
|
||||
def_match = _GENERIC_DEF_RE.match(line)
|
||||
if def_match:
|
||||
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
|
||||
return symbols
|
||||
|
||||
|
||||
# Markdown heading regex: # Heading, ## Heading, etc.
|
||||
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
|
||||
|
||||
|
||||
def _parse_markdown_symbols(text: str) -> List[Symbol]:
|
||||
"""Parse Markdown headings as symbols.
|
||||
|
||||
Extracts # headings as 'section' symbols with heading level as kind suffix.
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
for i, line in enumerate(text.splitlines(), start=1):
|
||||
heading_match = _MD_HEADING_RE.match(line)
|
||||
if heading_match:
|
||||
level = len(heading_match.group(1))
|
||||
title = heading_match.group(2).strip()
|
||||
# Use 'section' kind with level indicator
|
||||
kind = f"h{level}"
|
||||
symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_text_symbols(text: str) -> List[Symbol]:
|
||||
"""Parse plain text files - no symbols, just index content."""
|
||||
# Text files don't have structured symbols, return empty list
|
||||
# The file content will still be indexed for FTS search
|
||||
return []
|
||||
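A sketch of the factory flow above: request a parser per language, parse a file, inspect the resulting symbols. `Config()` being default-constructible is an assumption, since the diff does not show its signature, and the file path is illustrative.

```python
from pathlib import Path
from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory

factory = ParserFactory(Config())      # assumed default-constructible Config
parser = factory.get_parser("python")  # tree-sitter first, regex fallback
source = Path("example.py")            # hypothetical file
indexed = parser.parse(source.read_text(encoding="utf-8"), source)
for symbol in indexed.symbols:
    print(symbol.kind, symbol.name, symbol.range)
```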
98  codex-lens/build/lib/codexlens/parsers/tokenizer.py  (normal file)
@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.

Provides accurate token counting using tiktoken with character count fallback.
"""

from __future__ import annotations

from typing import Optional

try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False


class Tokenizer:
    """Token counter with tiktoken primary and character count fallback."""

    def __init__(self, encoding_name: str = "cl100k_base") -> None:
        """Initialize tokenizer.

        Args:
            encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
        """
        self._encoding: Optional[object] = None
        self._encoding_name = encoding_name

        if TIKTOKEN_AVAILABLE:
            try:
                self._encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                # Fallback to character counting if encoding fails
                self._encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text.

        Uses tiktoken if available, otherwise falls back to character count / 4.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated token count
        """
        if not text:
            return 0

        if self._encoding is not None:
            try:
                return len(self._encoding.encode(text))  # type: ignore[attr-defined]
            except Exception:
                # Fall through to character count fallback
                pass

        # Fallback: rough estimate using character count
        # Average of ~4 characters per token for English text
        return max(1, len(text) // 4)

    def is_using_tiktoken(self) -> bool:
        """Check if tiktoken is being used.

        Returns:
            True if tiktoken is available and initialized
        """
        return self._encoding is not None


# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None


def get_default_tokenizer() -> Tokenizer:
    """Get the global default tokenizer instance.

    Returns:
        Shared Tokenizer instance
    """
    global _default_tokenizer
    if _default_tokenizer is None:
        _default_tokenizer = Tokenizer()
    return _default_tokenizer


def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
    """Count tokens in text using default or provided tokenizer.

    Args:
        text: Text to count tokens for
        tokenizer: Optional tokenizer instance (uses default if None)

    Returns:
        Estimated token count
    """
    if tokenizer is None:
        tokenizer = get_default_tokenizer()
    return tokenizer.count_tokens(text)
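A short sketch of the tokenizer's primary/fallback behaviour, assuming the module imports as `codexlens.parsers.tokenizer`.

```python
from codexlens.parsers.tokenizer import Tokenizer, count_tokens

tok = Tokenizer()               # cl100k_base when tiktoken is installed
print(tok.is_using_tiktoken())  # False means the len(text) // 4 estimate is used
print(tok.count_tokens("def add(a, b): return a + b"))
print(count_tokens("same text, shared default tokenizer instance"))
```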
809  codex-lens/build/lib/codexlens/parsers/treesitter_parser.py  (normal file)
@@ -0,0 +1,809 @@
|
||||
"""Tree-sitter based parser for CodexLens.
|
||||
|
||||
Provides precise AST-level parsing via tree-sitter.
|
||||
|
||||
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
|
||||
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
|
||||
return `None`; callers should use a regex-based fallback such as
|
||||
`codexlens.parsers.factory.SimpleRegexParser`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from tree_sitter import Language as TreeSitterLanguage
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
from tree_sitter import Parser as TreeSitterParser
|
||||
TREE_SITTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TreeSitterLanguage = None # type: ignore[assignment]
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TreeSitterParser = None # type: ignore[assignment]
|
||||
TREE_SITTER_AVAILABLE = False
|
||||
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
class TreeSitterSymbolParser:
|
||||
"""Parser using tree-sitter for AST-level symbol extraction."""
|
||||
|
||||
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
|
||||
"""Initialize tree-sitter parser for a language.
|
||||
|
||||
Args:
|
||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||
path: Optional file path for language variant detection (e.g., .tsx)
|
||||
"""
|
||||
self.language_id = language_id
|
||||
self.path = path
|
||||
self._parser: Optional[object] = None
|
||||
self._language: Optional[TreeSitterLanguage] = None
|
||||
self._tokenizer = get_default_tokenizer()
|
||||
|
||||
if TREE_SITTER_AVAILABLE:
|
||||
self._initialize_parser()
|
||||
|
||||
def _initialize_parser(self) -> None:
|
||||
"""Initialize tree-sitter parser and language."""
|
||||
if TreeSitterParser is None or TreeSitterLanguage is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Load language grammar
|
||||
if self.language_id == "python":
|
||||
import tree_sitter_python
|
||||
self._language = TreeSitterLanguage(tree_sitter_python.language())
|
||||
elif self.language_id == "javascript":
|
||||
import tree_sitter_javascript
|
||||
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
|
||||
elif self.language_id == "typescript":
|
||||
import tree_sitter_typescript
|
||||
# Detect TSX files by extension
|
||||
if self.path is not None and self.path.suffix.lower() == ".tsx":
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
|
||||
else:
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
|
||||
else:
|
||||
return
|
||||
|
||||
# Create parser
|
||||
self._parser = TreeSitterParser()
|
||||
if hasattr(self._parser, "set_language"):
|
||||
self._parser.set_language(self._language) # type: ignore[attr-defined]
|
||||
else:
|
||||
self._parser.language = self._language # type: ignore[assignment]
|
||||
|
||||
except Exception:
|
||||
# Gracefully handle missing language bindings
|
||||
self._parser = None
|
||||
self._language = None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if tree-sitter parser is available.
|
||||
|
||||
Returns:
|
||||
True if parser is initialized and ready
|
||||
"""
|
||||
return self._parser is not None and self._language is not None
|
||||
|
||||
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
|
||||
if not self.is_available() or self._parser is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
return source_bytes, tree.root_node
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
|
||||
"""Parse source code and extract symbols without creating IndexedFile.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
|
||||
Returns:
|
||||
List of symbols if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
return self._extract_symbols(source_bytes, root)
|
||||
except Exception:
|
||||
# Gracefully handle extraction errors
|
||||
return None
|
||||
|
||||
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
|
||||
"""Parse source code and extract symbols.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
path: File path
|
||||
|
||||
Returns:
|
||||
IndexedFile if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
symbols = self._extract_symbols(source_bytes, root)
|
||||
relationships = self._extract_relationships(source_bytes, root, path)
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
relationships=relationships,
|
||||
)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return None
|
||||
|
||||
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of extracted symbols
|
||||
"""
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_symbols(source_bytes, root)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_symbols(source_bytes, root)
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_relationships(source_bytes, root, path)
|
||||
if self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_relationships(source_bytes, root, path)
|
||||
return []
|
||||
|
||||
def _extract_python_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"self", "cls"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "class_definition" and pushed_scope:
|
||||
superclasses = node.child_by_field_name("superclasses")
|
||||
if superclasses is not None:
|
||||
for child in superclasses.children:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, child)
|
||||
if not dotted:
|
||||
continue
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type in {"import_statement", "import_from_statement"}:
|
||||
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
if node.type == "call":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _extract_js_ts_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"this", "super"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if pushed_scope:
|
||||
superclass = node.child_by_field_name("superclass")
|
||||
if superclass is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, superclass)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type in {"identifier", "property_identifier"}
|
||||
and value_node.type == "arrow_function"
|
||||
):
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name and scope_name != "constructor":
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"import_declaration", "import_statement"}:
|
||||
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
# Best-effort support for CommonJS require() imports:
|
||||
# const fs = require("fs")
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type == "identifier"
|
||||
and value_node.type == "call_expression"
|
||||
):
|
||||
callee = value_node.child_by_field_name("function")
|
||||
args = value_node.child_by_field_name("arguments")
|
||||
if (
|
||||
callee is not None
|
||||
and self._node_text(source_bytes, callee).strip() == "require"
|
||||
and args is not None
|
||||
):
|
||||
module_name = self._js_first_string_argument(source_bytes, args)
|
||||
if module_name:
|
||||
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
|
||||
record_import(module_name, self._node_start_line(node))
|
||||
|
||||
if node.type == "call_expression":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _node_start_line(self, node: TreeSitterNode) -> int:
|
||||
return node.start_point[0] + 1
|
||||
|
||||
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
|
||||
dotted = (dotted or "").strip()
|
||||
if not dotted:
|
||||
return ""
|
||||
|
||||
base, sep, rest = dotted.partition(".")
|
||||
resolved_base = aliases.get(base, base)
|
||||
if not rest:
|
||||
return resolved_base
|
||||
if resolved_base and rest:
|
||||
return f"{resolved_base}.{rest}"
|
||||
return resolved_base
|
||||
|
||||
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"identifier", "dotted_name"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "attribute":
|
||||
obj = node.child_by_field_name("object")
|
||||
attr = node.child_by_field_name("attribute")
|
||||
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
|
||||
if obj_text and attr_text:
|
||||
return f"{obj_text}.{attr_text}"
|
||||
return obj_text or attr_text
|
||||
return ""
|
||||
|
||||
def _python_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
if node.type == "import_statement":
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
module_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else module_name.split(".", 1)[0]
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = module_name
|
||||
targets.append(module_name)
|
||||
elif child.type == "dotted_name":
|
||||
module_name = self._node_text(source_bytes, child).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = module_name.split(".", 1)[0]
|
||||
if bound_name:
|
||||
aliases[bound_name] = bound_name
|
||||
targets.append(module_name)
|
||||
|
||||
if node.type == "import_from_statement":
|
||||
module_name = ""
|
||||
module_node = node.child_by_field_name("module_name")
|
||||
if module_node is None:
|
||||
for child in node.children:
|
||||
if child.type == "dotted_name":
|
||||
module_node = child
|
||||
break
|
||||
if module_node is not None:
|
||||
module_name = self._node_text(source_bytes, module_node).strip()
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported_name or imported_name == "*":
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported_name
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = target
|
||||
targets.append(target)
|
||||
elif child.type == "identifier":
|
||||
imported_name = self._node_text(source_bytes, child).strip()
|
||||
if not imported_name or imported_name in {"from", "import", "*"}:
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
aliases[imported_name] = target
|
||||
targets.append(target)
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"this", "super"}:
|
||||
return node.type
|
||||
if node.type in {"identifier", "property_identifier"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "member_expression":
|
||||
obj = node.child_by_field_name("object")
|
||||
prop = node.child_by_field_name("property")
|
||||
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
|
||||
if obj_text and prop_text:
|
||||
return f"{obj_text}.{prop_text}"
|
||||
return obj_text or prop_text
|
||||
return ""
|
||||
|
||||
def _js_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
module_name = ""
|
||||
source_node = node.child_by_field_name("source")
|
||||
if source_node is not None:
|
||||
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
|
||||
if module_name:
|
||||
targets.append(module_name)
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "import_clause":
|
||||
for clause_child in child.children:
|
||||
if clause_child.type == "identifier":
|
||||
# Default import: import React from "react"
|
||||
local = self._node_text(source_bytes, clause_child).strip()
|
||||
if local and module_name:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "namespace_import":
|
||||
# Namespace import: import * as fs from "fs"
|
||||
name_node = clause_child.child_by_field_name("name")
|
||||
if name_node is not None and module_name:
|
||||
local = self._node_text(source_bytes, name_node).strip()
|
||||
if local:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "named_imports":
|
||||
for spec in clause_child.children:
|
||||
if spec.type != "import_specifier":
|
||||
continue
|
||||
name_node = spec.child_by_field_name("name")
|
||||
alias_node = spec.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported:
|
||||
continue
|
||||
local = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported
|
||||
)
|
||||
if local and module_name:
|
||||
aliases[local] = f"{module_name}.{imported}"
|
||||
targets.append(f"{module_name}.{imported}")
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
|
||||
for child in args_node.children:
|
||||
if child.type == "string":
|
||||
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
|
||||
return ""
|
||||
|
||||
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract Python symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of Python symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind=self._python_function_kind(node),
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract JavaScript/TypeScript symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of JS/TS symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is None
|
||||
or value_node is None
|
||||
or name_node.type not in {"identifier", "property_identifier"}
|
||||
or value_node.type != "arrow_function"
|
||||
):
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name == "constructor":
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=name,
|
||||
kind="method",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _python_function_kind(self, node: TreeSitterNode) -> str:
|
||||
"""Determine if Python function is a method or standalone function.
|
||||
|
||||
Args:
|
||||
node: Function definition node
|
||||
|
||||
Returns:
|
||||
'method' if inside a class, 'function' otherwise
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"function_definition", "async_function_definition"}:
|
||||
return "function"
|
||||
if parent.type == "class_definition":
|
||||
return "method"
|
||||
parent = parent.parent
|
||||
return "function"
|
||||
|
||||
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
|
||||
"""Check if node has a class ancestor.
|
||||
|
||||
Args:
|
||||
node: AST node to check
|
||||
|
||||
Returns:
|
||||
True if node is inside a class
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"class_declaration", "class"}:
|
||||
return True
|
||||
parent = parent.parent
|
||||
return False
|
||||
|
||||
def _iter_nodes(self, root: TreeSitterNode):
|
||||
"""Iterate over all nodes in AST.
|
||||
|
||||
Args:
|
||||
root: Root node to start iteration
|
||||
|
||||
Yields:
|
||||
AST nodes in depth-first order
|
||||
"""
|
||||
stack = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
for child in reversed(node.children):
|
||||
stack.append(child)
|
||||
|
||||
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
"""Extract text for a node.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
Text content of node
|
||||
"""
|
||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||
|
||||
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
|
||||
"""Get line range for a node.
|
||||
|
||||
Args:
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
(start_line, end_line) tuple, 1-based inclusive
|
||||
"""
|
||||
start_line = node.start_point[0] + 1
|
||||
end_line = node.end_point[0] + 1
|
||||
return (start_line, max(start_line, end_line))
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens in text.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
return self._tokenizer.count_tokens(text)
|
||||
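A sketch of the tree-sitter-first, regex-fallback contract described in the module docstring above; the file path is illustrative.

```python
from pathlib import Path
from codexlens.parsers.factory import SimpleRegexParser
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser

path = Path("module.py")  # hypothetical file
text = path.read_text(encoding="utf-8")

ts_parser = TreeSitterSymbolParser("python", path)
indexed = ts_parser.parse(text, path) if ts_parser.is_available() else None
if indexed is None:
    # Regex fallback preserves the same IndexedFile interface.
    indexed = SimpleRegexParser("python").parse(text, path)

for rel in indexed.relationships:
    print(rel.relationship_type, rel.source_symbol, "->", rel.target_symbol)
```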
53  codex-lens/build/lib/codexlens/search/__init__.py  (normal file)
@@ -0,0 +1,53 @@
from .chain_search import (
    ChainSearchEngine,
    SearchOptions,
    SearchStats,
    ChainSearchResult,
    quick_search,
)

# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None

try:
    from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
    from .clustering import check_clustering_available
    CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
    _clustering_import_error = str(e)

    def check_clustering_available() -> tuple[bool, str | None]:
        """Fallback when clustering module not loadable."""
        return False, _clustering_import_error


# Clustering module exports (conditional)
try:
    from .clustering import (
        BaseClusteringStrategy,
        ClusteringConfig,
        ClusteringStrategyFactory,
        get_strategy,
    )
    _clustering_exports = [
        "BaseClusteringStrategy",
        "ClusteringConfig",
        "ClusteringStrategyFactory",
        "get_strategy",
    ]
except ImportError:
    _clustering_exports = []


__all__ = [
    "ChainSearchEngine",
    "SearchOptions",
    "SearchStats",
    "ChainSearchResult",
    "quick_search",
    # Clustering
    "CLUSTERING_AVAILABLE",
    "check_clustering_available",
    *_clustering_exports,
]
@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.

This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""

from .builder import AssociationTreeBuilder
from .data_structures import (
    CallTree,
    TreeNode,
    UniqueNode,
)
from .deduplicator import ResultDeduplicator

__all__ = [
    "AssociationTreeBuilder",
    "CallTree",
    "TreeNode",
    "UniqueNode",
    "ResultDeduplicator",
]
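The hunks that follow define the classes this package exports. As a quick orientation, here is a minimal sketch of how they compose; it is illustrative only — the package import path and the already-started `StandaloneLspManager` are assumptions, not something this diff confirms.

```python
# Hypothetical composition of the association-tree components (sketch only).
from codexlens.lsp.standalone_manager import StandaloneLspManager  # path taken from builder.py imports
# The package path below is assumed; the diff does not show where this __init__.py lives.
from codexlens.search.associations import AssociationTreeBuilder, ResultDeduplicator


async def related_symbols(lsp: StandaloneLspManager, path: str, line: int) -> list[dict]:
    builder = AssociationTreeBuilder(lsp_manager=lsp, timeout=5.0)
    tree = await builder.build_tree(seed_file_path=path, seed_line=line, max_depth=3)
    dedup = ResultDeduplicator()
    unique = dedup.deduplicate(tree, max_results=20)   # scored, sorted UniqueNode list
    return dedup.to_dict_list(unique)
    # run with: asyncio.run(related_symbols(lsp, "src/app.py", 42))
```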
@@ -0,0 +1,450 @@
|
||||
"""Association tree builder using LSP call hierarchy.
|
||||
|
||||
Builds call relationship trees by recursively expanding from seed locations
|
||||
using Language Server Protocol (LSP) call hierarchy capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
|
||||
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||
from .data_structures import CallTree, TreeNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssociationTreeBuilder:
|
||||
"""Builds association trees from seed locations using LSP call hierarchy.
|
||||
|
||||
Uses depth-first recursive expansion to build a tree of code relationships
|
||||
starting from seed locations (typically from vector search results).
|
||||
|
||||
Strategy:
|
||||
- Start from seed locations (vector search results)
|
||||
- For each seed, get call hierarchy items via LSP
|
||||
- Recursively expand incoming calls (callers) if expand_callers=True
|
||||
- Recursively expand outgoing calls (callees) if expand_callees=True
|
||||
- Track visited nodes to prevent cycles
|
||||
- Stop at max_depth or when no more relations found
|
||||
|
||||
Attributes:
|
||||
lsp_manager: StandaloneLspManager for LSP communication
|
||||
visited: Set of visited node IDs to prevent cycles
|
||||
timeout: Timeout for individual LSP requests (seconds)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lsp_manager: StandaloneLspManager,
|
||||
timeout: float = 5.0,
|
||||
analysis_wait: float = 2.0,
|
||||
):
|
||||
"""Initialize AssociationTreeBuilder.
|
||||
|
||||
Args:
|
||||
lsp_manager: StandaloneLspManager instance for LSP communication
|
||||
timeout: Timeout for individual LSP requests in seconds
|
||||
analysis_wait: Time to wait for LSP analysis on first file (seconds)
|
||||
"""
|
||||
self.lsp_manager = lsp_manager
|
||||
self.timeout = timeout
|
||||
self.analysis_wait = analysis_wait
|
||||
self.visited: Set[str] = set()
|
||||
self._analyzed_files: Set[str] = set() # Track files already analyzed
|
||||
|
||||
async def build_tree(
|
||||
self,
|
||||
seed_file_path: str,
|
||||
seed_line: int,
|
||||
seed_character: int = 1,
|
||||
max_depth: int = 5,
|
||||
expand_callers: bool = True,
|
||||
expand_callees: bool = True,
|
||||
) -> CallTree:
|
||||
"""Build call tree from a single seed location.
|
||||
|
||||
Args:
|
||||
seed_file_path: Path to the seed file
|
||||
seed_line: Line number of the seed symbol (1-based)
|
||||
seed_character: Character position (1-based, default 1)
|
||||
max_depth: Maximum recursion depth (default 5)
|
||||
expand_callers: Whether to expand incoming calls (callers)
|
||||
expand_callees: Whether to expand outgoing calls (callees)
|
||||
|
||||
Returns:
|
||||
CallTree containing all discovered nodes and relationships
|
||||
"""
|
||||
tree = CallTree()
|
||||
self.visited.clear()
|
||||
|
||||
# Determine wait time - only wait for analysis on first encounter of file
|
||||
wait_time = 0.0
|
||||
if seed_file_path not in self._analyzed_files:
|
||||
wait_time = self.analysis_wait
|
||||
self._analyzed_files.add(seed_file_path)
|
||||
|
||||
# Get call hierarchy items for the seed position
|
||||
try:
|
||||
hierarchy_items = await asyncio.wait_for(
|
||||
self.lsp_manager.get_call_hierarchy_items(
|
||||
file_path=seed_file_path,
|
||||
line=seed_line,
|
||||
character=seed_character,
|
||||
wait_for_analysis=wait_time,
|
||||
),
|
||||
timeout=self.timeout + wait_time,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Timeout getting call hierarchy items for %s:%d",
|
||||
seed_file_path,
|
||||
seed_line,
|
||||
)
|
||||
return tree
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting call hierarchy items for %s:%d: %s",
|
||||
seed_file_path,
|
||||
seed_line,
|
||||
e,
|
||||
)
|
||||
return tree
|
||||
|
||||
if not hierarchy_items:
|
||||
logger.debug(
|
||||
"No call hierarchy items found for %s:%d",
|
||||
seed_file_path,
|
||||
seed_line,
|
||||
)
|
||||
return tree
|
||||
|
||||
# Create root nodes from hierarchy items
|
||||
for item_dict in hierarchy_items:
|
||||
# Convert LSP dict to CallHierarchyItem
|
||||
item = self._dict_to_call_hierarchy_item(item_dict)
|
||||
if not item:
|
||||
continue
|
||||
|
||||
root_node = TreeNode(
|
||||
item=item,
|
||||
depth=0,
|
||||
path_from_root=[self._create_node_id(item)],
|
||||
)
|
||||
tree.roots.append(root_node)
|
||||
tree.add_node(root_node)
|
||||
|
||||
# Mark as visited
|
||||
self.visited.add(root_node.node_id)
|
||||
|
||||
# Recursively expand the tree
|
||||
await self._expand_node(
|
||||
node=root_node,
|
||||
node_dict=item_dict,
|
||||
tree=tree,
|
||||
current_depth=0,
|
||||
max_depth=max_depth,
|
||||
expand_callers=expand_callers,
|
||||
expand_callees=expand_callees,
|
||||
)
|
||||
|
||||
tree.depth_reached = max_depth
|
||||
return tree
|
||||
|
||||
async def _expand_node(
|
||||
self,
|
||||
node: TreeNode,
|
||||
node_dict: Dict,
|
||||
tree: CallTree,
|
||||
current_depth: int,
|
||||
max_depth: int,
|
||||
expand_callers: bool,
|
||||
expand_callees: bool,
|
||||
) -> None:
|
||||
"""Recursively expand a node by fetching its callers and callees.
|
||||
|
||||
Args:
|
||||
node: TreeNode to expand
|
||||
node_dict: LSP CallHierarchyItem dict (for LSP requests)
|
||||
tree: CallTree to add discovered nodes to
|
||||
current_depth: Current recursion depth
|
||||
max_depth: Maximum allowed depth
|
||||
expand_callers: Whether to expand incoming calls
|
||||
expand_callees: Whether to expand outgoing calls
|
||||
"""
|
||||
# Stop if max depth reached
|
||||
if current_depth >= max_depth:
|
||||
return
|
||||
|
||||
# Prepare tasks for parallel expansion
|
||||
tasks = []
|
||||
|
||||
if expand_callers:
|
||||
tasks.append(
|
||||
self._expand_incoming_calls(
|
||||
node=node,
|
||||
node_dict=node_dict,
|
||||
tree=tree,
|
||||
current_depth=current_depth,
|
||||
max_depth=max_depth,
|
||||
expand_callers=expand_callers,
|
||||
expand_callees=expand_callees,
|
||||
)
|
||||
)
|
||||
|
||||
if expand_callees:
|
||||
tasks.append(
|
||||
self._expand_outgoing_calls(
|
||||
node=node,
|
||||
node_dict=node_dict,
|
||||
tree=tree,
|
||||
current_depth=current_depth,
|
||||
max_depth=max_depth,
|
||||
expand_callers=expand_callers,
|
||||
expand_callees=expand_callees,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute expansions in parallel
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def _expand_incoming_calls(
|
||||
self,
|
||||
node: TreeNode,
|
||||
node_dict: Dict,
|
||||
tree: CallTree,
|
||||
current_depth: int,
|
||||
max_depth: int,
|
||||
expand_callers: bool,
|
||||
expand_callees: bool,
|
||||
) -> None:
|
||||
"""Expand incoming calls (callers) for a node.
|
||||
|
||||
Args:
|
||||
node: TreeNode being expanded
|
||||
node_dict: LSP dict for the node
|
||||
tree: CallTree to add nodes to
|
||||
current_depth: Current depth
|
||||
max_depth: Maximum depth
|
||||
expand_callers: Whether to continue expanding callers
|
||||
expand_callees: Whether to expand callees
|
||||
"""
|
||||
try:
|
||||
incoming_calls = await asyncio.wait_for(
|
||||
self.lsp_manager.get_incoming_calls(item=node_dict),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Timeout getting incoming calls for %s", node.node_id)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
|
||||
return
|
||||
|
||||
if not incoming_calls:
|
||||
return
|
||||
|
||||
# Process each incoming call
|
||||
for call_dict in incoming_calls:
|
||||
caller_dict = call_dict.get("from")
|
||||
if not caller_dict:
|
||||
continue
|
||||
|
||||
# Convert to CallHierarchyItem
|
||||
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
|
||||
if not caller_item:
|
||||
continue
|
||||
|
||||
caller_id = self._create_node_id(caller_item)
|
||||
|
||||
# Check for cycles
|
||||
if caller_id in self.visited:
|
||||
# Create cycle marker node
|
||||
cycle_node = TreeNode(
|
||||
item=caller_item,
|
||||
depth=current_depth + 1,
|
||||
is_cycle=True,
|
||||
path_from_root=node.path_from_root + [caller_id],
|
||||
)
|
||||
node.parents.append(cycle_node)
|
||||
continue
|
||||
|
||||
# Create new caller node
|
||||
caller_node = TreeNode(
|
||||
item=caller_item,
|
||||
depth=current_depth + 1,
|
||||
path_from_root=node.path_from_root + [caller_id],
|
||||
)
|
||||
|
||||
# Add to tree
|
||||
tree.add_node(caller_node)
|
||||
tree.add_edge(caller_node, node)
|
||||
|
||||
# Update relationships
|
||||
node.parents.append(caller_node)
|
||||
caller_node.children.append(node)
|
||||
|
||||
# Mark as visited
|
||||
self.visited.add(caller_id)
|
||||
|
||||
# Recursively expand the caller
|
||||
await self._expand_node(
|
||||
node=caller_node,
|
||||
node_dict=caller_dict,
|
||||
tree=tree,
|
||||
current_depth=current_depth + 1,
|
||||
max_depth=max_depth,
|
||||
expand_callers=expand_callers,
|
||||
expand_callees=expand_callees,
|
||||
)
|
||||
|
||||
async def _expand_outgoing_calls(
|
||||
self,
|
||||
node: TreeNode,
|
||||
node_dict: Dict,
|
||||
tree: CallTree,
|
||||
current_depth: int,
|
||||
max_depth: int,
|
||||
expand_callers: bool,
|
||||
expand_callees: bool,
|
||||
) -> None:
|
||||
"""Expand outgoing calls (callees) for a node.
|
||||
|
||||
Args:
|
||||
node: TreeNode being expanded
|
||||
node_dict: LSP dict for the node
|
||||
tree: CallTree to add nodes to
|
||||
current_depth: Current depth
|
||||
max_depth: Maximum depth
|
||||
expand_callers: Whether to expand callers
|
||||
expand_callees: Whether to continue expanding callees
|
||||
"""
|
||||
try:
|
||||
outgoing_calls = await asyncio.wait_for(
|
||||
self.lsp_manager.get_outgoing_calls(item=node_dict),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
|
||||
return
|
||||
|
||||
if not outgoing_calls:
|
||||
return
|
||||
|
||||
# Process each outgoing call
|
||||
for call_dict in outgoing_calls:
|
||||
callee_dict = call_dict.get("to")
|
||||
if not callee_dict:
|
||||
continue
|
||||
|
||||
# Convert to CallHierarchyItem
|
||||
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
|
||||
if not callee_item:
|
||||
continue
|
||||
|
||||
callee_id = self._create_node_id(callee_item)
|
||||
|
||||
# Check for cycles
|
||||
if callee_id in self.visited:
|
||||
# Create cycle marker node
|
||||
cycle_node = TreeNode(
|
||||
item=callee_item,
|
||||
depth=current_depth + 1,
|
||||
is_cycle=True,
|
||||
path_from_root=node.path_from_root + [callee_id],
|
||||
)
|
||||
node.children.append(cycle_node)
|
||||
continue
|
||||
|
||||
# Create new callee node
|
||||
callee_node = TreeNode(
|
||||
item=callee_item,
|
||||
depth=current_depth + 1,
|
||||
path_from_root=node.path_from_root + [callee_id],
|
||||
)
|
||||
|
||||
# Add to tree
|
||||
tree.add_node(callee_node)
|
||||
tree.add_edge(node, callee_node)
|
||||
|
||||
# Update relationships
|
||||
node.children.append(callee_node)
|
||||
callee_node.parents.append(node)
|
||||
|
||||
# Mark as visited
|
||||
self.visited.add(callee_id)
|
||||
|
||||
# Recursively expand the callee
|
||||
await self._expand_node(
|
||||
node=callee_node,
|
||||
node_dict=callee_dict,
|
||||
tree=tree,
|
||||
current_depth=current_depth + 1,
|
||||
max_depth=max_depth,
|
||||
expand_callers=expand_callers,
|
||||
expand_callees=expand_callees,
|
||||
)
|
||||
|
||||
def _dict_to_call_hierarchy_item(
|
||||
self, item_dict: Dict
|
||||
) -> Optional[CallHierarchyItem]:
|
||||
"""Convert LSP dict to CallHierarchyItem.
|
||||
|
||||
Args:
|
||||
item_dict: LSP CallHierarchyItem dictionary
|
||||
|
||||
Returns:
|
||||
CallHierarchyItem or None if conversion fails
|
||||
"""
|
||||
try:
|
||||
# Extract URI and convert to file path
|
||||
uri = item_dict.get("uri", "")
|
||||
file_path = uri.replace("file:///", "").replace("file://", "")
|
||||
|
||||
# Handle Windows paths (file:///C:/...)
|
||||
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||
file_path = file_path[1:]
|
||||
|
||||
# Extract range
|
||||
range_dict = item_dict.get("range", {})
|
||||
start = range_dict.get("start", {})
|
||||
end = range_dict.get("end", {})
|
||||
|
||||
# Create Range (convert from 0-based to 1-based)
|
||||
item_range = Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
)
|
||||
|
||||
return CallHierarchyItem(
|
||||
name=item_dict.get("name", "unknown"),
|
||||
kind=str(item_dict.get("kind", "unknown")),
|
||||
file_path=file_path,
|
||||
range=item_range,
|
||||
detail=item_dict.get("detail"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
|
||||
return None
|
||||
|
||||
def _create_node_id(self, item: CallHierarchyItem) -> str:
|
||||
"""Create unique node ID from CallHierarchyItem.
|
||||
|
||||
Args:
|
||||
item: CallHierarchyItem
|
||||
|
||||
Returns:
|
||||
Unique node ID string
|
||||
"""
|
||||
return f"{item.file_path}:{item.name}:{item.range.start_line}"
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Data structures for association tree building.
|
||||
|
||||
Defines the core data classes for representing call hierarchy trees and
|
||||
deduplicated results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
|
||||
|
||||
|
||||
@dataclass
|
||||
class TreeNode:
|
||||
"""Node in the call association tree.
|
||||
|
||||
Represents a single function/method in the tree, including its position
|
||||
in the hierarchy and relationships.
|
||||
|
||||
Attributes:
|
||||
item: LSP CallHierarchyItem containing symbol information
|
||||
depth: Distance from the root node (seed) - 0 for roots
|
||||
children: List of child nodes (functions called by this node)
|
||||
parents: List of parent nodes (functions that call this node)
|
||||
is_cycle: Whether this node creates a circular reference
|
||||
path_from_root: Path (list of node IDs) from root to this node
|
||||
"""
|
||||
|
||||
item: CallHierarchyItem
|
||||
depth: int = 0
|
||||
children: List[TreeNode] = field(default_factory=list)
|
||||
parents: List[TreeNode] = field(default_factory=list)
|
||||
is_cycle: bool = False
|
||||
path_from_root: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def node_id(self) -> str:
|
||||
"""Unique identifier for this node."""
|
||||
return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on node ID."""
|
||||
return hash(self.node_id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on node ID."""
|
||||
if not isinstance(other, TreeNode):
|
||||
return False
|
||||
return self.node_id == other.node_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the node."""
|
||||
cycle_marker = " [CYCLE]" if self.is_cycle else ""
|
||||
return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallTree:
|
||||
"""Complete call tree structure built from seeds.
|
||||
|
||||
Contains all nodes discovered through recursive expansion and
|
||||
the relationships between them.
|
||||
|
||||
Attributes:
|
||||
roots: List of root nodes (seed symbols)
|
||||
all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
|
||||
node_list: Flat list of all nodes in tree order
|
||||
edges: List of (from_node_id, to_node_id) tuples representing calls
|
||||
depth_reached: Maximum depth achieved in expansion
|
||||
"""
|
||||
|
||||
roots: List[TreeNode] = field(default_factory=list)
|
||||
all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
|
||||
node_list: List[TreeNode] = field(default_factory=list)
|
||||
edges: List[tuple[str, str]] = field(default_factory=list)
|
||||
depth_reached: int = 0
|
||||
|
||||
def add_node(self, node: TreeNode) -> None:
|
||||
"""Add a node to the tree.
|
||||
|
||||
Args:
|
||||
node: TreeNode to add
|
||||
"""
|
||||
if node.node_id not in self.all_nodes:
|
||||
self.all_nodes[node.node_id] = node
|
||||
self.node_list.append(node)
|
||||
|
||||
def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
|
||||
"""Add an edge between two nodes.
|
||||
|
||||
Args:
|
||||
from_node: Source node
|
||||
to_node: Target node
|
||||
"""
|
||||
edge = (from_node.node_id, to_node.node_id)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[TreeNode]:
|
||||
"""Get a node by ID.
|
||||
|
||||
Args:
|
||||
node_id: Node identifier
|
||||
|
||||
Returns:
|
||||
TreeNode if found, None otherwise
|
||||
"""
|
||||
return self.all_nodes.get(node_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return total number of nodes in tree."""
|
||||
return len(self.all_nodes)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the tree."""
|
||||
return (
|
||||
f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
|
||||
f"depth={self.depth_reached})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniqueNode:
|
||||
"""Deduplicated unique code symbol from the tree.
|
||||
|
||||
Represents a single unique code location that may appear multiple times
|
||||
in the tree under different contexts. Contains aggregated information
|
||||
about all occurrences.
|
||||
|
||||
Attributes:
|
||||
file_path: Absolute path to the file
|
||||
name: Symbol name (function, method, class, etc.)
|
||||
kind: Symbol kind (function, method, class, etc.)
|
||||
range: Code range in the file
|
||||
min_depth: Minimum depth at which this node appears in the tree
|
||||
occurrences: Number of times this node appears in the tree
|
||||
paths: List of paths from roots to this node
|
||||
context_nodes: Related nodes from the tree
|
||||
score: Composite relevance score (higher is better)
|
||||
"""
|
||||
|
||||
file_path: str
|
||||
name: str
|
||||
kind: str
|
||||
range: Range
|
||||
min_depth: int = 0
|
||||
occurrences: int = 1
|
||||
paths: List[List[str]] = field(default_factory=list)
|
||||
context_nodes: List[str] = field(default_factory=list)
|
||||
score: float = 0.0
|
||||
|
||||
@property
|
||||
def node_key(self) -> tuple[str, int, int]:
|
||||
"""Unique key for deduplication.
|
||||
|
||||
Uses (file_path, start_line, end_line) as the unique identifier
|
||||
for this symbol across all occurrences.
|
||||
"""
|
||||
return (
|
||||
self.file_path,
|
||||
self.range.start_line,
|
||||
self.range.end_line,
|
||||
)
|
||||
|
||||
def add_path(self, path: List[str]) -> None:
|
||||
"""Add a path from root to this node.
|
||||
|
||||
Args:
|
||||
path: List of node IDs from root to this node
|
||||
"""
|
||||
if path not in self.paths:
|
||||
self.paths.append(path)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on node key."""
|
||||
return hash(self.node_key)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on node key."""
|
||||
if not isinstance(other, UniqueNode):
|
||||
return False
|
||||
return self.node_key == other.node_key
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the unique node."""
|
||||
return (
|
||||
f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
|
||||
f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
|
||||
)
|
||||
@@ -0,0 +1,301 @@
|
||||
"""Result deduplication for association tree nodes.
|
||||
|
||||
Provides functionality to extract unique nodes from a call tree and assign
|
||||
relevance scores based on various factors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .data_structures import (
|
||||
CallTree,
|
||||
TreeNode,
|
||||
UniqueNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Symbol kind weights for scoring (higher = more relevant)
|
||||
KIND_WEIGHTS: Dict[str, float] = {
|
||||
# Functions and methods are primary targets
|
||||
"function": 1.0,
|
||||
"method": 1.0,
|
||||
"12": 1.0, # LSP SymbolKind.Function
|
||||
"6": 1.0, # LSP SymbolKind.Method
|
||||
# Classes are important but secondary
|
||||
"class": 0.8,
|
||||
"5": 0.8, # LSP SymbolKind.Class
|
||||
# Interfaces and types
|
||||
"interface": 0.7,
|
||||
"11": 0.7, # LSP SymbolKind.Interface
|
||||
"type": 0.6,
|
||||
# Constructors
|
||||
"constructor": 0.9,
|
||||
"9": 0.9, # LSP SymbolKind.Constructor
|
||||
# Variables and constants
|
||||
"variable": 0.4,
|
||||
"13": 0.4, # LSP SymbolKind.Variable
|
||||
"constant": 0.5,
|
||||
"14": 0.5, # LSP SymbolKind.Constant
|
||||
# Default for unknown kinds
|
||||
"unknown": 0.3,
|
||||
}
|
||||
|
||||
|
||||
class ResultDeduplicator:
|
||||
"""Extracts and scores unique nodes from call trees.
|
||||
|
||||
Processes a CallTree to extract unique code locations, merging duplicates
|
||||
and assigning relevance scores based on:
|
||||
- Depth: Shallower nodes (closer to seeds) score higher
|
||||
- Frequency: Nodes appearing multiple times score higher
|
||||
- Kind: Function/method > class > variable
|
||||
|
||||
Attributes:
|
||||
depth_weight: Weight for depth factor in scoring (default 0.4)
|
||||
frequency_weight: Weight for frequency factor (default 0.3)
|
||||
kind_weight: Weight for symbol kind factor (default 0.3)
|
||||
max_depth_penalty: Maximum depth before full penalty applied
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth_weight: float = 0.4,
|
||||
frequency_weight: float = 0.3,
|
||||
kind_weight: float = 0.3,
|
||||
max_depth_penalty: int = 10,
|
||||
):
|
||||
"""Initialize ResultDeduplicator.
|
||||
|
||||
Args:
|
||||
depth_weight: Weight for depth factor (0.0-1.0)
|
||||
frequency_weight: Weight for frequency factor (0.0-1.0)
|
||||
kind_weight: Weight for symbol kind factor (0.0-1.0)
|
||||
max_depth_penalty: Depth at which score becomes 0 for depth factor
|
||||
"""
|
||||
self.depth_weight = depth_weight
|
||||
self.frequency_weight = frequency_weight
|
||||
self.kind_weight = kind_weight
|
||||
self.max_depth_penalty = max_depth_penalty
|
||||
|
||||
def deduplicate(
|
||||
self,
|
||||
tree: CallTree,
|
||||
max_results: Optional[int] = None,
|
||||
) -> List[UniqueNode]:
|
||||
"""Extract unique nodes from the call tree.
|
||||
|
||||
Traverses the tree, groups nodes by their unique key (file_path,
|
||||
start_line, end_line), and merges duplicate occurrences.
|
||||
|
||||
Args:
|
||||
tree: CallTree to process
|
||||
max_results: Maximum number of results to return (None = all)
|
||||
|
||||
Returns:
|
||||
List of UniqueNode objects, sorted by score descending
|
||||
"""
|
||||
if not tree.node_list:
|
||||
return []
|
||||
|
||||
# Group nodes by unique key
|
||||
unique_map: Dict[tuple, UniqueNode] = {}
|
||||
|
||||
for node in tree.node_list:
|
||||
if node.is_cycle:
|
||||
# Skip cycle markers - they point to already-counted nodes
|
||||
continue
|
||||
|
||||
key = self._get_node_key(node)
|
||||
|
||||
if key in unique_map:
|
||||
# Update existing unique node
|
||||
unique_node = unique_map[key]
|
||||
unique_node.occurrences += 1
|
||||
unique_node.min_depth = min(unique_node.min_depth, node.depth)
|
||||
unique_node.add_path(node.path_from_root)
|
||||
|
||||
# Collect context from relationships
|
||||
for parent in node.parents:
|
||||
if not parent.is_cycle:
|
||||
unique_node.context_nodes.append(parent.node_id)
|
||||
for child in node.children:
|
||||
if not child.is_cycle:
|
||||
unique_node.context_nodes.append(child.node_id)
|
||||
else:
|
||||
# Create new unique node
|
||||
unique_node = UniqueNode(
|
||||
file_path=node.item.file_path,
|
||||
name=node.item.name,
|
||||
kind=node.item.kind,
|
||||
range=node.item.range,
|
||||
min_depth=node.depth,
|
||||
occurrences=1,
|
||||
paths=[node.path_from_root.copy()],
|
||||
context_nodes=[],
|
||||
score=0.0,
|
||||
)
|
||||
|
||||
# Collect initial context
|
||||
for parent in node.parents:
|
||||
if not parent.is_cycle:
|
||||
unique_node.context_nodes.append(parent.node_id)
|
||||
for child in node.children:
|
||||
if not child.is_cycle:
|
||||
unique_node.context_nodes.append(child.node_id)
|
||||
|
||||
unique_map[key] = unique_node
|
||||
|
||||
# Calculate scores for all unique nodes
|
||||
unique_nodes = list(unique_map.values())
|
||||
|
||||
# Find max frequency for normalization
|
||||
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
|
||||
|
||||
for node in unique_nodes:
|
||||
node.score = self._score_node(node, max_frequency)
|
||||
|
||||
# Sort by score descending
|
||||
unique_nodes.sort(key=lambda n: n.score, reverse=True)
|
||||
|
||||
# Apply max_results limit
|
||||
if max_results is not None and max_results > 0:
|
||||
unique_nodes = unique_nodes[:max_results]
|
||||
|
||||
logger.debug(
|
||||
"Deduplicated %d tree nodes to %d unique nodes",
|
||||
len(tree.node_list),
|
||||
len(unique_nodes),
|
||||
)
|
||||
|
||||
return unique_nodes
|
||||
|
||||
def _score_node(
|
||||
self,
|
||||
node: UniqueNode,
|
||||
max_frequency: int,
|
||||
) -> float:
|
||||
"""Calculate composite score for a unique node.
|
||||
|
||||
Score = depth_weight * depth_score +
|
||||
frequency_weight * frequency_score +
|
||||
kind_weight * kind_score
|
||||
|
||||
Args:
|
||||
node: UniqueNode to score
|
||||
max_frequency: Maximum occurrence count for normalization
|
||||
|
||||
Returns:
|
||||
Composite score between 0.0 and 1.0
|
||||
"""
|
||||
# Depth score: closer to root = higher score
|
||||
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
|
||||
depth_score = max(
|
||||
0.0,
|
||||
1.0 - (node.min_depth / self.max_depth_penalty),
|
||||
)
|
||||
|
||||
# Frequency score: more occurrences = higher score
|
||||
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
|
||||
|
||||
# Kind score: function/method > class > variable
|
||||
kind_str = str(node.kind).lower()
|
||||
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
|
||||
|
||||
# Composite score
|
||||
score = (
|
||||
self.depth_weight * depth_score
|
||||
+ self.frequency_weight * frequency_score
|
||||
+ self.kind_weight * kind_score
|
||||
)
|
||||
|
||||
return score
|
||||
|
||||
def _get_node_key(self, node: TreeNode) -> tuple:
|
||||
"""Get unique key for a tree node.
|
||||
|
||||
Uses (file_path, start_line, end_line) as the unique identifier.
|
||||
|
||||
Args:
|
||||
node: TreeNode
|
||||
|
||||
Returns:
|
||||
Tuple key for deduplication
|
||||
"""
|
||||
return (
|
||||
node.item.file_path,
|
||||
node.item.range.start_line,
|
||||
node.item.range.end_line,
|
||||
)
|
||||
|
||||
def filter_by_kind(
|
||||
self,
|
||||
nodes: List[UniqueNode],
|
||||
kinds: List[str],
|
||||
) -> List[UniqueNode]:
|
||||
"""Filter unique nodes by symbol kind.
|
||||
|
||||
Args:
|
||||
nodes: List of UniqueNode to filter
|
||||
kinds: List of allowed kinds (e.g., ["function", "method"])
|
||||
|
||||
Returns:
|
||||
Filtered list of UniqueNode
|
||||
"""
|
||||
kinds_lower = [k.lower() for k in kinds]
|
||||
return [
|
||||
node
|
||||
for node in nodes
|
||||
if str(node.kind).lower() in kinds_lower
|
||||
]
|
||||
|
||||
def filter_by_file(
|
||||
self,
|
||||
nodes: List[UniqueNode],
|
||||
file_patterns: List[str],
|
||||
) -> List[UniqueNode]:
|
||||
"""Filter unique nodes by file path patterns.
|
||||
|
||||
Args:
|
||||
nodes: List of UniqueNode to filter
|
||||
file_patterns: List of path substrings to match
|
||||
|
||||
Returns:
|
||||
Filtered list of UniqueNode
|
||||
"""
|
||||
return [
|
||||
node
|
||||
for node in nodes
|
||||
if any(pattern in node.file_path for pattern in file_patterns)
|
||||
]
|
||||
|
||||
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
|
||||
"""Convert list of UniqueNode to JSON-serializable dicts.
|
||||
|
||||
Args:
|
||||
nodes: List of UniqueNode
|
||||
|
||||
Returns:
|
||||
List of dictionaries
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"file_path": node.file_path,
|
||||
"name": node.name,
|
||||
"kind": node.kind,
|
||||
"range": {
|
||||
"start_line": node.range.start_line,
|
||||
"start_character": node.range.start_character,
|
||||
"end_line": node.range.end_line,
|
||||
"end_character": node.range.end_character,
|
||||
},
|
||||
"min_depth": node.min_depth,
|
||||
"occurrences": node.occurrences,
|
||||
"path_count": len(node.paths),
|
||||
"score": round(node.score, 4),
|
||||
}
|
||||
for node in nodes
|
||||
]
|
||||
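To make the composite score in `_score_node` above concrete, here is a small worked example with made-up numbers (default weights 0.4 / 0.3 / 0.3, `max_depth_penalty=10`); the values are arbitrary and not taken from this diff.

```python
# Made-up numbers, only to trace the formula used by _score_node.
depth_score = max(0.0, 1.0 - 2 / 10)     # node at min_depth=2 -> 0.8
frequency_score = 3 / 5                   # 3 occurrences, max frequency 5 -> 0.6
kind_score = 1.0                          # "function" per KIND_WEIGHTS
score = 0.4 * depth_score + 0.3 * frequency_score + 0.3 * kind_score
print(round(score, 2))                    # 0.32 + 0.18 + 0.30 = 0.80
```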
codex-lens/build/lib/codexlens/search/binary_searcher.py (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
"""Binary vector searcher for cascade search.
|
||||
|
||||
This module provides fast binary vector search using Hamming distance
|
||||
for the first stage of cascade search (coarse filtering).
|
||||
|
||||
Supports two loading modes:
|
||||
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
|
||||
2. Database loading (fallback): Loads all vectors into RAM
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pre-computed popcount lookup table for vectorized Hamming distance
|
||||
# Each byte value (0-255) maps to its bit count
|
||||
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
|
||||
|
||||
|
||||
class BinarySearcher:
|
||||
"""Fast binary vector search using Hamming distance.
|
||||
|
||||
This class implements the first stage of cascade search:
|
||||
fast, approximate retrieval using binary vectors and Hamming distance.
|
||||
|
||||
The binary vectors are derived from dense embeddings by thresholding:
|
||||
binary[i] = 1 if dense[i] > 0 else 0
|
||||
|
||||
Hamming distance between two binary vectors counts the number of
|
||||
differing bits, which can be computed very efficiently using XOR
|
||||
and population count.
|
||||
|
||||
Supports two loading modes:
|
||||
- Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
|
||||
- Database (fallback): Loads all vectors into memory from SQLite
|
||||
"""
|
||||
|
||||
def __init__(self, index_root_or_meta_path: Path) -> None:
|
||||
"""Initialize BinarySearcher.
|
||||
|
||||
Args:
|
||||
index_root_or_meta_path: Either:
|
||||
- Path to index root directory (containing _binary_vectors.mmap)
|
||||
- Path to _vectors_meta.db (legacy mode, loads from DB)
|
||||
"""
|
||||
path = Path(index_root_or_meta_path)
|
||||
|
||||
# Determine if this is an index root or a specific DB path
|
||||
if path.suffix == '.db':
|
||||
# Legacy mode: specific DB path
|
||||
self.index_root = path.parent
|
||||
self.meta_store_path = path
|
||||
else:
|
||||
# New mode: index root directory
|
||||
self.index_root = path
|
||||
self.meta_store_path = path / "_vectors_meta.db"
|
||||
|
||||
self._chunk_ids: Optional[np.ndarray] = None
|
||||
self._binary_matrix: Optional[np.ndarray] = None
|
||||
self._is_memmap = False
|
||||
self._loaded = False
|
||||
|
||||
def load(self) -> bool:
|
||||
"""Load binary vectors using memory-mapped file or database fallback.
|
||||
|
||||
Tries to load from memory-mapped file first (preferred for large indexes),
|
||||
falls back to database loading if mmap file doesn't exist.
|
||||
|
||||
Returns:
|
||||
True if vectors were loaded successfully.
|
||||
"""
|
||||
if self._loaded:
|
||||
return True
|
||||
|
||||
# Try memory-mapped file first (preferred)
|
||||
mmap_path = self.index_root / "_binary_vectors.mmap"
|
||||
meta_path = mmap_path.with_suffix('.meta.json')
|
||||
|
||||
if mmap_path.exists() and meta_path.exists():
|
||||
try:
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
shape = tuple(meta['shape'])
|
||||
self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)
|
||||
|
||||
# Memory-map the binary matrix (read-only)
|
||||
self._binary_matrix = np.memmap(
|
||||
str(mmap_path),
|
||||
dtype=np.uint8,
|
||||
mode='r',
|
||||
shape=shape
|
||||
)
|
||||
self._is_memmap = True
|
||||
self._loaded = True
|
||||
|
||||
logger.info(
|
||||
"Memory-mapped %d binary vectors (%d bytes each)",
|
||||
len(self._chunk_ids), shape[1]
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)
|
||||
|
||||
# Fallback: load from database
|
||||
return self._load_from_db()
|
||||
|
||||
def _load_from_db(self) -> bool:
|
||||
"""Load binary vectors from database (legacy/fallback mode).
|
||||
|
||||
Returns:
|
||||
True if vectors were loaded successfully.
|
||||
"""
|
||||
try:
|
||||
from codexlens.storage.vector_meta_store import VectorMetadataStore
|
||||
|
||||
with VectorMetadataStore(self.meta_store_path) as store:
|
||||
rows = store.get_all_binary_vectors()
|
||||
|
||||
if not rows:
|
||||
logger.warning("No binary vectors found in %s", self.meta_store_path)
|
||||
return False
|
||||
|
||||
# Convert to numpy arrays for fast computation
|
||||
self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)
|
||||
|
||||
# Unpack bytes to numpy array
|
||||
binary_arrays = []
|
||||
for _, vec_bytes in rows:
|
||||
arr = np.frombuffer(vec_bytes, dtype=np.uint8)
|
||||
binary_arrays.append(arr)
|
||||
|
||||
self._binary_matrix = np.vstack(binary_arrays)
|
||||
self._is_memmap = False
|
||||
self._loaded = True
|
||||
|
||||
logger.info(
|
||||
"Loaded %d binary vectors from DB (%d bytes each)",
|
||||
len(self._chunk_ids), self._binary_matrix.shape[1]
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to load binary vectors: %s", e)
|
||||
return False
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_vector: np.ndarray,
|
||||
top_k: int = 100
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""Search for similar vectors using Hamming distance.
|
||||
|
||||
Args:
|
||||
query_vector: Dense query vector (will be binarized).
|
||||
top_k: Number of top results to return.
|
||||
|
||||
Returns:
|
||||
List of (chunk_id, hamming_distance) tuples sorted by distance.
|
||||
"""
|
||||
if not self._loaded and not self.load():
|
||||
return []
|
||||
|
||||
# Binarize query vector
|
||||
query_binary = (query_vector > 0).astype(np.uint8)
|
||||
query_packed = np.packbits(query_binary)
|
||||
|
||||
# Compute Hamming distances using XOR and popcount
|
||||
# XOR gives 1 for differing bits
|
||||
xor_result = np.bitwise_xor(self._binary_matrix, query_packed)
|
||||
|
||||
# Vectorized popcount using lookup table (orders of magnitude faster)
|
||||
# Sum the bit counts for each byte across all columns
|
||||
distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)
|
||||
|
||||
# Get top-k with smallest distances
|
||||
if top_k >= len(distances):
|
||||
top_indices = np.argsort(distances)
|
||||
else:
|
||||
# Partial sort for efficiency
|
||||
top_indices = np.argpartition(distances, top_k)[:top_k]
|
||||
top_indices = top_indices[np.argsort(distances[top_indices])]
|
||||
|
||||
results = [
|
||||
(int(self._chunk_ids[i]), int(distances[i]))
|
||||
for i in top_indices
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
def search_with_rerank(
|
||||
self,
|
||||
query_dense: np.ndarray,
|
||||
dense_vectors: np.ndarray,
|
||||
dense_chunk_ids: np.ndarray,
|
||||
top_k: int = 10,
|
||||
candidates: int = 100
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""Two-stage cascade search: binary filter + dense rerank.
|
||||
|
||||
Args:
|
||||
query_dense: Dense query vector.
|
||||
dense_vectors: Dense vectors for reranking (from HNSW or stored).
|
||||
dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
|
||||
top_k: Final number of results.
|
||||
candidates: Number of candidates from binary search.
|
||||
|
||||
Returns:
|
||||
List of (chunk_id, cosine_similarity) tuples.
|
||||
"""
|
||||
# Stage 1: Binary filtering
|
||||
binary_results = self.search(query_dense, top_k=candidates)
|
||||
if not binary_results:
|
||||
return []
|
||||
|
||||
candidate_ids = {r[0] for r in binary_results}
|
||||
|
||||
# Stage 2: Dense reranking
|
||||
# Find indices of candidates in dense_vectors
|
||||
candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
|
||||
candidate_indices = np.where(candidate_mask)[0]
|
||||
|
||||
if len(candidate_indices) == 0:
|
||||
# Fallback: return binary results with normalized distance
|
||||
max_dist = max(r[1] for r in binary_results) if binary_results else 1
|
||||
return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]
|
||||
|
||||
# Compute cosine similarities for candidates
|
||||
candidate_vectors = dense_vectors[candidate_indices]
|
||||
candidate_ids_array = dense_chunk_ids[candidate_indices]
|
||||
|
||||
# Normalize vectors
|
||||
query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
|
||||
cand_norms = candidate_vectors / (
|
||||
np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
|
||||
)
|
||||
|
||||
# Cosine similarities
|
||||
similarities = np.dot(cand_norms, query_norm)
|
||||
|
||||
# Sort by similarity (descending)
|
||||
sorted_indices = np.argsort(-similarities)[:top_k]
|
||||
|
||||
results = [
|
||||
(int(candidate_ids_array[i]), float(similarities[i]))
|
||||
for i in sorted_indices
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def vector_count(self) -> int:
|
||||
"""Get number of loaded binary vectors."""
|
||||
return len(self._chunk_ids) if self._chunk_ids is not None else 0
|
||||
|
||||
@property
|
||||
def is_memmap(self) -> bool:
|
||||
"""Check if using memory-mapped file (vs in-memory array)."""
|
||||
return self._is_memmap
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear loaded vectors from memory."""
|
||||
# For memmap, just delete the reference (OS will handle cleanup)
|
||||
if self._is_memmap and self._binary_matrix is not None:
|
||||
del self._binary_matrix
|
||||
self._chunk_ids = None
|
||||
self._binary_matrix = None
|
||||
self._is_memmap = False
|
||||
self._loaded = False
|
||||
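A minimal sketch of driving the coarse stage of the cascade above; the index location and the 768-dimensional query vector are placeholders, not values taken from this diff.

```python
# Sketch only: the index root and the embedding dimension (768) are assumptions.
from pathlib import Path

import numpy as np

from codexlens.search.binary_searcher import BinarySearcher

searcher = BinarySearcher(Path("index_root"))        # directory holding _binary_vectors.mmap
if searcher.load():
    query = np.random.rand(768).astype(np.float32)   # stand-in for a real dense embedding
    coarse = searcher.search(query, top_k=100)       # list of (chunk_id, hamming_distance)
    print(searcher.vector_count, coarse[:5])
```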
codex-lens/build/lib/codexlens/search/chain_search.py (new file, 3268 lines)
File diff suppressed because it is too large.
codex-lens/build/lib/codexlens/search/clustering/__init__.py (new file, 124 lines)
@@ -0,0 +1,124 @@
|
||||
"""Clustering strategies for the staged hybrid search pipeline.
|
||||
|
||||
This module provides extensible clustering infrastructure for grouping
|
||||
similar search results and selecting representative results.
|
||||
|
||||
Install with: pip install codexlens[clustering]
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import (
|
||||
... CLUSTERING_AVAILABLE,
|
||||
... ClusteringConfig,
|
||||
... get_strategy,
|
||||
... )
|
||||
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||
>>> # Auto-select best available strategy with fallback
|
||||
>>> strategy = get_strategy("auto", config)
|
||||
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||
>>>
|
||||
>>> # Or explicitly use a specific strategy
|
||||
>>> if CLUSTERING_AVAILABLE:
|
||||
... from codexlens.search.clustering import HDBSCANStrategy
|
||||
... strategy = HDBSCANStrategy(config)
|
||||
... representatives = strategy.fit_predict(embeddings, results)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Always export base classes and factory (no heavy dependencies)
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
from .factory import (
|
||||
ClusteringStrategyFactory,
|
||||
check_clustering_strategy_available,
|
||||
get_strategy,
|
||||
)
|
||||
from .noop_strategy import NoOpStrategy
|
||||
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||
|
||||
# Feature flag for clustering availability (hdbscan + sklearn)
|
||||
CLUSTERING_AVAILABLE = False
|
||||
HDBSCAN_AVAILABLE = False
|
||||
DBSCAN_AVAILABLE = False
|
||||
_import_error: str | None = None
|
||||
|
||||
|
||||
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
|
||||
"""Detect if clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
|
||||
"""
|
||||
hdbscan_ok = False
|
||||
dbscan_ok = False
|
||||
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
hdbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
dbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
all_ok = hdbscan_ok and dbscan_ok
|
||||
error = None
|
||||
if not all_ok:
|
||||
missing = []
|
||||
if not hdbscan_ok:
|
||||
missing.append("hdbscan")
|
||||
if not dbscan_ok:
|
||||
missing.append("scikit-learn")
|
||||
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
|
||||
|
||||
return all_ok, hdbscan_ok, dbscan_ok, error
|
||||
|
||||
|
||||
# Initialize on module load
|
||||
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
|
||||
_detect_clustering_available()
|
||||
)
|
||||
|
||||
|
||||
def check_clustering_available() -> tuple[bool, str | None]:
|
||||
"""Check if all clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, error_message).
|
||||
error_message is None if available, otherwise contains install instructions.
|
||||
"""
|
||||
return CLUSTERING_AVAILABLE, _import_error
|
||||
|
||||
|
||||
# Conditionally export strategy implementations
|
||||
__all__ = [
|
||||
# Feature flags
|
||||
"CLUSTERING_AVAILABLE",
|
||||
"HDBSCAN_AVAILABLE",
|
||||
"DBSCAN_AVAILABLE",
|
||||
"check_clustering_available",
|
||||
# Base classes
|
||||
"BaseClusteringStrategy",
|
||||
"ClusteringConfig",
|
||||
# Factory
|
||||
"ClusteringStrategyFactory",
|
||||
"get_strategy",
|
||||
"check_clustering_strategy_available",
|
||||
# Always-available strategies
|
||||
"NoOpStrategy",
|
||||
"FrequencyStrategy",
|
||||
"FrequencyConfig",
|
||||
]
|
||||
|
||||
# Conditionally add strategy classes to __all__ and module namespace
|
||||
if HDBSCAN_AVAILABLE:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
|
||||
__all__.append("HDBSCANStrategy")
|
||||
|
||||
if DBSCAN_AVAILABLE:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
|
||||
__all__.append("DBSCANStrategy")
|
||||
codex-lens/build/lib/codexlens/search/clustering/base.py (new file, 153 lines)
@@ -0,0 +1,153 @@
|
||||
"""Base classes for clustering strategies in the hybrid search pipeline.
|
||||
|
||||
This module defines the abstract base class for clustering strategies used
|
||||
in the staged hybrid search pipeline. Strategies cluster search results
|
||||
based on their embeddings and select representative results from each cluster.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusteringConfig:
|
||||
"""Configuration parameters for clustering strategies.
|
||||
|
||||
Attributes:
|
||||
min_cluster_size: Minimum number of results to form a cluster.
|
||||
HDBSCAN default is 5, but for search results 2-3 is often better.
|
||||
min_samples: Number of samples in a neighborhood for a point to be
|
||||
considered a core point. Lower values allow more clusters.
|
||||
metric: Distance metric for clustering. Common options:
|
||||
- 'euclidean': Standard L2 distance
|
||||
- 'cosine': Cosine distance (1 - cosine_similarity)
|
||||
- 'manhattan': L1 distance
|
||||
cluster_selection_epsilon: Distance threshold for cluster selection.
|
||||
Results within this distance may be merged into the same cluster.
|
||||
allow_single_cluster: If True, allow all results to form one cluster.
|
||||
Useful when results are very similar.
|
||||
prediction_data: If True, generate prediction data for new points.
|
||||
"""
|
||||
|
||||
min_cluster_size: int = 3
|
||||
min_samples: int = 2
|
||||
metric: str = "cosine"
|
||||
cluster_selection_epsilon: float = 0.0
|
||||
allow_single_cluster: bool = True
|
||||
prediction_data: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
if self.min_cluster_size < 2:
|
||||
raise ValueError("min_cluster_size must be >= 2")
|
||||
if self.min_samples < 1:
|
||||
raise ValueError("min_samples must be >= 1")
|
||||
if self.metric not in ("euclidean", "cosine", "manhattan"):
|
||||
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
|
||||
if self.cluster_selection_epsilon < 0:
|
||||
raise ValueError("cluster_selection_epsilon must be >= 0")
|
||||
|
||||
|
||||
class BaseClusteringStrategy(ABC):
|
||||
"""Abstract base class for clustering strategies.
|
||||
|
||||
Clustering strategies are used in the staged hybrid search pipeline to
|
||||
group similar search results and select representative results from each
|
||||
cluster, reducing redundancy while maintaining diversity.
|
||||
|
||||
Subclasses must implement:
|
||||
- cluster(): Group results into clusters based on embeddings
|
||||
- select_representatives(): Choose best result(s) from each cluster
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize the clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
"""
|
||||
self.config = config or ClusteringConfig()
|
||||
|
||||
@abstractmethod
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results based on their embeddings.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
Used for additional metadata during clustering.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Results not assigned to any cluster
|
||||
(noise points) should be returned as single-element clusters.
|
||||
|
||||
Example:
|
||||
>>> strategy = HDBSCANStrategy()
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
|
||||
>>> # Result indices 0, 2, 5 are in cluster 0
|
||||
>>> # Result indices 1, 3 are in cluster 1
|
||||
>>> # Result index 4 is a noise point (singleton cluster)
|
||||
>>> # Result indices 6, 7, 8 are in cluster 2
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
This method chooses the best result(s) from each cluster to include
|
||||
in the final search results. The selection can be based on:
|
||||
- Highest score within cluster
|
||||
- Closest to cluster centroid
|
||||
- Custom selection logic
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings array for centroid-based selection.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one or more per cluster,
|
||||
ordered by relevance (highest score first).
|
||||
|
||||
Example:
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
>>> # Returns best result from each cluster
|
||||
"""
|
||||
...
|
||||
|
||||
def fit_predict(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List["SearchResult"]:
|
||||
"""Convenience method to cluster and select representatives in one call.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects.
|
||||
"""
|
||||
clusters = self.cluster(embeddings, results)
|
||||
return self.select_representatives(clusters, results, embeddings)
|
||||
@@ -0,0 +1,197 @@
|
||||
"""DBSCAN-based clustering strategy for search results.
|
||||
|
||||
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
|
||||
is the fallback clustering strategy when HDBSCAN is not available.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class DBSCANStrategy(BaseClusteringStrategy):
|
||||
"""DBSCAN-based clustering strategy.
|
||||
|
||||
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
|
||||
DBSCAN requires an explicit eps parameter, which is auto-computed from the
|
||||
distance distribution if not provided.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||
>>> strategy = DBSCANStrategy(config)
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
"""
|
||||
|
||||
# Default eps percentile for auto-computation
|
||||
DEFAULT_EPS_PERCENTILE: float = 15.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
eps: Optional[float] = None,
|
||||
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
|
||||
) -> None:
|
||||
"""Initialize DBSCAN clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
```python
            eps: Explicit eps parameter for DBSCAN. If None, auto-computed
                from the distance distribution.
            eps_percentile: Percentile of pairwise distances to use for
                auto-computing eps. Default is 15th percentile.

        Raises:
            ImportError: If sklearn is not installed.
        """
        super().__init__(config)
        self.eps = eps
        self.eps_percentile = eps_percentile

        # Validate sklearn is available
        try:
            from sklearn.cluster import DBSCAN  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "scikit-learn package is required for DBSCANStrategy. "
                "Install with: pip install codexlens[clustering]"
            ) from exc

    def _compute_eps(self, embeddings: "np.ndarray") -> float:
        """Auto-compute eps from pairwise distance distribution.

        Uses the specified percentile of pairwise distances as eps,
        which typically captures local density well.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).

        Returns:
            Computed eps value.
        """
        import numpy as np
        from sklearn.metrics import pairwise_distances

        # Compute pairwise distances
        distances = pairwise_distances(embeddings, metric=self.config.metric)

        # Get upper triangle (excluding diagonal)
        upper_tri = distances[np.triu_indices_from(distances, k=1)]

        if len(upper_tri) == 0:
            # Only one point, return a default small eps
            return 0.1

        # Use percentile of distances as eps
        eps = float(np.percentile(upper_tri, self.eps_percentile))

        # Ensure eps is positive
        return max(eps, 1e-6)

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results using DBSCAN algorithm.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Noise points are returned as singleton clusters.
        """
        from sklearn.cluster import DBSCAN
        import numpy as np

        n_results = len(results)
        if n_results == 0:
            return []

        # Handle edge case: single result
        if n_results == 1:
            return [[0]]

        # Determine eps value
        eps = self.eps if self.eps is not None else self._compute_eps(embeddings)

        # Configure DBSCAN clusterer
        # Note: DBSCAN min_samples corresponds to min_cluster_size concept
        clusterer = DBSCAN(
            eps=eps,
            min_samples=self.config.min_samples,
            metric=self.config.metric,
        )

        # Fit and get cluster labels
        # Labels: -1 = noise, 0+ = cluster index
        labels = clusterer.fit_predict(embeddings)

        # Group indices by cluster label
        cluster_map: dict[int, list[int]] = {}
        for idx, label in enumerate(labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(idx)

        # Build result: non-noise clusters first, then noise as singletons
        clusters: List[List[int]] = []

        # Add proper clusters (label >= 0)
        for label in sorted(cluster_map.keys()):
            if label >= 0:
                clusters.append(cluster_map[label])

        # Add noise points as singleton clusters (label == -1)
        if -1 in cluster_map:
            for idx in cluster_map[-1]:
                clusters.append([idx])

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        Selects the result with the highest score from each cluster.

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used in score-based selection).

        Returns:
            List of representative SearchResult objects, one per cluster,
            ordered by score (highest first).
        """
        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            # Find the result with the highest score in this cluster
            best_idx = max(cluster_indices, key=lambda i: results[i].score)
            representatives.append(results[best_idx])

        # Sort by score descending
        representatives.sort(key=lambda r: r.score, reverse=True)

        return representatives
```
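The percentile-based eps heuristic above is easiest to see on toy data. The following is a minimal sketch (not part of the diff) that reproduces `_compute_eps` by hand with numpy and scikit-learn; the sample points and the 15th-percentile choice are illustrative only.

```python
# Minimal sketch of the percentile-based eps heuristic, on toy 2-D embeddings.
# Assumes numpy and scikit-learn are installed; values are illustrative.
import numpy as np
from sklearn.metrics import pairwise_distances

embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0]])
distances = pairwise_distances(embeddings, metric="euclidean")
upper_tri = distances[np.triu_indices_from(distances, k=1)]

# 15th percentile of the six pairwise distances: small enough that only the
# three nearby points fall within eps of each other, the far point is noise.
eps = max(float(np.percentile(upper_tri, 15)), 1e-6)
print(round(eps, 3))  # ~0.1 for this toy data
```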
codex-lens/build/lib/codexlens/search/clustering/factory.py (new file, 202 lines):

```python
"""Factory for creating clustering strategies.

Provides a unified interface for instantiating different clustering backends
with automatic fallback chain: hdbscan -> dbscan -> noop.
"""

from __future__ import annotations

from typing import Any, Optional

from .base import BaseClusteringStrategy, ClusteringConfig
from .noop_strategy import NoOpStrategy


def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
    """Check whether a specific clustering strategy can be used.

    Args:
        strategy: Strategy name to check. Options:
            - "hdbscan": HDBSCAN clustering (requires hdbscan package)
            - "dbscan": DBSCAN clustering (requires sklearn)
            - "frequency": Frequency-based clustering (always available)
            - "noop": No-op strategy (always available)

    Returns:
        Tuple of (is_available, error_message).
        error_message is None if available, otherwise contains install instructions.
    """
    strategy = (strategy or "").strip().lower()

    if strategy == "hdbscan":
        try:
            import hdbscan  # noqa: F401
        except ImportError:
            return False, (
                "hdbscan package not available. "
                "Install with: pip install codexlens[clustering]"
            )
        return True, None

    if strategy == "dbscan":
        try:
            from sklearn.cluster import DBSCAN  # noqa: F401
        except ImportError:
            return False, (
                "scikit-learn package not available. "
                "Install with: pip install codexlens[clustering]"
            )
        return True, None

    if strategy == "frequency":
        # Frequency strategy is always available (no external deps)
        return True, None

    if strategy == "noop":
        return True, None

    return False, (
        f"Invalid clustering strategy: {strategy}. "
        "Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
    )


def get_strategy(
    strategy: str = "hdbscan",
    config: Optional[ClusteringConfig] = None,
    *,
    fallback: bool = True,
    **kwargs: Any,
) -> BaseClusteringStrategy:
    """Factory function to create clustering strategy with fallback chain.

    The fallback chain is: hdbscan -> dbscan -> frequency -> noop

    Args:
        strategy: Clustering strategy to use. Options:
            - "hdbscan": HDBSCAN clustering (default, recommended)
            - "dbscan": DBSCAN clustering (fallback)
            - "frequency": Frequency-based clustering (groups by symbol occurrence)
            - "noop": No-op strategy (returns all results ungrouped)
            - "auto": Try hdbscan, then dbscan, then noop
        config: Clustering configuration. Uses defaults if not provided.
            For frequency strategy, pass FrequencyConfig for full control.
        fallback: If True (default), automatically fall back to next strategy
            in the chain when primary is unavailable. If False, raise ImportError
            when requested strategy is unavailable.
        **kwargs: Additional strategy-specific arguments.
            For DBSCANStrategy: eps, eps_percentile
            For FrequencyStrategy: group_by, min_frequency, etc.

    Returns:
        BaseClusteringStrategy: Configured clustering strategy instance.

    Raises:
        ValueError: If strategy is not recognized.
        ImportError: If required dependencies are not installed and fallback=False.

    Example:
        >>> from codexlens.search.clustering import get_strategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3)
        >>> # Auto-select best available strategy
        >>> strategy = get_strategy("auto", config)
        >>> # Explicitly use HDBSCAN (will fall back if unavailable)
        >>> strategy = get_strategy("hdbscan", config)
        >>> # Use frequency-based strategy
        >>> from codexlens.search.clustering import FrequencyConfig
        >>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
        >>> strategy = get_strategy("frequency", freq_config)
    """
    strategy = (strategy or "").strip().lower()

    # Handle "auto" - try strategies in order
    if strategy == "auto":
        return _get_best_available_strategy(config, **kwargs)

    if strategy == "hdbscan":
        ok, err = check_clustering_strategy_available("hdbscan")
        if ok:
            from .hdbscan_strategy import HDBSCANStrategy
            return HDBSCANStrategy(config)

        if fallback:
            # Try dbscan fallback
            ok_dbscan, _ = check_clustering_strategy_available("dbscan")
            if ok_dbscan:
                from .dbscan_strategy import DBSCANStrategy
                return DBSCANStrategy(config, **kwargs)
            # Final fallback to noop
            return NoOpStrategy(config)

        raise ImportError(err)

    if strategy == "dbscan":
        ok, err = check_clustering_strategy_available("dbscan")
        if ok:
            from .dbscan_strategy import DBSCANStrategy
            return DBSCANStrategy(config, **kwargs)

        if fallback:
            # Fallback to noop
            return NoOpStrategy(config)

        raise ImportError(err)

    if strategy == "frequency":
        from .frequency_strategy import FrequencyStrategy, FrequencyConfig
        # If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
        if config is None or not isinstance(config, FrequencyConfig):
            freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
        else:
            freq_config = config
        return FrequencyStrategy(freq_config)

    if strategy == "noop":
        return NoOpStrategy(config)

    raise ValueError(
        f"Unknown clustering strategy: {strategy}. "
        "Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
    )


def _get_best_available_strategy(
    config: Optional[ClusteringConfig] = None,
    **kwargs: Any,
) -> BaseClusteringStrategy:
    """Get the best available clustering strategy.

    Tries strategies in order: hdbscan -> dbscan -> noop

    Args:
        config: Clustering configuration.
        **kwargs: Additional strategy-specific arguments.

    Returns:
        Best available clustering strategy instance.
    """
    # Try HDBSCAN first
    ok, _ = check_clustering_strategy_available("hdbscan")
    if ok:
        from .hdbscan_strategy import HDBSCANStrategy
        return HDBSCANStrategy(config)

    # Try DBSCAN second
    ok, _ = check_clustering_strategy_available("dbscan")
    if ok:
        from .dbscan_strategy import DBSCANStrategy
        return DBSCANStrategy(config, **kwargs)

    # Fallback to NoOp
    return NoOpStrategy(config)


# Alias for backward compatibility
ClusteringStrategyFactory = type(
    "ClusteringStrategyFactory",
    (),
    {
        "get_strategy": staticmethod(get_strategy),
        "check_available": staticmethod(check_clustering_strategy_available),
    },
)
```
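Beyond the docstring examples, callers may want to log which backend the fallback chain will actually select. The sketch below is an assumption-laden illustration: it imports directly from the factory module path shown in this diff, which may differ from the package's public re-exports.

```python
# Sketch: probe backend availability before constructing a strategy, so the
# caller can log what the fallback chain will resolve to. The import path
# (codexlens.search.clustering.factory) is assumed from this diff's layout.
from codexlens.search.clustering.factory import (
    check_clustering_strategy_available,
    get_strategy,
)

for name in ("hdbscan", "dbscan", "frequency", "noop"):
    ok, err = check_clustering_strategy_available(name)
    print(f"{name}: {'available' if ok else err}")

# With fallback=True (the default) this never raises; it silently degrades
# hdbscan -> dbscan -> noop when optional dependencies are missing.
strategy = get_strategy("hdbscan", fallback=True)
print(type(strategy).__name__)
```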
codex-lens/build/lib/codexlens/search/clustering/frequency_strategy.py (new file, 263 lines):

```python
"""Frequency-based clustering strategy for search result deduplication.

This strategy groups search results by symbol/method name and prunes based on
occurrence frequency. High-frequency symbols (frequently referenced methods)
are considered more important and retained, while low-frequency results
(potentially noise) can be filtered out.

Use cases:
- Prioritize commonly called methods/functions
- Filter out one-off results that may be less relevant
- Deduplicate results pointing to the same symbol from different locations
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Literal

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


@dataclass
class FrequencyConfig(ClusteringConfig):
    """Configuration for frequency-based clustering strategy.

    Attributes:
        group_by: Field to group results by for frequency counting.
            - 'symbol': Group by symbol_name (default, for method/function dedup)
            - 'file': Group by file path
            - 'symbol_kind': Group by symbol type (function, class, etc.)
        min_frequency: Minimum occurrence count to keep a result.
            Results appearing less than this are considered noise and pruned.
        max_representatives_per_group: Maximum results to keep per symbol group.
        frequency_weight: How much to boost score based on frequency.
            Final score = original_score * (1 + frequency_weight * log(frequency))
        keep_mode: How to handle low-frequency results.
            - 'filter': Remove results below min_frequency
            - 'demote': Keep but lower their score ranking
    """

    group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
    min_frequency: int = 1  # 1 means keep all, 2+ filters singletons
    max_representatives_per_group: int = 3
    frequency_weight: float = 0.1  # Boost factor for frequency
    keep_mode: Literal["filter", "demote"] = "demote"

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        # Skip parent validation since we don't use HDBSCAN params
        if self.min_frequency < 1:
            raise ValueError("min_frequency must be >= 1")
        if self.max_representatives_per_group < 1:
            raise ValueError("max_representatives_per_group must be >= 1")
        if self.frequency_weight < 0:
            raise ValueError("frequency_weight must be >= 0")
        if self.group_by not in ("symbol", "file", "symbol_kind"):
            raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
        if self.keep_mode not in ("filter", "demote"):
            raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")


class FrequencyStrategy(BaseClusteringStrategy):
    """Frequency-based clustering strategy for search result deduplication.

    This strategy groups search results by symbol name (or file/kind) and:
    1. Counts how many times each symbol appears in results
    2. Higher frequency = more important (frequently referenced method)
    3. Filters or demotes low-frequency results
    4. Selects top representatives from each frequency group

    Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
    - Does NOT require embeddings (works with metadata only)
    - Is very fast (O(n) complexity)
    - Is deterministic (no random initialization)
    - Works well for symbol-level deduplication

    Example:
        >>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
        >>> strategy = FrequencyStrategy(config)
        >>> # Results with symbol "authenticate" appearing 5 times
        >>> # will be prioritized over "helper_func" appearing once
        >>> representatives = strategy.fit_predict(embeddings, results)
    """

    def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
        """Initialize the frequency strategy.

        Args:
            config: Frequency configuration. Uses defaults if not provided.
        """
        self.config: FrequencyConfig = config or FrequencyConfig()

    def _get_group_key(self, result: "SearchResult") -> str:
        """Extract grouping key from a search result.

        Args:
            result: SearchResult to extract key from.

        Returns:
            String key for grouping (symbol name, file path, or kind).
        """
        if self.config.group_by == "symbol":
            # Use symbol_name if available, otherwise fall back to file:line
            symbol = getattr(result, "symbol_name", None)
            if symbol:
                return str(symbol)
            # Fallback: use file path + start_line as pseudo-symbol
            start_line = getattr(result, "start_line", 0) or 0
            return f"{result.path}:{start_line}"

        elif self.config.group_by == "file":
            return str(result.path)

        elif self.config.group_by == "symbol_kind":
            kind = getattr(result, "symbol_kind", None)
            return str(kind) if kind else "unknown"

        return str(result.path)  # Default fallback

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Group search results by frequency of occurrence.

        Note: This method ignores embeddings and groups by metadata only.
        The embeddings parameter is kept for interface compatibility.

        Args:
            embeddings: Ignored (kept for interface compatibility).
            results: List of SearchResult objects to cluster.

        Returns:
            List of clusters (groups), where each cluster contains indices
            of results with the same grouping key. Clusters are ordered by
            frequency (highest frequency first).
        """
        if not results:
            return []

        # Group results by key
        groups: Dict[str, List[int]] = defaultdict(list)
        for idx, result in enumerate(results):
            key = self._get_group_key(result)
            groups[key].append(idx)

        # Sort groups by frequency (descending) then by key (for stability)
        sorted_groups = sorted(
            groups.items(),
            key=lambda x: (-len(x[1]), x[0])  # -frequency, then alphabetical
        )

        # Convert to list of clusters
        clusters = [indices for _, indices in sorted_groups]

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results based on frequency and score.

        For each frequency group:
        1. If frequency < min_frequency: filter or demote based on keep_mode
        2. Sort by score within group
        3. Apply frequency boost to scores
        4. Select top N representatives

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (used for tie-breaking if provided).

        Returns:
            List of representative SearchResult objects, ordered by
            frequency-adjusted score (highest first).
        """
        import math

        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []
        demoted: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            frequency = len(cluster_indices)

            # Get results in this cluster, sorted by score
            cluster_results = [results[i] for i in cluster_indices]
            cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)

            # Check frequency threshold
            if frequency < self.config.min_frequency:
                if self.config.keep_mode == "filter":
                    # Skip low-frequency results entirely
                    continue
                else:  # demote mode
                    # Keep but add to demoted list (lower priority)
                    for result in cluster_results[: self.config.max_representatives_per_group]:
                        demoted.append(result)
                    continue

            # Apply frequency boost and select top representatives
            for result in cluster_results[: self.config.max_representatives_per_group]:
                # Calculate frequency-boosted score
                original_score = getattr(result, "score", 0.0)
                # log(frequency + 1) to handle frequency=1 case smoothly
                frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
                boosted_score = original_score * frequency_boost

                # Create new result with boosted score and frequency metadata
                # Note: SearchResult might be immutable, so we preserve original
                # and track boosted score in metadata
                if hasattr(result, "metadata") and isinstance(result.metadata, dict):
                    result.metadata["frequency"] = frequency
                    result.metadata["frequency_boosted_score"] = boosted_score

                representatives.append(result)

        # Sort representatives by boosted score (or original score as fallback)
        def get_sort_score(r: "SearchResult") -> float:
            if hasattr(r, "metadata") and isinstance(r.metadata, dict):
                return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
            return getattr(r, "score", 0.0)

        representatives.sort(key=get_sort_score, reverse=True)

        # Add demoted results at the end
        if demoted:
            demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
            representatives.extend(demoted)

        return representatives

    def fit_predict(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List["SearchResult"]:
        """Convenience method to cluster and select representatives in one call.

        Args:
            embeddings: NumPy array (may be ignored for frequency-based clustering).
            results: List of SearchResult objects.

        Returns:
            List of representative SearchResult objects.
        """
        clusters = self.cluster(embeddings, results)
        return self.select_representatives(clusters, results, embeddings)
```
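The frequency boost formula used in `select_representatives` is small enough to check by hand. The sketch below is a standalone worked example of `original_score * (1 + frequency_weight * log(frequency + 1))` with illustrative numbers; it does not construct real SearchResult objects.

```python
# Worked example of the frequency boost: original * (1 + w * log(freq + 1)).
# Numbers are illustrative; the default frequency_weight of 0.1 is used.
import math

frequency_weight = 0.1

def boosted(original_score: float, frequency: int) -> float:
    return original_score * (1.0 + frequency_weight * math.log(frequency + 1))

# A symbol hit 5 times gets a noticeably larger boost than a one-off hit,
# even when the raw scores are equal.
print(round(boosted(0.80, 5), 3))  # ~0.943
print(round(boosted(0.80, 1), 3))  # ~0.855
```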
codex-lens/build/lib/codexlens/search/clustering/hdbscan_strategy.py (new file, 153 lines):

```python
"""HDBSCAN-based clustering strategy for search results.

HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


class HDBSCANStrategy(BaseClusteringStrategy):
    """HDBSCAN-based clustering strategy.

    Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
    HDBSCAN is preferred over DBSCAN because it:
    - Automatically determines the number of clusters
    - Handles varying density clusters well
    - Identifies noise points (outliers) effectively

    Example:
        >>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
        >>> strategy = HDBSCANStrategy(config)
        >>> clusters = strategy.cluster(embeddings, results)
        >>> representatives = strategy.select_representatives(clusters, results)
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize HDBSCAN clustering strategy.

        Args:
            config: Clustering configuration. Uses defaults if not provided.

        Raises:
            ImportError: If hdbscan package is not installed.
        """
        super().__init__(config)
        # Validate hdbscan is available
        try:
            import hdbscan  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "hdbscan package is required for HDBSCANStrategy. "
                "Install with: pip install codexlens[clustering]"
            ) from exc

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results using HDBSCAN algorithm.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Noise points are returned as singleton clusters.
        """
        import hdbscan
        import numpy as np

        n_results = len(results)
        if n_results == 0:
            return []

        # Handle edge case: fewer results than min_cluster_size
        if n_results < self.config.min_cluster_size:
            # Return each result as its own singleton cluster
            return [[i] for i in range(n_results)]

        # Configure HDBSCAN clusterer
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=self.config.min_cluster_size,
            min_samples=self.config.min_samples,
            metric=self.config.metric,
            cluster_selection_epsilon=self.config.cluster_selection_epsilon,
            allow_single_cluster=self.config.allow_single_cluster,
            prediction_data=self.config.prediction_data,
        )

        # Fit and get cluster labels
        # Labels: -1 = noise, 0+ = cluster index
        labels = clusterer.fit_predict(embeddings)

        # Group indices by cluster label
        cluster_map: dict[int, list[int]] = {}
        for idx, label in enumerate(labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(idx)

        # Build result: non-noise clusters first, then noise as singletons
        clusters: List[List[int]] = []

        # Add proper clusters (label >= 0)
        for label in sorted(cluster_map.keys()):
            if label >= 0:
                clusters.append(cluster_map[label])

        # Add noise points as singleton clusters (label == -1)
        if -1 in cluster_map:
            for idx in cluster_map[-1]:
                clusters.append([idx])

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        Selects the result with the highest score from each cluster.

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used in score-based selection).

        Returns:
            List of representative SearchResult objects, one per cluster,
            ordered by score (highest first).
        """
        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            # Find the result with the highest score in this cluster
            best_idx = max(cluster_indices, key=lambda i: results[i].score)
            representatives.append(results[best_idx])

        # Sort by score descending
        representatives.sort(key=lambda r: r.score, reverse=True)

        return representatives
```
codex-lens/build/lib/codexlens/search/clustering/noop_strategy.py (new file, 83 lines):

```python
"""No-op clustering strategy for search results.

NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


class NoOpStrategy(BaseClusteringStrategy):
    """No-op clustering strategy that returns all results ungrouped.

    This strategy is used as a final fallback when no clustering dependencies
    are available, or when clustering is explicitly disabled. Each result
    is treated as its own singleton cluster.

    Example:
        >>> from codexlens.search.clustering import NoOpStrategy
        >>> strategy = NoOpStrategy()
        >>> clusters = strategy.cluster(embeddings, results)
        >>> # Returns [[0], [1], [2], ...] - each result in its own cluster
        >>> representatives = strategy.select_representatives(clusters, results)
        >>> # Returns all results sorted by score
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize NoOp clustering strategy.

        Args:
            config: Clustering configuration. Ignored for NoOpStrategy
                but accepted for interface compatibility.
        """
        super().__init__(config)

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Return each result as its own singleton cluster.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).
                Not used but accepted for interface compatibility.
            results: List of SearchResult objects.

        Returns:
            List of singleton clusters, one per result.
        """
        return [[i] for i in range(len(results))]

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Return all results sorted by score.

        Since each cluster is a singleton, this effectively returns all
        results sorted by score descending.

        Args:
            clusters: List of singleton clusters.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used).

        Returns:
            All SearchResult objects sorted by score (highest first).
        """
        if not results:
            return []

        # Return all results sorted by score
        return sorted(results, key=lambda r: r.score, reverse=True)
```
codex-lens/build/lib/codexlens/search/enrichment.py (new file, 171 lines):

```python
# codex-lens/src/codexlens/search/enrichment.py
"""Relationship enrichment for search results."""
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper


class RelationshipEnricher:
    """Enriches search results with code graph relationships."""

    def __init__(self, index_path: Path):
        """Initialize with path to index database.

        Args:
            index_path: Path to _index.db SQLite database
        """
        self.index_path = index_path
        self.db_conn: Optional[sqlite3.Connection] = None
        self._connect()

    def _connect(self) -> None:
        """Establish read-only database connection."""
        if self.index_path.exists():
            self.db_conn = sqlite3.connect(
                f"file:{self.index_path}?mode=ro",
                uri=True,
                check_same_thread=False
            )
            self.db_conn.row_factory = sqlite3.Row

    def enrich(self, results: List[Dict[str, Any]], limit: int = 10) -> List[Dict[str, Any]]:
        """Add relationship data to search results.

        Args:
            results: List of search result dictionaries
            limit: Maximum number of results to enrich

        Returns:
            Results with relationships field added
        """
        if not self.db_conn:
            return results

        for result in results[:limit]:
            file_path = result.get('file') or result.get('path')
            symbol_name = result.get('symbol')
            result['relationships'] = self._find_relationships(file_path, symbol_name)
        return results

    def _find_relationships(self, file_path: Optional[str], symbol_name: Optional[str]) -> List[Dict[str, Any]]:
        """Query relationships for a symbol.

        Args:
            file_path: Path to file containing the symbol
            symbol_name: Name of the symbol

        Returns:
            List of relationship dictionaries with type, direction, target/source, file, line
        """
        if not self.db_conn or not symbol_name:
            return []

        relationships = []
        cursor = self.db_conn.cursor()

        try:
            # Find symbol ID(s) by name and optionally file
            if file_path:
                cursor.execute(
                    'SELECT id FROM symbols WHERE name = ? AND file_path = ?',
                    (symbol_name, file_path)
                )
            else:
                cursor.execute('SELECT id FROM symbols WHERE name = ?', (symbol_name,))

            symbol_ids = [row[0] for row in cursor.fetchall()]

            if not symbol_ids:
                return []

            # Query outgoing relationships (symbol is source)
            placeholders = ','.join('?' * len(symbol_ids))
            cursor.execute(f'''
                SELECT sr.relationship_type, sr.target_symbol_fqn, sr.file_path, sr.line
                FROM symbol_relationships sr
                WHERE sr.source_symbol_id IN ({placeholders})
            ''', symbol_ids)

            for row in cursor.fetchall():
                relationships.append({
                    'type': row[0],
                    'direction': 'outgoing',
                    'target': row[1],
                    'file': row[2],
                    'line': row[3],
                })

            # Query incoming relationships (symbol is target)
            # Match against symbol name or qualified name patterns
            cursor.execute('''
                SELECT sr.relationship_type, s.name AS source_name, sr.file_path, sr.line
                FROM symbol_relationships sr
                JOIN symbols s ON sr.source_symbol_id = s.id
                WHERE sr.target_symbol_fqn = ? OR sr.target_symbol_fqn LIKE ?
            ''', (symbol_name, f'%.{symbol_name}'))

            for row in cursor.fetchall():
                rel_type = row[0]
                # Convert to incoming type
                incoming_type = self._to_incoming_type(rel_type)
                relationships.append({
                    'type': incoming_type,
                    'direction': 'incoming',
                    'source': row[1],
                    'file': row[2],
                    'line': row[3],
                })

        except sqlite3.Error:
            return []

        return relationships

    def _to_incoming_type(self, outgoing_type: str) -> str:
        """Convert outgoing relationship type to incoming type.

        Args:
            outgoing_type: The outgoing relationship type (e.g., 'calls', 'imports')

        Returns:
            Corresponding incoming type (e.g., 'called_by', 'imported_by')
        """
        type_map = {
            'calls': 'called_by',
            'imports': 'imported_by',
            'extends': 'extended_by',
        }
        return type_map.get(outgoing_type, f'{outgoing_type}_by')

    def close(self) -> None:
        """Close database connection."""
        if self.db_conn:
            self.db_conn.close()
            self.db_conn = None

    def __enter__(self) -> 'RelationshipEnricher':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()


class SearchEnrichmentPipeline:
    """Search post-processing pipeline (optional enrichments)."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._config = config
        self._graph_expander = GraphExpander(mapper, config=config)

    def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
        """Expand base results with related symbols when enabled in config."""
        if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
            return []

        depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
        return self._graph_expander.expand(results, depth=depth)
```
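A minimal usage sketch for `RelationshipEnricher` follows. The `_index.db` location and the shape of the result dictionaries are illustrative assumptions; only the `file`/`path` and `symbol` keys read by `enrich()` are taken from the code above.

```python
# Sketch: enrich raw result dicts with graph relationships. The index path
# (.codexlens/_index.db) is an assumed example location.
from pathlib import Path
from codexlens.search.enrichment import RelationshipEnricher

results = [{"file": "src/auth.py", "symbol": "authenticate", "score": 0.91}]

# The context manager guarantees the read-only SQLite connection is closed.
with RelationshipEnricher(Path(".codexlens/_index.db")) as enricher:
    enriched = enricher.enrich(results, limit=10)

for rel in enriched[0].get("relationships", []):
    print(rel["direction"], rel["type"], rel.get("target") or rel.get("source"))
```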
codex-lens/build/lib/codexlens/search/graph_expander.py (new file, 264 lines):

```python
"""Graph expansion for search results using precomputed neighbors.

Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""

from __future__ import annotations

import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper

logger = logging.getLogger(__name__)


def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
    return (result.path, result.symbol_name, result.start_line, result.end_line)


def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
    if content is None:
        return None
    if start_line is None or end_line is None:
        return None
    if start_line < 1 or end_line < start_line:
        return None

    lines = content.splitlines()
    start_idx = max(0, start_line - 1)
    end_idx = min(len(lines), end_line)
    if start_idx >= len(lines):
        return None
    return "\n".join(lines[start_idx:end_idx])


class GraphExpander:
    """Expands SearchResult lists with related symbols from the code graph."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._mapper = mapper
        self._config = config
        self._logger = logging.getLogger(__name__)

    def expand(
        self,
        results: Sequence[SearchResult],
        *,
        depth: Optional[int] = None,
        max_expand: int = 10,
        max_related: int = 50,
    ) -> List[SearchResult]:
        """Expand top results with related symbols.

        Args:
            results: Base ranked results.
            depth: Maximum relationship depth to include (defaults to Config or 2).
            max_expand: Only expand the top-N base results to bound cost.
            max_related: Maximum related results to return.

        Returns:
            A list of related SearchResult objects with relationship_depth metadata.
        """
        if not results:
            return []

        configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
        max_depth = int(depth if depth is not None else configured_depth)
        if max_depth <= 0:
            return []
        max_depth = min(max_depth, 2)

        expand_count = max(0, int(max_expand))
        related_limit = max(0, int(max_related))
        if expand_count == 0 or related_limit == 0:
            return []

        seen = {_result_key(r) for r in results}
        related_results: List[SearchResult] = []
        conn_cache: Dict[Path, sqlite3.Connection] = {}

        try:
            for base in list(results)[:expand_count]:
                if len(related_results) >= related_limit:
                    break

                if not base.symbol_name or not base.path:
                    continue

                index_path = self._mapper.source_to_index_db(Path(base.path).parent)
                conn = conn_cache.get(index_path)
                if conn is None:
                    conn = self._connect_readonly(index_path)
                    if conn is None:
                        continue
                    conn_cache[index_path] = conn

                source_ids = self._resolve_source_symbol_ids(
                    conn,
                    file_path=base.path,
                    symbol_name=base.symbol_name,
                    symbol_kind=base.symbol_kind,
                )
                if not source_ids:
                    continue

                for source_id in source_ids:
                    neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
                    for neighbor_id, rel_depth in neighbors:
                        if len(related_results) >= related_limit:
                            break
                        row = self._get_symbol_details(conn, neighbor_id)
                        if row is None:
                            continue

                        path = str(row["full_path"])
                        symbol_name = str(row["name"])
                        symbol_kind = str(row["kind"])
                        start_line = int(row["start_line"]) if row["start_line"] is not None else None
                        end_line = int(row["end_line"]) if row["end_line"] is not None else None
                        content_block = _slice_content_block(
                            str(row["content"]) if row["content"] is not None else "",
                            start_line,
                            end_line,
                        )

                        score = float(base.score) * (0.5 ** int(rel_depth))
                        candidate = SearchResult(
                            path=path,
                            score=max(0.0, score),
                            excerpt=None,
                            content=content_block,
                            start_line=start_line,
                            end_line=end_line,
                            symbol_name=symbol_name,
                            symbol_kind=symbol_kind,
                            metadata={"relationship_depth": int(rel_depth)},
                        )

                        key = _result_key(candidate)
                        if key in seen:
                            continue
                        seen.add(key)
                        related_results.append(candidate)

        finally:
            for conn in conn_cache.values():
                try:
                    conn.close()
                except Exception:
                    pass

        return related_results

    def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
        try:
            if not index_path.exists() or index_path.stat().st_size == 0:
                return None
        except OSError:
            return None

        try:
            conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            return conn
        except Exception as exc:
            self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
            return None

    def _resolve_source_symbol_ids(
        self,
        conn: sqlite3.Connection,
        *,
        file_path: str,
        symbol_name: str,
        symbol_kind: Optional[str],
    ) -> List[int]:
        try:
            if symbol_kind:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
                    """,
                    (file_path, symbol_name, symbol_kind),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ?
                    """,
                    (file_path, symbol_name),
                ).fetchall()
        except sqlite3.Error:
            return []

        ids: List[int] = []
        for row in rows:
            try:
                ids.append(int(row["id"]))
            except Exception:
                continue
        return ids

    def _get_neighbors(
        self,
        conn: sqlite3.Connection,
        source_symbol_id: int,
        *,
        max_depth: int,
        limit: int,
    ) -> List[Tuple[int, int]]:
        try:
            rows = conn.execute(
                """
                SELECT neighbor_symbol_id, relationship_depth
                FROM graph_neighbors
                WHERE source_symbol_id = ? AND relationship_depth <= ?
                ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
                LIMIT ?
                """,
                (int(source_symbol_id), int(max_depth), int(limit)),
            ).fetchall()
        except sqlite3.Error:
            return []

        neighbors: List[Tuple[int, int]] = []
        for row in rows:
            try:
                neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
            except Exception:
                continue
        return neighbors

    def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
        try:
            return conn.execute(
                """
                SELECT
                    s.id,
                    s.name,
                    s.kind,
                    s.start_line,
                    s.end_line,
                    f.full_path,
                    f.content
                FROM symbols s
                JOIN files f ON f.id = s.file_id
                WHERE s.id = ?
                """,
                (int(symbol_id),),
            ).fetchone()
        except sqlite3.Error:
            return None
```
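The scoring rule used when promoting graph neighbors is a simple exponential decay in relationship depth: `base.score * 0.5 ** rel_depth`. The sketch below just evaluates that expression with an illustrative base score.

```python
# Sketch of the depth-based score decay used for graph-expanded neighbors:
# each extra hop halves the contribution of the base result's score.
base_score = 0.9  # illustrative base result score

for rel_depth in (1, 2):
    related_score = base_score * (0.5 ** rel_depth)
    print(f"depth={rel_depth}: {related_score:.3f}")  # 0.450, then 0.225
```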
codex-lens/build/lib/codexlens/search/hybrid_search.py (new file, 1409 lines): diff suppressed because it is too large.
codex-lens/build/lib/codexlens/search/query_parser.py (new file, 242 lines):

```python
"""Query preprocessing for CodexLens search.

Provides query expansion for better identifier matching:
- CamelCase splitting: UserAuth → User OR Auth
- snake_case splitting: user_auth → user OR auth
- Preserves original query for exact matching
"""

from __future__ import annotations

import logging
import re
from typing import Set, List

log = logging.getLogger(__name__)


class QueryParser:
    """Parser for preprocessing search queries before FTS5 execution.

    Expands identifier-style queries (CamelCase, snake_case) into OR queries
    to improve recall when searching for code symbols.

    Example transformations:
    - 'UserAuth' → 'UserAuth OR User OR Auth'
    - 'user_auth' → 'user_auth OR user OR auth'
    - 'getUserData' → 'getUserData OR get OR User OR Data'
    """

    # Patterns for identifier splitting
    CAMEL_CASE_PATTERN = re.compile(r'([a-z])([A-Z])')
    SNAKE_CASE_PATTERN = re.compile(r'_+')
    KEBAB_CASE_PATTERN = re.compile(r'-+')

    # Minimum token length to include in expansion (avoid noise from single chars)
    MIN_TOKEN_LENGTH = 2

    # All-caps acronyms pattern (e.g., HTTP, SQL, API)
    ALL_CAPS_PATTERN = re.compile(r'^[A-Z]{2,}$')

    def __init__(self, enable: bool = True, min_token_length: int = 2):
        """Initialize query parser.

        Args:
            enable: Whether to enable query preprocessing
            min_token_length: Minimum token length to include in expansion
        """
        self.enable = enable
        self.min_token_length = min_token_length

    def preprocess_query(self, query: str) -> str:
        """Preprocess query with identifier expansion.

        Args:
            query: Original search query

        Returns:
            Expanded query with OR operator connecting original and split tokens

        Example:
            >>> parser = QueryParser()
            >>> parser.preprocess_query('UserAuth')
            'UserAuth OR User OR Auth'
            >>> parser.preprocess_query('get_user_data')
            'get_user_data OR get OR user OR data'
        """
        if not self.enable:
            return query

        query = query.strip()
        if not query:
            return query

        # Extract tokens from query (handle multiple words/terms)
        # For simple queries, just process the whole thing
        # For complex FTS5 queries with operators, preserve structure
        if self._is_simple_query(query):
            return self._expand_simple_query(query)
        else:
            # Complex query with FTS5 operators, don't expand
            log.debug(f"Skipping expansion for complex FTS5 query: {query}")
            return query

    def _is_simple_query(self, query: str) -> bool:
        """Check if query is simple (no FTS5 operators).

        Args:
            query: Search query

        Returns:
            True if query is simple (safe to expand), False otherwise
        """
        # Check for FTS5 operators that indicate complex query
        fts5_operators = ['OR', 'AND', 'NOT', 'NEAR', '*', '^', '"']
        return not any(op in query for op in fts5_operators)

    def _expand_simple_query(self, query: str) -> str:
        """Expand a simple query with identifier splitting.

        Args:
            query: Simple search query

        Returns:
            Expanded query with OR operators
        """
        tokens: Set[str] = set()

        # Always include original query
        tokens.add(query)

        # Split on whitespace first
        words = query.split()

        for word in words:
            # Extract tokens from this word
            word_tokens = self._extract_tokens(word)
            tokens.update(word_tokens)

        # Filter out short tokens and duplicates
        filtered_tokens = [
            t for t in tokens
            if len(t) >= self.min_token_length
        ]

        # Remove duplicates while preserving original query first
        unique_tokens: List[str] = []
        seen: Set[str] = set()

        # Always put original query first
        if query not in seen and len(query) >= self.min_token_length:
            unique_tokens.append(query)
            seen.add(query)

        # Add other tokens
        for token in filtered_tokens:
            if token not in seen:
                unique_tokens.append(token)
                seen.add(token)

        # Join with OR operator (only if we have multiple tokens)
        if len(unique_tokens) > 1:
            expanded = ' OR '.join(unique_tokens)
            log.debug(f"Expanded query: '{query}' → '{expanded}'")
            return expanded
        else:
            return query

    def _extract_tokens(self, word: str) -> Set[str]:
        """Extract tokens from a single word using various splitting strategies.

        Args:
            word: Single word/identifier to split

        Returns:
            Set of extracted tokens
        """
        tokens: Set[str] = set()

        # Add original word
        tokens.add(word)

        # Handle all-caps acronyms (don't split)
        if self.ALL_CAPS_PATTERN.match(word):
            return tokens

        # CamelCase splitting
        camel_tokens = self._split_camel_case(word)
        tokens.update(camel_tokens)

        # snake_case splitting
        snake_tokens = self._split_snake_case(word)
        tokens.update(snake_tokens)

        # kebab-case splitting
        kebab_tokens = self._split_kebab_case(word)
        tokens.update(kebab_tokens)

        return tokens

    def _split_camel_case(self, word: str) -> List[str]:
        """Split CamelCase identifier into tokens.

        Args:
            word: CamelCase identifier (e.g., 'getUserData')

        Returns:
            List of tokens (e.g., ['get', 'User', 'Data'])
        """
        # Insert space before uppercase letters preceded by lowercase
        spaced = self.CAMEL_CASE_PATTERN.sub(r'\1 \2', word)
        # Split on spaces and filter empty
        return [t for t in spaced.split() if t]

    def _split_snake_case(self, word: str) -> List[str]:
        """Split snake_case identifier into tokens.

        Args:
            word: snake_case identifier (e.g., 'get_user_data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on underscores
        return [t for t in self.SNAKE_CASE_PATTERN.split(word) if t]

    def _split_kebab_case(self, word: str) -> List[str]:
        """Split kebab-case identifier into tokens.

        Args:
            word: kebab-case identifier (e.g., 'get-user-data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on hyphens
        return [t for t in self.KEBAB_CASE_PATTERN.split(word) if t]


# Global default parser instance
_default_parser = QueryParser(enable=True)


def preprocess_query(query: str, enable: bool = True) -> str:
    """Convenience function for query preprocessing.

    Args:
        query: Original search query
        enable: Whether to enable preprocessing

    Returns:
        Preprocessed query with identifier expansion
    """
    if not enable:
        return query

    return _default_parser.preprocess_query(query)


__all__ = [
    "QueryParser",
    "preprocess_query",
]
```
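For quick orientation, a sketch of the module-level helper in use follows. The indicated outputs mirror the doctest examples in the class docstring; because expansion collects tokens in a set, the order of the OR'd terms after the original query is not guaranteed.

```python
# Sketch exercising the module-level preprocess_query helper above.
from codexlens.search.query_parser import preprocess_query

print(preprocess_query("UserAuth"))       # e.g. 'UserAuth OR User OR Auth' (per the docstring)
print(preprocess_query("get_user_data"))  # e.g. 'get_user_data OR get OR user OR data'

# Queries that already contain FTS5 operators are passed through untouched.
print(preprocess_query('"exact phrase" AND UserAuth'))  # unchanged
```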
codex-lens/build/lib/codexlens/search/ranking.py (new file, 942 lines):

```python
"""Ranking algorithms for hybrid search result fusion.

Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
"""

from __future__ import annotations

import re
import math
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

from codexlens.entities import SearchResult, AdditionalLocation


# Default RRF weights for SPLADE-based hybrid search
DEFAULT_WEIGHTS = {
    "splade": 0.35,  # Replaces exact(0.3) + fuzzy(0.1)
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}

# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
    "exact": 0.25,
    "fuzzy": 0.1,
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}


class QueryIntent(str, Enum):
    """Query intent for adaptive RRF weights (Python/TypeScript parity)."""

    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    MIXED = "mixed"


def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
    """Normalize weights to sum to 1.0 (best-effort)."""
    total = sum(float(v) for v in weights.values() if v is not None)

    # NaN total: do not attempt to normalize (division would propagate NaNs).
    if math.isnan(total):
        return dict(weights)

    # Infinite total: do not attempt to normalize (division yields 0 or NaN).
    if not math.isfinite(total):
        return dict(weights)

    # Zero/negative total: do not attempt to normalize (invalid denominator).
    if total <= 0:
        return dict(weights)

    return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}


def detect_query_intent(query: str) -> QueryIntent:
    """Detect whether a query is code-like, natural-language, or mixed.

    Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
    """
    trimmed = (query or "").strip()
    if not trimmed:
        return QueryIntent.MIXED

    lower = trimmed.lower()
    word_count = len([w for w in re.split(r"\s+", trimmed) if w])

    has_code_signals = bool(
        re.search(r"(::|->|\.)", trimmed)
        or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
        or re.search(r"\b\w+_\w+\b", trimmed)
        or re.search(
            r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
            lower,
            flags=re.IGNORECASE,
        )
    )
    has_natural_signals = bool(
        word_count > 5
        or "?" in trimmed
        or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
        or re.search(
            r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
            trimmed,
            flags=re.IGNORECASE,
        )
    )

    if has_code_signals and has_natural_signals:
        return QueryIntent.MIXED
    if has_code_signals:
        return QueryIntent.KEYWORD
    if has_natural_signals:
        return QueryIntent.SEMANTIC
    return QueryIntent.MIXED


def adjust_weights_by_intent(
    intent: QueryIntent,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Adjust RRF weights based on query intent."""
    # Check if using SPLADE or FTS mode
    use_splade = "splade" in base_weights

    if intent == QueryIntent.KEYWORD:
        if use_splade:
            target = {"splade": 0.6, "vector": 0.4}
        else:
            target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
    elif intent == QueryIntent.SEMANTIC:
        if use_splade:
            target = {"splade": 0.3, "vector": 0.7}
        else:
            target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
    else:
        target = dict(base_weights)

    # Filter to active backends
    keys = list(base_weights.keys())
    filtered = {k: float(target.get(k, 0.0)) for k in keys}
    return normalize_weights(filtered)


def get_rrf_weights(
    query: str,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Compute adaptive RRF weights from query intent."""
    return adjust_weights_by_intent(detect_query_intent(query), base_weights)
```
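The intent-adaptive weighting above is easiest to see with two contrasting queries. The sketch below assumes the module is importable as `codexlens.search.ranking`; the queries themselves are made up for illustration.

```python
# Sketch: a code-like query shifts weight toward the sparse/SPLADE backend,
# a question-like query toward the vector backend.
from codexlens.search.ranking import DEFAULT_WEIGHTS, detect_query_intent, get_rrf_weights

code_query = "UserAuth.validate_token"                 # dotted + CamelCase + snake_case
nl_query = "how do we refresh expired sessions?"       # >5 words, question form

print(detect_query_intent(code_query))  # QueryIntent.KEYWORD
print(detect_query_intent(nl_query))    # QueryIntent.SEMANTIC

# Weights are re-normalized over the active backends; lsp_graph drops to 0
# here because the intent targets only the splade/vector backends.
print(get_rrf_weights(code_query, DEFAULT_WEIGHTS))  # {'splade': 0.6, 'vector': 0.4, 'lsp_graph': 0.0}
print(get_rrf_weights(nl_query, DEFAULT_WEIGHTS))    # {'splade': 0.3, 'vector': 0.7, 'lsp_graph': 0.0}
```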
# File extensions to category mapping for fast lookup
|
||||
_EXT_TO_CATEGORY: Dict[str, str] = {
|
||||
# Code extensions
|
||||
".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
|
||||
".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
|
||||
".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
|
||||
".rs": "code",
|
||||
# Doc extensions
|
||||
".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
|
||||
}
|
||||
|
||||
|
||||
def get_file_category(path: str) -> Optional[str]:
|
||||
"""Get file category ('code' or 'doc') from path extension.
|
||||
|
||||
Args:
|
||||
path: File path string
|
||||
|
||||
Returns:
|
||||
'code', 'doc', or None if unknown
|
||||
"""
|
||||
ext = Path(path).suffix.lower()
|
||||
return _EXT_TO_CATEGORY.get(ext)
|
||||
|
||||
|
||||
def filter_results_by_category(
|
||||
results: List[SearchResult],
|
||||
intent: QueryIntent,
|
||||
allow_mixed: bool = True,
|
||||
) -> List[SearchResult]:
|
||||
"""Filter results by category based on query intent.
|
||||
|
||||
Strategy:
|
||||
- KEYWORD (code intent): Only return code files
|
||||
- SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
|
||||
- MIXED: Return all results
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
intent: Query intent from detect_query_intent()
|
||||
allow_mixed: If True, SEMANTIC intent includes code files with lower priority
|
||||
|
||||
Returns:
|
||||
Filtered and re-ranked list of SearchResult objects
|
||||
"""
|
||||
if not results or intent == QueryIntent.MIXED:
|
||||
return results
|
||||
|
||||
code_results = []
|
||||
doc_results = []
|
||||
unknown_results = []
|
||||
|
||||
for r in results:
|
||||
category = get_file_category(r.path)
|
||||
if category == "code":
|
||||
code_results.append(r)
|
||||
elif category == "doc":
|
||||
doc_results.append(r)
|
||||
else:
|
||||
unknown_results.append(r)
|
||||
|
||||
if intent == QueryIntent.KEYWORD:
|
||||
# Code intent: return only code files + unknown (might be code)
|
||||
filtered = code_results + unknown_results
|
||||
elif intent == QueryIntent.SEMANTIC:
|
||||
if allow_mixed:
|
||||
# Semantic intent with mixed: docs first, then code
|
||||
filtered = doc_results + code_results + unknown_results
|
||||
else:
|
||||
# Semantic intent strict: only docs
|
||||
filtered = doc_results + unknown_results
|
||||
else:
|
||||
filtered = results
|
||||
|
||||
return filtered
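A small worked case for the category filter above (paths are hypothetical):

# KEYWORD intent over hits ["auth.py", "README.md", "Makefile"]
#   -> ["auth.py", "Makefile"]              (doc files dropped, unknown extensions kept)
# SEMANTIC intent with allow_mixed=True over the same hits
#   -> ["README.md", "auth.py", "Makefile"] (docs promoted, nothing dropped)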
|
||||
|
||||
|
||||
def simple_weighted_fusion(
|
||||
results_map: Dict[str, List[SearchResult]],
|
||||
weights: Dict[str, float] = None,
|
||||
) -> List[SearchResult]:
|
||||
"""Combine search results using simple weighted sum of normalized scores.
|
||||
|
||||
This is an alternative to RRF that preserves score magnitude information.
|
||||
Scores are min-max normalized per source before weighted combination.
|
||||
|
||||
Formula: score(d) = Σ weight_source * normalized_score_source(d)
|
||||
|
||||
Args:
|
||||
results_map: Dictionary mapping source name to list of SearchResult objects
|
||||
Sources: 'exact', 'fuzzy', 'vector', 'splade'
|
||||
weights: Dictionary mapping source name to weight (default: equal weights)
|
||||
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects sorted by fused score (descending)
|
||||
|
||||
Examples:
|
||||
>>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
|
||||
>>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
|
||||
>>> results_map = {'exact': fts_results, 'vector': vector_results}
|
||||
>>> fused = simple_weighted_fusion(results_map)
|
||||
"""
|
||||
if not results_map:
|
||||
return []
|
||||
|
||||
# Default equal weights if not provided
|
||||
if weights is None:
|
||||
num_sources = len(results_map)
|
||||
weights = {source: 1.0 / num_sources for source in results_map}
|
||||
|
||||
# Normalize weights to sum to 1.0
|
||||
weight_sum = sum(weights.values())
|
||||
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
|
||||
weights = {source: w / weight_sum for source, w in weights.items()}
|
||||
|
||||
# Compute min-max normalization parameters per source
|
||||
source_stats: Dict[str, tuple] = {}
|
||||
for source_name, results in results_map.items():
|
||||
if not results:
|
||||
continue
|
||||
scores = [r.score for r in results]
|
||||
min_s, max_s = min(scores), max(scores)
|
||||
source_stats[source_name] = (min_s, max_s)
|
||||
|
||||
def normalize_score(score: float, source: str) -> float:
|
||||
"""Normalize score to [0, 1] range using min-max scaling."""
|
||||
if source not in source_stats:
|
||||
return 0.0
|
||||
min_s, max_s = source_stats[source]
|
||||
if max_s == min_s:
|
||||
return 1.0 if score >= min_s else 0.0
|
||||
return (score - min_s) / (max_s - min_s)
|
||||
|
||||
# Build unified result set with weighted scores
|
||||
path_to_result: Dict[str, SearchResult] = {}
|
||||
path_to_fusion_score: Dict[str, float] = {}
|
||||
path_to_source_scores: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
for source_name, results in results_map.items():
|
||||
weight = weights.get(source_name, 0.0)
|
||||
if weight == 0:
|
||||
continue
|
||||
|
||||
for result in results:
|
||||
path = result.path
|
||||
normalized = normalize_score(result.score, source_name)
|
||||
contribution = weight * normalized
|
||||
|
||||
if path not in path_to_fusion_score:
|
||||
path_to_fusion_score[path] = 0.0
|
||||
path_to_result[path] = result
|
||||
path_to_source_scores[path] = {}
|
||||
|
||||
path_to_fusion_score[path] += contribution
|
||||
path_to_source_scores[path][source_name] = normalized
|
||||
|
||||
# Create final results with fusion scores
|
||||
fused_results = []
|
||||
for path, base_result in path_to_result.items():
|
||||
fusion_score = path_to_fusion_score[path]
|
||||
|
||||
fused_result = SearchResult(
|
||||
path=base_result.path,
|
||||
score=fusion_score,
|
||||
excerpt=base_result.excerpt,
|
||||
content=base_result.content,
|
||||
symbol=base_result.symbol,
|
||||
chunk=base_result.chunk,
|
||||
metadata={
|
||||
**base_result.metadata,
|
||||
"fusion_method": "simple_weighted",
|
||||
"fusion_score": fusion_score,
|
||||
"original_score": base_result.score,
|
||||
"source_scores": path_to_source_scores[path],
|
||||
},
|
||||
start_line=base_result.start_line,
|
||||
end_line=base_result.end_line,
|
||||
symbol_name=base_result.symbol_name,
|
||||
symbol_kind=base_result.symbol_kind,
|
||||
)
|
||||
fused_results.append(fused_result)
|
||||
|
||||
fused_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return fused_results
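A hand-worked pass through the min-max plus weighted-sum fusion above, with hypothetical scores:

# exact:  a.py=10.0, b.py=2.0   -> normalized a=1.0, b=0.0
# vector: a.py=0.70, b.py=0.90  -> normalized a=0.0, b=1.0
# weights {"exact": 0.4, "vector": 0.6}:
#   fused(a.py) = 0.4*1.0 + 0.6*0.0 = 0.40
#   fused(b.py) = 0.4*0.0 + 0.6*1.0 = 0.60   -> b.py ranks first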
|
||||
|
||||
|
||||
def reciprocal_rank_fusion(
|
||||
results_map: Dict[str, List[SearchResult]],
|
||||
weights: Dict[str, float] = None,
|
||||
k: int = 60,
|
||||
) -> List[SearchResult]:
|
||||
"""Combine search results from multiple sources using Reciprocal Rank Fusion.
|
||||
|
||||
RRF formula: score(d) = Σ weight_source / (k + rank_source(d))
|
||||
|
||||
Supports three-way fusion with FTS, Vector, and SPLADE sources.
|
||||
|
||||
Args:
|
||||
results_map: Dictionary mapping source name to list of SearchResult objects
|
||||
Sources: 'exact', 'fuzzy', 'vector', 'splade'
|
||||
weights: Dictionary mapping source name to weight (default: equal weights)
|
||||
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
|
||||
Or: {'splade': 0.4, 'vector': 0.6}
|
||||
k: Smoothing constant that controls how strongly top ranks dominate the fused score (default 60)
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects sorted by fused score (descending)
|
||||
|
||||
Examples:
|
||||
>>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
|
||||
>>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
|
||||
>>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
|
||||
>>> fused = reciprocal_rank_fusion(results_map)
|
||||
|
||||
# Three-way fusion with SPLADE
|
||||
>>> results_map = {
|
||||
... 'exact': exact_results,
|
||||
... 'vector': vector_results,
|
||||
... 'splade': splade_results
|
||||
... }
|
||||
>>> fused = reciprocal_rank_fusion(results_map, k=60)
|
||||
"""
|
||||
if not results_map:
|
||||
return []
|
||||
|
||||
# Default equal weights if not provided
|
||||
if weights is None:
|
||||
num_sources = len(results_map)
|
||||
weights = {source: 1.0 / num_sources for source in results_map}
|
||||
|
||||
# Validate weights sum to 1.0
|
||||
weight_sum = sum(weights.values())
|
||||
if weight_sum > 0 and not math.isclose(weight_sum, 1.0, abs_tol=0.01):
|
||||
# Normalize weights to sum to 1.0
|
||||
weights = {source: w / weight_sum for source, w in weights.items()}
|
||||
|
||||
# Build unified result set with RRF scores
|
||||
path_to_result: Dict[str, SearchResult] = {}
|
||||
path_to_fusion_score: Dict[str, float] = {}
|
||||
path_to_source_ranks: Dict[str, Dict[str, int]] = {}
|
||||
|
||||
for source_name, results in results_map.items():
|
||||
weight = weights.get(source_name, 0.0)
|
||||
if weight == 0:
|
||||
continue
|
||||
|
||||
for rank, result in enumerate(results, start=1):
|
||||
path = result.path
|
||||
rrf_contribution = weight / (k + rank)
|
||||
|
||||
# Initialize or accumulate fusion score
|
||||
if path not in path_to_fusion_score:
|
||||
path_to_fusion_score[path] = 0.0
|
||||
path_to_result[path] = result
|
||||
path_to_source_ranks[path] = {}
|
||||
|
||||
path_to_fusion_score[path] += rrf_contribution
|
||||
path_to_source_ranks[path][source_name] = rank
|
||||
|
||||
# Create final results with fusion scores
|
||||
fused_results = []
|
||||
for path, base_result in path_to_result.items():
|
||||
fusion_score = path_to_fusion_score[path]
|
||||
|
||||
# Create new SearchResult with fusion_score in metadata
|
||||
fused_result = SearchResult(
|
||||
path=base_result.path,
|
||||
score=fusion_score,
|
||||
excerpt=base_result.excerpt,
|
||||
content=base_result.content,
|
||||
symbol=base_result.symbol,
|
||||
chunk=base_result.chunk,
|
||||
metadata={
|
||||
**base_result.metadata,
|
||||
"fusion_method": "rrf",
|
||||
"fusion_score": fusion_score,
|
||||
"original_score": base_result.score,
|
||||
"rrf_k": k,
|
||||
"source_ranks": path_to_source_ranks[path],
|
||||
},
|
||||
start_line=base_result.start_line,
|
||||
end_line=base_result.end_line,
|
||||
symbol_name=base_result.symbol_name,
|
||||
symbol_kind=base_result.symbol_kind,
|
||||
)
|
||||
fused_results.append(fused_result)
|
||||
|
||||
# Sort by fusion score descending
|
||||
fused_results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return fused_results
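The RRF formula above in numbers (ranks and weights are hypothetical):

# document d ranked 1st by "exact" (weight 0.3) and 3rd by "vector" (weight 0.6), with k=60:
#   score(d) = 0.3/(60+1) + 0.6/(60+3) ≈ 0.00492 + 0.00952 ≈ 0.0144
# a document seen by only one backend contributes a single term, so multi-source hits rank higher.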
|
||||
|
||||
|
||||
def apply_symbol_boost(
|
||||
results: List[SearchResult],
|
||||
boost_factor: float = 1.5,
|
||||
) -> List[SearchResult]:
|
||||
"""Boost fused scores for results that include an explicit symbol match.
|
||||
|
||||
The boost is multiplicative on the current result.score (typically the RRF fusion score).
|
||||
When boosted, the original score is preserved in metadata["original_fusion_score"] and
|
||||
metadata["boosted"] is set to True.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if boost_factor <= 1.0:
|
||||
# Still return new objects to follow immutable transformation pattern.
|
||||
return [
|
||||
SearchResult(
|
||||
path=r.path,
|
||||
score=r.score,
|
||||
excerpt=r.excerpt,
|
||||
content=r.content,
|
||||
symbol=r.symbol,
|
||||
chunk=r.chunk,
|
||||
metadata={**r.metadata},
|
||||
start_line=r.start_line,
|
||||
end_line=r.end_line,
|
||||
symbol_name=r.symbol_name,
|
||||
symbol_kind=r.symbol_kind,
|
||||
additional_locations=list(r.additional_locations),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
boosted_results: List[SearchResult] = []
|
||||
for result in results:
|
||||
has_symbol = bool(result.symbol_name)
|
||||
original_score = float(result.score)
|
||||
boosted_score = original_score * boost_factor if has_symbol else original_score
|
||||
|
||||
metadata = {**result.metadata}
|
||||
if has_symbol:
|
||||
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
|
||||
metadata["boosted"] = True
|
||||
metadata["symbol_boost_factor"] = boost_factor
|
||||
|
||||
boosted_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=boosted_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata=metadata,
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
boosted_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return boosted_results
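Numerically, the boost above is a plain multiplication (values hypothetical):

# boost_factor=1.5: a fused score of 0.0144 on a result carrying symbol_name becomes 0.0216;
# results without a symbol keep their score, and the list is re-sorted afterwards.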
|
||||
|
||||
|
||||
def rerank_results(
|
||||
query: str,
|
||||
results: List[SearchResult],
|
||||
embedder: Any,
|
||||
top_k: int = 50,
|
||||
) -> List[SearchResult]:
|
||||
"""Re-rank results with embedding cosine similarity, combined with current score.
|
||||
|
||||
Combined score formula:
|
||||
0.5 * rrf_score + 0.5 * cosine_similarity
|
||||
|
||||
If embedder is None or embedding fails, returns results as-is.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if embedder is None or top_k <= 0:
|
||||
return results
|
||||
|
||||
rerank_count = min(int(top_k), len(results))
|
||||
|
||||
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
|
||||
# Defensive: handle mismatched lengths and zero vectors.
|
||||
n = min(len(vec_a), len(vec_b))
|
||||
if n == 0:
|
||||
return 0.0
|
||||
dot = 0.0
|
||||
norm_a = 0.0
|
||||
norm_b = 0.0
|
||||
for i in range(n):
|
||||
a = float(vec_a[i])
|
||||
b = float(vec_b[i])
|
||||
dot += a * b
|
||||
norm_a += a * a
|
||||
norm_b += b * b
|
||||
if norm_a <= 0.0 or norm_b <= 0.0:
|
||||
return 0.0
|
||||
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
|
||||
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
|
||||
return max(0.0, min(1.0, sim))
|
||||
|
||||
def text_for_embedding(r: SearchResult) -> str:
|
||||
if r.excerpt and r.excerpt.strip():
|
||||
return r.excerpt
|
||||
if r.content and r.content.strip():
|
||||
return r.content
|
||||
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||
return r.chunk.content
|
||||
# Fallback: stable, non-empty text.
|
||||
return r.symbol_name or r.path
|
||||
|
||||
try:
|
||||
if hasattr(embedder, "embed_single"):
|
||||
query_vec = embedder.embed_single(query)
|
||||
else:
|
||||
query_vec = embedder.embed(query)[0]
|
||||
|
||||
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
|
||||
doc_vecs = embedder.embed(doc_texts)
|
||||
except Exception:
|
||||
return results
|
||||
|
||||
reranked_results: List[SearchResult] = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if idx < rerank_count:
|
||||
rrf_score = float(result.score)
|
||||
sim = cosine_similarity(query_vec, doc_vecs[idx])
|
||||
combined_score = 0.5 * rrf_score + 0.5 * sim
|
||||
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=combined_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={
|
||||
**result.metadata,
|
||||
"rrf_score": rrf_score,
|
||||
"cosine_similarity": sim,
|
||||
"reranked": True,
|
||||
},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Preserve remaining results without re-ranking, but keep immutability.
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=result.score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={**result.metadata},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return reranked_results
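The 50/50 blend above in numbers (values hypothetical):

# rrf_score=0.02, cosine_similarity=0.80 -> combined = 0.5*0.02 + 0.5*0.80 = 0.41
# results beyond top_k, or any embedding failure, keep their original scores.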
|
||||
|
||||
|
||||
def cross_encoder_rerank(
|
||||
query: str,
|
||||
results: List[SearchResult],
|
||||
reranker: Any,
|
||||
top_k: int = 50,
|
||||
batch_size: int = 32,
|
||||
chunk_type_weights: Optional[Dict[str, float]] = None,
|
||||
test_file_penalty: float = 0.0,
|
||||
) -> List[SearchResult]:
|
||||
"""Second-stage reranking using a cross-encoder model.
|
||||
|
||||
This function is dependency-agnostic: callers can pass any object that exposes
|
||||
a compatible `score_pairs(pairs, batch_size=...)` method.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
results: List of search results to rerank
|
||||
reranker: Cross-encoder model with score_pairs or predict method
|
||||
top_k: Number of top results to rerank
|
||||
batch_size: Batch size for reranking
|
||||
chunk_type_weights: Optional weights for different chunk types.
|
||||
Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
|
||||
test_file_penalty: Penalty applied to test files (0.0-1.0).
|
||||
Example: 0.2 means test files get 20% score reduction
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if reranker is None or top_k <= 0:
|
||||
return results
|
||||
|
||||
rerank_count = min(int(top_k), len(results))
|
||||
|
||||
def text_for_pair(r: SearchResult) -> str:
|
||||
if r.excerpt and r.excerpt.strip():
|
||||
return r.excerpt
|
||||
if r.content and r.content.strip():
|
||||
return r.content
|
||||
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||
return r.chunk.content
|
||||
return r.symbol_name or r.path
|
||||
|
||||
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
|
||||
|
||||
try:
|
||||
if hasattr(reranker, "score_pairs"):
|
||||
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
|
||||
elif hasattr(reranker, "predict"):
|
||||
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
|
||||
else:
|
||||
return results
|
||||
except Exception:
|
||||
return results
|
||||
|
||||
if not raw_scores or len(raw_scores) != rerank_count:
|
||||
return results
|
||||
|
||||
scores = [float(s) for s in raw_scores]
|
||||
min_s = min(scores)
|
||||
max_s = max(scores)
|
||||
|
||||
def sigmoid(x: float) -> float:
|
||||
# Clamp to keep exp() stable.
|
||||
x = max(-50.0, min(50.0, x))
|
||||
return 1.0 / (1.0 + math.exp(-x))
|
||||
|
||||
if 0.0 <= min_s and max_s <= 1.0:
|
||||
probs = scores
|
||||
else:
|
||||
probs = [sigmoid(s) for s in scores]
|
||||
|
||||
reranked_results: List[SearchResult] = []
|
||||
|
||||
# Helper to detect test files
|
||||
def is_test_file(path: str) -> bool:
|
||||
if not path:
|
||||
return False
|
||||
basename = path.split("/")[-1].split("\\")[-1]
|
||||
return (
|
||||
basename.startswith("test_") or
|
||||
basename.endswith("_test.py") or
|
||||
basename.endswith(".test.ts") or
|
||||
basename.endswith(".test.js") or
|
||||
basename.endswith(".spec.ts") or
|
||||
basename.endswith(".spec.js") or
|
||||
"/tests/" in path or
|
||||
"\\tests\\" in path or
|
||||
"/test/" in path or
|
||||
"\\test\\" in path
|
||||
)
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if idx < rerank_count:
|
||||
prev_score = float(result.score)
|
||||
ce_score = scores[idx]
|
||||
ce_prob = probs[idx]
|
||||
|
||||
# Base combined score
|
||||
combined_score = 0.5 * prev_score + 0.5 * ce_prob
|
||||
|
||||
# Apply chunk_type weight adjustment
|
||||
if chunk_type_weights:
|
||||
chunk_type = None
|
||||
if result.chunk and hasattr(result.chunk, "metadata"):
|
||||
chunk_type = result.chunk.metadata.get("chunk_type")
|
||||
elif result.metadata:
|
||||
chunk_type = result.metadata.get("chunk_type")
|
||||
|
||||
if chunk_type and chunk_type in chunk_type_weights:
|
||||
weight = chunk_type_weights[chunk_type]
|
||||
# Apply weight to CE contribution only
|
||||
combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight
|
||||
|
||||
# Apply test file penalty
|
||||
if test_file_penalty > 0 and is_test_file(result.path):
|
||||
combined_score = combined_score * (1.0 - test_file_penalty)
|
||||
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=combined_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={
|
||||
**result.metadata,
|
||||
"pre_cross_encoder_score": prev_score,
|
||||
"cross_encoder_score": ce_score,
|
||||
"cross_encoder_prob": ce_prob,
|
||||
"cross_encoder_reranked": True,
|
||||
},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
else:
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=result.score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={**result.metadata},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return reranked_results
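A worked pass through the weighting and penalty logic above (all values hypothetical):

# chunk_type_weights={"docstring": 0.7}, test_file_penalty=0.2
# docstring chunk with prev_score=0.40 and ce_prob=0.90:
#   combined = 0.5*0.40 + 0.5*0.90*0.7 = 0.515
# the same chunk under a tests/ directory: 0.515 * (1 - 0.2) = 0.412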
|
||||
|
||||
|
||||
def normalize_bm25_score(score: float) -> float:
|
||||
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
|
||||
|
||||
SQLite FTS5 returns negative BM25 scores (more negative = better match).
|
||||
Uses sigmoid transformation for normalization.
|
||||
|
||||
Args:
|
||||
score: Raw BM25 score from SQLite (typically negative)
|
||||
|
||||
Returns:
|
||||
Normalized score in range [0, 1]
|
||||
|
||||
Examples:
|
||||
>>> round(normalize_bm25_score(-10.5), 2)  # Good match
0.74
>>> round(normalize_bm25_score(-1.2), 2)   # Weak match
0.53
|
||||
"""
|
||||
# Take absolute value (BM25 is negative in SQLite)
|
||||
abs_score = abs(score)
|
||||
|
||||
# Sigmoid transformation: 1 / (1 + e^(-x))
|
||||
# Scale factor of 0.1 maps typical BM25 magnitudes (0 to 20) to roughly (0.5, 0.88)
|
||||
normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))
|
||||
|
||||
return normalized
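Checking the sigmoid arithmetic above:

# score=-10.5: sigmoid(10.5 * 0.1) = 1/(1 + e**-1.05) ≈ 0.74
# score=-1.2:  sigmoid(1.2 * 0.1)  = 1/(1 + e**-0.12) ≈ 0.53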
|
||||
|
||||
|
||||
def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
|
||||
"""Tag search results with their source for RRF tracking.
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
source: Source identifier ('exact', 'fuzzy', 'vector')
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects with 'search_source' in metadata
|
||||
"""
|
||||
tagged_results = []
|
||||
for result in results:
|
||||
tagged_result = SearchResult(
|
||||
path=result.path,
|
||||
score=result.score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={**result.metadata, "search_source": source},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
)
|
||||
tagged_results.append(tagged_result)
|
||||
|
||||
return tagged_results
|
||||
|
||||
|
||||
def group_similar_results(
|
||||
results: List[SearchResult],
|
||||
score_threshold_abs: float = 0.01,
|
||||
content_field: str = "excerpt"
|
||||
) -> List[SearchResult]:
|
||||
"""Group search results by content and score similarity.
|
||||
|
||||
Groups results that have similar content and similar scores into a single
|
||||
representative result, with other locations stored in additional_locations.
|
||||
|
||||
Algorithm:
|
||||
1. Group results by content (using excerpt or content field)
|
||||
2. Within each content group, create subgroups based on score similarity
|
||||
3. Select highest-scoring result as representative for each subgroup
|
||||
4. Store other results in subgroup as additional_locations
|
||||
|
||||
Args:
|
||||
results: A list of SearchResult objects (typically sorted by score)
|
||||
score_threshold_abs: Absolute score difference to consider results similar.
|
||||
Results with |score_a - score_b| <= threshold are grouped.
|
||||
Default 0.01 is suitable for RRF fusion scores.
|
||||
content_field: The field to use for content grouping ('excerpt' or 'content')
|
||||
|
||||
Returns:
|
||||
A new list of SearchResult objects where similar items are grouped.
|
||||
The list is sorted by score descending.
|
||||
|
||||
Examples:
|
||||
>>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
|
||||
... SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
|
||||
>>> grouped = group_similar_results(results)
|
||||
>>> len(grouped) # Two results merged into one
|
||||
1
|
||||
>>> len(grouped[0].additional_locations) # One additional location
|
||||
1
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Group results by content
|
||||
content_map: Dict[str, List[SearchResult]] = {}
|
||||
unidentifiable_results: List[SearchResult] = []
|
||||
|
||||
for r in results:
|
||||
key = getattr(r, content_field, None)
|
||||
if key and key.strip():
|
||||
content_map.setdefault(key, []).append(r)
|
||||
else:
|
||||
# Results without content can't be grouped by content
|
||||
unidentifiable_results.append(r)
|
||||
|
||||
final_results: List[SearchResult] = []
|
||||
|
||||
# Process each content group
|
||||
for content_group in content_map.values():
|
||||
# Sort by score descending within group
|
||||
content_group.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
while content_group:
|
||||
# Take highest scoring as representative
|
||||
representative = content_group.pop(0)
|
||||
others_in_group = []
|
||||
remaining_for_next_pass = []
|
||||
|
||||
# Find results with similar scores
|
||||
for item in content_group:
|
||||
if abs(representative.score - item.score) <= score_threshold_abs:
|
||||
others_in_group.append(item)
|
||||
else:
|
||||
remaining_for_next_pass.append(item)
|
||||
|
||||
# Create grouped result with additional locations
|
||||
if others_in_group:
|
||||
# Build new result with additional_locations populated
|
||||
grouped_result = SearchResult(
|
||||
path=representative.path,
|
||||
score=representative.score,
|
||||
excerpt=representative.excerpt,
|
||||
content=representative.content,
|
||||
symbol=representative.symbol,
|
||||
chunk=representative.chunk,
|
||||
metadata={
|
||||
**representative.metadata,
|
||||
"grouped_count": len(others_in_group) + 1,
|
||||
},
|
||||
start_line=representative.start_line,
|
||||
end_line=representative.end_line,
|
||||
symbol_name=representative.symbol_name,
|
||||
symbol_kind=representative.symbol_kind,
|
||||
additional_locations=[
|
||||
AdditionalLocation(
|
||||
path=other.path,
|
||||
score=other.score,
|
||||
start_line=other.start_line,
|
||||
end_line=other.end_line,
|
||||
symbol_name=other.symbol_name,
|
||||
) for other in others_in_group
|
||||
],
|
||||
)
|
||||
final_results.append(grouped_result)
|
||||
else:
|
||||
final_results.append(representative)
|
||||
|
||||
content_group = remaining_for_next_pass
|
||||
|
||||
# Add ungroupable results
|
||||
final_results.extend(unidentifiable_results)
|
||||
|
||||
# Sort final results by score descending
|
||||
final_results.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return final_results
codex-lens/build/lib/codexlens/semantic/__init__.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.

Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)

GPU Acceleration:
- Automatic GPU detection and usage when available
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
"""

from __future__ import annotations

SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None


def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
    """Detect if fastembed and GPU are available."""
    try:
        import numpy as np
    except ImportError as e:
        return False, None, False, f"numpy not available: {e}"

    try:
        from fastembed import TextEmbedding
    except ImportError:
        return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"

    # Check GPU availability
    gpu_available = False
    try:
        from .gpu_support import is_gpu_available
        gpu_available = is_gpu_available()
    except ImportError:
        pass

    return True, "fastembed", gpu_available, None


# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()


def check_semantic_available() -> tuple[bool, str | None]:
    """Check if semantic search dependencies are available."""
    return SEMANTIC_AVAILABLE, _import_error


def check_gpu_available() -> tuple[bool, str]:
    """Check if GPU acceleration is available.

    Returns:
        Tuple of (is_available, status_message)
    """
    if not SEMANTIC_AVAILABLE:
        return False, "Semantic search not available"

    try:
        from .gpu_support import is_gpu_available, get_gpu_summary
        if is_gpu_available():
            return True, get_gpu_summary()
        return False, "No GPU detected (using CPU)"
    except ImportError:
        return False, "GPU support module not available"


# Export embedder components
# BaseEmbedder is always available (abstract base class)
from .base import BaseEmbedder

# Factory function for creating embedders
from .factory import get_embedder as get_embedder_factory

# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
    import ccw_litellm  # noqa: F401
    from .litellm_embedder import LiteLLMEmbedderWrapper
    LITELLM_AVAILABLE = True
except ImportError:
    LiteLLMEmbedderWrapper = None
    LITELLM_AVAILABLE = False


def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific embedding backend can be used.

    Notes:
        - "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
        - "litellm" requires ccw-litellm to be installed in the same environment.
    """
    backend = (backend or "").strip().lower()
    if backend == "fastembed":
        if SEMANTIC_AVAILABLE:
            return True, None
        return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
    if backend == "litellm":
        if LITELLM_AVAILABLE:
            return True, None
        return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
    return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."


__all__ = [
    "SEMANTIC_AVAILABLE",
    "SEMANTIC_BACKEND",
    "GPU_AVAILABLE",
    "LITELLM_AVAILABLE",
    "check_semantic_available",
    "is_embedding_backend_available",
    "check_gpu_available",
    "BaseEmbedder",
    "get_embedder_factory",
    "LiteLLMEmbedderWrapper",
]
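A minimal caller-side sketch of how these availability checks are meant to be combined; it assumes the built package is importable as codexlens, which is not shown in this diff:

# from codexlens.semantic import is_embedding_backend_available, check_gpu_available
# ok, err = is_embedding_backend_available("fastembed")
# if not ok:
#     raise RuntimeError(err)
# gpu_ok, gpu_msg = check_gpu_available()
# print("GPU:", gpu_msg)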
codex-lens/build/lib/codexlens/semantic/ann_index.py (new file, 1068 lines)
File diff suppressed because it is too large.
codex-lens/build/lib/codexlens/semantic/base.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""Base class for embedders.

Defines the interface that all embedders must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable

import numpy as np


class BaseEmbedder(ABC):
    """Base class for all embedders.

    All embedder implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Return embedding dimensions.

        Returns:
            int: Dimension of the embedding vectors.
        """
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return model name.

        Returns:
            str: Name or identifier of the underlying model.
        """
        ...

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit for embeddings.

        Returns:
            int: Maximum number of tokens that can be embedded at once.
            Default is 8192 if not overridden by implementation.
        """
        return 8192

    @abstractmethod
    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Embed texts to numpy array.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
        """
        ...
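A minimal sketch of a concrete subclass, showing what the interface requires; the hash-based vectors and the import path are illustrative assumptions, not part of the diff:

# import numpy as np
# from codexlens.semantic.base import BaseEmbedder
#
# class ToyEmbedder(BaseEmbedder):
#     @property
#     def embedding_dim(self) -> int:
#         return 8
#
#     @property
#     def model_name(self) -> str:
#         return "toy-hash"
#
#     def embed_to_numpy(self, texts):
#         items = [texts] if isinstance(texts, str) else list(texts)
#         return np.array([[float(hash(t) % (97 + i)) for i in range(8)] for t in items])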
codex-lens/build/lib/codexlens/semantic/chunker.py (new file, 821 lines)
@@ -0,0 +1,821 @@
|
||||
"""Code chunking strategies for semantic search.
|
||||
|
||||
This module provides various chunking strategies for breaking down source code
|
||||
into semantic chunks suitable for embedding and search.
|
||||
|
||||
Lightweight Mode:
|
||||
The ChunkConfig supports a `skip_token_count` option for performance optimization.
|
||||
When enabled, token counting uses a fast character-based estimation (char/4)
|
||||
instead of expensive tiktoken encoding.
|
||||
|
||||
Use cases for lightweight mode:
|
||||
- Large-scale indexing where speed is critical
|
||||
- Scenarios where approximate token counts are acceptable
|
||||
- Memory-constrained environments
|
||||
- Initial prototyping and development
|
||||
|
||||
Example:
|
||||
# Default mode (accurate tiktoken encoding)
|
||||
config = ChunkConfig()
|
||||
chunker = Chunker(config)
|
||||
|
||||
# Lightweight mode (fast char/4 estimation)
|
||||
config = ChunkConfig(skip_token_count=True)
|
||||
chunker = Chunker(config)
|
||||
chunks = chunker.chunk_file(content, symbols, path, language)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codexlens.entities import SemanticChunk, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkConfig:
|
||||
"""Configuration for chunking strategies."""
|
||||
max_chunk_size: int = 1000 # Max characters per chunk
|
||||
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
|
||||
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
|
||||
min_chunk_size: int = 50 # Minimum chunk size
|
||||
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
|
||||
strip_comments: bool = True # Remove comments from chunk content for embedding
|
||||
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
|
||||
preserve_original: bool = True # Store original content in metadata when stripping
|
||||
|
||||
|
||||
class CommentStripper:
|
||||
"""Remove comments from source code while preserving structure."""
|
||||
|
||||
@staticmethod
|
||||
def strip_python_comments(content: str) -> str:
|
||||
"""Strip Python comments (# style) but preserve docstrings.
|
||||
|
||||
Args:
|
||||
content: Python source code
|
||||
|
||||
Returns:
|
||||
Code with comments removed
|
||||
"""
|
||||
lines = content.splitlines(keepends=True)
|
||||
result_lines: List[str] = []
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
for line in lines:
|
||||
new_line = []
|
||||
i = 0
|
||||
while i < len(line):
|
||||
char = line[i]
|
||||
|
||||
# Handle string literals
|
||||
if char in ('"', "'") and not in_string:
|
||||
# Check for triple quotes
|
||||
if line[i:i+3] in ('"""', "'''"):
|
||||
in_string = True
|
||||
string_char = line[i:i+3]
|
||||
new_line.append(line[i:i+3])
|
||||
i += 3
|
||||
continue
|
||||
else:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif in_string:
|
||||
if string_char and len(string_char) == 3:
|
||||
if line[i:i+3] == string_char:
|
||||
in_string = False
|
||||
new_line.append(line[i:i+3])
|
||||
i += 3
|
||||
string_char = None
|
||||
continue
|
||||
elif char == string_char:
|
||||
# Check for escape
|
||||
if i > 0 and line[i-1] != '\\':
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
# Handle comments (only outside strings)
|
||||
if char == '#' and not in_string:
|
||||
# Rest of line is comment, skip it
|
||||
new_line.append('\n' if line.endswith('\n') else '')
|
||||
break
|
||||
|
||||
new_line.append(char)
|
||||
i += 1
|
||||
|
||||
result_lines.append(''.join(new_line))
|
||||
|
||||
return ''.join(result_lines)
|
||||
|
||||
@staticmethod
|
||||
def strip_c_style_comments(content: str) -> str:
|
||||
"""Strip C-style comments (// and /* */) from code.
|
||||
|
||||
Args:
|
||||
content: Source code with C-style comments
|
||||
|
||||
Returns:
|
||||
Code with comments removed
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
string_char = None
|
||||
in_multiline_comment = False
|
||||
|
||||
while i < len(content):
|
||||
# Handle multi-line comment end
|
||||
if in_multiline_comment:
|
||||
if content[i:i+2] == '*/':
|
||||
in_multiline_comment = False
|
||||
i += 2
|
||||
continue
|
||||
i += 1
|
||||
continue
|
||||
|
||||
char = content[i]
|
||||
|
||||
# Handle string literals
|
||||
if char in ('"', "'", '`') and not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
elif in_string:
|
||||
result.append(char)
|
||||
if char == string_char and (i == 0 or content[i-1] != '\\'):
|
||||
in_string = False
|
||||
string_char = None
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Handle comments
|
||||
if content[i:i+2] == '//':
|
||||
# Single line comment - skip to end of line
|
||||
while i < len(content) and content[i] != '\n':
|
||||
i += 1
|
||||
if i < len(content):
|
||||
result.append('\n')
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if content[i:i+2] == '/*':
|
||||
in_multiline_comment = True
|
||||
i += 2
|
||||
continue
|
||||
|
||||
result.append(char)
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
@classmethod
|
||||
def strip_comments(cls, content: str, language: str) -> str:
|
||||
"""Strip comments based on language.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
language: Programming language
|
||||
|
||||
Returns:
|
||||
Code with comments removed
|
||||
"""
|
||||
if language == "python":
|
||||
return cls.strip_python_comments(content)
|
||||
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
|
||||
return cls.strip_c_style_comments(content)
|
||||
return content
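Two quick behaviour checks for the stripper above:

# CommentStripper.strip_comments("x = 1  # counter", "python")      -> "x = 1  "
# CommentStripper.strip_comments("a = 1; // note\n", "javascript")  -> "a = 1; \n"
# a "#" or "//" inside a string literal is preserved because string state is tracked.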
|
||||
|
||||
|
||||
class DocstringStripper:
|
||||
"""Remove docstrings from source code."""
|
||||
|
||||
@staticmethod
|
||||
def strip_python_docstrings(content: str) -> str:
|
||||
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
|
||||
|
||||
Args:
|
||||
content: Python source code
|
||||
|
||||
Returns:
|
||||
Code with docstrings removed
|
||||
"""
|
||||
lines = content.splitlines(keepends=True)
|
||||
result_lines: List[str] = []
|
||||
i = 0
|
||||
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
|
||||
# Check for docstring start
|
||||
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||
quote_type = '"""' if stripped.startswith('"""') else "'''"
|
||||
|
||||
# Single line docstring
|
||||
if stripped.count(quote_type) >= 2:
|
||||
# Skip this line (docstring)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Multi-line docstring - skip until closing
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
if quote_type in lines[i]:
|
||||
i += 1
|
||||
break
|
||||
i += 1
|
||||
continue
|
||||
|
||||
result_lines.append(line)
|
||||
i += 1
|
||||
|
||||
return ''.join(result_lines)
|
||||
|
||||
@staticmethod
|
||||
def strip_jsdoc_comments(content: str) -> str:
|
||||
"""Strip JSDoc comments (/** ... */) from code.
|
||||
|
||||
Args:
|
||||
content: JavaScript/TypeScript source code
|
||||
|
||||
Returns:
|
||||
Code with JSDoc comments removed
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
in_jsdoc = False
|
||||
|
||||
while i < len(content):
|
||||
if in_jsdoc:
|
||||
if content[i:i+2] == '*/':
|
||||
in_jsdoc = False
|
||||
i += 2
|
||||
continue
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Check for JSDoc start (/** but not /*)
|
||||
if content[i:i+3] == '/**':
|
||||
in_jsdoc = True
|
||||
i += 3
|
||||
continue
|
||||
|
||||
result.append(content[i])
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
@classmethod
|
||||
def strip_docstrings(cls, content: str, language: str) -> str:
|
||||
"""Strip docstrings based on language.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
language: Programming language
|
||||
|
||||
Returns:
|
||||
Code with docstrings removed
|
||||
"""
|
||||
if language == "python":
|
||||
return cls.strip_python_docstrings(content)
|
||||
elif language in {"javascript", "typescript"}:
|
||||
return cls.strip_jsdoc_comments(content)
|
||||
return content
|
||||
|
||||
|
||||
class Chunker:
|
||||
"""Chunk code files for semantic embedding."""
|
||||
|
||||
def __init__(self, config: ChunkConfig | None = None) -> None:
|
||||
self.config = config or ChunkConfig()
|
||||
self._tokenizer = get_default_tokenizer()
|
||||
self._comment_stripper = CommentStripper()
|
||||
self._docstring_stripper = DocstringStripper()
|
||||
|
||||
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
|
||||
"""Process chunk content by stripping comments/docstrings if configured.
|
||||
|
||||
Args:
|
||||
content: Original chunk content
|
||||
language: Programming language
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_content, original_content_if_preserved)
|
||||
"""
|
||||
original = content if self.config.preserve_original else None
|
||||
processed = content
|
||||
|
||||
if self.config.strip_comments:
|
||||
processed = self._comment_stripper.strip_comments(processed, language)
|
||||
|
||||
if self.config.strip_docstrings:
|
||||
processed = self._docstring_stripper.strip_docstrings(processed, language)
|
||||
|
||||
# If nothing changed, don't store original
|
||||
if processed == content:
|
||||
original = None
|
||||
|
||||
return processed, original
|
||||
|
||||
def _estimate_token_count(self, text: str) -> int:
|
||||
"""Estimate token count based on config.
|
||||
|
||||
If skip_token_count is True, uses character-based estimation (char/4).
|
||||
Otherwise, uses accurate tiktoken encoding.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if self.config.skip_token_count:
|
||||
# Fast character-based estimation: ~4 chars per token
|
||||
return max(1, len(text) // 4)
|
||||
return self._tokenizer.count_tokens(text)
|
||||
|
||||
def chunk_by_symbol(
|
||||
self,
|
||||
content: str,
|
||||
symbols: List[Symbol],
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk code by extracted symbols (functions, classes).
|
||||
|
||||
Each symbol becomes one chunk with its full content.
|
||||
Large symbols exceeding max_chunk_size are recursively split using sliding window.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
symbols: List of extracted symbols
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||
"""
|
||||
chunks: List[SemanticChunk] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
for symbol in symbols:
|
||||
start_line, end_line = symbol.range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start_line - 1)
|
||||
end_idx = min(len(lines), end_line)
|
||||
|
||||
chunk_content = "".join(lines[start_idx:end_idx])
|
||||
if len(chunk_content.strip()) < self.config.min_chunk_size:
|
||||
continue
|
||||
|
||||
# Check if symbol content exceeds max_chunk_size
|
||||
if len(chunk_content) > self.config.max_chunk_size:
|
||||
# Create line mapping for correct line number tracking
|
||||
line_mapping = list(range(start_line, end_line + 1))
|
||||
|
||||
# Use sliding window to split large symbol
|
||||
sub_chunks = self.chunk_sliding_window(
|
||||
chunk_content,
|
||||
file_path=file_path,
|
||||
language=language,
|
||||
line_mapping=line_mapping
|
||||
)
|
||||
|
||||
# Update sub_chunks with parent symbol metadata
|
||||
for sub_chunk in sub_chunks:
|
||||
sub_chunk.metadata["symbol_name"] = symbol.name
|
||||
sub_chunk.metadata["symbol_kind"] = symbol.kind
|
||||
sub_chunk.metadata["strategy"] = "symbol_split"
|
||||
sub_chunk.metadata["chunk_type"] = "code"
|
||||
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
|
||||
|
||||
chunks.extend(sub_chunks)
|
||||
else:
|
||||
# Process content (strip comments/docstrings if configured)
|
||||
processed_content, original_content = self._process_content(chunk_content, language)
|
||||
|
||||
# Skip if processed content is too small
|
||||
if len(processed_content.strip()) < self.config.min_chunk_size:
|
||||
continue
|
||||
|
||||
# Calculate token count if not provided
|
||||
token_count = None
|
||||
if symbol_token_counts and symbol.name in symbol_token_counts:
|
||||
token_count = symbol_token_counts[symbol.name]
|
||||
else:
|
||||
token_count = self._estimate_token_count(processed_content)
|
||||
|
||||
metadata = {
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"symbol_name": symbol.name,
|
||||
"symbol_kind": symbol.kind,
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "symbol",
|
||||
"chunk_type": "code",
|
||||
"token_count": token_count,
|
||||
}
|
||||
|
||||
# Store original content if it was modified
|
||||
if original_content is not None:
|
||||
metadata["original_content"] = original_content
|
||||
|
||||
chunks.append(SemanticChunk(
|
||||
content=processed_content,
|
||||
embedding=None,
|
||||
metadata=metadata
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def chunk_sliding_window(
|
||||
self,
|
||||
content: str,
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
line_mapping: Optional[List[int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk code using sliding window approach.
|
||||
|
||||
Used for files without clear symbol boundaries or very long functions.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
line_mapping: Optional list mapping content line indices to original line numbers
|
||||
(1-indexed). If provided, line_mapping[i] is the original line number
|
||||
for the i-th line in content.
|
||||
"""
|
||||
chunks: List[SemanticChunk] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
if not lines:
|
||||
return chunks
|
||||
|
||||
# Calculate lines per chunk based on average line length
|
||||
avg_line_len = len(content) / max(len(lines), 1)
|
||||
lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
|
||||
overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
|
||||
# Ensure overlap is less than chunk size to prevent infinite loop
|
||||
overlap_lines = min(overlap_lines, lines_per_chunk - 1)
|
||||
|
||||
start = 0
|
||||
chunk_idx = 0
|
||||
|
||||
while start < len(lines):
|
||||
end = min(start + lines_per_chunk, len(lines))
|
||||
chunk_content = "".join(lines[start:end])
|
||||
|
||||
if len(chunk_content.strip()) >= self.config.min_chunk_size:
|
||||
# Process content (strip comments/docstrings if configured)
|
||||
processed_content, original_content = self._process_content(chunk_content, language)
|
||||
|
||||
# Skip if processed content is too small
|
||||
if len(processed_content.strip()) < self.config.min_chunk_size:
|
||||
# Move window forward
|
||||
step = lines_per_chunk - overlap_lines
|
||||
if step <= 0:
|
||||
step = 1
|
||||
start += step
|
||||
continue
|
||||
|
||||
token_count = self._estimate_token_count(processed_content)
|
||||
|
||||
# Calculate correct line numbers
|
||||
if line_mapping:
|
||||
# Use line mapping to get original line numbers
|
||||
start_line = line_mapping[start]
|
||||
end_line = line_mapping[end - 1]
|
||||
else:
|
||||
# Default behavior: treat content as starting at line 1
|
||||
start_line = start + 1
|
||||
end_line = end
|
||||
|
||||
metadata = {
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"chunk_index": chunk_idx,
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "sliding_window",
|
||||
"chunk_type": "code",
|
||||
"token_count": token_count,
|
||||
}
|
||||
|
||||
# Store original content if it was modified
|
||||
if original_content is not None:
|
||||
metadata["original_content"] = original_content
|
||||
|
||||
chunks.append(SemanticChunk(
|
||||
content=processed_content,
|
||||
embedding=None,
|
||||
metadata=metadata
|
||||
))
|
||||
chunk_idx += 1
|
||||
|
||||
# Move window, accounting for overlap
|
||||
step = lines_per_chunk - overlap_lines
|
||||
if step <= 0:
|
||||
step = 1 # Failsafe to prevent infinite loop
|
||||
start += step
|
||||
|
||||
# Break if we've reached the end
|
||||
if end >= len(lines):
|
||||
break
|
||||
|
||||
return chunks
|
||||
|
||||
def chunk_file(
|
||||
self,
|
||||
content: str,
|
||||
symbols: List[Symbol],
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk a file using the best strategy.
|
||||
|
||||
Uses symbol-based chunking if symbols available,
|
||||
falls back to sliding window for files without symbols.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
symbols: List of extracted symbols
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||
"""
|
||||
if symbols:
|
||||
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
|
||||
return self.chunk_sliding_window(content, file_path, language)
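A usage sketch for the chunker; source_text and symbols are assumed to come from the project's parser layer and are not defined here:

# config = ChunkConfig(skip_token_count=True)   # fast char/4 token estimate
# chunks = Chunker(config).chunk_file(source_text, symbols, "pkg/mod.py", "python")
# for c in chunks:
#     print(c.metadata["strategy"], c.metadata["start_line"], c.metadata["end_line"], c.metadata["token_count"])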
|
||||
|
||||
class DocstringExtractor:
|
||||
"""Extract docstrings from source code."""
|
||||
|
||||
@staticmethod
|
||||
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
|
||||
"""Extract Python docstrings with their line ranges.
|
||||
|
||||
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||
"""
|
||||
docstrings: List[Tuple[str, int, int]] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||
quote_type = '"""' if stripped.startswith('"""') else "'''"
|
||||
start_line = i + 1
|
||||
|
||||
if stripped.count(quote_type) >= 2:
|
||||
docstring_content = line
|
||||
end_line = i + 1
|
||||
docstrings.append((docstring_content, start_line, end_line))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
docstring_lines = [line]
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
docstring_lines.append(lines[i])
|
||||
if quote_type in lines[i]:
|
||||
break
|
||||
i += 1
|
||||
|
||||
end_line = i + 1
|
||||
docstring_content = "".join(docstring_lines)
|
||||
docstrings.append((docstring_content, start_line, end_line))
|
||||
|
||||
i += 1
|
||||
|
||||
return docstrings
|
||||
|
||||
@staticmethod
|
||||
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
|
||||
"""Extract JSDoc comments with their line ranges.
|
||||
|
||||
Returns: List of (comment_content, start_line, end_line) tuples
|
||||
"""
|
||||
comments: List[Tuple[str, int, int]] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith('/**'):
|
||||
start_line = i + 1
|
||||
comment_lines = [line]
|
||||
i += 1
|
||||
|
||||
while i < len(lines):
|
||||
comment_lines.append(lines[i])
|
||||
if '*/' in lines[i]:
|
||||
break
|
||||
i += 1
|
||||
|
||||
end_line = i + 1
|
||||
comment_content = "".join(comment_lines)
|
||||
comments.append((comment_content, start_line, end_line))
|
||||
|
||||
i += 1
|
||||
|
||||
return comments
|
||||
|
||||
@classmethod
|
||||
def extract_docstrings(
|
||||
cls,
|
||||
content: str,
|
||||
language: str
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
"""Extract docstrings based on language.
|
||||
|
||||
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||
"""
|
||||
if language == "python":
|
||||
return cls.extract_python_docstrings(content)
|
||||
elif language in {"javascript", "typescript"}:
|
||||
return cls.extract_jsdoc_comments(content)
|
||||
return []
|
||||
|
||||
|
||||
class HybridChunker:
|
||||
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
|
||||
|
||||
Composition-based strategy that:
|
||||
1. Extracts docstrings as dedicated chunks
|
||||
2. For remaining code, uses base chunker (symbol or sliding window)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_chunker: Chunker | None = None,
|
||||
config: ChunkConfig | None = None
|
||||
) -> None:
|
||||
"""Initialize hybrid chunker.
|
||||
|
||||
Args:
|
||||
base_chunker: Chunker to use for non-docstring content
|
||||
config: Configuration for chunking
|
||||
"""
|
||||
self.config = config or ChunkConfig()
|
||||
self.base_chunker = base_chunker or Chunker(self.config)
|
||||
self.docstring_extractor = DocstringExtractor()
|
||||
|
||||
def _get_excluded_line_ranges(
|
||||
self,
|
||||
        docstrings: List[Tuple[str, int, int]]
    ) -> set[int]:
        """Get set of line numbers that are part of docstrings."""
        excluded_lines: set[int] = set()
        for _, start_line, end_line in docstrings:
            for line_num in range(start_line, end_line + 1):
                excluded_lines.add(line_num)
        return excluded_lines

    def _filter_symbols_outside_docstrings(
        self,
        symbols: List[Symbol],
        excluded_lines: set[int]
    ) -> List[Symbol]:
        """Filter symbols to exclude those completely within docstrings."""
        filtered: List[Symbol] = []
        for symbol in symbols:
            start_line, end_line = symbol.range
            symbol_lines = set(range(start_line, end_line + 1))
            if not symbol_lines.issubset(excluded_lines):
                filtered.append(symbol)
        return filtered

    def _find_parent_symbol(
        self,
        start_line: int,
        end_line: int,
        symbols: List[Symbol],
    ) -> Optional[Symbol]:
        """Find the smallest symbol range that fully contains a docstring span."""
        candidates: List[Symbol] = []
        for symbol in symbols:
            sym_start, sym_end = symbol.range
            if sym_start <= start_line and end_line <= sym_end:
                candidates.append(symbol)
        if not candidates:
            return None
        return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))

    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk file using hybrid strategy.

        Extracts docstrings first, then chunks remaining code.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []

        # Step 1: Extract docstrings as dedicated chunks
        docstrings: List[Tuple[str, int, int]] = []
        if language == "python":
            # Fast path: avoid expensive docstring extraction if delimiters are absent.
            if '"""' in content or "'''" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        elif language in {"javascript", "typescript"}:
            if "/**" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        else:
            docstrings = self.docstring_extractor.extract_docstrings(content, language)

        # Fast path: no docstrings -> delegate to base chunker directly.
        if not docstrings:
            if symbols:
                base_chunks = self.base_chunker.chunk_by_symbol(
                    content, symbols, file_path, language, symbol_token_counts
                )
            else:
                base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)

            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
            return base_chunks

        for docstring_content, start_line, end_line in docstrings:
            if len(docstring_content.strip()) >= self.config.min_chunk_size:
                parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
                # Use base chunker's token estimation method
                token_count = self.base_chunker._estimate_token_count(docstring_content)
                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_type": "docstring",
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "hybrid",
                    "token_count": token_count,
                }
                if parent_symbol is not None:
                    metadata["parent_symbol"] = parent_symbol.name
                    metadata["parent_symbol_kind"] = parent_symbol.kind
                    metadata["parent_symbol_range"] = parent_symbol.range
                chunks.append(SemanticChunk(
                    content=docstring_content,
                    embedding=None,
                    metadata=metadata
                ))

        # Step 2: Get line ranges occupied by docstrings
        excluded_lines = self._get_excluded_line_ranges(docstrings)

        # Step 3: Filter symbols to exclude docstring-only ranges
        filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)

        # Step 4: Chunk remaining content using base chunker
        if filtered_symbols:
            base_chunks = self.base_chunker.chunk_by_symbol(
                content, filtered_symbols, file_path, language, symbol_token_counts
            )
            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
                chunks.append(chunk)
        else:
            lines = content.splitlines(keepends=True)
            remaining_lines: List[str] = []

            for i, line in enumerate(lines, start=1):
                if i not in excluded_lines:
                    remaining_lines.append(line)

            if remaining_lines:
                remaining_content = "".join(remaining_lines)
                if len(remaining_content.strip()) >= self.config.min_chunk_size:
                    base_chunks = self.base_chunker.chunk_sliding_window(
                        remaining_content, file_path, language
                    )
                    for chunk in base_chunks:
                        chunk.metadata["strategy"] = "hybrid"
                        chunk.metadata["chunk_type"] = "code"
                        chunks.append(chunk)

        return chunks
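A minimal usage sketch for the hybrid chunker above. The enclosing class name, its constructor, and the `Symbol` constructor are not shown in this diff, so `HybridChunker`, its module path, and the `Symbol(...)` call are assumptions; only `chunk_file` and the chunk metadata keys come from the code above.

```python
# Hypothetical usage sketch: names outside this diff (HybridChunker, Symbol(...)) are assumptions.
from pathlib import Path

from codexlens.entities import Symbol
from codexlens.semantic.hybrid_chunker import HybridChunker  # assumed module/class name

source = Path("example.py").read_text(encoding="utf-8")
symbols = [Symbol(name="load", kind="function", range=(1, 20))]  # normally produced by the indexer

chunker = HybridChunker()  # assumed default construction
chunks = chunker.chunk_file(source, symbols, "example.py", language="python")

for chunk in chunks:
    # Docstring chunks carry chunk_type="docstring" and, when a parent is found, parent_symbol metadata.
    print(chunk.metadata["chunk_type"], chunk.metadata.get("parent_symbol"))
```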
codex-lens/build/lib/codexlens/semantic/code_extractor.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import SearchResult, Symbol


def extract_complete_code_block(
    result: SearchResult,
    source_file_path: Optional[str] = None,
    context_lines: int = 0,
) -> str:
    """Extract complete code block from a search result.

    Args:
        result: SearchResult from semantic search.
        source_file_path: Optional path to source file for re-reading.
        context_lines: Additional lines of context to include above/below.

    Returns:
        Complete code block as string.
    """
    # If we have full content stored, use it
    if result.content:
        if context_lines == 0:
            return result.content
        # Need to add context, read from file

    # Try to read from source file
    file_path = source_file_path or result.path
    if not file_path or not Path(file_path).exists():
        # Fall back to excerpt
        return result.excerpt or ""

    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        # Get line range
        start_line = result.start_line or 1
        end_line = result.end_line or len(lines)

        # Add context
        start_idx = max(0, start_line - 1 - context_lines)
        end_idx = min(len(lines), end_line + context_lines)

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return result.excerpt or result.content or ""


def extract_symbol_with_context(
    file_path: str,
    symbol: Symbol,
    include_docstring: bool = True,
    include_decorators: bool = True,
) -> str:
    """Extract a symbol (function/class) with its docstring and decorators.

    Args:
        file_path: Path to source file.
        symbol: Symbol to extract.
        include_docstring: Include docstring if present.
        include_decorators: Include decorators/annotations above symbol.

    Returns:
        Complete symbol code with context.
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        start_line, end_line = symbol.range
        start_idx = start_line - 1
        end_idx = end_line

        # Look for decorators above the symbol
        if include_decorators and start_idx > 0:
            decorator_start = start_idx
            # Search backwards for decorators
            i = start_idx - 1
            while i >= 0 and i >= start_idx - 20:  # Look up to 20 lines back
                line = lines[i].strip()
                if line.startswith("@"):
                    decorator_start = i
                    i -= 1
                elif line == "" or line.startswith("#"):
                    # Skip empty lines and comments, continue looking
                    i -= 1
                elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
                    # JavaScript/Java style comments
                    decorator_start = i
                    i -= 1
                else:
                    # Found non-decorator, non-comment line, stop
                    break
            start_idx = decorator_start

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return ""


def format_search_result_code(
    result: SearchResult,
    max_lines: Optional[int] = None,
    show_line_numbers: bool = True,
    highlight_match: bool = False,
) -> str:
    """Format search result code for display.

    Args:
        result: SearchResult to format.
        max_lines: Maximum lines to show (None for all).
        show_line_numbers: Include line numbers in output.
        highlight_match: Add markers for matched region.

    Returns:
        Formatted code string.
    """
    content = result.content or result.excerpt or ""
    if not content:
        return ""

    lines = content.splitlines()

    # Truncate if needed
    truncated = False
    if max_lines and len(lines) > max_lines:
        lines = lines[:max_lines]
        truncated = True

    # Format with line numbers
    if show_line_numbers:
        start = result.start_line or 1
        formatted_lines = []
        for i, line in enumerate(lines):
            line_num = start + i
            formatted_lines.append(f"{line_num:4d} | {line}")
        output = "\n".join(formatted_lines)
    else:
        output = "\n".join(lines)

    if truncated:
        output += "\n... (truncated)"

    return output


def get_code_block_summary(result: SearchResult) -> str:
    """Get a concise summary of a code block.

    Args:
        result: SearchResult to summarize.

    Returns:
        Summary string like "function hello_world (lines 10-25)"
    """
    parts = []

    if result.symbol_kind:
        parts.append(result.symbol_kind)

    if result.symbol_name:
        parts.append(f"`{result.symbol_name}`")
    elif result.excerpt:
        # Extract first meaningful identifier
        first_line = result.excerpt.split("\n")[0][:50]
        parts.append(f'"{first_line}..."')

    if result.start_line and result.end_line:
        if result.start_line == result.end_line:
            parts.append(f"(line {result.start_line})")
        else:
            parts.append(f"(lines {result.start_line}-{result.end_line})")

    if result.path:
        file_name = Path(result.path).name
        parts.append(f"in {file_name}")

    return " ".join(parts) if parts else "unknown code block"


class CodeBlockResult:
    """Enhanced search result with complete code block."""

    def __init__(self, result: SearchResult, source_path: Optional[str] = None):
        self.result = result
        self.source_path = source_path or result.path
        self._full_code: Optional[str] = None

    @property
    def score(self) -> float:
        return self.result.score

    @property
    def path(self) -> str:
        return self.result.path

    @property
    def file_name(self) -> str:
        return Path(self.result.path).name

    @property
    def symbol_name(self) -> Optional[str]:
        return self.result.symbol_name

    @property
    def symbol_kind(self) -> Optional[str]:
        return self.result.symbol_kind

    @property
    def line_range(self) -> Tuple[int, int]:
        return (
            self.result.start_line or 1,
            self.result.end_line or 1
        )

    @property
    def full_code(self) -> str:
        """Get full code block content."""
        if self._full_code is None:
            self._full_code = extract_complete_code_block(self.result, self.source_path)
        return self._full_code

    @property
    def excerpt(self) -> str:
        """Get short excerpt."""
        return self.result.excerpt or ""

    @property
    def summary(self) -> str:
        """Get code block summary."""
        return get_code_block_summary(self.result)

    def format(
        self,
        max_lines: Optional[int] = None,
        show_line_numbers: bool = True,
    ) -> str:
        """Format code for display."""
        # Use full code if available
        display_result = SearchResult(
            path=self.result.path,
            score=self.result.score,
            content=self.full_code,
            start_line=self.result.start_line,
            end_line=self.result.end_line,
        )
        return format_search_result_code(
            display_result,
            max_lines=max_lines,
            show_line_numbers=show_line_numbers
        )

    def __repr__(self) -> str:
        return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"


def enhance_search_results(
    results: List[SearchResult],
) -> List[CodeBlockResult]:
    """Enhance search results with complete code block access.

    Args:
        results: List of SearchResult from semantic search.

    Returns:
        List of CodeBlockResult with full code access.
    """
    return [CodeBlockResult(r) for r in results]
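A minimal sketch of wrapping search results with this module. The `SearchResult` keyword arguments mirror the call inside `CodeBlockResult.format`; the import path for the module is assumed from the file path in the diff.

```python
from codexlens.entities import SearchResult
from codexlens.semantic.code_extractor import enhance_search_results  # module path assumed from the diff

results = [
    SearchResult(
        path="src/app.py",
        score=0.92,
        content="def handler(event):\n    return {'ok': True}\n",
        start_line=10,
        end_line=12,
    )
]

for block in enhance_search_results(results):
    print(block.summary)               # e.g. "(lines 10-12) in app.py"
    print(block.format(max_lines=20))  # line-numbered listing; appends "... (truncated)" when cut
```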
codex-lens/build/lib/codexlens/semantic/embedder.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.

Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""

from __future__ import annotations

import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional

import numpy as np

from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id

logger = logging.getLogger(__name__)

# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()


def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
    """Get or create a cached Embedder instance (thread-safe singleton).

    This function provides significant performance improvement by reusing
    Embedder instances across multiple searches, avoiding repeated model
    loading overhead (~0.8s per load).

    Args:
        profile: Model profile ("fast", "code", "multilingual", "balanced")
        use_gpu: If True, use GPU acceleration when available (default: True)

    Returns:
        Cached Embedder instance for the given profile
    """
    global _embedder_cache

    # Cache key includes GPU preference to support mixed configurations
    cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"

    # All cache access is protected by _cache_lock to avoid races with
    # clear_embedder_cache() during concurrent access.
    with _cache_lock:
        embedder = _embedder_cache.get(cache_key)
        if embedder is not None:
            return embedder

        # Create new embedder and cache it
        embedder = Embedder(profile=profile, use_gpu=use_gpu)
        # Pre-load model to ensure it's ready
        embedder._load_model()
        _embedder_cache[cache_key] = embedder

        # Log GPU status on first embedder creation
        if use_gpu and is_gpu_available():
            logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
        elif use_gpu:
            logger.debug("GPU not available, using CPU for embeddings")

        return embedder


def clear_embedder_cache() -> None:
    """Clear the embedder cache and release ONNX resources.

    This method ensures proper cleanup of ONNX model resources to prevent
    memory leaks when embedders are no longer needed.
    """
    global _embedder_cache
    with _cache_lock:
        # Release ONNX resources before clearing cache
        for embedder in _embedder_cache.values():
            if embedder._model is not None:
                del embedder._model
                embedder._model = None
        _embedder_cache.clear()
        gc.collect()


class Embedder(BaseEmbedder):
    """Generate embeddings for code chunks using fastembed (ONNX-based).

    Supported Model Profiles:
    - fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
    - code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
    - multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
    - balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
    """

    # Model profiles for different use cases
    MODELS = {
        "fast": "BAAI/bge-small-en-v1.5",                  # 384 dim - Fast, lightweight
        "code": "jinaai/jina-embeddings-v2-base-code",     # 768 dim - Code-optimized
        "multilingual": "intfloat/multilingual-e5-large",  # 1024 dim - Multilingual
        "balanced": "mixedbread-ai/mxbai-embed-large-v1",  # 1024 dim - High accuracy
    }

    # Dimension mapping for each model
    MODEL_DIMS = {
        "BAAI/bge-small-en-v1.5": 384,
        "jinaai/jina-embeddings-v2-base-code": 768,
        "intfloat/multilingual-e5-large": 1024,
        "mixedbread-ai/mxbai-embed-large-v1": 1024,
    }

    # Default model (fast profile)
    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
    DEFAULT_PROFILE = "fast"

    def __init__(
        self,
        model_name: str | None = None,
        profile: str | None = None,
        use_gpu: bool = True,
        providers: List[str] | None = None,
    ) -> None:
        """Initialize embedder with model or profile.

        Args:
            model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
            profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
                If both provided, model_name takes precedence.
            use_gpu: If True, use GPU acceleration when available (default: True)
            providers: Explicit ONNX providers list (overrides use_gpu if provided)
        """
        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        # Resolve model name from profile or use explicit name
        if model_name:
            self._model_name = model_name
        elif profile and profile in self.MODELS:
            self._model_name = self.MODELS[profile]
        else:
            self._model_name = self.DEFAULT_MODEL

        # Configure ONNX execution providers with device_id options for GPU selection
        # Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
        if providers is not None:
            self._providers = providers
        else:
            self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)

        self._use_gpu = use_gpu
        self._model = None

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimension for current model."""
        return self.MODEL_DIMS.get(self._model_name, 768)  # Default to 768 if unknown

    @property
    def max_tokens(self) -> int:
        """Get maximum token limit for current model.

        Returns:
            int: Maximum number of tokens based on model profile.
                - fast: 512 (lightweight, optimized for speed)
                - code: 8192 (code-optimized, larger context)
                - multilingual: 512 (standard multilingual model)
                - balanced: 512 (general purpose)
        """
        # Determine profile from model name
        profile = None
        for prof, model in self.MODELS.items():
            if model == self._model_name:
                profile = prof
                break

        # Return token limit based on profile
        if profile == "code":
            return 8192
        elif profile in ("fast", "multilingual", "balanced"):
            return 512
        else:
            # Default for unknown models
            return 512

    @property
    def providers(self) -> List[str]:
        """Get configured ONNX execution providers."""
        return self._providers

    @property
    def is_gpu_enabled(self) -> bool:
        """Check if GPU acceleration is enabled for this embedder."""
        gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
                         "DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
        # Handle both string providers and tuple providers (name, options)
        for p in self._providers:
            provider_name = p[0] if isinstance(p, tuple) else p
            if provider_name in gpu_providers:
                return True
        return False

    def _load_model(self) -> None:
        """Lazy load the embedding model with configured providers."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding

        # providers already include device_id options via get_optimal_providers(with_device_options=True)
        # DO NOT pass device_ids separately - fastembed ignores it when providers is specified
        # See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
        try:
            self._model = TextEmbedding(
                model_name=self.model_name,
                providers=self._providers,
            )
            logger.debug(f"Model loaded with providers: {self._providers}")
        except TypeError:
            # Fallback for older fastembed versions without providers parameter
            logger.warning(
                "fastembed version doesn't support 'providers' parameter. "
                "Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
            )
            self._model = TextEmbedding(model_name=self.model_name)

    def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
        """Generate embeddings for one or more texts.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            List of embedding vectors (each is a list of floats).

        Note:
            This method converts numpy arrays to Python lists for backward compatibility.
            For memory-efficient processing, use embed_to_numpy() instead.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        embeddings = list(self._model.embed(texts))
        return [emb.tolist() for emb in embeddings]

    def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Generate embeddings for one or more texts (returns numpy arrays).

        This method is more memory-efficient than embed() as it avoids converting
        numpy arrays to Python lists, which can significantly reduce memory usage
        during batch processing.

        Args:
            texts: Single text or iterable of texts to embed.
            batch_size: Optional batch size for fastembed processing.
                Larger values improve GPU utilization but use more memory.

        Returns:
            numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Pass batch_size to fastembed for optimal GPU utilization
        # Default batch_size in fastembed is 256, but larger values can improve throughput
        if batch_size is not None:
            embeddings = list(self._model.embed(texts, batch_size=batch_size))
        else:
            embeddings = list(self._model.embed(texts))
        return np.array(embeddings)

    def embed_single(self, text: str) -> List[float]:
        """Generate embedding for a single text."""
        return self.embed(text)[0]
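A minimal sketch of the singleton helper above, assuming the optional semantic dependencies are installed (`pip install codexlens[semantic]`); the import path is assumed from the file path in the diff.

```python
from codexlens.semantic.embedder import get_embedder, clear_embedder_cache  # module path assumed from the diff

# First call loads the ONNX model; later calls with the same profile reuse the cached instance.
embedder = get_embedder(profile="code", use_gpu=True)
print(embedder.model_name, embedder.embedding_dim, embedder.is_gpu_enabled)

vectors = embedder.embed_to_numpy(
    ["def add(a, b):\n    return a + b", "class Cache: ..."],
    batch_size=64,
)
print(vectors.shape)  # (2, 768) for the code profile

clear_embedder_cache()  # releases ONNX resources once embedding work is done
```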
codex-lens/build/lib/codexlens/semantic/factory.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""Factory for creating embedders.

Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Dict, List, Optional

from .base import BaseEmbedder

# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def get_embedder(
    backend: str = "fastembed",
    profile: str = "code",
    model: str = "default",
    use_gpu: bool = True,
    endpoints: Optional[List[Dict[str, Any]]] = None,
    strategy: str = "latency_aware",
    cooldown: float = 60.0,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create embedder based on backend.

    Args:
        backend: Embedder backend to use. Options:
            - "fastembed": Use fastembed (ONNX-based) embedder (default)
            - "litellm": Use ccw-litellm embedder
        profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
            Used only when backend="fastembed". Default: "code"
        model: Model identifier for litellm backend.
            Used only when backend="litellm". Default: "default"
        use_gpu: Whether to use GPU acceleration when available (default: True).
            Used only when backend="fastembed".
        endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
            Each endpoint is a dict with keys: model, api_key, api_base, weight.
            Used only when backend="litellm" and multiple endpoints provided.
        strategy: Selection strategy for multi-endpoint mode:
            "round_robin", "latency_aware", "weighted_random".
            Default: "latency_aware"
        cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
        **kwargs: Additional backend-specific arguments

    Returns:
        BaseEmbedder: Configured embedder instance

    Raises:
        ValueError: If backend is not recognized
        ImportError: If required backend dependencies are not installed

    Examples:
        Create fastembed embedder with code profile:
        >>> embedder = get_embedder(backend="fastembed", profile="code")

        Create fastembed embedder with fast profile and CPU only:
        >>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)

        Create litellm embedder:
        >>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")

        Create rotational embedder with multiple endpoints:
        >>> endpoints = [
        ...     {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
        ...     {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ... ]
        >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
    """
    # Build cache key from immutable configuration
    if backend == "fastembed":
        cache_key = ("fastembed", profile, None, use_gpu)
    elif backend == "litellm":
        # For litellm, use model as part of cache key
        # Multi-endpoint mode is not cached as it's more complex
        if endpoints and len(endpoints) > 1:
            cache_key = None  # Skip cache for multi-endpoint
        else:
            effective_model = endpoints[0]["model"] if endpoints else model
            cache_key = ("litellm", None, effective_model, None)
    else:
        cache_key = None

    # Check cache first (thread-safe)
    if cache_key is not None:
        with _cache_lock:
            if cache_key in _embedder_cache:
                _logger.debug("Returning cached embedder for %s", cache_key)
                return _embedder_cache[cache_key]

    # Create new embedder instance
    embedder: Optional[BaseEmbedder] = None

    if backend == "fastembed":
        from .embedder import Embedder
        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
    elif backend == "litellm":
        # Check if multi-endpoint mode is requested
        if endpoints and len(endpoints) > 1:
            from .rotational_embedder import create_rotational_embedder
            # Multi-endpoint is not cached
            return create_rotational_embedder(
                endpoints_config=endpoints,
                strategy=strategy,
                default_cooldown=cooldown,
            )
        elif endpoints and len(endpoints) == 1:
            # Single endpoint in list - use it directly
            ep = endpoints[0]
            ep_kwargs = {**kwargs}
            if "api_key" in ep:
                ep_kwargs["api_key"] = ep["api_key"]
            if "api_base" in ep:
                ep_kwargs["api_base"] = ep["api_base"]
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
        else:
            # No endpoints list - use model parameter
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. "
            f"Supported backends: 'fastembed', 'litellm'"
        )

    # Cache the embedder for future use (thread-safe)
    if cache_key is not None and embedder is not None:
        with _cache_lock:
            # Double-check to avoid race condition
            if cache_key not in _embedder_cache:
                _embedder_cache[cache_key] = embedder
                _logger.debug("Cached new embedder for %s", cache_key)
            else:
                # Another thread created it already, use that one
                embedder = _embedder_cache[cache_key]

    return embedder  # type: ignore


def clear_embedder_cache() -> int:
    """Clear the embedder cache.

    Returns:
        Number of embedders cleared from cache
    """
    with _cache_lock:
        count = len(_embedder_cache)
        _embedder_cache.clear()
        _logger.debug("Cleared %d embedders from cache", count)
        return count
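A minimal sketch of the factory's caching behavior, assuming the fastembed backend dependencies are installed; the import path is assumed from the file path in the diff.

```python
from codexlens.semantic.factory import get_embedder, clear_embedder_cache  # module path assumed from the diff

# Same (backend, profile, use_gpu) key -> the cached instance is returned, so the model loads once.
a = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
b = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
assert a is b

# Multi-endpoint litellm configurations bypass the cache and build a rotational embedder each call:
# endpoints = [{"model": "openai/text-embedding-3-small", "api_key": "sk-..."}, ...]
# rotational = get_embedder(backend="litellm", endpoints=endpoints, strategy="round_robin")

print(clear_embedder_cache())  # number of cached embedders released
```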
codex-lens/build/lib/codexlens/semantic/gpu_support.py (new file, 431 lines)
@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.

This module provides GPU detection, initialization, and fallback handling
for ONNX-based embedding generation.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List, Optional

logger = logging.getLogger(__name__)


@dataclass
class GPUDevice:
    """Individual GPU device info."""
    device_id: int
    name: str
    is_discrete: bool  # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
    vendor: str  # "nvidia", "amd", "intel", "unknown"


@dataclass
class GPUInfo:
    """GPU availability and configuration info."""

    gpu_available: bool = False
    cuda_available: bool = False
    gpu_count: int = 0
    gpu_name: Optional[str] = None
    onnx_providers: List[str] = None
    devices: List[GPUDevice] = None  # List of detected GPU devices
    preferred_device_id: Optional[int] = None  # Preferred GPU for embedding

    def __post_init__(self):
        if self.onnx_providers is None:
            self.onnx_providers = ["CPUExecutionProvider"]
        if self.devices is None:
            self.devices = []


_gpu_info_cache: Optional[GPUInfo] = None


def _enumerate_gpus() -> List[GPUDevice]:
    """Enumerate available GPU devices using WMI on Windows.

    Returns:
        List of GPUDevice with device info, ordered by device_id.
    """
    devices = []

    try:
        import subprocess
        import sys

        if sys.platform == "win32":
            # Use PowerShell to query GPU information via WMI
            cmd = [
                "powershell", "-NoProfile", "-Command",
                "Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

            if result.returncode == 0 and result.stdout.strip():
                import json
                gpu_data = json.loads(result.stdout)

                # Handle single GPU case (returns dict instead of list)
                if isinstance(gpu_data, dict):
                    gpu_data = [gpu_data]

                for idx, gpu in enumerate(gpu_data):
                    name = gpu.get("Name", "Unknown GPU")
                    compat = gpu.get("AdapterCompatibility", "").lower()

                    # Determine vendor
                    name_lower = name.lower()
                    if "nvidia" in name_lower or "nvidia" in compat:
                        vendor = "nvidia"
                        is_discrete = True
                    elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
                        vendor = "amd"
                        is_discrete = True
                    elif "intel" in name_lower or "intel" in compat:
                        vendor = "intel"
                        # Intel UHD/Iris are integrated, Intel Arc is discrete
                        is_discrete = "arc" in name_lower
                    else:
                        vendor = "unknown"
                        is_discrete = False

                    devices.append(GPUDevice(
                        device_id=idx,
                        name=name,
                        is_discrete=is_discrete,
                        vendor=vendor
                    ))
                    logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")

    except Exception as e:
        logger.debug(f"GPU enumeration failed: {e}")

    return devices


def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
    """Determine the preferred GPU device_id for embedding.

    Preference order:
    1. NVIDIA discrete GPU (best DirectML/CUDA support)
    2. AMD discrete GPU
    3. Intel Arc (discrete)
    4. Intel integrated (fallback)

    Returns:
        device_id of preferred GPU, or None to use default.
    """
    if not devices:
        return None

    # Priority: NVIDIA > AMD > Intel Arc > Intel integrated
    priority_order = [
        ("nvidia", True),   # NVIDIA discrete
        ("amd", True),      # AMD discrete
        ("intel", True),    # Intel Arc (discrete)
        ("intel", False),   # Intel integrated (fallback)
    ]

    for target_vendor, target_discrete in priority_order:
        for device in devices:
            if device.vendor == target_vendor and device.is_discrete == target_discrete:
                logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
                return device.device_id

    # If no match, use first device
    if devices:
        return devices[0].device_id

    return None


def detect_gpu(force_refresh: bool = False) -> GPUInfo:
    """Detect available GPU resources for embedding acceleration.

    Args:
        force_refresh: If True, re-detect GPU even if cached.

    Returns:
        GPUInfo with detection results.
    """
    global _gpu_info_cache

    if _gpu_info_cache is not None and not force_refresh:
        return _gpu_info_cache

    info = GPUInfo()

    # Enumerate GPU devices first
    info.devices = _enumerate_gpus()
    info.gpu_count = len(info.devices)
    if info.devices:
        # Set preferred device (discrete GPU preferred over integrated)
        info.preferred_device_id = _get_preferred_device_id(info.devices)
        # Set gpu_name to preferred device name
        for dev in info.devices:
            if dev.device_id == info.preferred_device_id:
                info.gpu_name = dev.name
                break

    # Check PyTorch CUDA availability (most reliable detection)
    try:
        import torch
        if torch.cuda.is_available():
            info.cuda_available = True
            info.gpu_available = True
            info.gpu_count = torch.cuda.device_count()
            if info.gpu_count > 0:
                info.gpu_name = torch.cuda.get_device_name(0)
            logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
    except ImportError:
        logger.debug("PyTorch not available for GPU detection")

    # Check ONNX Runtime providers with validation
    try:
        import onnxruntime as ort
        available_providers = ort.get_available_providers()

        # Build provider list with priority order
        providers = []

        # Test each provider to ensure it actually works
        def test_provider(provider_name: str) -> bool:
            """Test if a provider actually works by creating a dummy session."""
            try:
                # Create a minimal ONNX model to test provider
                import numpy as np
                # Simple test: just check if provider can be instantiated
                sess_options = ort.SessionOptions()
                sess_options.log_severity_level = 4  # Suppress warnings
                return True
            except Exception:
                return False

        # CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
        if "CUDAExecutionProvider" in available_providers:
            # Verify CUDA is actually usable by checking for cuBLAS
            cuda_works = False
            try:
                import ctypes
                # Try to load cuBLAS to verify CUDA installation
                try:
                    ctypes.CDLL("cublas64_12.dll")
                    cuda_works = True
                except OSError:
                    try:
                        ctypes.CDLL("cublas64_11.dll")
                        cuda_works = True
                    except OSError:
                        pass
            except Exception:
                pass

            if cuda_works:
                providers.append("CUDAExecutionProvider")
                info.gpu_available = True
                logger.debug("ONNX CUDAExecutionProvider available and working")
            else:
                logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")

        # TensorRT provider (optimized NVIDIA inference)
        if "TensorrtExecutionProvider" in available_providers:
            # TensorRT requires additional libraries, skip for now
            logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")

        # DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
        if "DmlExecutionProvider" in available_providers:
            providers.append("DmlExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX DmlExecutionProvider available (DirectML)")

        # ROCm provider (AMD GPU on Linux)
        if "ROCMExecutionProvider" in available_providers:
            providers.append("ROCMExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX ROCMExecutionProvider available (AMD)")

        # CoreML provider (Apple Silicon)
        if "CoreMLExecutionProvider" in available_providers:
            providers.append("CoreMLExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX CoreMLExecutionProvider available (Apple)")

        # Always include CPU as fallback
        providers.append("CPUExecutionProvider")

        info.onnx_providers = providers

    except ImportError:
        logger.debug("ONNX Runtime not available")
        info.onnx_providers = ["CPUExecutionProvider"]

    _gpu_info_cache = info
    return info


def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
    """Get optimal ONNX execution providers based on availability.

    Args:
        use_gpu: If True, include GPU providers when available.
            If False, force CPU-only execution.
        with_device_options: If True, return providers as tuples with device_id options
            for proper GPU device selection (required for DirectML).

    Returns:
        List of provider names or tuples (provider_name, options_dict) in priority order.
    """
    if not use_gpu:
        return ["CPUExecutionProvider"]

    gpu_info = detect_gpu()

    # Check if GPU was requested but not available - log warning
    if not gpu_info.gpu_available:
        try:
            import onnxruntime as ort
            available_providers = ort.get_available_providers()
        except ImportError:
            available_providers = []
        logger.warning(
            "GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
            f"was found. Available providers: {available_providers}. Falling back to CPU."
        )
    else:
        # Log which GPU provider is being used
        gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
        if gpu_providers:
            logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")

    if not with_device_options:
        return gpu_info.onnx_providers

    # Build providers with device_id options for GPU providers
    device_id = get_selected_device_id()
    providers = []

    for provider in gpu_info.onnx_providers:
        if provider == "DmlExecutionProvider" and device_id is not None:
            # DirectML requires device_id in provider_options tuple
            providers.append(("DmlExecutionProvider", {"device_id": device_id}))
            logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
        elif provider == "CUDAExecutionProvider" and device_id is not None:
            # CUDA also supports device_id in provider_options
            providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
            logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
        elif provider == "ROCMExecutionProvider" and device_id is not None:
            # ROCm supports device_id
            providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
            logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
        else:
            # CPU and other providers don't need device_id
            providers.append(provider)

    return providers


def is_gpu_available() -> bool:
    """Check if any GPU acceleration is available."""
    return detect_gpu().gpu_available


def get_gpu_summary() -> str:
    """Get human-readable GPU status summary."""
    info = detect_gpu()

    if not info.gpu_available:
        return "GPU: Not available (using CPU)"

    parts = []
    if info.gpu_name:
        parts.append(f"GPU: {info.gpu_name}")
    if info.gpu_count > 1:
        parts.append(f"({info.gpu_count} devices)")

    # Show active providers (excluding CPU fallback)
    gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
    if gpu_providers:
        parts.append(f"Providers: {', '.join(gpu_providers)}")

    return " | ".join(parts) if parts else "GPU: Available"


def clear_gpu_cache() -> None:
    """Clear cached GPU detection info."""
    global _gpu_info_cache
    _gpu_info_cache = None


# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None


def get_gpu_devices() -> List[dict]:
    """Get list of available GPU devices for frontend selection.

    Returns:
        List of dicts with device info for each GPU.
    """
    info = detect_gpu()
    devices = []

    for dev in info.devices:
        devices.append({
            "device_id": dev.device_id,
            "name": dev.name,
            "vendor": dev.vendor,
            "is_discrete": dev.is_discrete,
            "is_preferred": dev.device_id == info.preferred_device_id,
            "is_selected": dev.device_id == get_selected_device_id(),
        })

    return devices


def get_selected_device_id() -> Optional[int]:
    """Get the user-selected GPU device_id.

    Returns:
        User-selected device_id, or auto-detected preferred device_id if not set.
    """
    global _selected_device_id

    if _selected_device_id is not None:
        return _selected_device_id

    # Fall back to auto-detected preferred device
    info = detect_gpu()
    return info.preferred_device_id


def set_selected_device_id(device_id: Optional[int]) -> bool:
    """Set the GPU device_id to use for embeddings.

    Args:
        device_id: GPU device_id to use, or None to use auto-detection.

    Returns:
        True if device_id is valid, False otherwise.
    """
    global _selected_device_id

    if device_id is None:
        _selected_device_id = None
        logger.info("GPU selection reset to auto-detection")
        return True

    # Validate device_id exists
    info = detect_gpu()
    valid_ids = [dev.device_id for dev in info.devices]

    if device_id in valid_ids:
        _selected_device_id = device_id
        device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
        logger.info(f"GPU selection set to device {device_id}: {device_name}")
        return True
    else:
        logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
        return False
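A minimal sketch of detection plus manual device pinning with the module above; the import path is assumed from the file path in the diff, and results depend on what onnxruntime and drivers are installed.

```python
from codexlens.semantic.gpu_support import (  # module path assumed from the diff
    detect_gpu,
    get_gpu_devices,
    get_optimal_providers,
    set_selected_device_id,
)

info = detect_gpu()
print(info.gpu_available, info.gpu_name, info.onnx_providers)

# Pin a specific adapter before building providers; invalid ids are rejected with a warning.
for dev in get_gpu_devices():
    if dev["vendor"] == "nvidia":
        set_selected_device_id(dev["device_id"])

# With device options, GPU providers come back as (name, {"device_id": ...}) tuples.
print(get_optimal_providers(use_gpu=True, with_device_options=True))
```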
codex-lens/build/lib/codexlens/semantic/litellm_embedder.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.

Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
"""

from __future__ import annotations

from typing import Iterable

import numpy as np

from .base import BaseEmbedder


class LiteLLMEmbedderWrapper(BaseEmbedder):
    """Wrapper for ccw-litellm LiteLLMEmbedder.

    This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
    BaseEmbedder interface, enabling seamless integration with CodexLens
    semantic search functionality.

    Args:
        model: Model identifier for LiteLLM (default: "default")
        **kwargs: Additional arguments passed to LiteLLMEmbedder

    Raises:
        ImportError: If ccw-litellm package is not installed
    """

    def __init__(self, model: str = "default", **kwargs) -> None:
        """Initialize LiteLLM embedder wrapper.

        Args:
            model: Model identifier for LiteLLM (default: "default")
            **kwargs: Additional arguments passed to LiteLLMEmbedder

        Raises:
            ImportError: If ccw-litellm package is not installed
        """
        try:
            from ccw_litellm import LiteLLMEmbedder
            self._embedder = LiteLLMEmbedder(model=model, **kwargs)
        except ImportError as e:
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from e

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimensions from LiteLLMEmbedder.

        Returns:
            int: Dimension of the embedding vectors.
        """
        return self._embedder.dimensions

    @property
    def model_name(self) -> str:
        """Return model name from LiteLLMEmbedder.

        Returns:
            str: Name or identifier of the underlying model.
        """
        return self._embedder.model_name

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit for the embedding model.

        Returns:
            int: Maximum number of tokens that can be embedded at once.
                Reads from LiteLLM config's max_input_tokens property.
        """
        # Get from LiteLLM embedder's max_input_tokens property (now exposed)
        if hasattr(self._embedder, 'max_input_tokens'):
            return self._embedder.max_input_tokens

        # Fallback: infer from model name
        model_name_lower = self.model_name.lower()

        # Large models (8B or "large" in name)
        if '8b' in model_name_lower or 'large' in model_name_lower:
            return 32768

        # OpenAI text-embedding-3-* models
        if 'text-embedding-3' in model_name_lower:
            return 8191

        # Default fallback
        return 8192

    def _sanitize_text(self, text: str) -> str:
        """Sanitize text to work around ModelScope API routing bug.

        ModelScope incorrectly routes text starting with lowercase 'import'
        to an Ollama endpoint, causing failures. This adds a leading space
        to work around the issue without affecting embedding quality.

        Args:
            text: Text to sanitize.

        Returns:
            Sanitized text safe for embedding API.
        """
        if text.startswith('import'):
            return ' ' + text
        return text

    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts to numpy array using LiteLLMEmbedder.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments (ignored for LiteLLM backend).
                Accepts batch_size for API compatibility with fastembed.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Sanitize texts to avoid ModelScope routing bug
        texts = [self._sanitize_text(t) for t in texts]

        # LiteLLM handles batching internally, ignore batch_size parameter
        return self._embedder.embed(texts)

    def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            list[float]: Embedding vector as a list of floats.
        """
        # Sanitize text before embedding
        sanitized = self._sanitize_text(text)
        embedding = self._embedder.embed([sanitized])
        return embedding[0].tolist()
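A minimal sketch of the wrapper above, assuming ccw-litellm is installed and the chosen model's credentials are configured; the import path and the example model id are assumptions, not requirements of this code.

```python
from codexlens.semantic.litellm_embedder import LiteLLMEmbedderWrapper  # module path assumed from the diff

# Requires the ccw-litellm package plus whatever credentials the chosen model needs.
embedder = LiteLLMEmbedderWrapper(model="text-embedding-3-small")  # example model id, not a default

vectors = embedder.embed_to_numpy(["import json", "def parse(raw): ..."])
print(vectors.shape, embedder.max_tokens)

# Texts starting with "import" are prefixed with a space internally (see _sanitize_text)
# to dodge the ModelScope routing bug; embeddings are otherwise unaffected.
```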
codex-lens/build/lib/codexlens/semantic/reranker/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.

This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""

from __future__ import annotations

from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available

__all__ = [
    "BaseReranker",
    "check_reranker_available",
    "get_reranker",
    "CrossEncoderReranker",
    "check_cross_encoder_available",
    "FastEmbedReranker",
    "check_fastembed_reranker_available",
    "ONNXReranker",
    "check_onnx_reranker_available",
]
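A minimal sketch of constructing the API-backed reranker whose constructor appears in api_reranker.py below. It assumes httpx is installed and RERANKER_API_KEY (or an explicit api_key) is available; the import path is assumed from the file path in the diff, and only construction, max_input_tokens, and close() are shown because the scoring entry point is outside this excerpt.

```python
import os

from codexlens.semantic.reranker.api_reranker import APIReranker  # module path assumed from the diff

# Constructor arguments mirror APIReranker.__init__ below; the key may also come from a .env file.
reranker = APIReranker(
    provider="siliconflow",
    api_key=os.environ.get("RERANKER_API_KEY"),
    timeout=15.0,
    max_retries=2,
)
print(reranker.model_name, reranker.max_input_tokens)
reranker.close()  # releases the underlying httpx client
```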
403
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py
Normal file
403
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""API-based reranker using a remote HTTP provider.
|
||||
|
||||
Supported providers:
|
||||
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
|
||||
- Cohere: https://api.cohere.ai/v1/rerank
|
||||
- Jina: https://api.jina.ai/v1/rerank
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
|
||||
|
||||
|
||||
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
|
||||
"""Get environment variable with .env file fallback."""
|
||||
# Check os.environ first
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
# Try loading from .env files
|
||||
try:
|
||||
from codexlens.env_config import get_env
|
||||
return get_env(key, workspace_root=workspace_root)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def check_httpx_available() -> tuple[bool, str | None]:
|
||||
try:
|
||||
import httpx # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return False, f"httpx not available: {exc}. Install with: pip install httpx"
|
||||
return True, None
|
||||
|
||||
|
||||
class APIReranker(BaseReranker):
|
||||
"""Reranker backed by a remote reranking HTTP API."""
|
||||
|
||||
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
|
||||
"siliconflow": {
|
||||
"api_base": "https://api.siliconflow.cn",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "BAAI/bge-reranker-v2-m3",
|
||||
},
|
||||
"cohere": {
|
||||
"api_base": "https://api.cohere.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "rerank-english-v3.0",
|
||||
},
|
||||
"jina": {
|
||||
"api_base": "https://api.jina.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "jina-reranker-v2-base-multilingual",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: str = "siliconflow",
|
||||
model_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
backoff_base_s: float = 0.5,
|
||||
backoff_max_s: float = 8.0,
|
||||
env_api_key: str = _DEFAULT_ENV_API_KEY,
|
||||
workspace_root: Path | str | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> None:
|
||||
ok, err = check_httpx_available()
|
||||
if not ok: # pragma: no cover - exercised via factory availability tests
|
||||
raise ImportError(err)
|
||||
|
||||
import httpx
|
||||
|
||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||
|
||||
self.provider = (provider or "").strip().lower()
|
||||
if self.provider not in self._PROVIDER_DEFAULTS:
|
||||
raise ValueError(
|
||||
f"Unknown reranker provider: {provider}. "
|
||||
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
|
||||
)
|
||||
|
||||
defaults = self._PROVIDER_DEFAULTS[self.provider]
|
||||
|
||||
# Load api_base from env with .env fallback
|
||||
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
|
||||
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
|
||||
self.endpoint = defaults["endpoint"]
|
||||
|
||||
# Load model from env with .env fallback
|
||||
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
|
||||
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
# Load API key from env with .env fallback
|
||||
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
|
||||
resolved_key = resolved_key.strip()
|
||||
if not resolved_key:
|
||||
raise ValueError(
|
||||
f"Missing API key for reranker provider '{self.provider}'. "
|
||||
f"Pass api_key=... or set ${env_api_key}."
|
||||
)
|
||||
self._api_key = resolved_key
|
||||
|
||||
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
|
||||
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
|
||||
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
|
||||
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.provider == "cohere":
|
||||
headers.setdefault("Cohere-Version", "2022-12-06")
|
||||
|
||||
self._client = httpx.Client(
|
||||
base_url=self.api_base,
|
||||
headers=headers,
|
||||
timeout=self.timeout_s,
|
||||
)
|
||||
|
||||
# Store max_input_tokens with model-aware defaults
|
||||
if max_input_tokens is not None:
|
||||
self._max_input_tokens = max_input_tokens
|
||||
else:
|
||||
# Infer from model name
|
||||
model_lower = self.model_name.lower()
|
||||
if '8b' in model_lower or 'large' in model_lower:
|
||||
self._max_input_tokens = 32768
|
||||
else:
|
||||
self._max_input_tokens = 8192
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int:
|
||||
"""Return maximum token limit for reranking."""
|
||||
return self._max_input_tokens
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return
|
||||
|
||||
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
|
||||
if retry_after_s is not None and retry_after_s > 0:
|
||||
time.sleep(min(float(retry_after_s), self.backoff_max_s))
|
||||
return
|
||||
|
||||
exp = self.backoff_base_s * (2**attempt)
|
||||
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
|
||||
time.sleep(min(self.backoff_max_s, exp + jitter))

    @staticmethod
    def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
        value = (headers.get("Retry-After") or "").strip()
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    @staticmethod
    def _should_retry_status(status_code: int) -> bool:
        return status_code == 429 or 500 <= status_code <= 599

    def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
        last_exc: Exception | None = None

        for attempt in range(self.max_retries + 1):
            try:
                response = self._client.post(self.endpoint, json=dict(payload))
            except Exception as exc:  # httpx is optional at import-time
                last_exc = exc
                if attempt < self.max_retries:
                    self._sleep_backoff(attempt)
                    continue
                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' after "
                    f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
                ) from exc

            status = int(getattr(response, "status_code", 0) or 0)
            if status >= 400:
                body_preview = ""
                try:
                    body_preview = (response.text or "").strip()
                except Exception:
                    body_preview = ""
                if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"

                if self._should_retry_status(status) and attempt < self.max_retries:
                    retry_after = self._parse_retry_after_seconds(response.headers)
                    logger.warning(
                        "Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
                        self.api_base,
                        self.endpoint,
                        status,
                        attempt + 1,
                        self.max_retries + 1,
                    )
                    self._sleep_backoff(attempt, retry_after_s=retry_after)
                    continue

                if status in {401, 403}:
                    raise RuntimeError(
                        f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
                        "Check your API key."
                    )

                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
                    f"Response: {body_preview or '<empty>'}"
                )

            try:
                data = response.json()
            except Exception as exc:
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' is not valid JSON: "
                    f"{type(exc).__name__}: {exc}"
                ) from exc

            if not isinstance(data, dict):
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' must be a JSON object; "
                    f"got {type(data).__name__}"
                )

            return data

        raise RuntimeError(
            f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
        )

    @staticmethod
    def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
        if not isinstance(results, list):
            raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")

        scores: list[float] = [0.0 for _ in range(expected)]
        filled = 0

        for item in results:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            score = item.get("relevance_score", item.get("score"))
            if idx is None or score is None:
                continue
            try:
                idx_int = int(idx)
                score_f = float(score)
            except (TypeError, ValueError):
                continue
            if 0 <= idx_int < expected:
                scores[idx_int] = score_f
                filled += 1

        if filled != expected:
            raise RuntimeError(
                f"Rerank response contained {filled}/{expected} scored documents; "
                "ensure top_n matches the number of documents."
            )

        return scores

    def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
        payload: dict[str, Any] = {
            "model": self.model_name,
            "query": query,
            "documents": list(documents),
            "top_n": len(documents),
            "return_documents": False,
        }
        return payload

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using fast heuristic.

        Uses len(text) // 4 as approximation (~4 chars per token for English).
        Not perfectly accurate for all models/languages but sufficient for
        batch sizing decisions where exact counts aren't critical.
        """
        return len(text) // 4

    def _create_token_aware_batches(
        self,
        query: str,
        documents: Sequence[str],
    ) -> list[list[tuple[int, str]]]:
        """Split documents into batches that fit within token limits.

        Uses 90% of max_input_tokens as safety margin.
        Each batch includes the query tokens overhead.
        """
        max_tokens = int(self._max_input_tokens * 0.9)
        query_tokens = self._estimate_tokens(query)

        batches: list[list[tuple[int, str]]] = []
        current_batch: list[tuple[int, str]] = []
        current_tokens = query_tokens  # Start with query overhead

        for idx, doc in enumerate(documents):
            doc_tokens = self._estimate_tokens(doc)

            # Warn if single document exceeds token limit (will be truncated by API)
            if doc_tokens > max_tokens - query_tokens:
                logger.warning(
                    f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
                    f"(limit: {max_tokens - query_tokens} after query overhead). "
                    "Document will likely be truncated by the API."
                )

            # If batch would exceed limit, start new batch
            if current_tokens + doc_tokens > max_tokens and current_batch:
                batches.append(current_batch)
                current_batch = []
                current_tokens = query_tokens

            current_batch.append((idx, doc))
            current_tokens += doc_tokens

        if current_batch:
            batches.append(current_batch)

        return batches
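
    # Worked example of the batching above, assuming the ~4 chars/token
    # heuristic and the default limit of 8192 tokens: the per-request budget is
    # int(8192 * 0.9) = 7372 tokens. With a 400-char query (~100 tokens) and
    # documents of ~12,000 chars (~3,000 tokens) each, two documents fit in a
    # batch (100 + 3,000 + 3,000 = 6,100 <= 7,372) and a third starts a new one.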

    def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
        if not documents:
            return []

        # Create token-aware batches
        batches = self._create_token_aware_batches(query, documents)

        if len(batches) == 1:
            # Single batch - original behavior
            payload = self._build_payload(query=query, documents=documents)
            data = self._request_json(payload)
            results = data.get("results")
            return self._extract_scores_from_results(results, expected=len(documents))

        # Multiple batches - process each and merge results
        logger.info(
            f"Splitting {len(documents)} documents into {len(batches)} batches "
            f"(max_input_tokens: {self._max_input_tokens})"
        )

        all_scores: list[float] = [0.0] * len(documents)

        for batch in batches:
            batch_docs = [doc for _, doc in batch]
            payload = self._build_payload(query=query, documents=batch_docs)
            data = self._request_json(payload)
            results = data.get("results")
            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))

            # Map scores back to original indices
            for (orig_idx, _), score in zip(batch, batch_scores):
                all_scores[orig_idx] = score

        return all_scores

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # noqa: ARG002 - kept for BaseReranker compatibility
    ) -> list[float]:
        if not pairs:
            return []

        grouped: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            grouped.setdefault(str(query), []).append((idx, str(doc)))

        scores: list[float] = [0.0 for _ in range(len(pairs))]

        for query, items in grouped.items():
            documents = [doc for _, doc in items]
            query_scores = self._rerank_one_query(query=query, documents=documents)
            for (orig_idx, _), score in zip(items, query_scores):
                scores[orig_idx] = float(score)

        return scores
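

# A minimal usage sketch of this client; the concrete class name and the
# constructor keywords below are assumptions (only the attributes provider,
# model_name, api_base, and the retry/timeout settings are visible here).
if __name__ == "__main__":  # pragma: no cover - illustrative only
    reranker = HTTPReranker(  # hypothetical name; the real class is defined earlier in this module
        provider="cohere",
        api_key="YOUR_KEY",  # or rely on the provider's API-key environment variable
    )
    try:
        scores = reranker.score_pairs(
            [
                ("what is a reranker?", "Rerankers reorder retrieved documents by relevance."),
                ("what is a reranker?", "Unrelated text about cooking pasta."),
            ]
        )
        # One float per (query, doc) pair; pairs sharing a query are sent together.
        print(scores)
    finally:
        reranker.close()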
46
codex-lens/build/lib/codexlens/semantic/reranker/base.py
Normal file
@@ -0,0 +1,46 @@
"""Base class for rerankers.

Defines the interface that all rerankers must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence


class BaseReranker(ABC):
    """Base class for all rerankers.

    All reranker implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking.

        Returns:
            int: Maximum number of tokens that can be processed at once.
                Default is 8192 if not overridden by implementation.
        """
        return 8192

    @abstractmethod
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair).
        """
        ...
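

# A minimal sketch of a concrete subclass, scoring by simple term overlap
# instead of a model, just to show how the interface is satisfied; the class
# name and scoring rule here are illustrative assumptions.
class KeywordOverlapReranker(BaseReranker):
    """Toy reranker: fraction of query terms that appear in the document."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            terms = {t for t in query.lower().split() if t}
            doc_lower = doc.lower()
            hits = sum(1 for t in terms if t in doc_lower)
            scores.append(hits / len(terms) if terms else 0.0)
        return scores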
Some files were not shown because too many files have changed in this diff.