Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-12 02:37:45 +08:00)
Refactor code structure and remove redundant changes
@@ -401,7 +401,10 @@ async function executeCommandChain(chain, analysis) {
 state.updated_at = new Date().toISOString();
 Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));

-// Assemble prompt with previous results
+// Assemble prompt: Task context + Command instruction
+let promptContent = formatCommand(cmd, state.execution_results, analysis);
+
+// Build full prompt with context
 let prompt = `Task: ${analysis.goal}\n`;
 if (state.execution_results.length > 0) {
   prompt += '\nPrevious results:\n';
@@ -411,7 +414,7 @@ async function executeCommandChain(chain, analysis) {
     }
   });
 }
-prompt += `\n${formatCommand(cmd, state.execution_results, analysis)}\n`;
+prompt += `\n${promptContent}`;

 // Record prompt used
 state.prompts_used.push({
@@ -421,9 +424,12 @@ async function executeCommandChain(chain, analysis) {
 });

 // Execute CLI command in background and stop
+// Format: ccw cli -p "PROMPT" --tool <tool> --mode <mode>
+// Note: -y is a command parameter INSIDE the prompt, not a ccw cli parameter
+// Example prompt: "/workflow:plan -y \"task description here\""
 try {
   const taskId = Bash(
-    `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write -y`,
+    `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
     { run_in_background: true }
   ).task_id;

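The hunk above passes the assembled prompt through `escapePrompt(prompt)` before interpolating it into the shell command, but that helper is not part of this diff. A minimal sketch of what such a helper might look like, assuming its only job is to survive double-quoted shell interpolation (the function name comes from the diff; the body is an assumption, not the repository's implementation):

```javascript
// Hypothetical sketch of escapePrompt: escape a prompt for use inside a
// double-quoted shell argument. The real implementation may differ.
function escapePrompt(prompt) {
  return prompt
    .replace(/\\/g, '\\\\')  // escape backslashes first
    .replace(/"/g, '\\"')    // escape double quotes
    .replace(/\$/g, '\\$')   // prevent shell variable expansion
    .replace(/`/g, '\\`');   // prevent command substitution
}
```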
@@ -486,69 +492,71 @@ async function executeCommandChain(chain, analysis) {
 }

 // Smart parameter assembly
+// Returns prompt content to be used with: ccw cli -p "RETURNED_VALUE" --tool claude --mode write
 function formatCommand(cmd, previousResults, analysis) {
-  let line = cmd.command + ' --yes';
+  // Format: /workflow:<command> -y <parameters>
+  let prompt = `/workflow:${cmd.name} -y`;
   const name = cmd.name;

   // Planning commands - take task description
   if (['lite-plan', 'plan', 'tdd-plan', 'multi-cli-plan'].includes(name)) {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Lite execution - use --in-memory if plan exists
   } else if (name === 'lite-execute') {
     const hasPlan = previousResults.some(r => r.command.includes('plan'));
-    line += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
+    prompt += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;

   // Standard execution - resume from planning session
   } else if (name === 'execute') {
     const plan = previousResults.find(r => r.command.includes('plan'));
-    if (plan?.session_id) line += ` --resume-session="${plan.session_id}"`;
+    if (plan?.session_id) prompt += ` --resume-session="${plan.session_id}"`;

   // Bug fix commands - take bug description
   } else if (['lite-fix', 'debug'].includes(name)) {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Brainstorm - take topic description
   } else if (name === 'brainstorm:auto-parallel' || name === 'auto-parallel') {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Test generation from session - needs source session
   } else if (name === 'test-gen') {
     const impl = previousResults.find(r =>
       r.command.includes('execute') || r.command.includes('lite-execute')
     );
-    if (impl?.session_id) line += ` "${impl.session_id}"`;
-    else line += ` "${analysis.goal}"`;
+    if (impl?.session_id) prompt += ` "${impl.session_id}"`;
+    else prompt += ` "${analysis.goal}"`;

   // Test fix generation - session or description
   } else if (name === 'test-fix-gen') {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` "${latest.session_id}"`;
-    else line += ` "${analysis.goal}"`;
+    if (latest?.session_id) prompt += ` "${latest.session_id}"`;
+    else prompt += ` "${analysis.goal}"`;

   // Review commands - take session or use latest
   } else if (name === 'review') {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;

   // Review fix - takes session from review
   } else if (name === 'review-fix') {
     const review = previousResults.find(r => r.command.includes('review'));
     const latest = review || previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;

   // TDD verify - takes execution session
   } else if (name === 'tdd-verify') {
     const exec = previousResults.find(r => r.command.includes('execute'));
-    if (exec?.session_id) line += ` --session="${exec.session_id}"`;
+    if (exec?.session_id) prompt += ` --session="${exec.session_id}"`;

   // Session-based commands (test-cycle, review-session, plan-verify)
   } else if (name.includes('test') || name.includes('review') || name.includes('verify')) {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
   }

-  return line;
+  return prompt;
 }

 // Hook callback: Called when background CLI completes
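To make the effect of the rewritten `formatCommand` concrete, here is a small usage sketch derived from the branches above (the sample `analysis` and `previousResults` values are illustrative only):

```javascript
// Illustrative only: the return values follow from the branches in the diff above.
const analysis = { goal: 'Add API endpoint for user profile' };

// Planning command with no previous results: embeds the task description.
formatCommand({ name: 'lite-plan' }, [], analysis);
// -> '/workflow:lite-plan -y "Add API endpoint for user profile"'

// After a plan command has run, lite-execute switches to --in-memory.
const previousResults = [{ command: '/workflow:lite-plan', session_id: 'WFS-xxx' }];
formatCommand({ name: 'lite-execute' }, previousResults, analysis);
// -> '/workflow:lite-execute -y --in-memory'
```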
@@ -728,226 +736,68 @@ const cmd = registry.getCommand('lite-plan');
 // {name, command, description, argumentHint, allowedTools, filePath}
 ```

-## Execution Examples
-
-### Simple Feature
-```
-Goal: Add API endpoint for user profile
-Scope: [api]
-Complexity: simple
-Constraints: []
-Task Type: feature
-
-Pipeline (with Minimum Execution Units):
-Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Quick Implementation
-1. /workflow:lite-plan --yes "Add API endpoint..."
-2. /workflow:lite-execute --yes --in-memory
-
-# Unit 2: Test Validation
-3. /workflow:test-fix-gen --yes --session="WFS-xxx"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
-
-### Complex Feature with Verification
-```
-Goal: Implement OAuth2 authentication system
-Scope: [auth, database, api, frontend]
-Complexity: complex
-Constraints: [no breaking changes]
-Task Type: feature
-
-Pipeline (with Minimum Execution Units):
-Requirements →【plan → plan-verify】→ Verified plan → execute → Code
-→【review-session-cycle → review-fix】→ Fixed code
-→【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Full Planning (plan + plan-verify)
-1. /workflow:plan --yes "Implement OAuth2..."
-2. /workflow:plan-verify --yes --session="WFS-xxx"
-
-# Execution phase
-3. /workflow:execute --yes --resume-session="WFS-xxx"
-
-# Unit 2: Code Review (review-session-cycle + review-fix)
-4. /workflow:review-session-cycle --yes --session="WFS-xxx"
-5. /workflow:review-fix --yes --session="WFS-xxx"
-
-# Unit 3: Test Validation (test-fix-gen + test-cycle-execute)
-6. /workflow:test-fix-gen --yes --session="WFS-xxx"
-7. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
-
-### Quick Bug Fix
-```
-Goal: Fix login timeout issue
-Scope: [auth]
-Complexity: simple
-Constraints: [urgent]
-Task Type: bugfix
-
-Pipeline:
-Bug report → lite-fix → Fixed code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:lite-fix --yes "Fix login timeout..."
-2. /workflow:test-fix-gen --yes --session="WFS-xxx"
-3. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
-
-### Skip Tests
-```
-Goal: Update documentation
-Scope: [docs]
-Complexity: simple
-Constraints: [skip-tests]
-Task Type: feature
-
-Pipeline:
-Requirements → lite-plan → Plan → lite-execute → Code
-
-Chain:
-1. /workflow:lite-plan --yes "Update documentation..."
-2. /workflow:lite-execute --yes --in-memory
-```
-
-### TDD Workflow
-```
-Goal: Implement user authentication with test-first approach
-Scope: [auth]
-Complexity: medium
-Constraints: [test-driven]
-Task Type: tdd
-
-Pipeline:
-Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify → TDD verification passed
-
-Chain:
-1. /workflow:tdd-plan --yes "Implement user authentication..."
-2. /workflow:execute --yes --resume-session="WFS-xxx"
-3. /workflow:tdd-verify --yes --session="WFS-xxx"
-```
-
-### Debug Workflow
-```
-Goal: Fix memory leak in WebSocket handler
-Scope: [websocket]
-Complexity: medium
-Constraints: [production-issue]
-Task Type: bugfix
-
-Pipeline (quick fix):
-Bug report → lite-fix → Fixed code → test-cycle-execute → Tests pass
-
-Pipeline (systematic debugging):
-Bug report → debug → Debug logs → Analyze and locate → Fix
-
-Chain:
-1. /workflow:lite-fix --yes "Fix memory leak in WebSocket..."
-2. /workflow:test-cycle-execute --yes --session="WFS-xxx"
-
-OR (for hypothesis-driven debugging):
-1. /workflow:debug --yes "Memory leak in WebSocket handler..."
-```
-
-### Test Fix Workflow
-```
-Goal: Fix failing authentication tests
-Scope: [auth, tests]
-Complexity: simple
-Constraints: []
-Task Type: test-fix
-
-Pipeline:
-Failing tests → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:test-fix-gen --yes "WFS-auth-impl-001"
-2. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
-
-### Test Generation from Implementation
-```
-Goal: Generate comprehensive tests for completed user registration feature
-Scope: [auth, tests]
-Complexity: medium
-Constraints: []
-Task Type: test-gen
-
-Pipeline (with Minimum Execution Units):
-Code/Session →【test-gen → execute】→ Tests pass
-
-Chain:
-# Unit: Test Generation (test-gen + execute)
-1. /workflow:test-gen --yes "WFS-registration-20250124"
-2. /workflow:execute --yes --session="WFS-test-registration"
-
-Note: test-gen creates IMPL-001 (test generation) and IMPL-002 (test execution & fix)
-execute runs both tasks - this is a Minimum Execution Unit
-```
-
-### Review + Fix Workflow
-```
-Goal: Code review of payment module
-Scope: [payment]
-Complexity: medium
-Constraints: []
-Task Type: review
-
-Pipeline (with Minimum Execution Units):
-Code →【review-session-cycle → review-fix】→ Fixed code
-→【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Code Review (review-session-cycle + review-fix)
-1. /workflow:review-session-cycle --yes --session="WFS-payment-impl"
-2. /workflow:review-fix --yes --session="WFS-payment-impl"
-
-# Unit 2: Test Validation (test-fix-gen + test-cycle-execute)
-3. /workflow:test-fix-gen --yes --session="WFS-payment-impl"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-payment-impl"
-```
-
-### Brainstorm Workflow (Uncertain Requirements)
-```
-Goal: Explore solutions for real-time notification system
-Scope: [notifications, architecture]
-Complexity: complex
-Constraints: []
-Task Type: brainstorm
-
-Pipeline:
-Exploration topic → brainstorm:auto-parallel → Analysis results → plan → Detailed plan
-→ plan-verify → Verified plan → execute → Code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:brainstorm:auto-parallel --yes "Explore solutions for real-time..."
-2. /workflow:plan --yes "Implement chosen notification approach..."
-3. /workflow:plan-verify --yes --session="WFS-xxx"
-4. /workflow:execute --yes --resume-session="WFS-xxx"
-5. /workflow:test-fix-gen --yes --session="WFS-xxx"
-6. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
-
-### Multi-CLI Plan (Multi-Perspective Analysis)
-```
-Goal: Compare microservices vs monolith architecture
-Scope: [architecture]
-Complexity: complex
-Constraints: []
-Task Type: multi-cli
-
-Pipeline:
-Requirements → multi-cli-plan → Comparison plan → lite-execute → Code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:multi-cli-plan --yes "Compare microservices vs monolith..."
-2. /workflow:lite-execute --yes --in-memory
-3. /workflow:test-fix-gen --yes --session="WFS-xxx"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
+## Universal Prompt Template
+
+### Standard Format
+```bash
+ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
+```
+
+### Prompt Content Template
+```
+Task: <task_description>
+
+<optional_previous_results>
+
+/workflow:<command> -y <command_parameters>
+```
+
+### Template Variables
+
+| Variable | Description | Examples |
+|----------|-------------|----------|
+| `<task_description>` | Brief task description | "Implement user authentication", "Fix memory leak" |
+| `<optional_previous_results>` | Context from previous commands | "Previous results:\n- /workflow:plan: WFS-xxx" |
+| `<command>` | Workflow command name | `plan`, `lite-execute`, `test-cycle-execute` |
+| `-y` | Auto-confirm flag (inside prompt) | Always include for automation |
+| `<command_parameters>` | Command-specific parameters | Task description, session ID, flags |
+
+### Command Parameter Patterns
+
+| Command Type | Parameter Pattern | Example |
+|--------------|------------------|---------|
+| **Planning** | `"task description"` | `/workflow:plan -y "Implement OAuth2"` |
+| **Execution (with plan)** | `--resume-session="WFS-xxx"` | `/workflow:execute -y --resume-session="WFS-plan-001"` |
+| **Execution (standalone)** | `--in-memory` or `"task"` | `/workflow:lite-execute -y --in-memory` |
+| **Session-based** | `--session="WFS-xxx"` | `/workflow:test-fix-gen -y --session="WFS-impl-001"` |
+| **Fix/Debug** | `"problem description"` | `/workflow:lite-fix -y "Fix timeout bug"` |
+
+### Complete Examples
+
+**Planning Command**:
+```bash
+ccw cli -p 'Task: Implement user registration
+
+/workflow:plan -y "Implement user registration with email validation"' --tool claude --mode write
+```
+
+**Execution with Context**:
+```bash
+ccw cli -p 'Task: Implement user registration
+
+Previous results:
+- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)
+
+/workflow:execute -y --resume-session="WFS-plan-20250124"' --tool claude --mode write
+```
+
+**Standalone Lite Execution**:
+```bash
+ccw cli -p 'Task: Fix login timeout
+
+/workflow:lite-fix -y "Fix login timeout in auth module"' --tool claude --mode write
+```

 ## Execution Flow
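The "Prompt Content Template" above is filled in by the orchestrator from the task goal, the accumulated execution results, and the command line produced by `formatCommand`. A minimal sketch of that assembly, mirroring the template (the helper name and result fields are illustrative; the repository's actual assembly lives inside `executeCommandChain`, shown in the earlier hunks):

```javascript
// Illustrative sketch of how the prompt template is filled in.
function assemblePrompt(analysis, executionResults, commandLine) {
  let prompt = `Task: ${analysis.goal}\n`;
  if (executionResults.length > 0) {
    prompt += '\nPrevious results:\n';
    for (const r of executionResults) {
      prompt += `- ${r.command}: ${r.session_id ?? 'no session'}\n`;
    }
  }
  // e.g. '/workflow:execute -y --resume-session="WFS-plan-20250124"'
  prompt += `\n${commandLine}`;
  return prompt;
}
```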
@@ -983,11 +833,49 @@ async function ccwCoordinator(taskDescription) {

 ## CLI Execution Model

-**Serial Blocking**: Commands execute one-by-one. After launching CLI in background, orchestrator stops immediately and waits for hook callback.
+### CLI Invocation Format
+
+**IMPORTANT**: The `ccw cli` command executes prompts through external tools. The format is:
+
+```bash
+ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
+```
+
+**Parameters**:
+- `-p "PROMPT_CONTENT"`: The prompt content to execute (required)
+- `--tool <tool>`: CLI tool to use (e.g., `claude`, `gemini`, `qwen`)
+- `--mode <mode>`: Execution mode (`analysis` or `write`)
+
+**Note**: `-y` is a **command parameter inside the prompt**, NOT a `ccw cli` parameter.
+
+### Prompt Assembly
+
+The prompt content should include the workflow command with `-y` for auto-confirm:
+
+```
+/workflow:<command> -y "<task description or parameters>"
+```
+
+**Examples**:
+```bash
+# Planning command
+ccw cli -p "/workflow:plan -y \"Implement user registration feature\"" --tool claude --mode write
+
+# Execution command (with session reference)
+ccw cli -p "/workflow:execute -y --resume-session=\"WFS-plan-20250124\"" --tool claude --mode write
+
+# Lite execution (in-memory from previous plan)
+ccw cli -p "/workflow:lite-execute -y --in-memory" --tool claude --mode write
+```
+
+### Serial Blocking
+
+Commands execute one-by-one. After launching CLI in background, orchestrator stops immediately and waits for hook callback.
+
 ```javascript
 // Example: Execute command and stop
-const taskId = Bash(`ccw cli -p "..." --tool claude --mode write -y`, { run_in_background: true }).task_id;
+const prompt = '/workflow:plan -y "Implement user authentication"';
+const taskId = Bash(`ccw cli -p "${prompt}" --tool claude --mode write`, { run_in_background: true }).task_id;
 state.execution_results.push({ status: 'in-progress', task_id: taskId, ... });
 Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
 break; // Stop, wait for hook callback
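The serial-blocking model above relies on a hook callback to resume the chain once the background CLI task finishes; the callback itself is outside this hunk. A hedged sketch of what such a callback might do, based only on the state fields visible above (the function name, the `result` shape, and the `state.chain`/`state.analysis` fields are assumptions, not the repository's implementation):

```javascript
// Hypothetical sketch: hook fired when a background ccw cli task completes.
async function onCliTaskComplete(taskId, result) {
  const state = JSON.parse(Read(`${stateDir}/state.json`)); // Read assumed to mirror Write

  // Mark the in-progress entry for this task as finished.
  const entry = state.execution_results.find(r => r.task_id === taskId);
  if (entry) {
    entry.status = result.success ? 'completed' : 'failed';
    entry.session_id = result.session_id; // e.g. "WFS-xxx", if the command produced one
  }

  state.updated_at = new Date().toISOString();
  Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));

  // Resume the chain so the next pending command is launched.
  if (result.success) await executeCommandChain(state.chain, state.analysis);
}
```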
@@ -1023,20 +911,20 @@ All from `~/.claude/commands/workflow/`:
 - **test-gen → execute**: generates a comprehensive test suite; execute runs both generation and testing
 - **test-fix-gen → test-cycle-execute**: generates fix tasks for specific issues; test-cycle-execute iterates testing and fixing until tests pass

-### Task Type Routing (Pipeline View)
+### Task Type Routing (Pipeline Summary)

 **Note**: `【 】` marks Minimum Execution Units - these commands must execute together.

-| Task Type | Pipeline |
-|-----------|----------|
-| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **feature** (complex) | Requirements →【plan → plan-verify】→ Verified plan → execute → Code →【review-session-cycle → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **bugfix** | Bug report → lite-fix → Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **tdd** | Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify → TDD verification passed |
-| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **test-gen** | Code/Session →【test-gen → execute】→ Tests pass |
-| **review** | Code →【review-session-cycle/review-module-cycle → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **brainstorm** | Exploration topic → brainstorm:auto-parallel → Analysis results →【plan → plan-verify】→ Verified plan → execute → Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **multi-cli** | Requirements → multi-cli-plan → Comparison plan → lite-execute → Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
+| Task Type | Pipeline | Minimum Units |
+|-----------|----------|---|
+| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass | Quick Implementation + Test Validation |
+| **feature** (complex) | Requirements →【plan → plan-verify】→ validate → execute → Code → review → fix | Full Planning + Code Review + Testing |
+| **bugfix** | Bug report → lite-fix → Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass | Bug Fix + Test Validation |
+| **tdd** | Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify | TDD Planning + Execution |
+| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests pass | Test Validation |
+| **test-gen** | Code/Session →【test-gen → execute】→ Tests pass | Test Generation + Execution |
+| **review** | Code →【review-* → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass | Code Review + Testing |
+| **brainstorm** | Exploration topic → brainstorm → Analysis →【plan → plan-verify】→ execute → test | Exploration + Planning + Execution |
+| **multi-cli** | Requirements → multi-cli-plan → Comparative analysis → lite-execute → test | Multi-Perspective + Testing |

 Use CommandRegistry.getAllCommandsSummary() to discover all commands dynamically.
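The routing table above is static documentation; at run time the same commands can be discovered through the registry, as the last line notes. A short sketch of that lookup, reusing the `getCommand` result shape shown earlier in this diff (the return shape of `getAllCommandsSummary()` is not shown here, so the iteration below is an assumption):

```javascript
// Assumed: getAllCommandsSummary() returns an array of command metadata objects.
const summary = registry.getAllCommandsSummary();
for (const entry of summary) {
  // Each entry is assumed to carry the fields shown for getCommand():
  // {name, command, description, argumentHint, allowedTools, filePath}
  console.log(`${entry.command} - ${entry.description}`);
}

// Single lookup, as shown earlier in this diff:
const cmd = registry.getCommand('lite-plan');
// -> {name, command, description, argumentHint, allowedTools, filePath}
```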
codex-lens/build/lib/codexlens/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""CodexLens package."""

from __future__ import annotations

from . import config, entities, errors
from .config import Config
from .entities import IndexedFile, SearchResult, SemanticChunk, Symbol
from .errors import CodexLensError, ConfigError, ParseError, SearchError, StorageError

__version__ = "0.1.0"

__all__ = [
    "__version__",
    "config",
    "entities",
    "errors",
    "Config",
    "IndexedFile",
    "SearchResult",
    "SemanticChunk",
    "Symbol",
    "CodexLensError",
    "ConfigError",
    "ParseError",
    "StorageError",
    "SearchError",
]
codex-lens/build/lib/codexlens/__main__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Module entrypoint for `python -m codexlens`."""

from __future__ import annotations

from codexlens.cli import app


def main() -> None:
    app()


if __name__ == "__main__":
    main()
codex-lens/build/lib/codexlens/api/__init__.py (new file, 88 lines)
@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.

This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.

Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result

Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity

Example:
    >>> from codexlens.api import (
    ...     DefinitionResult,
    ...     resolve_project,
    ...     normalize_relationship_type
    ... )
    >>> project = resolve_project("/path/to/project")
    >>> rel_type = normalize_relationship_type("calls")
    >>> print(rel_type)
    'call'
"""

from __future__ import annotations

# Dataclasses
from .models import (
    CallInfo,
    MethodContext,
    FileContextResult,
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
    SymbolInfo,
    HoverInfo,
    SemanticResult,
)

# Utility functions
from .utils import (
    resolve_project,
    normalize_relationship_type,
    rank_by_proximity,
    rank_by_score,
)

# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search

__all__ = [
    # Dataclasses
    "CallInfo",
    "MethodContext",
    "FileContextResult",
    "DefinitionResult",
    "ReferenceResult",
    "GroupedReferences",
    "SymbolInfo",
    "HoverInfo",
    "SemanticResult",
    # Utility functions
    "resolve_project",
    "normalize_relationship_type",
    "rank_by_proximity",
    "rank_by_score",
    # API functions
    "find_definition",
    "workspace_symbols",
    "get_hover",
    "file_context",
    "find_references",
    "semantic_search",
]
codex-lens/build/lib/codexlens/api/definition.py (new file, 126 lines)
@@ -0,0 +1,126 @@
"""find_definition API implementation.

This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity

logger = logging.getLogger(__name__)


def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """Find definition locations for a symbol.

    Uses a 3-stage fallback strategy:
    1. Exact match with kind filter
    2. Exact match without kind filter
    3. Prefix match

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to find
        symbol_kind: Optional symbol kind filter (class, function, etc.)
        file_context: Optional file path for proximity ranking
        limit: Maximum number of results to return

    Returns:
        List of DefinitionResult sorted by proximity if file_context provided

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Stage 1: Exact match with kind filter
    results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
    if results:
        logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    # Stage 2: Exact match without kind (if kind was specified)
    if symbol_kind:
        results = _search_with_kind(global_index, symbol_name, None, limit)
        if results:
            logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
            return _rank_and_convert(results, file_context)

    # Stage 3: Prefix match
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=limit,
        prefix_mode=True
    )
    if results:
        logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    logger.debug(f"No definitions found for {symbol_name}")
    return []


def _search_with_kind(
    global_index: GlobalSymbolIndex,
    symbol_name: str,
    symbol_kind: Optional[str],
    limit: int
) -> List[Symbol]:
    """Search for symbols with optional kind filter."""
    return global_index.search(
        name=symbol_name,
        kind=symbol_kind,
        limit=limit,
        prefix_mode=False
    )


def _rank_and_convert(
    symbols: List[Symbol],
    file_context: Optional[str]
) -> List[DefinitionResult]:
    """Convert symbols to DefinitionResult and rank by proximity."""
    results = [
        DefinitionResult(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            end_line=sym.range[1] if sym.range else 1,
            signature=None,  # Could extract from file if needed
            container=None,  # Could extract from parent symbol
            score=1.0
        )
        for sym in symbols
    ]
    return rank_by_proximity(results, file_context)
codex-lens/build/lib/codexlens/api/file_context.py (new file, 271 lines)
@@ -0,0 +1,271 @@
"""file_context API implementation.

This module provides the file_context() function for retrieving
method call graphs from a source file.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
    FileContextResult,
    MethodContext,
    CallInfo,
)
from .utils import resolve_project, normalize_relationship_type

logger = logging.getLogger(__name__)


def file_context(
    project_root: str,
    file_path: str,
    include_calls: bool = True,
    include_callers: bool = True,
    max_depth: int = 1,
    format: str = "brief"
) -> FileContextResult:
    """Get method call context for a code file.

    Retrieves all methods/functions in the file along with their
    outgoing calls and incoming callers.

    Args:
        project_root: Project root directory (for index location)
        file_path: Path to the code file to analyze
        include_calls: Whether to include outgoing calls
        include_callers: Whether to include incoming callers
        max_depth: Call chain depth (V1 only supports 1)
        format: Output format (brief | detailed | tree)

    Returns:
        FileContextResult with method contexts and summary

    Raises:
        IndexNotFoundError: If project is not indexed
        FileNotFoundError: If file does not exist
        ValueError: If max_depth > 1 (V1 limitation)
    """
    # V1 limitation: only depth=1 supported
    if max_depth > 1:
        raise ValueError(
            f"max_depth > 1 not supported in V1. "
            f"Requested: {max_depth}, supported: 1"
        )

    project_path = resolve_project(project_root)
    file_path_resolved = Path(file_path).resolve()

    # Validate file exists
    if not file_path_resolved.exists():
        raise FileNotFoundError(f"File not found: {file_path_resolved}")

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Get all symbols in the file
    symbols = global_index.get_file_symbols(str(file_path_resolved))

    # Filter to functions, methods, and classes
    method_symbols = [
        s for s in symbols
        if s.kind in ("function", "method", "class")
    ]

    logger.debug(f"Found {len(method_symbols)} methods in {file_path}")

    # Try to find dir_index for relationship queries
    dir_index = _find_dir_index(project_info, file_path_resolved)

    # Build method contexts
    methods: List[MethodContext] = []
    outgoing_resolved = True
    incoming_resolved = True
    targets_resolved = True

    for symbol in method_symbols:
        calls: List[CallInfo] = []
        callers: List[CallInfo] = []

        if include_calls and dir_index:
            try:
                outgoing = dir_index.get_outgoing_calls(
                    str(file_path_resolved),
                    symbol.name
                )
                for target_name, rel_type, line, target_file in outgoing:
                    calls.append(CallInfo(
                        symbol_name=target_name,
                        file_path=target_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
                    if target_file is None:
                        targets_resolved = False
            except Exception as e:
                logger.debug(f"Failed to get outgoing calls: {e}")
                outgoing_resolved = False

        if include_callers and dir_index:
            try:
                incoming = dir_index.get_incoming_calls(symbol.name)
                for source_name, rel_type, line, source_file in incoming:
                    callers.append(CallInfo(
                        symbol_name=source_name,
                        file_path=source_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
            except Exception as e:
                logger.debug(f"Failed to get incoming calls: {e}")
                incoming_resolved = False

        methods.append(MethodContext(
            name=symbol.name,
            kind=symbol.kind,
            line_range=symbol.range if symbol.range else (1, 1),
            signature=None,  # Could extract from source
            calls=calls,
            callers=callers
        ))

    # Detect language from file extension
    language = _detect_language(file_path_resolved)

    # Generate summary
    summary = _generate_summary(file_path_resolved, methods, format)

    return FileContextResult(
        file_path=str(file_path_resolved),
        language=language,
        methods=methods,
        summary=summary,
        discovery_status={
            "outgoing_resolved": outgoing_resolved,
            "incoming_resolved": incoming_resolved,
            "targets_resolved": targets_resolved
        }
    )


def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
    """Find the dir_index that contains the file.

    Args:
        project_info: Project information from registry
        file_path: Path to the file

    Returns:
        DirIndexStore if found, None otherwise
    """
    try:
        # Look for _index.db in file's directory or parent directories
        current = file_path.parent
        while current != current.parent:
            index_db = current / "_index.db"
            if index_db.exists():
                return DirIndexStore(str(index_db))

            # Also check in project's index_root
            relative = current.relative_to(project_info.source_root)
            index_in_cache = project_info.index_root / relative / "_index.db"
            if index_in_cache.exists():
                return DirIndexStore(str(index_in_cache))

            current = current.parent
    except Exception as e:
        logger.debug(f"Failed to find dir_index: {e}")

    return None


def _detect_language(file_path: Path) -> str:
    """Detect programming language from file extension.

    Args:
        file_path: Path to the file

    Returns:
        Language name
    """
    ext_map = {
        ".py": "python",
        ".js": "javascript",
        ".ts": "typescript",
        ".jsx": "javascript",
        ".tsx": "typescript",
        ".go": "go",
        ".rs": "rust",
        ".java": "java",
        ".c": "c",
        ".cpp": "cpp",
        ".h": "c",
        ".hpp": "cpp",
    }
    return ext_map.get(file_path.suffix.lower(), "unknown")


def _generate_summary(
    file_path: Path,
    methods: List[MethodContext],
    format: str
) -> str:
    """Generate human-readable summary of file context.

    Args:
        file_path: Path to the file
        methods: List of method contexts
        format: Output format (brief | detailed | tree)

    Returns:
        Markdown-formatted summary
    """
    lines = [f"## {file_path.name} ({len(methods)} methods)\n"]

    for method in methods:
        start, end = method.line_range
        lines.append(f"### {method.name} (line {start}-{end})")

        if method.calls:
            calls_str = ", ".join(
                f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.calls
            )
            lines.append(f"- Calls: {calls_str}")

        if method.callers:
            callers_str = ", ".join(
                f"{c.symbol_name} ({c.file_path}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.callers
            )
            lines.append(f"- Called by: {callers_str}")

        if not method.calls and not method.callers:
            lines.append("- (no call relationships)")

        lines.append("")

    return "\n".join(lines)
codex-lens/build/lib/codexlens/api/hover.py (new file, 148 lines)
@@ -0,0 +1,148 @@
"""get_hover API implementation.

This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def get_hover(
    project_root: str,
    symbol_name: str,
    file_path: Optional[str] = None
) -> Optional[HoverInfo]:
    """Get detailed hover information for a symbol.

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to look up
        file_path: Optional file path to disambiguate when symbol
            appears in multiple files

    Returns:
        HoverInfo if symbol found, None otherwise

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search for the symbol
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=50,
        prefix_mode=False
    )

    if not results:
        logger.debug(f"No hover info found for {symbol_name}")
        return None

    # If file_path provided, filter to that file
    if file_path:
        file_path_resolved = str(Path(file_path).resolve())
        matching = [s for s in results if s.file == file_path_resolved]
        if matching:
            results = matching

    # Take the first result
    symbol = results[0]

    # Build hover info
    return HoverInfo(
        name=symbol.name,
        kind=symbol.kind,
        signature=_extract_signature(symbol),
        documentation=_extract_documentation(symbol),
        file_path=symbol.file or "",
        line_range=symbol.range if symbol.range else (1, 1),
        type_info=_extract_type_info(symbol)
    )


def _extract_signature(symbol: Symbol) -> str:
    """Extract signature from symbol.

    For now, generates a basic signature based on kind and name.
    In a full implementation, this would parse the actual source code.

    Args:
        symbol: The symbol to extract signature from

    Returns:
        Signature string
    """
    if symbol.kind == "function":
        return f"def {symbol.name}(...)"
    elif symbol.kind == "method":
        return f"def {symbol.name}(self, ...)"
    elif symbol.kind == "class":
        return f"class {symbol.name}"
    elif symbol.kind == "variable":
        return symbol.name
    elif symbol.kind == "constant":
        return f"{symbol.name} = ..."
    else:
        return f"{symbol.kind} {symbol.name}"


def _extract_documentation(symbol: Symbol) -> Optional[str]:
    """Extract documentation from symbol.

    In a full implementation, this would parse docstrings from source.
    For now, returns None.

    Args:
        symbol: The symbol to extract documentation from

    Returns:
        Documentation string if available, None otherwise
    """
    # Would need to read source file and parse docstring
    # For V1, return None
    return None


def _extract_type_info(symbol: Symbol) -> Optional[str]:
    """Extract type information from symbol.

    In a full implementation, this would parse type annotations.
    For now, returns None.

    Args:
        symbol: The symbol to extract type info from

    Returns:
        Type info string if available, None otherwise
    """
    # Would need to parse type annotations from source
    # For V1, return None
    return None
codex-lens/build/lib/codexlens/api/models.py (new file, 281 lines)
@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.

This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple


# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================

@dataclass
class CallInfo:
    """Call relationship information.

    Attributes:
        symbol_name: Name of the called/calling symbol
        file_path: Target file path (may be None if unresolved)
        line: Line number of the call
        relationship: Type of relationship (call | import | inheritance)
    """
    symbol_name: str
    file_path: Optional[str]
    line: int
    relationship: str  # call | import | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MethodContext:
    """Method context with call relationships.

    Attributes:
        name: Method/function name
        kind: Symbol kind (function | method | class)
        line_range: Start and end line numbers
        signature: Function signature (if available)
        calls: List of outgoing calls
        callers: List of incoming calls
    """
    name: str
    kind: str  # function | method | class
    line_range: Tuple[int, int]
    signature: Optional[str]
    calls: List[CallInfo] = field(default_factory=list)
    callers: List[CallInfo] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "line_range": list(self.line_range),
            "calls": [c.to_dict() for c in self.calls],
            "callers": [c.to_dict() for c in self.callers],
        }
        if self.signature is not None:
            result["signature"] = self.signature
        return result


@dataclass
class FileContextResult:
    """File context result with method summaries.

    Attributes:
        file_path: Path to the analyzed file
        language: Programming language
        methods: List of method contexts
        summary: Human-readable summary
        discovery_status: Status flags for call resolution
    """
    file_path: str
    language: str
    methods: List[MethodContext]
    summary: str
    discovery_status: Dict[str, bool] = field(default_factory=lambda: {
        "outgoing_resolved": False,
        "incoming_resolved": True,
        "targets_resolved": False
    })

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "file_path": self.file_path,
            "language": self.language,
            "methods": [m.to_dict() for m in self.methods],
            "summary": self.summary,
            "discovery_status": self.discovery_status,
        }


# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================

@dataclass
class DefinitionResult:
    """Definition lookup result.

    Attributes:
        name: Symbol name
        kind: Symbol kind (class, function, method, etc.)
        file_path: File where symbol is defined
        line: Start line number
        end_line: End line number
        signature: Symbol signature (if available)
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str] = None
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================

@dataclass
class ReferenceResult:
    """Reference lookup result.

    Attributes:
        file_path: File containing the reference
        line: Line number
        column: Column number
        context_line: The line of code containing the reference
        relationship: Type of reference (call | import | type_annotation | inheritance)
    """
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class GroupedReferences:
    """References grouped by definition.

    Used when a symbol has multiple definitions (e.g., overloads).

    Attributes:
        definition: The definition this group refers to
        references: List of references to this definition
    """
    definition: DefinitionResult
    references: List[ReferenceResult] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "definition": self.definition.to_dict(),
            "references": [r.to_dict() for r in self.references],
        }


# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================

@dataclass
class SymbolInfo:
    """Symbol information for workspace search.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================

@dataclass
class HoverInfo:
    """Hover information for a symbol.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        signature: Symbol signature
        documentation: Documentation string (if available)
        file_path: File where symbol is defined
        line_range: Start and end line numbers
        type_info: Type information (if available)
    """
    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: Tuple[int, int]
    type_info: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "signature": self.signature,
            "file_path": self.file_path,
            "line_range": list(self.line_range),
        }
        if self.documentation is not None:
            result["documentation"] = self.documentation
        if self.type_info is not None:
            result["type_info"] = self.type_info
        return result


# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================

@dataclass
class SemanticResult:
    """Semantic search result.

    Attributes:
        symbol_name: Name of the matched symbol
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        vector_score: Vector similarity score (None if not available)
        structural_score: Structural match score (None if not available)
        fusion_score: Combined fusion score
        snippet: Code snippet
        match_reason: Explanation of why this matched (optional)
    """
    symbol_name: str
    kind: str
    file_path: str
    line: int
    vector_score: Optional[float]
    structural_score: Optional[float]
    fusion_score: float
    snippet: str
    match_reason: Optional[str] = None
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to dictionary, filtering None values."""
|
||||||
|
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
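

# Illustrative usage sketch (not part of the original file): shows how the result
# dataclasses above are meant to serialize via to_dict(). All values below are
# made-up examples; only the dataclass names come from this module.
if __name__ == "__main__":  # pragma: no cover - example only
    import json

    defn = DefinitionResult(name="authenticate", kind="function",
                            file_path="src/auth/service.py", line=10, end_line=42)
    group = GroupedReferences(definition=defn, references=[
        ReferenceResult(file_path="src/api/routes.py", line=88, column=12,
                        context_line="user = authenticate(token)", relationship="call"),
    ])
    # None-valued optional fields (signature, container) are dropped by to_dict().
    print(json.dumps(group.to_dict(), indent=2))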
codex-lens/build/lib/codexlens/api/references.py (new file)
@@ -0,0 +1,345 @@
"""Find references API for codexlens.

This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional, Dict

from .models import (
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
)
from .utils import (
    resolve_project,
    normalize_relationship_type,
)


logger = logging.getLogger(__name__)


def _read_line_from_file(file_path: str, line: int) -> str:
    """Read a specific line from a file.

    Args:
        file_path: Path to the file
        line: Line number (1-based)

    Returns:
        The line content, stripped of trailing whitespace.
        Returns empty string if file cannot be read or line doesn't exist.
    """
    try:
        path = Path(file_path)
        if not path.exists():
            return ""

        with path.open("r", encoding="utf-8", errors="replace") as f:
            for i, content in enumerate(f, 1):
                if i == line:
                    return content.rstrip()
        return ""
    except Exception as exc:
        logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
        return ""


def _transform_to_reference_result(
    raw_ref: "RawReferenceResult",
) -> ReferenceResult:
    """Transform raw ChainSearchEngine reference to API ReferenceResult.

    Args:
        raw_ref: Raw reference result from ChainSearchEngine

    Returns:
        API ReferenceResult with context_line and normalized relationship
    """
    # Read the actual line from the file
    context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)

    # Normalize relationship type
    relationship = normalize_relationship_type(raw_ref.relationship_type)

    return ReferenceResult(
        file_path=raw_ref.file_path,
        line=raw_ref.line,
        column=raw_ref.column,
        context_line=context_line,
        relationship=relationship,
    )


def find_references(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    include_definition: bool = True,
    group_by_definition: bool = True,
    limit: int = 100,
) -> List[GroupedReferences]:
    """Find all reference locations for a symbol.

    Multi-definition case returns grouped results to resolve ambiguity.

    This function wraps ChainSearchEngine.search_references() and groups
    the results by definition location. Each GroupedReferences contains
    a definition and all references that point to it.

    Args:
        project_root: Project root directory path
        symbol_name: Name of the symbol to find references for
        symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
        include_definition: Whether to include the definition location
            in the result (default True)
        group_by_definition: Whether to group references by definition.
            If False, returns a single group with all references.
            (default True)
        limit: Maximum number of references to return (default 100)

    Returns:
        List of GroupedReferences. Each group contains:
        - definition: The DefinitionResult for this symbol definition
        - references: List of ReferenceResult pointing to this definition

    Raises:
        ValueError: If project_root does not exist or is not a directory

    Examples:
        >>> refs = find_references("/path/to/project", "authenticate")
        >>> for group in refs:
        ...     print(f"Definition: {group.definition.file_path}:{group.definition.line}")
        ...     for ref in group.references:
        ...         print(f"  Reference: {ref.file_path}:{ref.line} ({ref.relationship})")

    Note:
        Reference relationship types are normalized:
        - 'calls' -> 'call'
        - 'imports' -> 'import'
        - 'inherits' -> 'inheritance'
    """
    # Validate and resolve project root
    project_path = resolve_project(project_root)

    # Import here to avoid circular imports
    from codexlens.config import Config
    from codexlens.storage.registry import RegistryStore
    from codexlens.storage.path_mapper import PathMapper
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.search.chain_search import ChainSearchEngine
    from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
    from codexlens.entities import Symbol

    # Initialize infrastructure
    config = Config()
    registry = RegistryStore()
    mapper = PathMapper(config.index_dir)

    # Create chain search engine
    engine = ChainSearchEngine(registry, mapper, config=config)

    try:
        # Step 1: Find definitions for the symbol
        definitions: List[DefinitionResult] = []

        if include_definition or group_by_definition:
            # Search for symbol definitions
            symbols = engine.search_symbols(
                name=symbol_name,
                source_path=project_path,
                kind=symbol_kind,
            )

            # Convert Symbol to DefinitionResult
            for sym in symbols:
                # Only include exact name matches for definitions
                if sym.name != symbol_name:
                    continue

                # Optionally filter by kind
                if symbol_kind and sym.kind != symbol_kind:
                    continue

                definitions.append(DefinitionResult(
                    name=sym.name,
                    kind=sym.kind,
                    file_path=sym.file or "",
                    line=sym.range[0] if sym.range else 1,
                    end_line=sym.range[1] if sym.range else 1,
                    signature=None,  # Not available from Symbol
                    container=None,  # Not available from Symbol
                    score=1.0,
                ))

        # Step 2: Get all references using ChainSearchEngine
        raw_references = engine.search_references(
            symbol_name=symbol_name,
            source_path=project_path,
            depth=-1,
            limit=limit,
        )

        # Step 3: Transform raw references to API ReferenceResult
        api_references: List[ReferenceResult] = []
        for raw_ref in raw_references:
            api_ref = _transform_to_reference_result(raw_ref)
            api_references.append(api_ref)

        # Step 4: Group references by definition
        if group_by_definition and definitions:
            return _group_references_by_definition(
                definitions=definitions,
                references=api_references,
                include_definition=include_definition,
            )
        else:
            # Return single group with placeholder definition or first definition
            if definitions:
                definition = definitions[0]
            else:
                # Create placeholder definition when no definition found
                definition = DefinitionResult(
                    name=symbol_name,
                    kind=symbol_kind or "unknown",
                    file_path="",
                    line=0,
                    end_line=0,
                    signature=None,
                    container=None,
                    score=0.0,
                )

            return [GroupedReferences(
                definition=definition,
                references=api_references,
            )]

    finally:
        engine.close()


def _group_references_by_definition(
    definitions: List[DefinitionResult],
    references: List[ReferenceResult],
    include_definition: bool = True,
) -> List[GroupedReferences]:
    """Group references by their likely definition.

    Uses file proximity heuristic to assign references to definitions.
    References in the same file or directory as a definition are
    assigned to that definition.

    Args:
        definitions: List of definition locations
        references: List of reference locations
        include_definition: Whether to include definition in results

    Returns:
        List of GroupedReferences with references assigned to definitions
    """
    import os

    if not definitions:
        return []

    if len(definitions) == 1:
        # Single definition - all references belong to it
        return [GroupedReferences(
            definition=definitions[0],
            references=references,
        )]

    # Multiple definitions - group by proximity
    groups: Dict[int, List[ReferenceResult]] = {
        i: [] for i in range(len(definitions))
    }

    for ref in references:
        # Find the closest definition by file proximity
        best_def_idx = 0
        best_score = -1

        for i, defn in enumerate(definitions):
            score = _proximity_score(ref.file_path, defn.file_path)
            if score > best_score:
                best_score = score
                best_def_idx = i

        groups[best_def_idx].append(ref)

    # Build result groups
    result: List[GroupedReferences] = []
    for i, defn in enumerate(definitions):
        # Skip definitions with no references if not including definition itself
        if not include_definition and not groups[i]:
            continue

        result.append(GroupedReferences(
            definition=defn,
            references=groups[i],
        ))

    return result


def _proximity_score(ref_path: str, def_path: str) -> int:
    """Calculate proximity score between two file paths.

    Args:
        ref_path: Reference file path
        def_path: Definition file path

    Returns:
        Proximity score (higher = closer):
        - Same file: 1000
        - Same directory: 100
        - Otherwise: common path prefix length
    """
    import os

    if not ref_path or not def_path:
        return 0

    # Normalize paths
    ref_path = os.path.normpath(ref_path)
    def_path = os.path.normpath(def_path)

    # Same file
    if ref_path == def_path:
        return 1000

    ref_dir = os.path.dirname(ref_path)
    def_dir = os.path.dirname(def_path)

    # Same directory
    if ref_dir == def_dir:
        return 100

    # Common path prefix
    try:
        common = os.path.commonpath([ref_path, def_path])
        return len(common)
    except ValueError:
        # No common path (different drives on Windows)
        return 0


# Type stub for the raw reference from ChainSearchEngine
class RawReferenceResult:
    """Type stub for ChainSearchEngine.ReferenceResult.

    This is only used for type hints and is replaced at runtime
    by the actual import.
    """
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str
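

# Illustrative usage sketch (not part of the original file): how find_references()
# is intended to be called once a project has been indexed. The project path and
# symbol name are hypothetical placeholders.
if __name__ == "__main__":  # pragma: no cover - example only
    for group in find_references("/path/to/project", "authenticate", limit=20):
        d = group.definition
        print(f"Definition: {d.file_path}:{d.line} ({d.kind})")
        for ref in group.references:
            print(f"  {ref.relationship:>15}  {ref.file_path}:{ref.line}  {ref.context_line}")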
codex-lens/build/lib/codexlens/api/semantic.py (new file)
@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.

This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from .models import SemanticResult
from .utils import resolve_project

logger = logging.getLogger(__name__)


def semantic_search(
    project_root: str,
    query: str,
    mode: str = "fusion",
    vector_weight: float = 0.5,
    structural_weight: float = 0.3,
    keyword_weight: float = 0.2,
    fusion_strategy: str = "rrf",
    kind_filter: Optional[List[str]] = None,
    limit: int = 20,
    include_match_reason: bool = False,
) -> List[SemanticResult]:
    """Semantic search - combining vector and structural search.

    This function provides a high-level API for semantic code search,
    combining vector similarity, structural (symbol + relationships),
    and keyword-based search methods with configurable fusion.

    Args:
        project_root: Project root directory
        query: Natural language query
        mode: Search mode
            - vector: Vector search only
            - structural: Structural search only (symbol + relationships)
            - fusion: Fusion search (default)
        vector_weight: Vector search weight [0, 1] (default 0.5)
        structural_weight: Structural search weight [0, 1] (default 0.3)
        keyword_weight: Keyword search weight [0, 1] (default 0.2)
        fusion_strategy: Fusion strategy (maps to chain_search.py)
            - rrf: Reciprocal Rank Fusion (recommended, default)
            - staged: Staged cascade -> staged_cascade_search
            - binary: Binary rerank cascade -> binary_cascade_search
            - hybrid: Hybrid cascade -> hybrid_cascade_search
        kind_filter: Symbol type filter (e.g., ["function", "class"])
        limit: Max return count (default 20)
        include_match_reason: Generate match reason (heuristic, not LLM)

    Returns:
        Results sorted by fusion_score

    Degradation:
        - No vector index: vector_score=None, uses FTS + structural search
        - No relationship data: structural_score=None, vector search only

    Examples:
        >>> results = semantic_search(
        ...     "/path/to/project",
        ...     "authentication handler",
        ...     mode="fusion",
        ...     fusion_strategy="rrf"
        ... )
        >>> for r in results:
        ...     print(f"{r.symbol_name}: {r.fusion_score:.3f}")
    """
    # Validate and resolve project path
    project_path = resolve_project(project_root)

    # Normalize weights to sum to 1.0
    total_weight = vector_weight + structural_weight + keyword_weight
    if total_weight > 0:
        vector_weight = vector_weight / total_weight
        structural_weight = structural_weight / total_weight
        keyword_weight = keyword_weight / total_weight
    else:
        # Default to equal weights if all zero
        vector_weight = structural_weight = keyword_weight = 1.0 / 3.0

    # Initialize search infrastructure
    try:
        from codexlens.config import Config
        from codexlens.storage.registry import RegistryStore
        from codexlens.storage.path_mapper import PathMapper
        from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
    except ImportError as exc:
        logger.error("Failed to import search dependencies: %s", exc)
        return []

    # Load config
    config = Config.load()

    # Get or create registry and mapper
    try:
        registry = RegistryStore.default()
        mapper = PathMapper(registry)
    except Exception as exc:
        logger.error("Failed to initialize search infrastructure: %s", exc)
        return []

    # Build search options based on mode
    search_options = _build_search_options(
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        limit=limit,
    )

    # Execute search based on fusion_strategy
    try:
        with ChainSearchEngine(registry, mapper, config=config) as engine:
            chain_result = _execute_search(
                engine=engine,
                query=query,
                source_path=project_path,
                fusion_strategy=fusion_strategy,
                options=search_options,
                limit=limit,
            )
    except Exception as exc:
        logger.error("Search execution failed: %s", exc)
        return []

    # Transform results to SemanticResult
    semantic_results = _transform_results(
        results=chain_result.results,
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        kind_filter=kind_filter,
        include_match_reason=include_match_reason,
        query=query,
    )

    return semantic_results[:limit]


def _build_search_options(
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    limit: int,
) -> "SearchOptions":
    """Build SearchOptions based on mode and weights.

    Args:
        mode: Search mode (vector, structural, fusion)
        vector_weight: Vector search weight
        structural_weight: Structural search weight
        keyword_weight: Keyword search weight
        limit: Result limit

    Returns:
        Configured SearchOptions
    """
    from codexlens.search.chain_search import SearchOptions

    # Default options
    options = SearchOptions(
        total_limit=limit * 2,  # Fetch extra for filtering
        limit_per_dir=limit,
        include_symbols=True,  # Always include symbols for structural
    )

    if mode == "vector":
        # Pure vector mode
        options.hybrid_mode = True
        options.enable_vector = True
        options.pure_vector = True
        options.enable_fuzzy = False
    elif mode == "structural":
        # Structural only - use FTS + symbols
        options.hybrid_mode = True
        options.enable_vector = False
        options.enable_fuzzy = True
        options.include_symbols = True
    else:
        # Fusion mode (default)
        options.hybrid_mode = True
        options.enable_vector = vector_weight > 0
        options.enable_fuzzy = keyword_weight > 0
        options.include_symbols = structural_weight > 0

        # Set custom weights for RRF
        if options.enable_vector and keyword_weight > 0:
            options.hybrid_weights = {
                "vector": vector_weight,
                "exact": keyword_weight * 0.7,
                "fuzzy": keyword_weight * 0.3,
            }

    return options


def _execute_search(
    engine: "ChainSearchEngine",
    query: str,
    source_path: Path,
    fusion_strategy: str,
    options: "SearchOptions",
    limit: int,
) -> "ChainSearchResult":
    """Execute search using appropriate strategy.

    Maps fusion_strategy to ChainSearchEngine methods:
    - rrf: Standard hybrid search with RRF fusion
    - staged: staged_cascade_search
    - binary: binary_cascade_search
    - hybrid: hybrid_cascade_search

    Args:
        engine: ChainSearchEngine instance
        query: Search query
        source_path: Project root path
        fusion_strategy: Strategy name
        options: Search options
        limit: Result limit

    Returns:
        ChainSearchResult from the search
    """
    from codexlens.search.chain_search import ChainSearchResult

    if fusion_strategy == "staged":
        # Use staged cascade search (4-stage pipeline)
        return engine.staged_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "binary":
        # Use binary cascade search (binary coarse + dense fine)
        return engine.binary_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "hybrid":
        # Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
        return engine.hybrid_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    else:
        # Default: rrf - Standard search with RRF fusion
        return engine.search(
            query=query,
            source_path=source_path,
            options=options,
        )


def _transform_results(
    results: List,
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    kind_filter: Optional[List[str]],
    include_match_reason: bool,
    query: str,
) -> List[SemanticResult]:
    """Transform ChainSearchEngine results to SemanticResult.

    Args:
        results: List of SearchResult objects
        mode: Search mode
        vector_weight: Vector weight used
        structural_weight: Structural weight used
        keyword_weight: Keyword weight used
        kind_filter: Optional symbol kind filter
        include_match_reason: Whether to generate match reasons
        query: Original query (for match reason generation)

    Returns:
        List of SemanticResult objects
    """
    semantic_results = []

    for result in results:
        # Extract symbol info
        symbol_name = getattr(result, "symbol_name", None)
        symbol_kind = getattr(result, "symbol_kind", None)
        start_line = getattr(result, "start_line", None)

        # Use symbol object if available
        if hasattr(result, "symbol") and result.symbol:
            symbol_name = symbol_name or result.symbol.name
            symbol_kind = symbol_kind or result.symbol.kind
            if hasattr(result.symbol, "range") and result.symbol.range:
                start_line = start_line or result.symbol.range[0]

        # Filter by kind if specified
        if kind_filter and symbol_kind:
            if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
                continue

        # Determine scores based on mode and metadata
        metadata = getattr(result, "metadata", {}) or {}
        fusion_score = result.score

        # Try to extract source scores from metadata
        source_scores = metadata.get("source_scores", {})
        vector_score: Optional[float] = None
        structural_score: Optional[float] = None

        if mode == "vector":
            # In pure vector mode, the main score is the vector score
            vector_score = result.score
            structural_score = None
        elif mode == "structural":
            # In structural mode, no vector score
            vector_score = None
            structural_score = result.score
        else:
            # Fusion mode - try to extract individual scores
            if "vector" in source_scores:
                vector_score = source_scores["vector"]
            elif metadata.get("fusion_method") == "simple_weighted":
                # From weighted fusion
                vector_score = source_scores.get("vector")

            # Structural score approximation (from exact/fuzzy FTS)
            fts_scores = []
            if "exact" in source_scores:
                fts_scores.append(source_scores["exact"])
            if "fuzzy" in source_scores:
                fts_scores.append(source_scores["fuzzy"])
            if "splade" in source_scores:
                fts_scores.append(source_scores["splade"])

            if fts_scores:
                structural_score = max(fts_scores)

        # Build snippet
        snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
        if len(snippet) > 500:
            snippet = snippet[:500] + "..."

        # Generate match reason if requested
        match_reason = None
        if include_match_reason:
            match_reason = _generate_match_reason(
                query=query,
                symbol_name=symbol_name,
                symbol_kind=symbol_kind,
                snippet=snippet,
                vector_score=vector_score,
                structural_score=structural_score,
            )

        semantic_result = SemanticResult(
            symbol_name=symbol_name or Path(result.path).stem,
            kind=symbol_kind or "unknown",
            file_path=result.path,
            line=start_line or 1,
            vector_score=vector_score,
            structural_score=structural_score,
            fusion_score=fusion_score,
            snippet=snippet,
            match_reason=match_reason,
        )

        semantic_results.append(semantic_result)

    # Sort by fusion_score descending
    semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)

    return semantic_results


def _generate_match_reason(
    query: str,
    symbol_name: Optional[str],
    symbol_kind: Optional[str],
    snippet: str,
    vector_score: Optional[float],
    structural_score: Optional[float],
) -> str:
    """Generate human-readable match reason heuristically.

    This is a simple heuristic-based approach, not LLM-powered.

    Args:
        query: Original search query
        symbol_name: Symbol name if available
        symbol_kind: Symbol kind if available
        snippet: Code snippet
        vector_score: Vector similarity score
        structural_score: Structural match score

    Returns:
        Human-readable explanation string
    """
    reasons = []

    # Check for direct name match
    query_lower = query.lower()
    query_words = set(query_lower.split())

    if symbol_name:
        name_lower = symbol_name.lower()
        # Direct substring match
        if query_lower in name_lower or name_lower in query_lower:
            reasons.append(f"Symbol name '{symbol_name}' matches query")
        # Word overlap
        name_words = set(_split_camel_case(symbol_name).lower().split())
        overlap = query_words & name_words
        if overlap and not reasons:
            reasons.append(f"Symbol name contains: {', '.join(overlap)}")

    # Check snippet for keyword matches
    snippet_lower = snippet.lower()
    matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
    if matching_words and len(reasons) < 2:
        reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")

    # Add score-based reasoning
    if vector_score is not None and vector_score > 0.7:
        reasons.append("High semantic similarity")
    elif vector_score is not None and vector_score > 0.5:
        reasons.append("Moderate semantic similarity")

    if structural_score is not None and structural_score > 0.8:
        reasons.append("Strong structural match")

    # Symbol kind context
    if symbol_kind and len(reasons) < 3:
        reasons.append(f"Matched {symbol_kind}")

    if not reasons:
        reasons.append("Partial relevance based on content analysis")

    return "; ".join(reasons[:3])


def _split_camel_case(name: str) -> str:
    """Split camelCase and PascalCase to words.

    Args:
        name: Symbol name in camelCase or PascalCase

    Returns:
        Space-separated words
    """
    import re

    # Insert space before uppercase letters
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Insert space before uppercase followed by lowercase
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
    # Replace underscores with spaces
    result = result.replace("_", " ")

    return result
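

# Illustrative usage sketch (not part of the original file): weights are normalized
# to sum to 1.0 before the search runs, so 2/1/1 below behaves like 0.5/0.25/0.25.
# The project path and query are hypothetical placeholders.
if __name__ == "__main__":  # pragma: no cover - example only
    hits = semantic_search(
        "/path/to/project",
        "retry logic for failed uploads",
        mode="fusion",
        vector_weight=2.0,
        structural_weight=1.0,
        keyword_weight=1.0,
        fusion_strategy="rrf",
        include_match_reason=True,
    )
    for hit in hits:
        print(f"{hit.fusion_score:.3f}  {hit.symbol_name}  {hit.file_path}:{hit.line}")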
codex-lens/build/lib/codexlens/api/symbols.py (new file)
@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.

This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""

from __future__ import annotations

import fnmatch
import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def workspace_symbols(
    project_root: str,
    query: str,
    kind_filter: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    limit: int = 50
) -> List[SymbolInfo]:
    """Search for symbols across the entire workspace.

    Uses prefix matching for efficient searching.

    Args:
        project_root: Project root directory (for index location)
        query: Search query (prefix match)
        kind_filter: Optional list of symbol kinds to include
            (e.g., ["class", "function"])
        file_pattern: Optional glob pattern to filter by file path
            (e.g., "*.py", "src/**/*.ts")
        limit: Maximum number of results to return

    Returns:
        List of SymbolInfo sorted by score

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search with prefix matching.
    # If kind_filter has multiple kinds, we need to search for each.
    all_results: List[Symbol] = []

    if kind_filter and len(kind_filter) > 0:
        # Search for each kind separately
        for kind in kind_filter:
            results = global_index.search(
                name=query,
                kind=kind,
                limit=limit,
                prefix_mode=True
            )
            all_results.extend(results)
    else:
        # Search without kind filter
        all_results = global_index.search(
            name=query,
            kind=None,
            limit=limit,
            prefix_mode=True
        )

    logger.debug(f"Found {len(all_results)} symbols matching '{query}'")

    # Apply file pattern filter if specified
    if file_pattern:
        all_results = [
            sym for sym in all_results
            if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
        ]
        logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")

    # Convert to SymbolInfo and sort by relevance
    symbols = [
        SymbolInfo(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            container=None,  # Could extract from parent
            score=_calculate_score(sym.name, query)
        )
        for sym in all_results
    ]

    # Sort by score (exact matches first)
    symbols.sort(key=lambda s: s.score, reverse=True)

    return symbols[:limit]


def _calculate_score(symbol_name: str, query: str) -> float:
    """Calculate relevance score for a symbol match.

    Scoring:
        - Exact match: 1.0
        - Case-insensitive exact match: 0.9
        - Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
        - Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
        - Any other index hit: 0.5

    Args:
        symbol_name: The matched symbol name
        query: The search query

    Returns:
        Score between 0.0 and 1.0
    """
    if symbol_name == query:
        return 1.0

    if symbol_name.lower() == query.lower():
        return 0.9

    if symbol_name.startswith(query):
        ratio = len(query) / len(symbol_name)
        return 0.8 + 0.2 * ratio

    if symbol_name.lower().startswith(query.lower()):
        ratio = len(query) / len(symbol_name)
        return 0.6 + 0.2 * ratio

    return 0.5
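

# Illustrative usage sketch (not part of the original file): prefix queries rank
# exact matches first, then case-sensitive prefixes, then case-insensitive ones,
# mirroring _calculate_score() above. The project path is a hypothetical placeholder.
if __name__ == "__main__":  # pragma: no cover - example only
    for name in ("Config", "Configuration", "configure_logging"):
        print(name, _calculate_score(name, "Config"))
    # Against a real index:
    # workspace_symbols("/path/to/project", "Config", kind_filter=["class"], file_pattern="*.py")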
codex-lens/build/lib/codexlens/api/utils.py (new file)
@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.

This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable

from .models import DefinitionResult


# Type variable for generic ranking
T = TypeVar('T')


def resolve_project(project_root: str) -> Path:
    """Resolve and validate project root path.

    Args:
        project_root: Path to project root (relative or absolute)

    Returns:
        Resolved absolute Path

    Raises:
        ValueError: If path does not exist or is not a directory
    """
    path = Path(project_root).resolve()
    if not path.exists():
        raise ValueError(f"Project root does not exist: {path}")
    if not path.is_dir():
        raise ValueError(f"Project root is not a directory: {path}")
    return path


# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
    # Plural to singular
    "calls": "call",
    "imports": "import",
    "inherits": "inheritance",
    "uses": "use",
    # Already normalized (passthrough)
    "call": "call",
    "import": "import",
    "inheritance": "inheritance",
    "use": "use",
    "type_annotation": "type_annotation",
}


def normalize_relationship_type(relationship: str) -> str:
    """Normalize relationship type to canonical form.

    Converts plural forms and variations to standard singular forms:
    - 'calls' -> 'call'
    - 'imports' -> 'import'
    - 'inherits' -> 'inheritance'
    - 'uses' -> 'use'

    Args:
        relationship: Raw relationship type string

    Returns:
        Normalized relationship type

    Examples:
        >>> normalize_relationship_type('calls')
        'call'
        >>> normalize_relationship_type('inherits')
        'inheritance'
        >>> normalize_relationship_type('call')
        'call'
    """
    return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)


def rank_by_proximity(
    results: List[DefinitionResult],
    file_context: Optional[str] = None
) -> List[DefinitionResult]:
    """Rank results by file path proximity to context.

    V1 Implementation: Uses path-based proximity scoring.

    Scoring algorithm:
    1. Same directory: highest score (100)
    2. Otherwise: length of common path prefix

    Args:
        results: List of definition results to rank
        file_context: Reference file path for proximity calculation.
            If None, returns results unchanged.

    Returns:
        Results sorted by proximity score (highest first)

    Examples:
        >>> results = [
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/b/c.py", line=1, end_line=10),
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/x/y.py", line=1, end_line=10),
        ... ]
        >>> ranked = rank_by_proximity(results, "/a/b/test.py")
        >>> ranked[0].file_path
        '/a/b/c.py'
    """
    if not file_context or not results:
        return results

    def proximity_score(result: DefinitionResult) -> int:
        """Calculate proximity score for a result."""
        result_dir = os.path.dirname(result.file_path)
        context_dir = os.path.dirname(file_context)

        # Same directory gets highest score
        if result_dir == context_dir:
            return 100

        # Otherwise, score by common path prefix length
        try:
            common = os.path.commonpath([result.file_path, file_context])
            return len(common)
        except ValueError:
            # No common path (different drives on Windows)
            return 0

    return sorted(results, key=proximity_score, reverse=True)


def rank_by_score(
    results: List[T],
    score_fn: Callable[[T], float],
    reverse: bool = True
) -> List[T]:
    """Generic ranking function by custom score.

    Args:
        results: List of items to rank
        score_fn: Function to extract score from item
        reverse: If True, highest scores first (default)

    Returns:
        Sorted list
    """
    return sorted(results, key=score_fn, reverse=reverse)
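

# Illustrative usage sketch (not part of the original file): exercises the helpers
# above with made-up paths to show the normalization rule and proximity ordering.
if __name__ == "__main__":  # pragma: no cover - example only
    print(normalize_relationship_type("inherits"))  # -> "inheritance"
    candidates = [
        DefinitionResult(name="foo", kind="function", file_path="/a/x/y.py", line=1, end_line=5),
        DefinitionResult(name="foo", kind="function", file_path="/a/b/c.py", line=1, end_line=5),
    ]
    ranked = rank_by_proximity(candidates, file_context="/a/b/test.py")
    print([d.file_path for d in ranked])  # same-directory match "/a/b/c.py" comes first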
codex-lens/build/lib/codexlens/cli/__init__.py (new file)
@@ -0,0 +1,27 @@
"""CLI package for CodexLens."""

from __future__ import annotations

import sys
import os

# Force UTF-8 encoding for Windows console.
# This ensures Chinese characters display correctly instead of GBK garbled text.
if sys.platform == "win32":
    # Set environment variable for Python I/O encoding
    os.environ.setdefault("PYTHONIOENCODING", "utf-8")

    # Reconfigure stdout/stderr to use UTF-8 if possible
    try:
        if hasattr(sys.stdout, "reconfigure"):
            sys.stdout.reconfigure(encoding="utf-8", errors="replace")
        if hasattr(sys.stderr, "reconfigure"):
            sys.stderr.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        # Fallback: some environments don't support reconfigure
        pass

from .commands import app

__all__ = ["app"]
codex-lens/build/lib/codexlens/cli/commands.py (new file, 4494 lines; diff suppressed because it is too large)
codex-lens/build/lib/codexlens/cli/embedding_manager.py (new file, 2001 lines; diff suppressed because it is too large)
codex-lens/build/lib/codexlens/cli/model_manager.py (new file, 1026 lines; diff suppressed because it is too large)
codex-lens/build/lib/codexlens/cli/output.py (new file)
@@ -0,0 +1,135 @@
"""Rich and JSON output helpers for CodexLens CLI."""

from __future__ import annotations

import json
import sys
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence

from rich.console import Console
from rich.table import Table
from rich.text import Text

from codexlens.entities import SearchResult, Symbol

# Force UTF-8 encoding for Windows console to properly display Chinese text.
# Use force_terminal=True and legacy_windows=False to avoid GBK encoding issues.
console = Console(force_terminal=True, legacy_windows=False)


def _to_jsonable(value: Any) -> Any:
    if value is None:
        return None
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if is_dataclass(value):
        return asdict(value)
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, Mapping):
        return {k: _to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_to_jsonable(v) for v in value]
    return value


def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
    """Print JSON output with optional additional fields.

    Args:
        success: Whether the operation succeeded
        result: Result data (used when success=True)
        error: Error message (used when success=False)
        **kwargs: Additional fields to include in the payload (e.g., code, details)
    """
    payload: dict[str, Any] = {"success": success}
    if success:
        payload["result"] = _to_jsonable(result)
    else:
        payload["error"] = error or "Unknown error"
        # Include additional error details if provided
        for key, value in kwargs.items():
            payload[key] = _to_jsonable(value)
    console.print_json(json.dumps(payload, ensure_ascii=False))


def render_search_results(
    results: Sequence[SearchResult], *, title: str = "Search Results", verbose: bool = False
) -> None:
    """Render search results with optional source tags in verbose mode.

    Args:
        results: Search results to display
        title: Table title
        verbose: If True, show search source tags ([E], [F], [V]) and fusion scores
    """
    table = Table(title=title, show_lines=False)

    if verbose:
        # Verbose mode: show source tags
        table.add_column("Source", style="dim", width=6, justify="center")

    table.add_column("Path", style="cyan", no_wrap=True)
    table.add_column("Score", style="magenta", justify="right")
    table.add_column("Excerpt", style="white")

    for res in results:
        excerpt = res.excerpt or ""
        score_str = f"{res.score:.3f}"

        if verbose:
            # Extract search source tag if available
            source = getattr(res, "search_source", None)
            source_tag = ""
            if source == "exact":
                source_tag = "[E]"
            elif source == "fuzzy":
                source_tag = "[F]"
            elif source == "vector":
                source_tag = "[V]"
            elif source == "fusion":
                source_tag = "[RRF]"
            table.add_row(source_tag, res.path, score_str, excerpt)
        else:
            table.add_row(res.path, score_str, excerpt)

    console.print(table)


def render_symbols(symbols: Sequence[Symbol], *, title: str = "Symbols") -> None:
    table = Table(title=title)
    table.add_column("Name", style="green")
    table.add_column("Kind", style="yellow")
    table.add_column("Range", style="white", justify="right")

    for sym in symbols:
        start, end = sym.range
        table.add_row(sym.name, sym.kind, f"{start}-{end}")

    console.print(table)


def render_status(stats: Mapping[str, Any]) -> None:
    table = Table(title="Index Status")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="white")

    for key, value in stats.items():
        if isinstance(value, Mapping):
            value_text = ", ".join(f"{k}:{v}" for k, v in value.items())
        elif isinstance(value, (list, tuple)):
            value_text = ", ".join(str(v) for v in value)
        else:
            value_text = str(value)
        table.add_row(str(key), value_text)

    console.print(table)


def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> None:
    header = Text.assemble(("File: ", "bold"), (path, "cyan"), (" Language: ", "bold"), (language, "green"))
    console.print(header)
    render_symbols(list(symbols), title="Discovered Symbols")
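

# Illustrative usage sketch (not part of the original file): shows the JSON payload
# shapes produced by print_json() for success and error cases. Extra keyword fields
# are merged into the error payload; the values here are made up.
if __name__ == "__main__":  # pragma: no cover - example only
    print_json(success=True, result={"indexed_files": 128})
    print_json(success=False, error="index not found", code="E_NO_INDEX",
               details={"project": "/path/to/project"})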
codex-lens/build/lib/codexlens/config.py (new file)
@@ -0,0 +1,692 @@
"""Configuration system for CodexLens."""

from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, List, Optional

from .errors import ConfigError


# Workspace-local directory name
WORKSPACE_DIR_NAME = ".codexlens"

# Settings file name
SETTINGS_FILE_NAME = "settings.json"

# SPLADE index database name (centralized storage)
SPLADE_DB_NAME = "_splade.db"

# Dense vector storage names (centralized storage)
VECTORS_HNSW_NAME = "_vectors.hnsw"
VECTORS_META_DB_NAME = "_vectors_meta.db"
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"

log = logging.getLogger(__name__)


def _default_global_dir() -> Path:
    """Get global CodexLens data directory."""
    env_override = os.getenv("CODEXLENS_DATA_DIR")
    if env_override:
        return Path(env_override).expanduser().resolve()
    return (Path.home() / ".codexlens").resolve()


def find_workspace_root(start_path: Path) -> Optional[Path]:
    """Find the workspace root by looking for .codexlens directory.

    Searches from start_path upward to find an existing .codexlens directory.
    Returns None if not found.
    """
    current = start_path.resolve()

    # Search up to filesystem root
    while current != current.parent:
        workspace_dir = current / WORKSPACE_DIR_NAME
        if workspace_dir.is_dir():
            return current
        current = current.parent

    # Check root as well
    workspace_dir = current / WORKSPACE_DIR_NAME
    if workspace_dir.is_dir():
        return current

    return None


@dataclass
class Config:
    """Runtime configuration for CodexLens.

    - data_dir: Base directory for all persistent CodexLens data.
    - venv_path: Optional virtualenv used for language tooling.
    - supported_languages: Language IDs and their associated file extensions.
    - parsing_rules: Per-language parsing and chunking hints.
    """

    data_dir: Path = field(default_factory=_default_global_dir)
    venv_path: Path = field(default_factory=lambda: _default_global_dir() / "venv")
    supported_languages: Dict[str, Dict[str, Any]] = field(
        default_factory=lambda: {
            # Source code languages (category: "code")
            "python": {"extensions": [".py"], "tree_sitter_language": "python", "category": "code"},
            "javascript": {"extensions": [".js", ".jsx"], "tree_sitter_language": "javascript", "category": "code"},
            "typescript": {"extensions": [".ts", ".tsx"], "tree_sitter_language": "typescript", "category": "code"},
            "java": {"extensions": [".java"], "tree_sitter_language": "java", "category": "code"},
            "go": {"extensions": [".go"], "tree_sitter_language": "go", "category": "code"},
            "zig": {"extensions": [".zig"], "tree_sitter_language": "zig", "category": "code"},
            "objective-c": {"extensions": [".m", ".mm"], "tree_sitter_language": "objc", "category": "code"},
            "c": {"extensions": [".c", ".h"], "tree_sitter_language": "c", "category": "code"},
            "cpp": {"extensions": [".cc", ".cpp", ".hpp", ".cxx"], "tree_sitter_language": "cpp", "category": "code"},
            "rust": {"extensions": [".rs"], "tree_sitter_language": "rust", "category": "code"},
        }
    )
    parsing_rules: Dict[str, Dict[str, Any]] = field(
        default_factory=lambda: {
            "default": {
                "max_chunk_chars": 4000,
                "max_chunk_lines": 200,
                "overlap_lines": 20,
            }
        }
    )

    llm_enabled: bool = False
    llm_tool: str = "gemini"
    llm_timeout_ms: int = 300000
    llm_batch_size: int = 5

    # Hybrid chunker configuration
    hybrid_max_chunk_size: int = 2000  # Max characters per chunk before LLM refinement
    hybrid_llm_refinement: bool = False  # Enable LLM-based semantic boundary refinement

    # Embedding configuration
    embedding_backend: str = "fastembed"  # "fastembed" (local) or "litellm" (API)
    embedding_model: str = "code"  # For fastembed: profile (fast/code/multilingual/balanced)
    # For litellm: model name from config (e.g., "qwen3-embedding")
    embedding_use_gpu: bool = True  # For fastembed: whether to use GPU acceleration

    # SPLADE sparse retrieval configuration
    enable_splade: bool = False  # Disable SPLADE by default (slow ~360ms, use FTS instead)
    splade_model: str = "naver/splade-cocondenser-ensembledistil"
    splade_threshold: float = 0.01  # Min weight to store in index
    splade_onnx_path: Optional[str] = None  # Custom ONNX model path

    # FTS fallback (enabled by default since SPLADE is disabled; toggle via --use-fts)
    use_fts_fallback: bool = True  # Use FTS for sparse search (fast, SPLADE disabled)

    # Indexing/search optimizations
    global_symbol_index_enabled: bool = True  # Enable project-wide symbol index fast path
    enable_merkle_detection: bool = True  # Enable content-hash based incremental indexing

    # Graph expansion (search-time, uses precomputed neighbors)
    enable_graph_expansion: bool = False
    graph_expansion_depth: int = 2

    # Optional search reranking (disabled by default)
    enable_reranking: bool = False
    reranking_top_k: int = 50
    symbol_boost_factor: float = 1.5

    # Optional cross-encoder reranking (second stage; requires optional reranker deps)
    enable_cross_encoder_rerank: bool = False
    reranker_backend: str = "onnx"
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    reranker_top_k: int = 50
    reranker_max_input_tokens: int = 8192  # Maximum tokens for reranker API batching
    reranker_chunk_type_weights: Optional[Dict[str, float]] = None  # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
    reranker_test_file_penalty: float = 0.0  # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)

    # Chunk stripping configuration (for semantic embedding)
    chunk_strip_comments: bool = True  # Strip comments from code chunks
    chunk_strip_docstrings: bool = True  # Strip docstrings from code chunks

    # Cascade search configuration (two-stage retrieval)
    enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
    cascade_coarse_k: int = 100  # Number of coarse candidates from first stage
    cascade_fine_k: int = 10  # Number of final results after reranking
    cascade_strategy: str = "binary"  # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)

    # Staged cascade search configuration (4-stage pipeline)
    staged_coarse_k: int = 200  # Number of coarse candidates from Stage 1 binary search
    staged_lsp_depth: int = 2  # LSP relationship expansion depth in Stage 2
    staged_clustering_strategy: str = "auto"  # "auto", "hdbscan", "dbscan", "frequency", "noop"
    staged_clustering_min_size: int = 3  # Minimum cluster size for Stage 3 grouping
    enable_staged_rerank: bool = True  # Enable optional cross-encoder reranking in Stage 4
|
||||||
|
|
||||||
|
# RRF fusion configuration
|
||||||
|
fusion_method: str = "rrf" # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
|
||||||
|
rrf_k: int = 60 # RRF constant (default 60)
|
||||||
|
|
||||||
|
# Category-based filtering to separate code/doc results
|
||||||
|
enable_category_filter: bool = True # Enable code/doc result separation
|
||||||
|
|
||||||
|
# Multi-endpoint configuration for litellm backend
|
||||||
|
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
|
||||||
|
embedding_pool_enabled: bool = False # Enable high availability pool for embeddings
|
||||||
|
embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
|
||||||
|
embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
|
||||||
|
|
||||||
|
# Reranker multi-endpoint configuration
|
||||||
|
reranker_pool_enabled: bool = False # Enable high availability pool for reranker
|
||||||
|
reranker_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
|
||||||
|
reranker_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
|
||||||
|
|
||||||
|
# API concurrency settings
|
||||||
|
api_max_workers: int = 4 # Max concurrent API calls for embedding/reranking
|
||||||
|
api_batch_size: int = 8 # Batch size for API requests
|
||||||
|
api_batch_size_dynamic: bool = False # Enable dynamic batch size calculation
|
||||||
|
api_batch_size_utilization_factor: float = 0.8 # Use 80% of model token capacity
|
||||||
|
api_batch_size_max: int = 2048 # Absolute upper limit for batch size
|
||||||
|
chars_per_token_estimate: int = 4 # Characters per token estimation ratio
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
try:
|
||||||
|
self.data_dir = self.data_dir.expanduser().resolve()
|
||||||
|
self.venv_path = self.venv_path.expanduser().resolve()
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
except PermissionError as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Permission denied initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||||
|
f"[{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Filesystem error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||||
|
f"[{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Unexpected error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
|
||||||
|
f"[{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def cache_dir(self) -> Path:
|
||||||
|
"""Directory for transient caches."""
|
||||||
|
return self.data_dir / "cache"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def index_dir(self) -> Path:
|
||||||
|
"""Directory where index artifacts are stored."""
|
||||||
|
return self.data_dir / "index"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def db_path(self) -> Path:
|
||||||
|
"""Default SQLite index path."""
|
||||||
|
return self.index_dir / "codexlens.db"
|
||||||
|
|
||||||
|
def ensure_runtime_dirs(self) -> None:
|
||||||
|
"""Create standard runtime directories if missing."""
|
||||||
|
for directory in (self.cache_dir, self.index_dir):
|
||||||
|
try:
|
||||||
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
except PermissionError as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Permission denied creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Filesystem error creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise ConfigError(
|
||||||
|
f"Unexpected error creating directory {directory} [{type(exc).__name__}]: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def language_for_path(self, path: str | Path) -> str | None:
|
||||||
|
"""Infer a supported language ID from a file path."""
|
||||||
|
extension = Path(path).suffix.lower()
|
||||||
|
for language_id, spec in self.supported_languages.items():
|
||||||
|
extensions: List[str] = spec.get("extensions", [])
|
||||||
|
if extension in extensions:
|
||||||
|
return language_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
def category_for_path(self, path: str | Path) -> str | None:
|
||||||
|
"""Get file category ('code' or 'doc') from a file path."""
|
||||||
|
language = self.language_for_path(path)
|
||||||
|
if language is None:
|
||||||
|
return None
|
||||||
|
spec = self.supported_languages.get(language, {})
|
||||||
|
return spec.get("category")
|
||||||
|
|
||||||
|
def rules_for_language(self, language_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get parsing rules for a specific language, falling back to defaults."""
|
||||||
|
return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def settings_path(self) -> Path:
|
||||||
|
"""Path to the settings file."""
|
||||||
|
return self.data_dir / SETTINGS_FILE_NAME
|
||||||
|
|
||||||
|
def save_settings(self) -> None:
|
||||||
|
"""Save embedding and other settings to file."""
|
||||||
|
embedding_config = {
|
||||||
|
"backend": self.embedding_backend,
|
||||||
|
"model": self.embedding_model,
|
||||||
|
"use_gpu": self.embedding_use_gpu,
|
||||||
|
"pool_enabled": self.embedding_pool_enabled,
|
||||||
|
"strategy": self.embedding_strategy,
|
||||||
|
"cooldown": self.embedding_cooldown,
|
||||||
|
}
|
||||||
|
# Include multi-endpoint config if present
|
||||||
|
if self.embedding_endpoints:
|
||||||
|
embedding_config["endpoints"] = self.embedding_endpoints
|
||||||
|
|
||||||
|
settings = {
|
||||||
|
"embedding": embedding_config,
|
||||||
|
"llm": {
|
||||||
|
"enabled": self.llm_enabled,
|
||||||
|
"tool": self.llm_tool,
|
||||||
|
"timeout_ms": self.llm_timeout_ms,
|
||||||
|
"batch_size": self.llm_batch_size,
|
||||||
|
},
|
||||||
|
"reranker": {
|
||||||
|
"enabled": self.enable_cross_encoder_rerank,
|
||||||
|
"backend": self.reranker_backend,
|
||||||
|
"model": self.reranker_model,
|
||||||
|
"top_k": self.reranker_top_k,
|
||||||
|
"max_input_tokens": self.reranker_max_input_tokens,
|
||||||
|
"pool_enabled": self.reranker_pool_enabled,
|
||||||
|
"strategy": self.reranker_strategy,
|
||||||
|
"cooldown": self.reranker_cooldown,
|
||||||
|
},
|
||||||
|
"cascade": {
|
||||||
|
"strategy": self.cascade_strategy,
|
||||||
|
"coarse_k": self.cascade_coarse_k,
|
||||||
|
"fine_k": self.cascade_fine_k,
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"max_workers": self.api_max_workers,
|
||||||
|
"batch_size": self.api_batch_size,
|
||||||
|
"batch_size_dynamic": self.api_batch_size_dynamic,
|
||||||
|
"batch_size_utilization_factor": self.api_batch_size_utilization_factor,
|
||||||
|
"batch_size_max": self.api_batch_size_max,
|
||||||
|
"chars_per_token_estimate": self.chars_per_token_estimate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with open(self.settings_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(settings, f, indent=2)
|
||||||
|
|
||||||
|
def load_settings(self) -> None:
|
||||||
|
"""Load settings from file if exists."""
|
||||||
|
if not self.settings_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.settings_path, "r", encoding="utf-8") as f:
|
||||||
|
settings = json.load(f)
|
||||||
|
|
||||||
|
# Load embedding settings
|
||||||
|
embedding = settings.get("embedding", {})
|
||||||
|
if "backend" in embedding:
|
||||||
|
backend = embedding["backend"]
|
||||||
|
# Support 'api' as alias for 'litellm'
|
||||||
|
if backend == "api":
|
||||||
|
backend = "litellm"
|
||||||
|
if backend in {"fastembed", "litellm"}:
|
||||||
|
self.embedding_backend = backend
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
"Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
|
||||||
|
self.settings_path,
|
||||||
|
embedding["backend"],
|
||||||
|
)
|
||||||
|
if "model" in embedding:
|
||||||
|
self.embedding_model = embedding["model"]
|
||||||
|
if "use_gpu" in embedding:
|
||||||
|
self.embedding_use_gpu = embedding["use_gpu"]
|
||||||
|
|
||||||
|
# Load multi-endpoint configuration
|
||||||
|
if "endpoints" in embedding:
|
||||||
|
self.embedding_endpoints = embedding["endpoints"]
|
||||||
|
if "pool_enabled" in embedding:
|
||||||
|
self.embedding_pool_enabled = embedding["pool_enabled"]
|
||||||
|
if "strategy" in embedding:
|
||||||
|
self.embedding_strategy = embedding["strategy"]
|
||||||
|
if "cooldown" in embedding:
|
||||||
|
self.embedding_cooldown = embedding["cooldown"]
|
||||||
|
|
||||||
|
# Load LLM settings
|
||||||
|
llm = settings.get("llm", {})
|
||||||
|
if "enabled" in llm:
|
||||||
|
self.llm_enabled = llm["enabled"]
|
||||||
|
if "tool" in llm:
|
||||||
|
self.llm_tool = llm["tool"]
|
||||||
|
if "timeout_ms" in llm:
|
||||||
|
self.llm_timeout_ms = llm["timeout_ms"]
|
||||||
|
if "batch_size" in llm:
|
||||||
|
self.llm_batch_size = llm["batch_size"]
|
||||||
|
|
||||||
|
# Load reranker settings
|
||||||
|
reranker = settings.get("reranker", {})
|
||||||
|
if "enabled" in reranker:
|
||||||
|
self.enable_cross_encoder_rerank = reranker["enabled"]
|
||||||
|
if "backend" in reranker:
|
||||||
|
backend = reranker["backend"]
|
||||||
|
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
|
||||||
|
self.reranker_backend = backend
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
"Invalid reranker backend in %s: %r (expected 'fastembed', 'onnx', 'api', 'litellm', or 'legacy')",
|
||||||
|
self.settings_path,
|
||||||
|
backend,
|
||||||
|
)
|
||||||
|
if "model" in reranker:
|
||||||
|
self.reranker_model = reranker["model"]
|
||||||
|
if "top_k" in reranker:
|
||||||
|
self.reranker_top_k = reranker["top_k"]
|
||||||
|
if "max_input_tokens" in reranker:
|
||||||
|
self.reranker_max_input_tokens = reranker["max_input_tokens"]
|
||||||
|
if "pool_enabled" in reranker:
|
||||||
|
self.reranker_pool_enabled = reranker["pool_enabled"]
|
||||||
|
if "strategy" in reranker:
|
||||||
|
self.reranker_strategy = reranker["strategy"]
|
||||||
|
if "cooldown" in reranker:
|
||||||
|
self.reranker_cooldown = reranker["cooldown"]
|
||||||
|
|
||||||
|
# Load cascade settings
|
||||||
|
cascade = settings.get("cascade", {})
|
||||||
|
if "strategy" in cascade:
|
||||||
|
strategy = cascade["strategy"]
|
||||||
|
if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
|
||||||
|
self.cascade_strategy = strategy
|
||||||
|
else:
|
||||||
|
log.warning(
|
||||||
|
"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
|
||||||
|
self.settings_path,
|
||||||
|
strategy,
|
||||||
|
)
|
||||||
|
if "coarse_k" in cascade:
|
||||||
|
self.cascade_coarse_k = cascade["coarse_k"]
|
||||||
|
if "fine_k" in cascade:
|
||||||
|
self.cascade_fine_k = cascade["fine_k"]
|
||||||
|
|
||||||
|
# Load API settings
|
||||||
|
api = settings.get("api", {})
|
||||||
|
if "max_workers" in api:
|
||||||
|
self.api_max_workers = api["max_workers"]
|
||||||
|
if "batch_size" in api:
|
||||||
|
self.api_batch_size = api["batch_size"]
|
||||||
|
if "batch_size_dynamic" in api:
|
||||||
|
self.api_batch_size_dynamic = api["batch_size_dynamic"]
|
||||||
|
if "batch_size_utilization_factor" in api:
|
||||||
|
self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
|
||||||
|
if "batch_size_max" in api:
|
||||||
|
self.api_batch_size_max = api["batch_size_max"]
|
||||||
|
if "chars_per_token_estimate" in api:
|
||||||
|
self.chars_per_token_estimate = api["chars_per_token_estimate"]
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning(
|
||||||
|
"Failed to load settings from %s (%s): %s",
|
||||||
|
self.settings_path,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply .env overrides (highest priority)
|
||||||
|
self._apply_env_overrides()
|
||||||
|
|
||||||
|
def _apply_env_overrides(self) -> None:
|
||||||
|
"""Apply environment variable overrides from .env file.
|
||||||
|
|
||||||
|
Priority: default → settings.json → .env (highest)
|
||||||
|
|
||||||
|
Supported variables (with or without CODEXLENS_ prefix):
|
||||||
|
EMBEDDING_MODEL: Override embedding model/profile
|
||||||
|
EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
|
||||||
|
EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
|
||||||
|
EMBEDDING_STRATEGY: Load balance strategy for embedding
|
||||||
|
EMBEDDING_COOLDOWN: Rate limit cooldown for embedding
|
||||||
|
RERANKER_MODEL: Override reranker model
|
||||||
|
RERANKER_BACKEND: Override reranker backend
|
||||||
|
RERANKER_ENABLED: Override reranker enabled state (true/false)
|
||||||
|
RERANKER_POOL_ENABLED: Enable reranker high availability pool
|
||||||
|
RERANKER_STRATEGY: Load balance strategy for reranker
|
||||||
|
RERANKER_COOLDOWN: Rate limit cooldown for reranker
|
||||||
|
"""
|
||||||
|
from .env_config import load_global_env
|
||||||
|
|
||||||
|
env_vars = load_global_env()
|
||||||
|
if not env_vars:
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_env(key: str) -> str | None:
|
||||||
|
"""Get env var with or without CODEXLENS_ prefix."""
|
||||||
|
# Check prefixed version first (Dashboard format), then unprefixed
|
||||||
|
return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)
|
||||||
|
|
||||||
|
# Embedding overrides
|
||||||
|
embedding_model = get_env("EMBEDDING_MODEL")
|
||||||
|
if embedding_model:
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
log.debug("Overriding embedding_model from .env: %s", self.embedding_model)
|
||||||
|
|
||||||
|
embedding_backend = get_env("EMBEDDING_BACKEND")
|
||||||
|
if embedding_backend:
|
||||||
|
backend = embedding_backend.lower()
|
||||||
|
# Support 'api' as alias for 'litellm'
|
||||||
|
if backend == "api":
|
||||||
|
backend = "litellm"
|
||||||
|
if backend in {"fastembed", "litellm"}:
|
||||||
|
self.embedding_backend = backend
|
||||||
|
log.debug("Overriding embedding_backend from .env: %s", backend)
|
||||||
|
else:
|
||||||
|
log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)
|
||||||
|
|
||||||
|
embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
|
||||||
|
if embedding_pool:
|
||||||
|
value = embedding_pool.lower()
|
||||||
|
self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
|
||||||
|
log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)
|
||||||
|
|
||||||
|
embedding_strategy = get_env("EMBEDDING_STRATEGY")
|
||||||
|
if embedding_strategy:
|
||||||
|
strategy = embedding_strategy.lower()
|
||||||
|
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
|
||||||
|
self.embedding_strategy = strategy
|
||||||
|
log.debug("Overriding embedding_strategy from .env: %s", strategy)
|
||||||
|
else:
|
||||||
|
log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)
|
||||||
|
|
||||||
|
embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
|
||||||
|
if embedding_cooldown:
|
||||||
|
try:
|
||||||
|
self.embedding_cooldown = float(embedding_cooldown)
|
||||||
|
log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
|
||||||
|
except ValueError:
|
||||||
|
log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)
|
||||||
|
|
||||||
|
# Reranker overrides
|
||||||
|
reranker_model = get_env("RERANKER_MODEL")
|
||||||
|
if reranker_model:
|
||||||
|
self.reranker_model = reranker_model
|
||||||
|
log.debug("Overriding reranker_model from .env: %s", self.reranker_model)
|
||||||
|
|
||||||
|
reranker_backend = get_env("RERANKER_BACKEND")
|
||||||
|
if reranker_backend:
|
||||||
|
backend = reranker_backend.lower()
|
||||||
|
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
|
||||||
|
self.reranker_backend = backend
|
||||||
|
log.debug("Overriding reranker_backend from .env: %s", backend)
|
||||||
|
else:
|
||||||
|
log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)
|
||||||
|
|
||||||
|
reranker_enabled = get_env("RERANKER_ENABLED")
|
||||||
|
if reranker_enabled:
|
||||||
|
value = reranker_enabled.lower()
|
||||||
|
self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
|
||||||
|
log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)
|
||||||
|
|
||||||
|
reranker_pool = get_env("RERANKER_POOL_ENABLED")
|
||||||
|
if reranker_pool:
|
||||||
|
value = reranker_pool.lower()
|
||||||
|
self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
|
||||||
|
log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)
|
||||||
|
|
||||||
|
reranker_strategy = get_env("RERANKER_STRATEGY")
|
||||||
|
if reranker_strategy:
|
||||||
|
strategy = reranker_strategy.lower()
|
||||||
|
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
|
||||||
|
self.reranker_strategy = strategy
|
||||||
|
log.debug("Overriding reranker_strategy from .env: %s", strategy)
|
||||||
|
else:
|
||||||
|
log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)
|
||||||
|
|
||||||
|
reranker_cooldown = get_env("RERANKER_COOLDOWN")
|
||||||
|
if reranker_cooldown:
|
||||||
|
try:
|
||||||
|
self.reranker_cooldown = float(reranker_cooldown)
|
||||||
|
log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
|
||||||
|
except ValueError:
|
||||||
|
log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)
|
||||||
|
|
||||||
|
reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
|
||||||
|
if reranker_max_tokens:
|
||||||
|
try:
|
||||||
|
self.reranker_max_input_tokens = int(reranker_max_tokens)
|
||||||
|
log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
|
||||||
|
except ValueError:
|
||||||
|
log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)
|
||||||
|
|
||||||
|
# Reranker tuning from environment
|
||||||
|
test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
|
||||||
|
if test_penalty:
|
||||||
|
try:
|
||||||
|
self.reranker_test_file_penalty = float(test_penalty)
|
||||||
|
log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
|
||||||
|
except ValueError:
|
||||||
|
log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)
|
||||||
|
|
||||||
|
docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
|
||||||
|
if docstring_weight:
|
||||||
|
try:
|
||||||
|
weight = float(docstring_weight)
|
||||||
|
self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
|
||||||
|
log.debug("Overriding reranker docstring weight from .env: %s", weight)
|
||||||
|
except ValueError:
|
||||||
|
log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)
|
||||||
|
|
||||||
|
# Chunk stripping from environment
|
||||||
|
strip_comments = get_env("CHUNK_STRIP_COMMENTS")
|
||||||
|
if strip_comments:
|
||||||
|
self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
|
||||||
|
log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)
|
||||||
|
|
||||||
|
strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
|
||||||
|
if strip_docstrings:
|
||||||
|
self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
|
||||||
|
log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls) -> "Config":
|
||||||
|
"""Load config with settings from file."""
|
||||||
|
config = cls()
|
||||||
|
config.load_settings()
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WorkspaceConfig:
|
||||||
|
"""Workspace-local configuration for CodexLens.
|
||||||
|
|
||||||
|
Stores index data in project/.codexlens/ directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
workspace_root: Path
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.workspace_root = Path(self.workspace_root).resolve()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def codexlens_dir(self) -> Path:
|
||||||
|
"""The .codexlens directory in workspace root."""
|
||||||
|
return self.workspace_root / WORKSPACE_DIR_NAME
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_path(self) -> Path:
|
||||||
|
"""SQLite index path for this workspace."""
|
||||||
|
return self.codexlens_dir / "index.db"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_dir(self) -> Path:
|
||||||
|
"""Cache directory for this workspace."""
|
||||||
|
return self.codexlens_dir / "cache"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def env_path(self) -> Path:
|
||||||
|
"""Path to workspace .env file."""
|
||||||
|
return self.codexlens_dir / ".env"
|
||||||
|
|
||||||
|
def load_env(self, *, override: bool = False) -> int:
|
||||||
|
"""Load .env file and apply to os.environ.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
override: If True, override existing environment variables
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of variables applied
|
||||||
|
"""
|
||||||
|
from .env_config import apply_workspace_env
|
||||||
|
return apply_workspace_env(self.workspace_root, override=override)
|
||||||
|
|
||||||
|
def get_api_config(self, prefix: str) -> dict:
|
||||||
|
"""Get API configuration from environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with api_key, api_base, model, etc.
|
||||||
|
"""
|
||||||
|
from .env_config import get_api_config
|
||||||
|
return get_api_config(prefix, workspace_root=self.workspace_root)
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
"""Create the .codexlens directory structure."""
|
||||||
|
try:
|
||||||
|
self.codexlens_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create .gitignore to exclude cache but keep index
|
||||||
|
gitignore_path = self.codexlens_dir / ".gitignore"
|
||||||
|
if not gitignore_path.exists():
|
||||||
|
gitignore_path.write_text(
|
||||||
|
"# CodexLens workspace data\n"
|
||||||
|
"cache/\n"
|
||||||
|
"*.log\n"
|
||||||
|
".env\n" # Exclude .env from git
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc
|
||||||
|
|
||||||
|
def exists(self) -> bool:
|
||||||
|
"""Check if workspace is already initialized."""
|
||||||
|
return self.codexlens_dir.is_dir() and self.db_path.exists()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_path(cls, path: Path) -> Optional["WorkspaceConfig"]:
|
||||||
|
"""Create WorkspaceConfig from a path by finding workspace root.
|
||||||
|
|
||||||
|
Returns None if no workspace found.
|
||||||
|
"""
|
||||||
|
root = find_workspace_root(path)
|
||||||
|
if root is None:
|
||||||
|
return None
|
||||||
|
return cls(workspace_root=root)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_at(cls, path: Path) -> "WorkspaceConfig":
|
||||||
|
"""Create a new workspace at the given path."""
|
||||||
|
config = cls(workspace_root=path)
|
||||||
|
config.initialize()
|
||||||
|
return config
|
||||||
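A minimal usage sketch of the configuration flow above (illustrative only, not part of the commit; it assumes the built package is importable as codexlens.config):

# Illustrative: resolve global config, then prefer a workspace index if one exists.
from pathlib import Path

from codexlens.config import Config, WorkspaceConfig

config = Config.load()            # defaults -> settings.json -> .env overrides
config.ensure_runtime_dirs()      # creates <data_dir>/cache and <data_dir>/index if missing

workspace = WorkspaceConfig.from_path(Path.cwd())
db_path = workspace.db_path if workspace else config.db_path
print(config.language_for_path("src/main.rs"))  # -> "rust"
print(db_path)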
128  codex-lens/build/lib/codexlens/entities.py  Normal file
@@ -0,0 +1,128 @@
"""Pydantic entity models for CodexLens."""

from __future__ import annotations

import math
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field, field_validator


class Symbol(BaseModel):
    """A code symbol discovered in a file."""

    name: str = Field(..., min_length=1)
    kind: str = Field(..., min_length=1)
    range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
    file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")

    @field_validator("range")
    @classmethod
    def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
        if len(value) != 2:
            raise ValueError("range must be a (start_line, end_line) tuple")
        start_line, end_line = value
        if start_line < 1 or end_line < 1:
            raise ValueError("range lines must be >= 1")
        if end_line < start_line:
            raise ValueError("end_line must be >= start_line")
        return value


class SemanticChunk(BaseModel):
    """A semantically meaningful chunk of content, optionally embedded."""

    content: str = Field(..., min_length=1)
    embedding: Optional[List[float]] = Field(default=None, description="Vector embedding for semantic search")
    metadata: Dict[str, Any] = Field(default_factory=dict)
    id: Optional[int] = Field(default=None, description="Database row ID")
    file_path: Optional[str] = Field(default=None, description="Source file path")

    @field_validator("embedding")
    @classmethod
    def validate_embedding(cls, value: Optional[List[float]]) -> Optional[List[float]]:
        if value is None:
            return value
        if not value:
            raise ValueError("embedding cannot be empty when provided")
        norm = math.sqrt(sum(x * x for x in value))
        epsilon = 1e-10
        if norm < epsilon:
            raise ValueError("embedding cannot be a zero vector")
        return value


class IndexedFile(BaseModel):
    """An indexed source file with symbols and optional semantic chunks."""

    path: str = Field(..., min_length=1)
    language: str = Field(..., min_length=1)
    symbols: List[Symbol] = Field(default_factory=list)
    chunks: List[SemanticChunk] = Field(default_factory=list)
    relationships: List["CodeRelationship"] = Field(default_factory=list)

    @field_validator("path", "language")
    @classmethod
    def strip_and_validate_nonempty(cls, value: str) -> str:
        cleaned = value.strip()
        if not cleaned:
            raise ValueError("value cannot be blank")
        return cleaned


class RelationshipType(str, Enum):
    """Types of code relationships."""

    CALL = "calls"
    INHERITS = "inherits"
    IMPORTS = "imports"


class CodeRelationship(BaseModel):
    """A relationship between code symbols (e.g., function calls, inheritance)."""

    source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
    target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
    relationship_type: RelationshipType = Field(..., description="Type of relationship (call, inherits, etc.)")
    source_file: str = Field(..., min_length=1, description="File path containing source symbol")
    target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
    source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")


class AdditionalLocation(BaseModel):
    """A pointer to another location where a similar result was found.

    Used for grouping search results with similar scores and content,
    where the primary result is stored in SearchResult and secondary
    locations are stored in this model.
    """

    path: str = Field(..., min_length=1)
    score: float = Field(..., ge=0.0)
    start_line: Optional[int] = Field(default=None, description="Start line of the result (1-based)")
    end_line: Optional[int] = Field(default=None, description="End line of the result (1-based)")
    symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol")


class SearchResult(BaseModel):
    """A unified search result for lexical or semantic search."""

    path: str = Field(..., min_length=1)
    score: float = Field(..., ge=0.0)
    excerpt: Optional[str] = None
    content: Optional[str] = Field(default=None, description="Full content of matched code block")
    symbol: Optional[Symbol] = None
    chunk: Optional[SemanticChunk] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)

    # Additional context for complete code blocks
    start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
    end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
    symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
    symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")

    # Field for grouping similar results
    additional_locations: List["AdditionalLocation"] = Field(
        default_factory=list,
        description="Other locations for grouped results with similar scores and content."
    )
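A short sketch of how the validators above behave (illustrative only, not part of the commit; assumes pydantic is installed and the module is importable as codexlens.entities):

# Illustrative: valid models construct normally; invalid ranges are rejected at creation time.
from codexlens.entities import SemanticChunk, Symbol

sym = Symbol(name="load_settings", kind="method", range=(310, 360), file="codexlens/config.py")
chunk = SemanticChunk(content="def load_settings(self): ...", embedding=[0.12, -0.05, 0.33])

try:
    Symbol(name="bad", kind="function", range=(10, 5))  # end_line < start_line
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(f"rejected: {err}")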
304  codex-lens/build/lib/codexlens/env_config.py  Normal file
@@ -0,0 +1,304 @@
"""Environment configuration loader for CodexLens.

Loads .env files from workspace .codexlens directory with fallback to project root.
Provides unified access to API configurations.

Priority order:
1. Environment variables (already set)
2. .codexlens/.env (workspace-local)
3. .env (project root)
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional

log = logging.getLogger(__name__)

# Supported environment variables with descriptions
ENV_VARS = {
    # Reranker configuration (overrides settings.json)
    "RERANKER_MODEL": "Reranker model name (overrides settings.json)",
    "RERANKER_BACKEND": "Reranker backend: fastembed, onnx, api, litellm, legacy",
    "RERANKER_ENABLED": "Enable reranker: true/false",
    "RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
    "RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
    "RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
    "RERANKER_POOL_ENABLED": "Enable reranker high availability pool: true/false",
    "RERANKER_STRATEGY": "Reranker load balance strategy: round_robin, latency_aware, weighted_random",
    "RERANKER_COOLDOWN": "Reranker rate limit cooldown in seconds",
    # Embedding configuration (overrides settings.json)
    "EMBEDDING_MODEL": "Embedding model/profile name (overrides settings.json)",
    "EMBEDDING_BACKEND": "Embedding backend: fastembed, litellm",
    "EMBEDDING_API_KEY": "API key for embedding service",
    "EMBEDDING_API_BASE": "Base URL for embedding API",
    "EMBEDDING_POOL_ENABLED": "Enable embedding high availability pool: true/false",
    "EMBEDDING_STRATEGY": "Embedding load balance strategy: round_robin, latency_aware, weighted_random",
    "EMBEDDING_COOLDOWN": "Embedding rate limit cooldown in seconds",
    # LiteLLM configuration
    "LITELLM_API_KEY": "API key for LiteLLM",
    "LITELLM_API_BASE": "Base URL for LiteLLM",
    "LITELLM_MODEL": "LiteLLM model name",
    # General configuration
    "CODEXLENS_DATA_DIR": "Custom data directory path",
    "CODEXLENS_DEBUG": "Enable debug mode (true/false)",
    # Chunking configuration
    "CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
    "CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
    # Reranker tuning
    "RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
    "RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
}


def _parse_env_line(line: str) -> tuple[str, str] | None:
    """Parse a single .env line, returning (key, value) or None."""
    line = line.strip()

    # Skip empty lines and comments
    if not line or line.startswith("#"):
        return None

    # Handle export prefix
    if line.startswith("export "):
        line = line[7:].strip()

    # Split on first =
    if "=" not in line:
        return None

    key, _, value = line.partition("=")
    key = key.strip()
    value = value.strip()

    # Remove surrounding quotes
    if len(value) >= 2:
        if (value.startswith('"') and value.endswith('"')) or \
           (value.startswith("'") and value.endswith("'")):
            value = value[1:-1]

    return key, value


def load_env_file(env_path: Path) -> Dict[str, str]:
    """Load environment variables from a .env file.

    Args:
        env_path: Path to .env file

    Returns:
        Dictionary of environment variables
    """
    if not env_path.is_file():
        return {}

    env_vars: Dict[str, str] = {}

    try:
        content = env_path.read_text(encoding="utf-8")
        for line in content.splitlines():
            result = _parse_env_line(line)
            if result:
                key, value = result
                env_vars[key] = value
    except Exception as exc:
        log.warning("Failed to load .env file %s: %s", env_path, exc)

    return env_vars


def _get_global_data_dir() -> Path:
    """Get global CodexLens data directory."""
    env_override = os.environ.get("CODEXLENS_DATA_DIR")
    if env_override:
        return Path(env_override).expanduser().resolve()
    return (Path.home() / ".codexlens").resolve()


def load_global_env() -> Dict[str, str]:
    """Load environment variables from global ~/.codexlens/.env file.

    Returns:
        Dictionary of environment variables from global config
    """
    global_env_path = _get_global_data_dir() / ".env"
    if global_env_path.is_file():
        env_vars = load_env_file(global_env_path)
        log.debug("Loaded %d vars from global %s", len(env_vars), global_env_path)
        return env_vars
    return {}


def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
    """Load environment variables from workspace .env files.

    Priority (later overrides earlier):
    1. Global ~/.codexlens/.env (lowest priority)
    2. Project root .env
    3. .codexlens/.env (highest priority)

    Args:
        workspace_root: Workspace root directory. If None, uses current directory.

    Returns:
        Merged dictionary of environment variables
    """
    if workspace_root is None:
        workspace_root = Path.cwd()

    workspace_root = Path(workspace_root).resolve()

    env_vars: Dict[str, str] = {}

    # Load from global ~/.codexlens/.env (lowest priority)
    global_vars = load_global_env()
    if global_vars:
        env_vars.update(global_vars)

    # Load from project root .env (medium priority)
    root_env = workspace_root / ".env"
    if root_env.is_file():
        loaded = load_env_file(root_env)
        env_vars.update(loaded)
        log.debug("Loaded %d vars from %s", len(loaded), root_env)

    # Load from .codexlens/.env (highest priority)
    codexlens_env = workspace_root / ".codexlens" / ".env"
    if codexlens_env.is_file():
        loaded = load_env_file(codexlens_env)
        env_vars.update(loaded)
        log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)

    return env_vars


def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
    """Load .env files and apply to os.environ.

    Args:
        workspace_root: Workspace root directory
        override: If True, override existing environment variables

    Returns:
        Number of variables applied
    """
    env_vars = load_workspace_env(workspace_root)
    applied = 0

    for key, value in env_vars.items():
        if override or key not in os.environ:
            os.environ[key] = value
            applied += 1
            log.debug("Applied env var: %s", key)

    return applied


def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback.

    Priority:
    1. os.environ (already set)
    2. .codexlens/.env
    3. .env
    4. default value

    Args:
        key: Environment variable name
        default: Default value if not found
        workspace_root: Workspace root for .env file lookup

    Returns:
        Value or default
    """
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Load from .env files
    env_vars = load_workspace_env(workspace_root)
    if key in env_vars:
        return env_vars[key]

    return default


def get_api_config(
    prefix: str,
    *,
    workspace_root: Path | None = None,
    defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Get API configuration from environment.

    Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.

    Args:
        prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
        workspace_root: Workspace root for .env file lookup
        defaults: Default values

    Returns:
        Dictionary with api_key, api_base, model, etc.
    """
    defaults = defaults or {}

    config: Dict[str, Any] = {}

    # Standard API config fields
    field_mapping = {
        "api_key": f"{prefix}_API_KEY",
        "api_base": f"{prefix}_API_BASE",
        "model": f"{prefix}_MODEL",
        "provider": f"{prefix}_PROVIDER",
        "timeout": f"{prefix}_TIMEOUT",
    }

    for field, env_key in field_mapping.items():
        value = get_env(env_key, workspace_root=workspace_root)
        if value is not None:
            # Type conversion for specific fields
            if field == "timeout":
                try:
                    config[field] = float(value)
                except ValueError:
                    pass
            else:
                config[field] = value
        elif field in defaults:
            config[field] = defaults[field]

    return config


def generate_env_example() -> str:
    """Generate .env.example content with all supported variables.

    Returns:
        String content for .env.example file
    """
    lines = [
        "# CodexLens Environment Configuration",
        "# Copy this file to .codexlens/.env and fill in your values",
        "",
    ]

    # Group by prefix
    groups: Dict[str, list] = {}
    for key, desc in ENV_VARS.items():
        prefix = key.split("_")[0]
        if prefix not in groups:
            groups[prefix] = []
        groups[prefix].append((key, desc))

    for prefix, items in groups.items():
        lines.append(f"# {prefix} Configuration")
        for key, desc in items:
            lines.append(f"# {desc}")
            lines.append(f"# {key}=")
        lines.append("")

    return "\n".join(lines)
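A small sketch of the lookup chain above (illustrative only, not part of the commit; assumes the module is importable as codexlens.env_config, and that any .env files are your own):

# Illustrative: os.environ wins over .codexlens/.env, which wins over the project root .env.
from pathlib import Path

from codexlens.env_config import generate_env_example, get_api_config

reranker_cfg = get_api_config(
    "RERANKER",
    workspace_root=Path.cwd(),
    defaults={"provider": "siliconflow"},  # used only when RERANKER_PROVIDER is unset
)
print(reranker_cfg.get("api_base"))

# Write a documented template that users can copy into .codexlens/.env
Path(".env.example").write_text(generate_env_example(), encoding="utf-8")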
59  codex-lens/build/lib/codexlens/errors.py  Normal file
@@ -0,0 +1,59 @@
"""CodexLens exception hierarchy."""

from __future__ import annotations


class CodexLensError(Exception):
    """Base class for all CodexLens errors."""


class ConfigError(CodexLensError):
    """Raised when configuration is invalid or cannot be loaded."""


class ParseError(CodexLensError):
    """Raised when parsing or indexing a file fails."""


class StorageError(CodexLensError):
    """Raised when reading/writing index storage fails.

    Attributes:
        message: Human-readable error description
        db_path: Path to the database file (if applicable)
        operation: The operation that failed (e.g., 'query', 'initialize', 'migrate')
        details: Additional context for debugging
    """

    def __init__(
        self,
        message: str,
        db_path: str | None = None,
        operation: str | None = None,
        details: dict | None = None
    ) -> None:
        super().__init__(message)
        self.message = message
        self.db_path = db_path
        self.operation = operation
        self.details = details or {}

    def __str__(self) -> str:
        parts = [self.message]
        if self.db_path:
            parts.append(f"[db: {self.db_path}]")
        if self.operation:
            parts.append(f"[op: {self.operation}]")
        if self.details:
            detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
            parts.append(f"[{detail_str}]")
        return " ".join(parts)


class SearchError(CodexLensError):
    """Raised when a search operation fails."""


class IndexNotFoundError(CodexLensError):
    """Raised when a project's index cannot be found."""
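How StorageError renders its structured context (illustrative only, not part of the commit):

# Illustrative: __str__ folds db_path, operation, and details into a single log-friendly line.
from codexlens.errors import StorageError

err = StorageError(
    "failed to open index",
    db_path="~/.codexlens/index/codexlens.db",
    operation="initialize",
    details={"retries": 3},
)
print(err)  # failed to open index [db: ~/.codexlens/index/codexlens.db] [op: initialize] [retries=3]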
28  codex-lens/build/lib/codexlens/hybrid_search/__init__.py  Normal file
@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.

This module provides core data structures for hybrid search:
- CodeSymbolNode: Graph node representing a code symbol
- CodeAssociationGraph: Graph of code relationships
- SearchResultCluster: Clustered search results
- Range: Position range in source files
- CallHierarchyItem: LSP call hierarchy item

Note: The search engine is in codexlens.search.hybrid_search
LSP-based expansion is in codexlens.lsp module
"""

from codexlens.hybrid_search.data_structures import (
    CallHierarchyItem,
    CodeAssociationGraph,
    CodeSymbolNode,
    Range,
    SearchResultCluster,
)

__all__ = [
    "CallHierarchyItem",
    "CodeAssociationGraph",
    "CodeSymbolNode",
    "Range",
    "SearchResultCluster",
]
602
codex-lens/build/lib/codexlens/hybrid_search/data_structures.py
Normal file
602
codex-lens/build/lib/codexlens/hybrid_search/data_structures.py
Normal file
@@ -0,0 +1,602 @@
|
|||||||
|
"""Core data structures for the hybrid search system.
|
||||||
|
|
||||||
|
This module defines the fundamental data structures used throughout the
|
||||||
|
hybrid search pipeline, including code symbol representations, association
|
||||||
|
graphs, and clustered search results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Range:
|
||||||
|
"""Position range within a source file.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
start_line: Starting line number (0-based).
|
||||||
|
start_character: Starting character offset within the line.
|
||||||
|
end_line: Ending line number (0-based).
|
||||||
|
end_character: Ending character offset within the line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start_line: int
|
||||||
|
start_character: int
|
||||||
|
end_line: int
|
||||||
|
end_character: int
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate range values."""
|
||||||
|
if self.start_line < 0:
|
||||||
|
raise ValueError("start_line must be >= 0")
|
||||||
|
if self.start_character < 0:
|
||||||
|
raise ValueError("start_character must be >= 0")
|
||||||
|
if self.end_line < 0:
|
||||||
|
raise ValueError("end_line must be >= 0")
|
||||||
|
if self.end_character < 0:
|
||||||
|
raise ValueError("end_character must be >= 0")
|
||||||
|
if self.end_line < self.start_line:
|
||||||
|
raise ValueError("end_line must be >= start_line")
|
||||||
|
if self.end_line == self.start_line and self.end_character < self.start_character:
|
||||||
|
raise ValueError("end_character must be >= start_character on the same line")
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
return {
|
||||||
|
"start": {"line": self.start_line, "character": self.start_character},
|
||||||
|
"end": {"line": self.end_line, "character": self.end_character},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> Range:
|
||||||
|
"""Create Range from dictionary representation."""
|
||||||
|
return cls(
|
||||||
|
start_line=data["start"]["line"],
|
||||||
|
start_character=data["start"]["character"],
|
||||||
|
end_line=data["end"]["line"],
|
||||||
|
end_character=data["end"]["character"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
|
||||||
|
"""Create Range from LSP Range object.
|
||||||
|
|
||||||
|
LSP Range format:
|
||||||
|
{"start": {"line": int, "character": int},
|
||||||
|
"end": {"line": int, "character": int}}
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
start_line=lsp_range["start"]["line"],
|
||||||
|
start_character=lsp_range["start"]["character"],
|
||||||
|
end_line=lsp_range["end"]["line"],
|
||||||
|
end_character=lsp_range["end"]["character"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CallHierarchyItem:
|
||||||
|
"""LSP CallHierarchyItem for representing callers/callees.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Symbol name (function, method, class name).
|
||||||
|
kind: Symbol kind (function, method, class, etc.).
|
||||||
|
file_path: Absolute file path where the symbol is defined.
|
||||||
|
range: Position range in the source file.
|
||||||
|
detail: Optional additional detail about the symbol.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
kind: str
|
||||||
|
file_path: str
|
||||||
|
range: Range
|
||||||
|
detail: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"name": self.name,
|
||||||
|
"kind": self.kind,
|
||||||
|
"file_path": self.file_path,
|
||||||
|
"range": self.range.to_dict(),
|
||||||
|
}
|
||||||
|
if self.detail:
|
||||||
|
result["detail"] = self.detail
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||||
|
"""Create CallHierarchyItem from dictionary representation."""
|
||||||
|
return cls(
|
||||||
|
name=data["name"],
|
||||||
|
kind=data["kind"],
|
||||||
|
file_path=data["file_path"],
|
||||||
|
range=Range.from_dict(data["range"]),
|
||||||
|
detail=data.get("detail"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CodeSymbolNode:
|
||||||
|
"""Graph node representing a code symbol.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Unique identifier in format 'file_path:name:line'.
|
||||||
|
name: Symbol name (function, class, variable name).
|
||||||
|
kind: Symbol kind (function, class, method, variable, etc.).
|
||||||
|
file_path: Absolute file path where symbol is defined.
|
||||||
|
range: Start/end position in the source file.
|
||||||
|
embedding: Optional vector embedding for semantic search.
|
||||||
|
raw_code: Raw source code of the symbol.
|
||||||
|
docstring: Documentation string (if available).
|
||||||
|
score: Ranking score (used during reranking).
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
kind: str
|
||||||
|
file_path: str
|
||||||
|
range: Range
|
||||||
|
embedding: Optional[List[float]] = None
|
||||||
|
raw_code: str = ""
|
||||||
|
docstring: str = ""
|
||||||
|
score: float = 0.0
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate required fields."""
|
||||||
|
if not self.id:
|
||||||
|
raise ValueError("id cannot be empty")
|
||||||
|
if not self.name:
|
||||||
|
raise ValueError("name cannot be empty")
|
||||||
|
if not self.kind:
|
||||||
|
raise ValueError("kind cannot be empty")
|
||||||
|
if not self.file_path:
|
||||||
|
raise ValueError("file_path cannot be empty")
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""Hash based on unique ID."""
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Equality based on unique ID."""
|
||||||
|
if not isinstance(other, CodeSymbolNode):
|
||||||
|
return False
|
||||||
|
return self.id == other.id
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"id": self.id,
|
||||||
|
"name": self.name,
|
||||||
|
"kind": self.kind,
|
||||||
|
"file_path": self.file_path,
|
||||||
|
"range": self.range.to_dict(),
|
||||||
|
"score": self.score,
|
||||||
|
}
|
||||||
|
if self.raw_code:
|
||||||
|
result["raw_code"] = self.raw_code
|
||||||
|
if self.docstring:
|
||||||
|
result["docstring"] = self.docstring
|
||||||
|
# Exclude embedding from serialization (too large for JSON responses)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
|
||||||
|
"""Create CodeSymbolNode from dictionary representation."""
|
||||||
|
return cls(
|
||||||
|
id=data["id"],
|
||||||
|
name=data["name"],
|
||||||
|
kind=data["kind"],
|
||||||
|
file_path=data["file_path"],
|
||||||
|
range=Range.from_dict(data["range"]),
|
||||||
|
embedding=data.get("embedding"),
|
||||||
|
raw_code=data.get("raw_code", ""),
|
||||||
|
docstring=data.get("docstring", ""),
|
||||||
|
score=data.get("score", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_lsp_location(
|
||||||
|
cls,
|
||||||
|
uri: str,
|
||||||
|
name: str,
|
||||||
|
kind: str,
|
||||||
|
lsp_range: Dict[str, Any],
|
||||||
|
raw_code: str = "",
|
||||||
|
docstring: str = "",
|
||||||
|
) -> CodeSymbolNode:
|
||||||
|
"""Create CodeSymbolNode from LSP location data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: File URI (file:// prefix will be stripped).
|
||||||
|
name: Symbol name.
|
||||||
|
kind: Symbol kind.
|
||||||
|
lsp_range: LSP Range object.
|
||||||
|
raw_code: Optional raw source code.
|
||||||
|
docstring: Optional documentation string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New CodeSymbolNode instance.
|
||||||
|
"""
|
||||||
|
# Strip file:// prefix if present
|
||||||
|
file_path = uri
|
||||||
|
if file_path.startswith("file://"):
|
||||||
|
file_path = file_path[7:]
|
||||||
|
# Handle Windows paths (file:///C:/...)
|
||||||
|
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||||
|
file_path = file_path[1:]
|
||||||
|
|
||||||
|
range_obj = Range.from_lsp_range(lsp_range)
|
||||||
|
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=symbol_id,
|
||||||
|
name=name,
|
||||||
|
kind=kind,
|
||||||
|
file_path=file_path,
|
||||||
|
range=range_obj,
|
||||||
|
raw_code=raw_code,
|
||||||
|
docstring=docstring,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_id(cls, file_path: str, name: str, line: int) -> str:
|
||||||
|
"""Generate a unique symbol ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Absolute file path.
|
||||||
|
name: Symbol name.
|
||||||
|
line: Start line number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unique ID string in format 'file_path:name:line'.
|
||||||
|
"""
|
||||||
|
return f"{file_path}:{name}:{line}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CodeAssociationGraph:
|
||||||
|
"""Graph of code relationships between symbols.
|
||||||
|
|
||||||
|
This graph represents the association between code symbols discovered
|
||||||
|
through LSP queries (references, call hierarchy, etc.).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
|
||||||
|
edges: List of (from_id, to_id, relationship_type) tuples.
|
||||||
|
relationship_type: 'calls', 'references', 'inherits', 'imports'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
|
||||||
|
edges: List[Tuple[str, str, str]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def add_node(self, node: CodeSymbolNode) -> None:
|
||||||
|
"""Add a node to the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: CodeSymbolNode to add. If a node with the same ID exists,
|
||||||
|
it will be replaced.
|
||||||
|
"""
|
||||||
|
self.nodes[node.id] = node
|
||||||
|
|
||||||
|
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||||
|
"""Add an edge to the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_id: Source node ID.
|
||||||
|
to_id: Target node ID.
|
||||||
|
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If from_id or to_id not in graph nodes.
|
||||||
|
"""
|
||||||
|
if from_id not in self.nodes:
|
||||||
|
raise ValueError(f"Source node '{from_id}' not found in graph")
|
||||||
|
if to_id not in self.nodes:
|
||||||
|
raise ValueError(f"Target node '{to_id}' not found in graph")
|
||||||
|
|
||||||
|
edge = (from_id, to_id, rel_type)
|
||||||
|
if edge not in self.edges:
|
||||||
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||||
|
"""Add an edge without validating node existence.
|
||||||
|
|
||||||
|
Use this method during bulk graph construction where nodes may be
|
||||||
|
added after edges, or when performance is critical.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_id: Source node ID.
|
||||||
|
to_id: Target node ID.
|
||||||
|
rel_type: Relationship type.
|
||||||
|
"""
|
||||||
|
edge = (from_id, to_id, rel_type)
|
||||||
|
if edge not in self.edges:
|
||||||
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
|
||||||
|
"""Get a node by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node ID to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CodeSymbolNode if found, None otherwise.
|
||||||
|
"""
|
||||||
|
return self.nodes.get(node_id)
|
||||||
|
|
||||||
|
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||||
|
"""Get neighboring nodes connected by outgoing edges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node ID to find neighbors for.
|
||||||
|
rel_type: Optional filter by relationship type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of neighboring CodeSymbolNode objects.
|
||||||
|
"""
|
||||||
|
neighbors = []
|
||||||
|
for from_id, to_id, edge_rel in self.edges:
|
||||||
|
if from_id == node_id:
|
||||||
|
if rel_type is None or edge_rel == rel_type:
|
||||||
|
node = self.nodes.get(to_id)
|
||||||
|
if node:
|
||||||
|
neighbors.append(node)
|
||||||
|
return neighbors
|
||||||
|
|
||||||
|
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||||
|
"""Get nodes connected by incoming edges.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: Node ID to find incoming connections for.
|
||||||
|
rel_type: Optional filter by relationship type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of CodeSymbolNode objects with edges pointing to node_id.
|
||||||
|
"""
|
||||||
|
incoming = []
|
||||||
|
for from_id, to_id, edge_rel in self.edges:
|
||||||
|
if to_id == node_id:
|
||||||
|
if rel_type is None or edge_rel == rel_type:
|
||||||
|
node = self.nodes.get(from_id)
|
||||||
|
if node:
|
||||||
|
incoming.append(node)
|
||||||
|
return incoming
|
||||||
|
|
||||||
|
def to_networkx(self) -> "nx.DiGraph":
|
||||||
|
"""Convert to NetworkX DiGraph for graph algorithms.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NetworkX directed graph with nodes and edges.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If networkx is not installed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import networkx as nx
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"networkx is required for graph algorithms. "
|
||||||
|
"Install with: pip install networkx"
|
||||||
|
)
|
||||||
|
|
||||||
|
graph = nx.DiGraph()
|
||||||
|
|
||||||
|
# Add nodes with attributes
|
||||||
|
for node_id, node in self.nodes.items():
|
||||||
|
graph.add_node(
|
||||||
|
node_id,
|
||||||
|
name=node.name,
|
||||||
|
kind=node.kind,
|
||||||
|
file_path=node.file_path,
|
||||||
|
score=node.score,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add edges with relationship type
|
||||||
|
for from_id, to_id, rel_type in self.edges:
|
||||||
|
graph.add_edge(from_id, to_id, relationship=rel_type)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'nodes' and 'edges' keys.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
|
||||||
|
"edges": [
|
||||||
|
{"from": from_id, "to": to_id, "relationship": rel_type}
|
||||||
|
for from_id, to_id, rel_type in self.edges
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
|
||||||
|
"""Create CodeAssociationGraph from dictionary representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary with 'nodes' and 'edges' keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New CodeAssociationGraph instance.
|
||||||
|
"""
|
||||||
|
graph = cls()
|
||||||
|
|
||||||
|
# Load nodes
|
||||||
|
for node_id, node_data in data.get("nodes", {}).items():
|
||||||
|
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
|
||||||
|
|
||||||
|
# Load edges
|
||||||
|
for edge_data in data.get("edges", []):
|
||||||
|
graph.edges.append((
|
||||||
|
edge_data["from"],
|
||||||
|
edge_data["to"],
|
||||||
|
edge_data["relationship"],
|
||||||
|
))
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of nodes in the graph."""
|
||||||
|
return len(self.nodes)
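# --- Illustrative sketch, not part of the committed file: a tiny association graph
# built from two invented symbols, exercising add_edge validation, neighbor queries
# and dict round-tripping.
_caller = CodeSymbolNode.from_lsp_location(
    uri="file:///src/app.py", name="main", kind="function",
    lsp_range={"start": {"line": 1, "character": 0}, "end": {"line": 5, "character": 0}},
)
_callee = CodeSymbolNode.from_lsp_location(
    uri="file:///src/util.py", name="helper", kind="function",
    lsp_range={"start": {"line": 3, "character": 0}, "end": {"line": 9, "character": 0}},
)
_graph = CodeAssociationGraph()
_graph.add_node(_caller)
_graph.add_node(_callee)
_graph.add_edge(_caller.id, _callee.id, "calls")  # both IDs must already be in the graph
assert [n.name for n in _graph.get_neighbors(_caller.id, "calls")] == ["helper"]
assert [n.name for n in _graph.get_incoming(_callee.id)] == ["main"]
assert len(CodeAssociationGraph.from_dict(_graph.to_dict())) == 2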
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchResultCluster:
|
||||||
|
"""Clustered search result containing related code symbols.
|
||||||
|
|
||||||
|
Search results are grouped into clusters based on graph community
|
||||||
|
detection or embedding similarity. Each cluster represents a
|
||||||
|
conceptually related group of code symbols.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cluster_id: Unique cluster identifier.
|
||||||
|
score: Cluster relevance score (max of symbol scores).
|
||||||
|
title: Human-readable cluster title/summary.
|
||||||
|
symbols: List of CodeSymbolNode in this cluster.
|
||||||
|
metadata: Additional cluster metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cluster_id: str
|
||||||
|
score: float
|
||||||
|
title: str
|
||||||
|
symbols: List[CodeSymbolNode] = field(default_factory=list)
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate cluster fields."""
|
||||||
|
if not self.cluster_id:
|
||||||
|
raise ValueError("cluster_id cannot be empty")
|
||||||
|
if self.score < 0:
|
||||||
|
raise ValueError("score must be >= 0")
|
||||||
|
|
||||||
|
def add_symbol(self, symbol: CodeSymbolNode) -> None:
|
||||||
|
"""Add a symbol to the cluster.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: CodeSymbolNode to add.
|
||||||
|
"""
|
||||||
|
self.symbols.append(symbol)
|
||||||
|
|
||||||
|
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
|
||||||
|
"""Get top N symbols by score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Number of symbols to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of top N CodeSymbolNode objects sorted by score descending.
|
||||||
|
"""
|
||||||
|
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
|
||||||
|
return sorted_symbols[:n]
|
||||||
|
|
||||||
|
def update_score(self) -> None:
|
||||||
|
"""Update cluster score to max of symbol scores."""
|
||||||
|
if self.symbols:
|
||||||
|
self.score = max(s.score for s in self.symbols)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation of the cluster.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"cluster_id": self.cluster_id,
|
||||||
|
"score": self.score,
|
||||||
|
"title": self.title,
|
||||||
|
"symbols": [s.to_dict() for s in self.symbols],
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
|
||||||
|
"""Create SearchResultCluster from dictionary representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary with cluster data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New SearchResultCluster instance.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
cluster_id=data["cluster_id"],
|
||||||
|
score=data["score"],
|
||||||
|
title=data["title"],
|
||||||
|
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of symbols in the cluster."""
|
||||||
|
return len(self.symbols)
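# --- Illustrative sketch, not part of the committed file: grouping two invented,
# pre-scored symbols into a SearchResultCluster and picking the best one.
_cluster = SearchResultCluster(cluster_id="c1", score=0.0, title="auth helpers")
for _name, _score in [("login", 0.9), ("logout", 0.4)]:
    _sym = CodeSymbolNode.from_lsp_location(
        uri="file:///src/auth.py", name=_name, kind="function",
        lsp_range={"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}},
    )
    _sym.score = _score
    _cluster.add_symbol(_sym)
_cluster.update_score()  # cluster score becomes the max symbol score
assert _cluster.score == 0.9
assert [s.name for s in _cluster.get_top_symbols(1)] == ["login"]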
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CallHierarchyItem:
|
||||||
|
"""LSP CallHierarchyItem for representing callers/callees.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Symbol name (function, method, etc.).
|
||||||
|
kind: Symbol kind (function, method, etc.).
|
||||||
|
file_path: Absolute file path.
|
||||||
|
range: Position range in the file.
|
||||||
|
detail: Optional additional detail (e.g., signature).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
kind: str
|
||||||
|
file_path: str
|
||||||
|
range: Range
|
||||||
|
detail: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for JSON serialization."""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"name": self.name,
|
||||||
|
"kind": self.kind,
|
||||||
|
"file_path": self.file_path,
|
||||||
|
"range": self.range.to_dict(),
|
||||||
|
}
|
||||||
|
if self.detail:
|
||||||
|
result["detail"] = self.detail
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||||
|
"""Create CallHierarchyItem from dictionary representation."""
|
||||||
|
return cls(
|
||||||
|
name=data.get("name", "unknown"),
|
||||||
|
kind=data.get("kind", "unknown"),
|
||||||
|
file_path=data.get("file_path", data.get("uri", "")),
|
||||||
|
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||||
|
detail=data.get("detail"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||||
|
"""Create CallHierarchyItem from LSP response format.
|
||||||
|
|
||||||
|
LSP uses 0-based line numbers and 'character' instead of 'char'.
|
||||||
|
"""
|
||||||
|
uri = data.get("uri", data.get("file_path", ""))
|
||||||
|
# Strip file:// prefix
|
||||||
|
file_path = uri
|
||||||
|
if file_path.startswith("file://"):
|
||||||
|
file_path = file_path[7:]
|
||||||
|
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||||
|
file_path = file_path[1:]
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=data.get("name", "unknown"),
|
||||||
|
kind=str(data.get("kind", "unknown")),
|
||||||
|
file_path=file_path,
|
||||||
|
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||||
|
detail=data.get("detail"),
|
||||||
|
)
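# --- Illustrative sketch, not part of the committed file: converting a hand-written
# LSP call-hierarchy payload (not a captured server response) with from_lsp().
_lsp_item = {
    "name": "handler",
    "kind": 12,
    "uri": "file:///C:/proj/src/server.py",
    "range": {"start": {"line": 4, "character": 0}, "end": {"line": 12, "character": 0}},
    "detail": "def handler(request)",
}
_item = CallHierarchyItem.from_lsp(_lsp_item)
assert _item.file_path == "C:/proj/src/server.py"  # file:// prefix and leading slash stripped
assert _item.kind == "12"                           # numeric LSP kinds are kept as strings here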
|
||||||
26  codex-lens/build/lib/codexlens/indexing/__init__.py  Normal file
@@ -0,0 +1,26 @@
"""Code indexing and symbol extraction."""
from codexlens.indexing.symbol_extractor import SymbolExtractor
from codexlens.indexing.embedding import (
    BinaryEmbeddingBackend,
    DenseEmbeddingBackend,
    CascadeEmbeddingBackend,
    get_cascade_embedder,
    binarize_embedding,
    pack_binary_embedding,
    unpack_binary_embedding,
    hamming_distance,
)

__all__ = [
    "SymbolExtractor",
    # Cascade embedding backends
    "BinaryEmbeddingBackend",
    "DenseEmbeddingBackend",
    "CascadeEmbeddingBackend",
    "get_cascade_embedder",
    # Utility functions
    "binarize_embedding",
    "pack_binary_embedding",
    "unpack_binary_embedding",
    "hamming_distance",
]
582  codex-lens/build/lib/codexlens/indexing/embedding.py  Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
"""Multi-type embedding backends for cascade retrieval.
|
||||||
|
|
||||||
|
This module provides embedding backends optimized for cascade retrieval:
|
||||||
|
1. BinaryEmbeddingBackend - Fast coarse filtering with binary vectors
|
||||||
|
2. DenseEmbeddingBackend - High-precision dense vectors for reranking
|
||||||
|
3. CascadeEmbeddingBackend - Combined binary + dense for two-stage retrieval
|
||||||
|
|
||||||
|
Cascade retrieval workflow:
|
||||||
|
1. Binary search (fast, ~32 bytes/vector) -> top-K candidates
|
||||||
|
2. Dense rerank (precise, ~3KB/vector at the default 768 dims) -> final results
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from codexlens.semantic.base import BaseEmbedder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Utility Functions
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def binarize_embedding(embedding: np.ndarray) -> np.ndarray:
|
||||||
|
"""Convert float embedding to binary vector.
|
||||||
|
|
||||||
|
Applies sign-based quantization: values > 0 become 1, values <= 0 become 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding: Float32 embedding of any dimension
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary vector (uint8 with values 0 or 1) of same dimension
|
||||||
|
"""
|
||||||
|
return (embedding > 0).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_binary_embedding(binary_vector: np.ndarray) -> bytes:
|
||||||
|
"""Pack binary vector into compact bytes format.
|
||||||
|
|
||||||
|
Packs 8 binary values into each byte for storage efficiency.
|
||||||
|
For a 256-dim binary vector, output is 32 bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
binary_vector: Binary vector (uint8 with values 0 or 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Packed bytes (length = ceil(dim / 8))
|
||||||
|
"""
|
||||||
|
# Ensure vector length is multiple of 8 by padding if needed
|
||||||
|
dim = len(binary_vector)
|
||||||
|
padded_dim = ((dim + 7) // 8) * 8
|
||||||
|
if padded_dim > dim:
|
||||||
|
padded = np.zeros(padded_dim, dtype=np.uint8)
|
||||||
|
padded[:dim] = binary_vector
|
||||||
|
binary_vector = padded
|
||||||
|
|
||||||
|
# Pack 8 bits per byte
|
||||||
|
packed = np.packbits(binary_vector)
|
||||||
|
return packed.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_binary_embedding(packed_bytes: bytes, dim: int = 256) -> np.ndarray:
|
||||||
|
"""Unpack bytes back to binary vector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
packed_bytes: Packed binary data
|
||||||
|
dim: Original vector dimension (default: 256)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary vector (uint8 with values 0 or 1)
|
||||||
|
"""
|
||||||
|
unpacked = np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))
|
||||||
|
return unpacked[:dim]
|
||||||
|
|
||||||
|
|
||||||
|
def hamming_distance(a: bytes, b: bytes) -> int:
|
||||||
|
"""Compute Hamming distance between two packed binary vectors.
|
||||||
|
|
||||||
|
Uses XOR and popcount for efficient distance computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: First packed binary vector
|
||||||
|
b: Second packed binary vector
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hamming distance (number of differing bits)
|
||||||
|
"""
|
||||||
|
a_arr = np.frombuffer(a, dtype=np.uint8)
|
||||||
|
b_arr = np.frombuffer(b, dtype=np.uint8)
|
||||||
|
xor = np.bitwise_xor(a_arr, b_arr)
|
||||||
|
return int(np.unpackbits(xor).sum())
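# --- Illustrative sketch, not part of the committed file: round-tripping the helpers
# above on random vectors, with no embedding model involved.
_rng = np.random.RandomState(0)
_vec_a = binarize_embedding(_rng.randn(256).astype(np.float32))
_vec_b = binarize_embedding(_rng.randn(256).astype(np.float32))
_packed_a, _packed_b = pack_binary_embedding(_vec_a), pack_binary_embedding(_vec_b)
assert len(_packed_a) == 32                                    # 256 bits -> 32 bytes
assert np.array_equal(unpack_binary_embedding(_packed_a, dim=256), _vec_a)
# Hamming distance counts the bits where the two sign patterns disagree.
assert hamming_distance(_packed_a, _packed_b) == int(np.sum(_vec_a != _vec_b))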
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Binary Embedding Backend
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryEmbeddingBackend(BaseEmbedder):
|
||||||
|
"""Generate 256-dimensional binary embeddings for fast coarse retrieval.
|
||||||
|
|
||||||
|
Uses a lightweight embedding model and applies sign-based quantization
|
||||||
|
to produce compact binary vectors (32 bytes per embedding).
|
||||||
|
|
||||||
|
Suitable for:
|
||||||
|
- First-stage candidate retrieval
|
||||||
|
- Hamming distance-based similarity search
|
||||||
|
- Memory-constrained environments
|
||||||
|
|
||||||
|
Default model: BAAI/bge-small-en-v1.5 (384 dim) -> randomly projected and quantized to 256 bits
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, fast
|
||||||
|
BINARY_DIM = 256
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize binary embedding backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Base embedding model name. Defaults to BAAI/bge-small-en-v1.5
|
||||||
|
use_gpu: Whether to use GPU acceleration
|
||||||
|
"""
|
||||||
|
from codexlens.semantic import SEMANTIC_AVAILABLE
|
||||||
|
|
||||||
|
if not SEMANTIC_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"Semantic search dependencies not available. "
|
||||||
|
"Install with: pip install codexlens[semantic]"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._model_name = model_name or self.DEFAULT_MODEL
|
||||||
|
self._use_gpu = use_gpu
|
||||||
|
self._model = None
|
||||||
|
|
||||||
|
# Projection matrix for dimension reduction (lazily initialized)
|
||||||
|
self._projection_matrix: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Return model name."""
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_dim(self) -> int:
|
||||||
|
"""Return binary embedding dimension (256)."""
|
||||||
|
return self.BINARY_DIM
|
||||||
|
|
||||||
|
@property
|
||||||
|
def packed_bytes(self) -> int:
|
||||||
|
"""Return packed bytes size (32 bytes for 256 bits)."""
|
||||||
|
return self.BINARY_DIM // 8
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Lazy load the embedding model."""
|
||||||
|
if self._model is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
from codexlens.semantic.gpu_support import get_optimal_providers
|
||||||
|
|
||||||
|
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
|
||||||
|
try:
|
||||||
|
self._model = TextEmbedding(
|
||||||
|
model_name=self._model_name,
|
||||||
|
providers=providers,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
# Fallback for older fastembed versions
|
||||||
|
self._model = TextEmbedding(model_name=self._model_name)
|
||||||
|
|
||||||
|
logger.debug(f"BinaryEmbeddingBackend loaded model: {self._model_name}")
|
||||||
|
|
||||||
|
def _get_projection_matrix(self, input_dim: int) -> np.ndarray:
|
||||||
|
"""Get or create projection matrix for dimension reduction.
|
||||||
|
|
||||||
|
Uses random projection with fixed seed for reproducibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: Input embedding dimension from base model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Projection matrix of shape (input_dim, BINARY_DIM)
|
||||||
|
"""
|
||||||
|
if self._projection_matrix is not None:
|
||||||
|
return self._projection_matrix
|
||||||
|
|
||||||
|
# Fixed seed for reproducibility across sessions
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
# Gaussian random projection
|
||||||
|
self._projection_matrix = rng.randn(input_dim, self.BINARY_DIM).astype(np.float32)
|
||||||
|
# Normalize columns for consistent scale
|
||||||
|
norms = np.linalg.norm(self._projection_matrix, axis=0, keepdims=True)
|
||||||
|
self._projection_matrix /= (norms + 1e-8)
|
||||||
|
|
||||||
|
return self._projection_matrix
|
||||||
|
|
||||||
|
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Generate binary embeddings as numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary embeddings of shape (n_texts, 256) with values 0 or 1
|
||||||
|
"""
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
else:
|
||||||
|
texts = list(texts)
|
||||||
|
|
||||||
|
# Get base float embeddings
|
||||||
|
float_embeddings = np.array(list(self._model.embed(texts)))
|
||||||
|
input_dim = float_embeddings.shape[1]
|
||||||
|
|
||||||
|
# Project to target dimension if needed
|
||||||
|
if input_dim != self.BINARY_DIM:
|
||||||
|
projection = self._get_projection_matrix(input_dim)
|
||||||
|
float_embeddings = float_embeddings @ projection
|
||||||
|
|
||||||
|
# Binarize
|
||||||
|
return binarize_embedding(float_embeddings)
|
||||||
|
|
||||||
|
def embed_packed(self, texts: str | Iterable[str]) -> List[bytes]:
|
||||||
|
"""Generate packed binary embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of packed bytes (32 bytes each for 256-dim)
|
||||||
|
"""
|
||||||
|
binary = self.embed_to_numpy(texts)
|
||||||
|
return [pack_binary_embedding(vec) for vec in binary]
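# --- Illustrative sketch, not part of the committed file: first-stage ranking over
# packed binary vectors. In real use the corpus would come from embed_packed();
# random stand-ins are used here so the snippet runs without downloading a model.
def _hamming_top_k(query_packed: bytes, corpus_packed: List[bytes], k: int) -> List[int]:
    """Return indices of the k corpus vectors closest to the query by Hamming distance."""
    distances = [hamming_distance(query_packed, doc) for doc in corpus_packed]
    return sorted(range(len(corpus_packed)), key=lambda i: distances[i])[:k]

_rng_demo = np.random.RandomState(1)
_demo_corpus = [pack_binary_embedding(binarize_embedding(_rng_demo.randn(256))) for _ in range(100)]
assert _hamming_top_k(_demo_corpus[42], _demo_corpus, k=5)[0] == 42  # identical vector ranks first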
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dense Embedding Backend
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class DenseEmbeddingBackend(BaseEmbedder):
|
||||||
|
"""Generate high-dimensional dense embeddings for precise reranking.
|
||||||
|
|
||||||
|
Uses dense embedding models to produce TARGET_DIM-dimensional (currently 768) float32 vectors
|
||||||
|
for maximum retrieval quality.
|
||||||
|
|
||||||
|
Suitable for:
|
||||||
|
- Second-stage reranking
|
||||||
|
- High-precision similarity search
|
||||||
|
- Quality-critical applications
|
||||||
|
|
||||||
|
Default model: BAAI/bge-small-en-v1.5 (384 dim) with optional expansion to TARGET_DIM
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, use small for testing
|
||||||
|
TARGET_DIM = 768 # Reduced target for faster testing
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
expand_dim: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize dense embedding backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Dense embedding model name. Defaults to BAAI/bge-small-en-v1.5
|
||||||
|
use_gpu: Whether to use GPU acceleration
|
||||||
|
expand_dim: If True, expand embeddings to TARGET_DIM using learned expansion
|
||||||
|
"""
|
||||||
|
from codexlens.semantic import SEMANTIC_AVAILABLE
|
||||||
|
|
||||||
|
if not SEMANTIC_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"Semantic search dependencies not available. "
|
||||||
|
"Install with: pip install codexlens[semantic]"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._model_name = model_name or self.DEFAULT_MODEL
|
||||||
|
self._use_gpu = use_gpu
|
||||||
|
self._expand_dim = expand_dim
|
||||||
|
self._model = None
|
||||||
|
self._native_dim: Optional[int] = None
|
||||||
|
|
||||||
|
# Expansion matrix for dimension expansion (lazily initialized)
|
||||||
|
self._expansion_matrix: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Return model name."""
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_dim(self) -> int:
|
||||||
|
"""Return embedding dimension.
|
||||||
|
|
||||||
|
Returns TARGET_DIM if expand_dim is True, otherwise native model dimension.
|
||||||
|
"""
|
||||||
|
if self._expand_dim:
|
||||||
|
return self.TARGET_DIM
|
||||||
|
# Return cached native dim or estimate based on model
|
||||||
|
if self._native_dim is not None:
|
||||||
|
return self._native_dim
|
||||||
|
# Model dimension estimates
|
||||||
|
model_dims = {
|
||||||
|
"BAAI/bge-large-en-v1.5": 1024,
|
||||||
|
"BAAI/bge-base-en-v1.5": 768,
|
||||||
|
"BAAI/bge-small-en-v1.5": 384,
|
||||||
|
"intfloat/multilingual-e5-large": 1024,
|
||||||
|
}
|
||||||
|
return model_dims.get(self._model_name, 1024)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_tokens(self) -> int:
|
||||||
|
"""Return maximum token limit."""
|
||||||
|
return 512 # Conservative default for large models
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Lazy load the embedding model."""
|
||||||
|
if self._model is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
from codexlens.semantic.gpu_support import get_optimal_providers
|
||||||
|
|
||||||
|
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
|
||||||
|
try:
|
||||||
|
self._model = TextEmbedding(
|
||||||
|
model_name=self._model_name,
|
||||||
|
providers=providers,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
self._model = TextEmbedding(model_name=self._model_name)
|
||||||
|
|
||||||
|
logger.debug(f"DenseEmbeddingBackend loaded model: {self._model_name}")
|
||||||
|
|
||||||
|
def _get_expansion_matrix(self, input_dim: int) -> np.ndarray:
|
||||||
|
"""Get or create expansion matrix for dimension expansion.
|
||||||
|
|
||||||
|
Uses random orthogonal projection for information-preserving expansion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: Input embedding dimension from base model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Expansion matrix of shape (input_dim, TARGET_DIM)
|
||||||
|
"""
|
||||||
|
if self._expansion_matrix is not None:
|
||||||
|
return self._expansion_matrix
|
||||||
|
|
||||||
|
# Fixed seed for reproducibility
|
||||||
|
rng = np.random.RandomState(123)
|
||||||
|
|
||||||
|
# Create semi-orthogonal expansion matrix
|
||||||
|
# First input_dim columns form identity-like structure
|
||||||
|
self._expansion_matrix = np.zeros((input_dim, self.TARGET_DIM), dtype=np.float32)
|
||||||
|
|
||||||
|
# Copy original dimensions
|
||||||
|
copy_dim = min(input_dim, self.TARGET_DIM)
|
||||||
|
self._expansion_matrix[:copy_dim, :copy_dim] = np.eye(copy_dim, dtype=np.float32)
|
||||||
|
|
||||||
|
# Fill remaining with random projections
|
||||||
|
if self.TARGET_DIM > input_dim:
|
||||||
|
random_part = rng.randn(input_dim, self.TARGET_DIM - input_dim).astype(np.float32)
|
||||||
|
# Normalize
|
||||||
|
norms = np.linalg.norm(random_part, axis=0, keepdims=True)
|
||||||
|
random_part /= (norms + 1e-8)
|
||||||
|
self._expansion_matrix[:, input_dim:] = random_part
|
||||||
|
|
||||||
|
return self._expansion_matrix
|
||||||
|
|
||||||
|
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Generate dense embeddings as numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dense embeddings of shape (n_texts, TARGET_DIM) as float32
|
||||||
|
"""
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
else:
|
||||||
|
texts = list(texts)
|
||||||
|
|
||||||
|
# Get base float embeddings
|
||||||
|
float_embeddings = np.array(list(self._model.embed(texts)), dtype=np.float32)
|
||||||
|
self._native_dim = float_embeddings.shape[1]
|
||||||
|
|
||||||
|
# Expand to target dimension if needed
|
||||||
|
if self._expand_dim and self._native_dim < self.TARGET_DIM:
|
||||||
|
expansion = self._get_expansion_matrix(self._native_dim)
|
||||||
|
float_embeddings = float_embeddings @ expansion
|
||||||
|
|
||||||
|
return float_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Cascade Embedding Backend
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class CascadeEmbeddingBackend(BaseEmbedder):
|
||||||
|
"""Combined binary + dense embedding backend for cascade retrieval.
|
||||||
|
|
||||||
|
Generates both binary (for fast coarse filtering) and dense (for precise
|
||||||
|
reranking) embeddings in a single pass, optimized for two-stage retrieval.
|
||||||
|
|
||||||
|
Cascade workflow:
|
||||||
|
1. encode_cascade() returns (binary_embeddings, dense_embeddings)
|
||||||
|
2. Binary search: Use Hamming distance on binary vectors -> top-K candidates
|
||||||
|
3. Dense rerank: Use cosine similarity on dense vectors -> final results
|
||||||
|
|
||||||
|
Memory efficiency:
|
||||||
|
- Binary: 32 bytes per vector (256 bits)
|
||||||
|
- Dense: 4 bytes per dimension (3072 bytes for the default 768 x float32)
|
||||||
|
- Total: ~3KB per document with the default 768-dim dense vectors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
binary_model: Optional[str] = None,
|
||||||
|
dense_model: Optional[str] = None,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize cascade embedding backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
binary_model: Model for binary embeddings. Defaults to BAAI/bge-small-en-v1.5
|
||||||
|
dense_model: Model for dense embeddings. Defaults to BAAI/bge-small-en-v1.5
|
||||||
|
use_gpu: Whether to use GPU acceleration
|
||||||
|
"""
|
||||||
|
self._binary_backend = BinaryEmbeddingBackend(
|
||||||
|
model_name=binary_model,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
)
|
||||||
|
self._dense_backend = DenseEmbeddingBackend(
|
||||||
|
model_name=dense_model,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
expand_dim=True,
|
||||||
|
)
|
||||||
|
self._use_gpu = use_gpu
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Return model names for both backends."""
|
||||||
|
return f"cascade({self._binary_backend.model_name}, {self._dense_backend.model_name})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_dim(self) -> int:
|
||||||
|
"""Return dense embedding dimension (for compatibility)."""
|
||||||
|
return self._dense_backend.embedding_dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def binary_dim(self) -> int:
|
||||||
|
"""Return binary embedding dimension."""
|
||||||
|
return self._binary_backend.embedding_dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dense_dim(self) -> int:
|
||||||
|
"""Return dense embedding dimension."""
|
||||||
|
return self._dense_backend.embedding_dim
|
||||||
|
|
||||||
|
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Generate dense embeddings (for BaseEmbedder compatibility).
|
||||||
|
|
||||||
|
For cascade embeddings, use encode_cascade() instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dense embeddings of shape (n_texts, dense_dim)
|
||||||
|
"""
|
||||||
|
return self._dense_backend.embed_to_numpy(texts)
|
||||||
|
|
||||||
|
def encode_cascade(
|
||||||
|
self,
|
||||||
|
texts: str | Iterable[str],
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Generate both binary and dense embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
batch_size: Batch size for processing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of:
|
||||||
|
- binary_embeddings: Shape (n_texts, 256), uint8 values 0/1
|
||||||
|
- dense_embeddings: Shape (n_texts, dense_dim), float32 (768 with the default model)
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
else:
|
||||||
|
texts = list(texts)
|
||||||
|
|
||||||
|
binary_embeddings = self._binary_backend.embed_to_numpy(texts)
|
||||||
|
dense_embeddings = self._dense_backend.embed_to_numpy(texts)
|
||||||
|
|
||||||
|
return binary_embeddings, dense_embeddings
|
||||||
|
|
||||||
|
def encode_binary(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Generate only binary embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary embeddings of shape (n_texts, 256)
|
||||||
|
"""
|
||||||
|
return self._binary_backend.embed_to_numpy(texts)
|
||||||
|
|
||||||
|
def encode_dense(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Generate only dense embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dense embeddings of shape (n_texts, dense_dim)
|
||||||
|
"""
|
||||||
|
return self._dense_backend.embed_to_numpy(texts)
|
||||||
|
|
||||||
|
def encode_binary_packed(self, texts: str | Iterable[str]) -> List[bytes]:
|
||||||
|
"""Generate packed binary embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of packed bytes (32 bytes each)
|
||||||
|
"""
|
||||||
|
return self._binary_backend.embed_packed(texts)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Factory Function
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_cascade_embedder(
|
||||||
|
binary_model: Optional[str] = None,
|
||||||
|
dense_model: Optional[str] = None,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
) -> CascadeEmbeddingBackend:
|
||||||
|
"""Factory function to create a cascade embedder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
binary_model: Model for binary embeddings (default: BAAI/bge-small-en-v1.5)
|
||||||
|
dense_model: Model for dense embeddings (default: BAAI/bge-small-en-v1.5)
|
||||||
|
use_gpu: Whether to use GPU acceleration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured CascadeEmbeddingBackend instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> embedder = get_cascade_embedder()
|
||||||
|
>>> binary, dense = embedder.encode_cascade(["hello world"])
|
||||||
|
>>> binary.shape # (1, 256)
|
||||||
|
>>> dense.shape  # (1, 768) with the default models
|
||||||
|
"""
|
||||||
|
return CascadeEmbeddingBackend(
|
||||||
|
binary_model=binary_model,
|
||||||
|
dense_model=dense_model,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
)
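# --- Illustrative sketch, not part of the committed file: the two-stage cascade over
# precomputed arrays. The arrays would normally come from encode_cascade(); random
# stand-ins (768-dim dense, matching the current TARGET_DIM) keep the example offline.
def _cascade_search(
    query_binary: np.ndarray,
    query_dense: np.ndarray,
    binary_corpus: np.ndarray,
    dense_corpus: np.ndarray,
    coarse_k: int = 20,
    final_k: int = 5,
) -> List[int]:
    """Stage 1: Hamming filter on binary vectors. Stage 2: cosine rerank on dense vectors."""
    hamming = np.count_nonzero(binary_corpus != query_binary, axis=1)
    candidates = np.argsort(hamming)[:coarse_k]
    hits = dense_corpus[candidates]
    sims = hits @ query_dense / (
        np.linalg.norm(hits, axis=1) * np.linalg.norm(query_dense) + 1e-8
    )
    return [int(candidates[i]) for i in np.argsort(-sims)[:final_k]]

_rng_cascade = np.random.RandomState(2)
_bin = (_rng_cascade.randn(200, 256) > 0).astype(np.uint8)
_dense = _rng_cascade.randn(200, 768).astype(np.float32)
assert _cascade_search(_bin[7], _dense[7], _bin, _dense)[0] == 7  # exact match wins both stages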
|
||||||
277  codex-lens/build/lib/codexlens/indexing/symbol_extractor.py  Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Symbol and relationship extraction from source code."""
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||||
|
except Exception: # pragma: no cover - optional dependency / platform variance
|
||||||
|
TreeSitterSymbolParser = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
class SymbolExtractor:
|
||||||
|
"""Extract symbols and relationships from source code using regex patterns."""
|
||||||
|
|
||||||
|
# Pattern definitions for different languages
|
||||||
|
PATTERNS = {
|
||||||
|
'python': {
|
||||||
|
'function': r'^(?:async\s+)?def\s+(\w+)\s*\(',
|
||||||
|
'class': r'^class\s+(\w+)\s*[:\(]',
|
||||||
|
'import': r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)',
|
||||||
|
'call': r'(?<![.\w])(\w+)\s*\(',
|
||||||
|
},
|
||||||
|
'typescript': {
|
||||||
|
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<\(]',
|
||||||
|
'class': r'(?:export\s+)?class\s+(\w+)',
|
||||||
|
'import': r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
|
||||||
|
'call': r'(?<![.\w])(\w+)\s*[<\(]',
|
||||||
|
},
|
||||||
|
'javascript': {
|
||||||
|
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(',
|
||||||
|
'class': r'(?:export\s+)?class\s+(\w+)',
|
||||||
|
'import': r"(?:import|require)\s*\(?['\"]([^'\"]+)['\"]",
|
||||||
|
'call': r'(?<![.\w])(\w+)\s*\(',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LANGUAGE_MAP = {
|
||||||
|
'.py': 'python',
|
||||||
|
'.ts': 'typescript',
|
||||||
|
'.tsx': 'typescript',
|
||||||
|
'.js': 'javascript',
|
||||||
|
'.jsx': 'javascript',
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.db_conn: Optional[sqlite3.Connection] = None
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Connect to database and ensure schema exists."""
|
||||||
|
self.db_conn = sqlite3.connect(str(self.db_path))
|
||||||
|
self._ensure_tables()
|
||||||
|
|
||||||
|
def __enter__(self) -> "SymbolExtractor":
|
||||||
|
"""Context manager entry: connect to database."""
|
||||||
|
self.connect()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
|
"""Context manager exit: close database connection."""
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def _ensure_tables(self) -> None:
|
||||||
|
"""Create symbols and relationships tables if they don't exist."""
|
||||||
|
if not self.db_conn:
|
||||||
|
return
|
||||||
|
cursor = self.db_conn.cursor()
|
||||||
|
|
||||||
|
# Create symbols table with qualified_name
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS symbols (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
qualified_name TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
kind TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
start_line INTEGER NOT NULL,
|
||||||
|
end_line INTEGER NOT NULL,
|
||||||
|
UNIQUE(file_path, name, start_line)
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# Create relationships table with target_symbol_fqn
|
||||||
|
cursor.execute('''
|
||||||
|
CREATE TABLE IF NOT EXISTS symbol_relationships (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
source_symbol_id INTEGER NOT NULL,
|
||||||
|
target_symbol_fqn TEXT NOT NULL,
|
||||||
|
relationship_type TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
line INTEGER,
|
||||||
|
FOREIGN KEY (source_symbol_id) REFERENCES symbols(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# Create performance indexes
|
||||||
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)')
|
||||||
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_path)')
|
||||||
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_source ON symbol_relationships(source_symbol_id)')
|
||||||
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_target ON symbol_relationships(target_symbol_fqn)')
|
||||||
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_type ON symbol_relationships(relationship_type)')
|
||||||
|
|
||||||
|
self.db_conn.commit()
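# --- Illustrative sketch, not part of the committed file: a query one could run against
# the schema created above to list resolved call edges. Column names match
# _ensure_tables(); the callee side stays as the stored FQN text rather than being
# resolved back to a symbols row.
EXAMPLE_CALL_EDGE_QUERY = """
    SELECT s.qualified_name AS caller,
           r.target_symbol_fqn AS callee,
           r.file_path,
           r.line
    FROM symbol_relationships AS r
    JOIN symbols AS s ON s.id = r.source_symbol_id
    WHERE r.relationship_type = 'calls'
    ORDER BY r.file_path, r.line
"""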
|
||||||
|
|
||||||
|
def extract_from_file(self, file_path: Path, content: str) -> Tuple[List[Dict], List[Dict]]:
|
||||||
|
"""Extract symbols and relationships from file content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the source file
|
||||||
|
content: File content as string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (symbols, relationships) where:
|
||||||
|
- symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
|
||||||
|
- relationships: List of relationship dicts with source_scope, target, type, file_path, line
|
||||||
|
"""
|
||||||
|
ext = file_path.suffix.lower()
|
||||||
|
lang = self.LANGUAGE_MAP.get(ext)
|
||||||
|
|
||||||
|
if not lang or lang not in self.PATTERNS:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
patterns = self.PATTERNS[lang]
|
||||||
|
symbols = []
|
||||||
|
relationships: List[Dict] = []
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
current_scope = None
|
||||||
|
|
||||||
|
for line_num, line in enumerate(lines, 1):
|
||||||
|
# Extract function/class definitions
|
||||||
|
for kind in ['function', 'class']:
|
||||||
|
if kind in patterns:
|
||||||
|
match = re.search(patterns[kind], line)
|
||||||
|
if match:
|
||||||
|
name = match.group(1)
|
||||||
|
qualified_name = f"{file_path.stem}.{name}"
|
||||||
|
symbols.append({
|
||||||
|
'qualified_name': qualified_name,
|
||||||
|
'name': name,
|
||||||
|
'kind': kind,
|
||||||
|
'file_path': str(file_path),
|
||||||
|
'start_line': line_num,
|
||||||
|
'end_line': line_num, # Simplified - would need proper parsing for actual end
|
||||||
|
})
|
||||||
|
current_scope = name
|
||||||
|
|
||||||
|
if TreeSitterSymbolParser is not None:
|
||||||
|
try:
|
||||||
|
ts_parser = TreeSitterSymbolParser(lang, file_path)
|
||||||
|
if ts_parser.is_available():
|
||||||
|
indexed = ts_parser.parse(content, file_path)
|
||||||
|
if indexed is not None and indexed.relationships:
|
||||||
|
relationships = [
|
||||||
|
{
|
||||||
|
"source_scope": r.source_symbol,
|
||||||
|
"target": r.target_symbol,
|
||||||
|
"type": r.relationship_type.value,
|
||||||
|
"file_path": str(file_path),
|
||||||
|
"line": r.source_line,
|
||||||
|
}
|
||||||
|
for r in indexed.relationships
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
relationships = []
|
||||||
|
|
||||||
|
# Regex fallback for relationships (when tree-sitter is unavailable)
|
||||||
|
if not relationships:
|
||||||
|
current_scope = None
|
||||||
|
for line_num, line in enumerate(lines, 1):
|
||||||
|
for kind in ['function', 'class']:
|
||||||
|
if kind in patterns:
|
||||||
|
match = re.search(patterns[kind], line)
|
||||||
|
if match:
|
||||||
|
current_scope = match.group(1)
|
||||||
|
|
||||||
|
# Extract imports
|
||||||
|
if 'import' in patterns:
|
||||||
|
match = re.search(patterns['import'], line)
|
||||||
|
if match:
|
||||||
|
# Prefer the "from X" module in group 1; otherwise fall back to the plain import in group 2.
import_target = match.group(1) or (match.group(2) if (match.lastindex or 0) >= 2 else None)
|
||||||
|
if import_target and current_scope:
|
||||||
|
relationships.append({
|
||||||
|
'source_scope': current_scope,
|
||||||
|
'target': import_target.strip(),
|
||||||
|
'type': 'imports',
|
||||||
|
'file_path': str(file_path),
|
||||||
|
'line': line_num,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Extract function calls (simplified)
|
||||||
|
if 'call' in patterns and current_scope:
|
||||||
|
for match in re.finditer(patterns['call'], line):
|
||||||
|
call_name = match.group(1)
|
||||||
|
# Skip common keywords and the current function
|
||||||
|
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
|
||||||
|
relationships.append({
|
||||||
|
'source_scope': current_scope,
|
||||||
|
'target': call_name,
|
||||||
|
'type': 'calls',
|
||||||
|
'file_path': str(file_path),
|
||||||
|
'line': line_num,
|
||||||
|
})
|
||||||
|
|
||||||
|
return symbols, relationships
|
||||||
|
|
||||||
|
def save_symbols(self, symbols: List[Dict]) -> Dict[str, int]:
|
||||||
|
"""Save symbols to database and return name->id mapping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping symbol name to database id
|
||||||
|
"""
|
||||||
|
if not self.db_conn or not symbols:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
cursor = self.db_conn.cursor()
|
||||||
|
name_to_id = {}
|
||||||
|
|
||||||
|
for sym in symbols:
|
||||||
|
try:
|
||||||
|
cursor.execute('''
|
||||||
|
INSERT OR IGNORE INTO symbols
|
||||||
|
(qualified_name, name, kind, file_path, start_line, end_line)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
''', (sym['qualified_name'], sym['name'], sym['kind'],
|
||||||
|
sym['file_path'], sym['start_line'], sym['end_line']))
|
||||||
|
|
||||||
|
# Get the id
|
||||||
|
cursor.execute('''
|
||||||
|
SELECT id FROM symbols
|
||||||
|
WHERE file_path = ? AND name = ? AND start_line = ?
|
||||||
|
''', (sym['file_path'], sym['name'], sym['start_line']))
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
name_to_id[sym['name']] = row[0]
|
||||||
|
except sqlite3.Error:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.db_conn.commit()
|
||||||
|
return name_to_id
|
||||||
|
|
||||||
|
def save_relationships(self, relationships: List[Dict], name_to_id: Dict[str, int]) -> None:
|
||||||
|
"""Save relationships to database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relationships: List of relationship dicts with source_scope, target, type, file_path, line
|
||||||
|
name_to_id: Dictionary mapping symbol names to database ids
|
||||||
|
"""
|
||||||
|
if not self.db_conn or not relationships:
|
||||||
|
return
|
||||||
|
|
||||||
|
cursor = self.db_conn.cursor()
|
||||||
|
|
||||||
|
for rel in relationships:
|
||||||
|
source_id = name_to_id.get(rel['source_scope'])
|
||||||
|
if source_id:
|
||||||
|
try:
|
||||||
|
cursor.execute('''
|
||||||
|
INSERT INTO symbol_relationships
|
||||||
|
(source_symbol_id, target_symbol_fqn, relationship_type, file_path, line)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
''', (source_id, rel['target'], rel['type'], rel['file_path'], rel['line']))
|
||||||
|
except sqlite3.Error:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.db_conn.commit()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close database connection."""
|
||||||
|
if self.db_conn:
|
||||||
|
self.db_conn.close()
|
||||||
|
self.db_conn = None
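# --- Illustrative usage sketch, not part of the committed file: extracting and persisting
# symbols from an invented snippet, written to a throwaway database in the working directory.
if __name__ == "__main__":  # pragma: no cover - example only
    _sample = "def greet(name):\n    return format_name(name)\n"
    with SymbolExtractor(Path("symbols_example.db")) as _extractor:
        _symbols, _relationships = _extractor.extract_from_file(Path("greet.py"), _sample)
        _ids = _extractor.save_symbols(_symbols)
        _extractor.save_relationships(_relationships, _ids)
        # With the regex fallback, `greet` is stored as a function with a 'calls' edge to `format_name`.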
|
||||||
34  codex-lens/build/lib/codexlens/lsp/__init__.py  Normal file
@@ -0,0 +1,34 @@
"""LSP module for real-time language server integration.

This module provides:
- LspBridge: HTTP bridge to VSCode language servers
- LspGraphBuilder: Build code association graphs via LSP
- Location: Position in a source file

Example:
    >>> from codexlens.lsp import LspBridge, LspGraphBuilder
    >>>
    >>> async with LspBridge() as bridge:
    ...     refs = await bridge.get_references(symbol)
    ...     graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
"""

from codexlens.lsp.lsp_bridge import (
    CacheEntry,
    Location,
    LspBridge,
)
from codexlens.lsp.lsp_graph_builder import (
    LspGraphBuilder,
)

# Alias for backward compatibility
GraphBuilder = LspGraphBuilder

__all__ = [
    "CacheEntry",
    "GraphBuilder",
    "Location",
    "LspBridge",
    "LspGraphBuilder",
]
551  codex-lens/build/lib/codexlens/lsp/handlers.py  Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
"""LSP request handlers for codex-lens.
|
||||||
|
|
||||||
|
This module contains handlers for LSP requests:
|
||||||
|
- textDocument/definition
|
||||||
|
- textDocument/completion
|
||||||
|
- workspace/symbol
|
||||||
|
- textDocument/didSave
|
||||||
|
- textDocument/hover
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from urllib.parse import quote, unquote
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lsprotocol import types as lsp
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
from codexlens.entities import Symbol
|
||||||
|
from codexlens.lsp.server import server
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Symbol kind mapping from codex-lens to LSP
|
||||||
|
SYMBOL_KIND_MAP = {
|
||||||
|
"class": lsp.SymbolKind.Class,
|
||||||
|
"function": lsp.SymbolKind.Function,
|
||||||
|
"method": lsp.SymbolKind.Method,
|
||||||
|
"variable": lsp.SymbolKind.Variable,
|
||||||
|
"constant": lsp.SymbolKind.Constant,
|
||||||
|
"property": lsp.SymbolKind.Property,
|
||||||
|
"field": lsp.SymbolKind.Field,
|
||||||
|
"interface": lsp.SymbolKind.Interface,
|
||||||
|
"module": lsp.SymbolKind.Module,
|
||||||
|
"namespace": lsp.SymbolKind.Namespace,
|
||||||
|
"package": lsp.SymbolKind.Package,
|
||||||
|
"enum": lsp.SymbolKind.Enum,
|
||||||
|
"enum_member": lsp.SymbolKind.EnumMember,
|
||||||
|
"struct": lsp.SymbolKind.Struct,
|
||||||
|
"type": lsp.SymbolKind.TypeParameter,
|
||||||
|
"type_alias": lsp.SymbolKind.TypeParameter,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Completion kind mapping from codex-lens to LSP
|
||||||
|
COMPLETION_KIND_MAP = {
|
||||||
|
"class": lsp.CompletionItemKind.Class,
|
||||||
|
"function": lsp.CompletionItemKind.Function,
|
||||||
|
"method": lsp.CompletionItemKind.Method,
|
||||||
|
"variable": lsp.CompletionItemKind.Variable,
|
||||||
|
"constant": lsp.CompletionItemKind.Constant,
|
||||||
|
"property": lsp.CompletionItemKind.Property,
|
||||||
|
"field": lsp.CompletionItemKind.Field,
|
||||||
|
"interface": lsp.CompletionItemKind.Interface,
|
||||||
|
"module": lsp.CompletionItemKind.Module,
|
||||||
|
"enum": lsp.CompletionItemKind.Enum,
|
||||||
|
"enum_member": lsp.CompletionItemKind.EnumMember,
|
||||||
|
"struct": lsp.CompletionItemKind.Struct,
|
||||||
|
"type": lsp.CompletionItemKind.TypeParameter,
|
||||||
|
"type_alias": lsp.CompletionItemKind.TypeParameter,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _path_to_uri(path: Union[str, Path]) -> str:
|
||||||
|
"""Convert a file path to a URI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: File path (string or Path object)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File URI string
|
||||||
|
"""
|
||||||
|
path_str = str(Path(path).resolve())
|
||||||
|
# POSIX paths already start with "/"; Windows paths need an extra leading slash and forward slashes
|
||||||
|
if path_str.startswith("/"):
|
||||||
|
return f"file://{quote(path_str)}"
|
||||||
|
else:
|
||||||
|
return f"file:///{quote(path_str.replace(chr(92), '/'))}"
|
||||||
|
|
||||||
|
|
||||||
|
def _uri_to_path(uri: str) -> Path:
|
||||||
|
"""Convert a URI to a file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: File URI string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path object
|
||||||
|
"""
|
||||||
|
path = uri.replace("file:///", "").replace("file://", "")
|
||||||
|
return Path(unquote(path))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
|
||||||
|
"""Extract the word at the given position in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_text: Full document text
|
||||||
|
line: 0-based line number
|
||||||
|
character: 0-based character position
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Word at position, or None if no word found
|
||||||
|
"""
|
||||||
|
lines = document_text.splitlines()
|
||||||
|
if line >= len(lines):
|
||||||
|
return None
|
||||||
|
|
||||||
|
line_text = lines[line]
|
||||||
|
if character > len(line_text):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find word boundaries
|
||||||
|
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
|
||||||
|
for match in word_pattern.finditer(line_text):
|
||||||
|
if match.start() <= character <= match.end():
|
||||||
|
return match.group()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
|
||||||
|
"""Extract the incomplete word prefix at the given position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_text: Full document text
|
||||||
|
line: 0-based line number
|
||||||
|
character: 0-based character position
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prefix string (may be empty)
|
||||||
|
"""
|
||||||
|
lines = document_text.splitlines()
|
||||||
|
if line >= len(lines):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
line_text = lines[line]
|
||||||
|
if character > len(line_text):
|
||||||
|
character = len(line_text)
|
||||||
|
|
||||||
|
# Extract text before cursor
|
||||||
|
before_cursor = line_text[:character]
|
||||||
|
|
||||||
|
# Find the start of the current word
|
||||||
|
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
|
||||||
|
if match:
|
||||||
|
return match.group()
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
|
||||||
|
"""Convert a codex-lens Symbol to an LSP Location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: codex-lens Symbol object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LSP Location, or None if symbol has no file
|
||||||
|
"""
|
||||||
|
if not symbol.file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# LSP uses 0-based lines, codex-lens uses 1-based
|
||||||
|
start_line = max(0, symbol.range[0] - 1)
|
||||||
|
end_line = max(0, symbol.range[1] - 1)
|
||||||
|
|
||||||
|
return lsp.Location(
|
||||||
|
uri=_path_to_uri(symbol.file),
|
||||||
|
range=lsp.Range(
|
||||||
|
start=lsp.Position(line=start_line, character=0),
|
||||||
|
end=lsp.Position(line=end_line, character=0),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
|
||||||
|
"""Map codex-lens symbol kind to LSP SymbolKind.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kind: codex-lens symbol kind string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LSP SymbolKind
|
||||||
|
"""
|
||||||
|
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
|
||||||
|
|
||||||
|
|
||||||
|
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
|
||||||
|
"""Map codex-lens symbol kind to LSP CompletionItemKind.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kind: codex-lens symbol kind string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LSP CompletionItemKind
|
||||||
|
"""
|
||||||
|
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# LSP Request Handlers
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
|
||||||
|
def lsp_definition(
|
||||||
|
params: lsp.DefinitionParams,
|
||||||
|
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
|
||||||
|
"""Handle textDocument/definition request.
|
||||||
|
|
||||||
|
Finds the definition of the symbol at the cursor position.
|
||||||
|
"""
|
||||||
|
if not server.global_index:
|
||||||
|
logger.debug("No global index available for definition lookup")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get document
|
||||||
|
document = server.workspace.get_text_document(params.text_document.uri)
|
||||||
|
if not document:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get word at position
|
||||||
|
word = _get_word_at_position(
|
||||||
|
document.source,
|
||||||
|
params.position.line,
|
||||||
|
params.position.character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not word:
|
||||||
|
logger.debug("No word found at position")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("Looking up definition for: %s", word)
|
||||||
|
|
||||||
|
# Search for exact symbol match
|
||||||
|
try:
|
||||||
|
symbols = server.global_index.search(
|
||||||
|
name=word,
|
||||||
|
limit=10,
|
||||||
|
prefix_mode=False, # Exact match preferred
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for exact name match
|
||||||
|
exact_matches = [s for s in symbols if s.name == word]
|
||||||
|
if not exact_matches:
|
||||||
|
# Fall back to prefix search
|
||||||
|
symbols = server.global_index.search(
|
||||||
|
name=word,
|
||||||
|
limit=10,
|
||||||
|
prefix_mode=True,
|
||||||
|
)
|
||||||
|
exact_matches = [s for s in symbols if s.name == word]
|
||||||
|
|
||||||
|
if not exact_matches:
|
||||||
|
logger.debug("No definition found for: %s", word)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert to LSP locations
|
||||||
|
locations = []
|
||||||
|
for sym in exact_matches:
|
||||||
|
loc = symbol_to_location(sym)
|
||||||
|
if loc:
|
||||||
|
locations.append(loc)
|
||||||
|
|
||||||
|
if len(locations) == 1:
|
||||||
|
return locations[0]
|
||||||
|
elif locations:
|
||||||
|
return locations
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error looking up definition: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
|
||||||
|
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
|
||||||
|
"""Handle textDocument/references request.
|
||||||
|
|
||||||
|
Finds all references to the symbol at the cursor position using
|
||||||
|
the code_relationships table for accurate call-site tracking.
|
||||||
|
Falls back to same-name symbol search if search_engine is unavailable.
|
||||||
|
"""
|
||||||
|
document = server.workspace.get_text_document(params.text_document.uri)
|
||||||
|
if not document:
|
||||||
|
return None
|
||||||
|
|
||||||
|
word = _get_word_at_position(
|
||||||
|
document.source,
|
||||||
|
params.position.line,
|
||||||
|
params.position.character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not word:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("Finding references for: %s", word)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try using search_engine.search_references() for accurate reference tracking
|
||||||
|
if server.search_engine and server.workspace_root:
|
||||||
|
references = server.search_engine.search_references(
|
||||||
|
symbol_name=word,
|
||||||
|
source_path=server.workspace_root,
|
||||||
|
limit=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
if references:
|
||||||
|
locations = []
|
||||||
|
for ref in references:
|
||||||
|
locations.append(
|
||||||
|
lsp.Location(
|
||||||
|
uri=_path_to_uri(ref.file_path),
|
||||||
|
range=lsp.Range(
|
||||||
|
start=lsp.Position(
|
||||||
|
line=max(0, ref.line - 1),
|
||||||
|
character=ref.column,
|
||||||
|
),
|
||||||
|
end=lsp.Position(
|
||||||
|
line=max(0, ref.line - 1),
|
||||||
|
character=ref.column + len(word),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return locations if locations else None
|
||||||
|
|
||||||
|
# Fallback: search for symbols with same name using global_index
|
||||||
|
if server.global_index:
|
||||||
|
symbols = server.global_index.search(
|
||||||
|
name=word,
|
||||||
|
limit=100,
|
||||||
|
prefix_mode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for exact matches
|
||||||
|
exact_matches = [s for s in symbols if s.name == word]
|
||||||
|
|
||||||
|
locations = []
|
||||||
|
for sym in exact_matches:
|
||||||
|
loc = symbol_to_location(sym)
|
||||||
|
if loc:
|
||||||
|
locations.append(loc)
|
||||||
|
|
||||||
|
return locations if locations else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error finding references: %s", exc)
|
||||||
|
return None
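A standalone sketch of the coordinate mapping used above when a stored reference is turned into an LSP range: codex-lens stores 1-based lines, LSP positions are 0-based, and the range spans the referenced name. The sample line and column values are illustrative.

word = "compute_total"
ref_line, ref_column = 42, 8  # 1-based line; column used as stored in code_relationships
start = {"line": max(0, ref_line - 1), "character": ref_column}
end = {"line": max(0, ref_line - 1), "character": ref_column + len(word)}
print(start, end)  # {'line': 41, 'character': 8} {'line': 41, 'character': 21}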
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
|
||||||
|
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
|
||||||
|
"""Handle textDocument/completion request.
|
||||||
|
|
||||||
|
Provides code completion suggestions based on indexed symbols.
|
||||||
|
"""
|
||||||
|
if not server.global_index:
|
||||||
|
return None
|
||||||
|
|
||||||
|
document = server.workspace.get_text_document(params.text_document.uri)
|
||||||
|
if not document:
|
||||||
|
return None
|
||||||
|
|
||||||
|
prefix = _get_prefix_at_position(
|
||||||
|
document.source,
|
||||||
|
params.position.line,
|
||||||
|
params.position.character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not prefix or len(prefix) < 2:
|
||||||
|
# Require at least 2 characters for completion
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("Completing prefix: %s", prefix)
|
||||||
|
|
||||||
|
try:
|
||||||
|
symbols = server.global_index.search(
|
||||||
|
name=prefix,
|
||||||
|
limit=50,
|
||||||
|
prefix_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not symbols:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Convert to completion items
|
||||||
|
items = []
|
||||||
|
seen_names = set()
|
||||||
|
|
||||||
|
for sym in symbols:
|
||||||
|
if sym.name in seen_names:
|
||||||
|
continue
|
||||||
|
seen_names.add(sym.name)
|
||||||
|
|
||||||
|
items.append(
|
||||||
|
lsp.CompletionItem(
|
||||||
|
label=sym.name,
|
||||||
|
kind=_symbol_kind_to_completion_kind(sym.kind),
|
||||||
|
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
|
||||||
|
sort_text=sym.name.lower(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lsp.CompletionList(
|
||||||
|
is_incomplete=len(symbols) >= 50,
|
||||||
|
items=items,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error getting completions: %s", exc)
|
||||||
|
return None
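A standalone sketch of the post-processing in the completion handler above: duplicate names are dropped while keeping first-seen order, and is_incomplete is set when the result hit the search limit (50 in the handler).

raw = ["parse", "parse", "parse_file", "parser"]
seen, labels = set(), []
for name in raw:
    if name in seen:
        continue
    seen.add(name)
    labels.append(name)
print(labels)            # ['parse', 'parse_file', 'parser']
print(len(raw) >= 50)    # False -> the completion list is marked complete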
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
|
||||||
|
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
|
||||||
|
"""Handle textDocument/hover request.
|
||||||
|
|
||||||
|
Provides hover information for the symbol at the cursor position
|
||||||
|
using HoverProvider for rich symbol information including
|
||||||
|
signature, documentation, and location.
|
||||||
|
"""
|
||||||
|
if not server.global_index:
|
||||||
|
return None
|
||||||
|
|
||||||
|
document = server.workspace.get_text_document(params.text_document.uri)
|
||||||
|
if not document:
|
||||||
|
return None
|
||||||
|
|
||||||
|
word = _get_word_at_position(
|
||||||
|
document.source,
|
||||||
|
params.position.line,
|
||||||
|
params.position.character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not word:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("Hover for: %s", word)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use HoverProvider for rich symbol information
|
||||||
|
from codexlens.lsp.providers import HoverProvider
|
||||||
|
|
||||||
|
provider = HoverProvider(server.global_index, server.registry)
|
||||||
|
info = provider.get_hover_info(word)
|
||||||
|
|
||||||
|
if not info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Format as markdown with signature and location
|
||||||
|
content = provider.format_hover_markdown(info)
|
||||||
|
|
||||||
|
return lsp.Hover(
|
||||||
|
contents=lsp.MarkupContent(
|
||||||
|
kind=lsp.MarkupKind.Markdown,
|
||||||
|
value=content,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error getting hover info: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.WORKSPACE_SYMBOL)
|
||||||
|
def lsp_workspace_symbol(
|
||||||
|
params: lsp.WorkspaceSymbolParams,
|
||||||
|
) -> Optional[List[lsp.SymbolInformation]]:
|
||||||
|
"""Handle workspace/symbol request.
|
||||||
|
|
||||||
|
Searches for symbols across the workspace.
|
||||||
|
"""
|
||||||
|
if not server.global_index:
|
||||||
|
return None
|
||||||
|
|
||||||
|
query = params.query
|
||||||
|
if not query or len(query) < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("Workspace symbol search: %s", query)
|
||||||
|
|
||||||
|
try:
|
||||||
|
symbols = server.global_index.search(
|
||||||
|
name=query,
|
||||||
|
limit=100,
|
||||||
|
prefix_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not symbols:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for sym in symbols:
|
||||||
|
loc = symbol_to_location(sym)
|
||||||
|
if loc:
|
||||||
|
result.append(
|
||||||
|
lsp.SymbolInformation(
|
||||||
|
name=sym.name,
|
||||||
|
kind=_symbol_kind_to_lsp(sym.kind),
|
||||||
|
location=loc,
|
||||||
|
container_name=Path(sym.file).parent.name if sym.file else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result if result else None
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error searching workspace symbols: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
|
||||||
|
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
|
||||||
|
"""Handle textDocument/didSave notification.
|
||||||
|
|
||||||
|
Logs the saved file; incremental re-indexing is not yet triggered here.
|
||||||
|
Note: Full incremental indexing requires WatcherManager integration,
|
||||||
|
which is planned for Phase 2.
|
||||||
|
"""
|
||||||
|
file_path = _uri_to_path(params.text_document.uri)
|
||||||
|
logger.info("File saved: %s", file_path)
|
||||||
|
|
||||||
|
# Phase 1: Just log the save event
|
||||||
|
# Phase 2 will integrate with WatcherManager for incremental indexing
|
||||||
|
# if server.watcher_manager:
|
||||||
|
# server.watcher_manager.trigger_reindex(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
|
||||||
|
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
|
||||||
|
"""Handle textDocument/didOpen notification."""
|
||||||
|
file_path = _uri_to_path(params.text_document.uri)
|
||||||
|
logger.debug("File opened: %s", file_path)
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
|
||||||
|
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
|
||||||
|
"""Handle textDocument/didClose notification."""
|
||||||
|
file_path = _uri_to_path(params.text_document.uri)
|
||||||
|
logger.debug("File closed: %s", file_path)
|
||||||
834
codex-lens/build/lib/codexlens/lsp/lsp_bridge.py
Normal file
@@ -0,0 +1,834 @@
|
|||||||
|
"""LspBridge service for real-time LSP communication with caching.
|
||||||
|
|
||||||
|
This module provides a bridge to communicate with language servers either via:
|
||||||
|
1. Standalone LSP Manager (direct subprocess communication - default)
|
||||||
|
2. VSCode Bridge extension (HTTP-based, legacy mode)
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Direct communication with language servers (no VSCode dependency)
|
||||||
|
- Cache with TTL and file modification time invalidation
|
||||||
|
- Graceful error handling with empty results on failure
|
||||||
|
- Support for definition, references, hover, and call hierarchy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||||
|
|
||||||
|
# Check for optional dependencies
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
HAS_AIOHTTP = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_AIOHTTP = False
|
||||||
|
|
||||||
|
from codexlens.hybrid_search.data_structures import (
|
||||||
|
CallHierarchyItem,
|
||||||
|
CodeSymbolNode,
|
||||||
|
Range,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Location:
|
||||||
|
"""A location in a source file (LSP response format)."""
|
||||||
|
|
||||||
|
file_path: str
|
||||||
|
line: int
|
||||||
|
character: int
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary format."""
|
||||||
|
return {
|
||||||
|
"file_path": self.file_path,
|
||||||
|
"line": self.line,
|
||||||
|
"character": self.character,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
|
||||||
|
"""Create Location from LSP response format.
|
||||||
|
|
||||||
|
Handles both direct format and VSCode URI format.
|
||||||
|
"""
|
||||||
|
# Handle VSCode URI format (file:///path/to/file)
|
||||||
|
uri = data.get("uri", data.get("file_path", ""))
|
||||||
|
if uri.startswith("file:///"):
|
||||||
|
# Windows: file:///C:/path -> C:/path
|
||||||
|
# Unix: file:///path -> /path
|
||||||
|
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||||
|
elif uri.startswith("file://"):
|
||||||
|
file_path = uri[7:]
|
||||||
|
else:
|
||||||
|
file_path = uri
|
||||||
|
|
||||||
|
# Get position from range or direct fields
|
||||||
|
if "range" in data:
|
||||||
|
range_data = data["range"]
|
||||||
|
start = range_data.get("start", {})
|
||||||
|
line = start.get("line", 0) + 1 # LSP is 0-based, convert to 1-based
|
||||||
|
character = start.get("character", 0) + 1
|
||||||
|
else:
|
||||||
|
line = data.get("line", 1)
|
||||||
|
character = data.get("character", 1)
|
||||||
|
|
||||||
|
return cls(file_path=file_path, line=line, character=character)
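A minimal usage sketch of from_lsp_response, assuming Location is importable from codexlens.lsp.lsp_bridge as defined above; it shows the Windows and POSIX URI handling and the 0-based to 1-based position shift.

from codexlens.lsp.lsp_bridge import Location

loc = Location.from_lsp_response({
    "uri": "file:///C:/proj/main.py",
    "range": {"start": {"line": 9, "character": 4}},
})
print(loc.file_path, loc.line, loc.character)  # C:/proj/main.py 10 5

loc = Location.from_lsp_response({"uri": "file:///home/dev/app.py", "line": 3})
print(loc.file_path, loc.line)                 # /home/dev/app.py 3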
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheEntry:
|
||||||
|
"""A cached LSP response with expiration metadata.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
data: The cached response data
|
||||||
|
file_mtime: File modification time when cached (for invalidation)
|
||||||
|
cached_at: Unix timestamp when entry was cached
|
||||||
|
"""
|
||||||
|
|
||||||
|
data: Any
|
||||||
|
file_mtime: float
|
||||||
|
cached_at: float
|
||||||
|
|
||||||
|
|
||||||
|
class LspBridge:
|
||||||
|
"""Bridge for real-time LSP communication with language servers.
|
||||||
|
|
||||||
|
By default, uses StandaloneLspManager to directly spawn and communicate
|
||||||
|
with language servers via JSON-RPC over stdio. No VSCode dependency required.
|
||||||
|
|
||||||
|
For legacy mode, can use VSCode Bridge HTTP server (set use_vscode_bridge=True).
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Direct language server communication (default)
|
||||||
|
- Response caching with TTL and file modification invalidation
|
||||||
|
- Timeout handling
|
||||||
|
- Graceful error handling returning empty results
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Default: standalone mode (no VSCode needed)
|
||||||
|
async with LspBridge() as bridge:
|
||||||
|
refs = await bridge.get_references(symbol)
|
||||||
|
definition = await bridge.get_definition(symbol)
|
||||||
|
|
||||||
|
# Legacy: VSCode Bridge mode
|
||||||
|
async with LspBridge(use_vscode_bridge=True) as bridge:
|
||||||
|
refs = await bridge.get_references(symbol)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
|
||||||
|
DEFAULT_TIMEOUT = 30.0 # seconds (increased for standalone mode)
|
||||||
|
DEFAULT_CACHE_TTL = 300 # 5 minutes
|
||||||
|
DEFAULT_MAX_CACHE_SIZE = 1000 # Maximum cache entries
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bridge_url: str = DEFAULT_BRIDGE_URL,
|
||||||
|
timeout: float = DEFAULT_TIMEOUT,
|
||||||
|
cache_ttl: int = DEFAULT_CACHE_TTL,
|
||||||
|
max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
|
||||||
|
use_vscode_bridge: bool = False,
|
||||||
|
workspace_root: Optional[str] = None,
|
||||||
|
config_file: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Initialize LspBridge.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
cache_ttl: Cache time-to-live in seconds
|
||||||
|
max_cache_size: Maximum number of cache entries (LRU eviction)
|
||||||
|
use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
|
||||||
|
workspace_root: Root directory for standalone LSP manager
|
||||||
|
config_file: Path to lsp-servers.json configuration file
|
||||||
|
"""
|
||||||
|
self.bridge_url = bridge_url
|
||||||
|
self.timeout = timeout
|
||||||
|
self.cache_ttl = cache_ttl
|
||||||
|
self.max_cache_size = max_cache_size
|
||||||
|
self.use_vscode_bridge = use_vscode_bridge
|
||||||
|
self.workspace_root = workspace_root
|
||||||
|
self.config_file = config_file
|
||||||
|
|
||||||
|
self.cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||||
|
|
||||||
|
# VSCode Bridge mode (legacy)
|
||||||
|
self._session: Optional["aiohttp.ClientSession"] = None
|
||||||
|
|
||||||
|
# Standalone mode (default)
|
||||||
|
self._manager: Optional["StandaloneLspManager"] = None
|
||||||
|
self._manager_started = False
|
||||||
|
|
||||||
|
# Validate dependencies
|
||||||
|
if use_vscode_bridge and not HAS_AIOHTTP:
|
||||||
|
raise ImportError(
|
||||||
|
"aiohttp is required for VSCode Bridge mode: pip install aiohttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ensure_manager(self) -> "StandaloneLspManager":
|
||||||
|
"""Ensure standalone LSP manager is started."""
|
||||||
|
if self._manager is None:
|
||||||
|
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||||
|
self._manager = StandaloneLspManager(
|
||||||
|
workspace_root=self.workspace_root,
|
||||||
|
config_file=self.config_file,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._manager_started:
|
||||||
|
await self._manager.start()
|
||||||
|
self._manager_started = True
|
||||||
|
|
||||||
|
return self._manager
|
||||||
|
|
||||||
|
async def _get_session(self) -> "aiohttp.ClientSession":
|
||||||
|
"""Get or create the aiohttp session (VSCode Bridge mode only)."""
|
||||||
|
if not HAS_AIOHTTP:
|
||||||
|
raise ImportError("aiohttp required for VSCode Bridge mode")
|
||||||
|
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||||
|
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close connections and cleanup resources."""
|
||||||
|
# Close VSCode Bridge session
|
||||||
|
if self._session and not self._session.closed:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
|
||||||
|
# Stop standalone manager
|
||||||
|
if self._manager and self._manager_started:
|
||||||
|
await self._manager.stop()
|
||||||
|
self._manager_started = False
|
||||||
|
|
||||||
|
def _get_file_mtime(self, file_path: str) -> float:
|
||||||
|
"""Get file modification time, or 0 if file doesn't exist."""
|
||||||
|
try:
|
||||||
|
return os.path.getmtime(file_path)
|
||||||
|
except OSError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _is_cached(self, cache_key: str, file_path: str) -> bool:
|
||||||
|
"""Check if cache entry is valid.
|
||||||
|
|
||||||
|
Cache is invalid if:
|
||||||
|
- Entry doesn't exist
|
||||||
|
- TTL has expired
|
||||||
|
- File has been modified since caching
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_key: The cache key to check
|
||||||
|
file_path: Path to source file for mtime check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if cache is valid and can be used
|
||||||
|
"""
|
||||||
|
if cache_key not in self.cache:
|
||||||
|
return False
|
||||||
|
|
||||||
|
entry = self.cache[cache_key]
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Check TTL
|
||||||
|
if now - entry.cached_at > self.cache_ttl:
|
||||||
|
del self.cache[cache_key]
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check file modification time
|
||||||
|
current_mtime = self._get_file_mtime(file_path)
|
||||||
|
if current_mtime != entry.file_mtime:
|
||||||
|
del self.cache[cache_key]
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Move to end on access (LRU behavior)
|
||||||
|
self.cache.move_to_end(cache_key)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _cache(self, key: str, file_path: str, data: Any) -> None:
|
||||||
|
"""Store data in cache with LRU eviction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
file_path: Path to source file (for mtime tracking)
|
||||||
|
data: Data to cache
|
||||||
|
"""
|
||||||
|
# Remove oldest entries if at capacity
|
||||||
|
while len(self.cache) >= self.max_cache_size:
|
||||||
|
self.cache.popitem(last=False)  # Evict the least-recently-used entry
|
||||||
|
|
||||||
|
# Move to end if key exists (update access order)
|
||||||
|
if key in self.cache:
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
|
||||||
|
self.cache[key] = CacheEntry(
|
||||||
|
data=data,
|
||||||
|
file_mtime=self._get_file_mtime(file_path),
|
||||||
|
cached_at=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear all cached entries."""
|
||||||
|
self.cache.clear()
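A standalone sketch of the invalidation rules _is_cached applies: an entry is only reused while its TTL has not expired and the source file's mtime is unchanged.

import time

def entry_is_valid(cached_at, cached_mtime, current_mtime, ttl=300):
    if time.time() - cached_at > ttl:
        return False                      # TTL expired
    return current_mtime == cached_mtime  # file rewritten since caching?

print(entry_is_valid(time.time(), 100.0, 100.0))        # True
print(entry_is_valid(time.time() - 400, 100.0, 100.0))  # False (stale)
print(entry_is_valid(time.time(), 100.0, 123.0))        # False (file modified)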
|
||||||
|
|
||||||
|
async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
|
||||||
|
"""Make HTTP request to VSCode Bridge (legacy mode).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The endpoint/action name (e.g., "get_definition")
|
||||||
|
params: Request parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response data on success, None on failure
|
||||||
|
"""
|
||||||
|
url = f"{self.bridge_url}/{action}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
session = await self._get_session()
|
||||||
|
async with session.post(url, json=params) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = await response.json()
|
||||||
|
if data.get("success") is False:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data.get("result")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
|
||||||
|
"""Get all references to a symbol via real-time LSP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: The code symbol to find references for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Location objects where the symbol is referenced.
|
||||||
|
Returns empty list on error or timeout.
|
||||||
|
"""
|
||||||
|
cache_key = f"refs:{symbol.id}"
|
||||||
|
|
||||||
|
if self._is_cached(cache_key, symbol.file_path):
|
||||||
|
return self.cache[cache_key].data
|
||||||
|
|
||||||
|
locations: List[Location] = []
|
||||||
|
|
||||||
|
if self.use_vscode_bridge:
|
||||||
|
# Legacy: VSCode Bridge HTTP mode
|
||||||
|
result = await self._request_vscode_bridge("get_references", {
|
||||||
|
"file_path": symbol.file_path,
|
||||||
|
"line": symbol.range.start_line,
|
||||||
|
"character": symbol.range.start_character,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Don't cache on connection error (result is None)
|
||||||
|
if result is None:
|
||||||
|
return locations
|
||||||
|
|
||||||
|
if isinstance(result, list):
|
||||||
|
for item in result:
|
||||||
|
try:
|
||||||
|
locations.append(Location.from_lsp_response(item))
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Default: Standalone mode
|
||||||
|
manager = await self._ensure_manager()
|
||||||
|
result = await manager.get_references(
|
||||||
|
file_path=symbol.file_path,
|
||||||
|
line=symbol.range.start_line,
|
||||||
|
character=symbol.range.start_character,
|
||||||
|
)
|
||||||
|
|
||||||
|
for item in result:
|
||||||
|
try:
|
||||||
|
locations.append(Location.from_lsp_response(item))
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._cache(cache_key, symbol.file_path, locations)
|
||||||
|
return locations
|
||||||
|
|
||||||
|
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
|
||||||
|
"""Get symbol definition location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: The code symbol to find definition for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Location of the definition, or None if not found
|
||||||
|
"""
|
||||||
|
cache_key = f"def:{symbol.id}"
|
||||||
|
|
||||||
|
if self._is_cached(cache_key, symbol.file_path):
|
||||||
|
return self.cache[cache_key].data
|
||||||
|
|
||||||
|
location: Optional[Location] = None
|
||||||
|
|
||||||
|
if self.use_vscode_bridge:
|
||||||
|
# Legacy: VSCode Bridge HTTP mode
|
||||||
|
result = await self._request_vscode_bridge("get_definition", {
|
||||||
|
"file_path": symbol.file_path,
|
||||||
|
"line": symbol.range.start_line,
|
||||||
|
"character": symbol.range.start_character,
|
||||||
|
})
|
||||||
|
|
||||||
|
if result:
|
||||||
|
if isinstance(result, list) and len(result) > 0:
|
||||||
|
try:
|
||||||
|
location = Location.from_lsp_response(result[0])
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
try:
|
||||||
|
location = Location.from_lsp_response(result)
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Default: Standalone mode
|
||||||
|
manager = await self._ensure_manager()
|
||||||
|
result = await manager.get_definition(
|
||||||
|
file_path=symbol.file_path,
|
||||||
|
line=symbol.range.start_line,
|
||||||
|
character=symbol.range.start_character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
try:
|
||||||
|
location = Location.from_lsp_response(result)
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._cache(cache_key, symbol.file_path, location)
|
||||||
|
return location
|
||||||
|
|
||||||
|
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
|
||||||
|
"""Get incoming/outgoing calls for a symbol.
|
||||||
|
|
||||||
|
If call hierarchy is not supported by the language server,
|
||||||
|
falls back to using references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: The code symbol to get call hierarchy for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of CallHierarchyItem representing callers/callees.
|
||||||
|
Returns empty list on error or if not supported.
|
||||||
|
"""
|
||||||
|
cache_key = f"calls:{symbol.id}"
|
||||||
|
|
||||||
|
if self._is_cached(cache_key, symbol.file_path):
|
||||||
|
return self.cache[cache_key].data
|
||||||
|
|
||||||
|
items: List[CallHierarchyItem] = []
|
||||||
|
|
||||||
|
if self.use_vscode_bridge:
|
||||||
|
# Legacy: VSCode Bridge HTTP mode
|
||||||
|
result = await self._request_vscode_bridge("get_call_hierarchy", {
|
||||||
|
"file_path": symbol.file_path,
|
||||||
|
"line": symbol.range.start_line,
|
||||||
|
"character": symbol.range.start_character,
|
||||||
|
})
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
# Fallback: use references
|
||||||
|
refs = await self.get_references(symbol)
|
||||||
|
for ref in refs:
|
||||||
|
items.append(CallHierarchyItem(
|
||||||
|
name=f"caller@{ref.line}",
|
||||||
|
kind="reference",
|
||||||
|
file_path=ref.file_path,
|
||||||
|
range=Range(
|
||||||
|
start_line=ref.line,
|
||||||
|
start_character=ref.character,
|
||||||
|
end_line=ref.line,
|
||||||
|
end_character=ref.character,
|
||||||
|
),
|
||||||
|
detail="Inferred from reference",
|
||||||
|
))
|
||||||
|
elif isinstance(result, list):
|
||||||
|
for item in result:
|
||||||
|
try:
|
||||||
|
range_data = item.get("range", {})
|
||||||
|
start = range_data.get("start", {})
|
||||||
|
end = range_data.get("end", {})
|
||||||
|
|
||||||
|
items.append(CallHierarchyItem(
|
||||||
|
name=item.get("name", "unknown"),
|
||||||
|
kind=item.get("kind", "unknown"),
|
||||||
|
file_path=item.get("file_path", item.get("uri", "")),
|
||||||
|
range=Range(
|
||||||
|
start_line=start.get("line", 0) + 1,
|
||||||
|
start_character=start.get("character", 0) + 1,
|
||||||
|
end_line=end.get("line", 0) + 1,
|
||||||
|
end_character=end.get("character", 0) + 1,
|
||||||
|
),
|
||||||
|
detail=item.get("detail"),
|
||||||
|
))
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Default: Standalone mode
|
||||||
|
manager = await self._ensure_manager()
|
||||||
|
|
||||||
|
# Try to get call hierarchy items
|
||||||
|
hierarchy_items = await manager.get_call_hierarchy_items(
|
||||||
|
file_path=symbol.file_path,
|
||||||
|
line=symbol.range.start_line,
|
||||||
|
character=symbol.range.start_character,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hierarchy_items:
|
||||||
|
# Get incoming calls for each item
|
||||||
|
for h_item in hierarchy_items:
|
||||||
|
incoming = await manager.get_incoming_calls(h_item)
|
||||||
|
for call in incoming:
|
||||||
|
from_item = call.get("from", {})
|
||||||
|
range_data = from_item.get("range", {})
|
||||||
|
start = range_data.get("start", {})
|
||||||
|
end = range_data.get("end", {})
|
||||||
|
|
||||||
|
# Parse URI
|
||||||
|
uri = from_item.get("uri", "")
|
||||||
|
if uri.startswith("file:///"):
|
||||||
|
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||||
|
elif uri.startswith("file://"):
|
||||||
|
fp = uri[7:]
|
||||||
|
else:
|
||||||
|
fp = uri
|
||||||
|
|
||||||
|
items.append(CallHierarchyItem(
|
||||||
|
name=from_item.get("name", "unknown"),
|
||||||
|
kind=str(from_item.get("kind", "unknown")),
|
||||||
|
file_path=fp,
|
||||||
|
range=Range(
|
||||||
|
start_line=start.get("line", 0) + 1,
|
||||||
|
start_character=start.get("character", 0) + 1,
|
||||||
|
end_line=end.get("line", 0) + 1,
|
||||||
|
end_character=end.get("character", 0) + 1,
|
||||||
|
),
|
||||||
|
detail=from_item.get("detail"),
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
# Fallback: use references
|
||||||
|
refs = await self.get_references(symbol)
|
||||||
|
for ref in refs:
|
||||||
|
items.append(CallHierarchyItem(
|
||||||
|
name=f"caller@{ref.line}",
|
||||||
|
kind="reference",
|
||||||
|
file_path=ref.file_path,
|
||||||
|
range=Range(
|
||||||
|
start_line=ref.line,
|
||||||
|
start_character=ref.character,
|
||||||
|
end_line=ref.line,
|
||||||
|
end_character=ref.character,
|
||||||
|
),
|
||||||
|
detail="Inferred from reference",
|
||||||
|
))
|
||||||
|
|
||||||
|
self._cache(cache_key, symbol.file_path, items)
|
||||||
|
return items
|
||||||
|
|
||||||
|
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all symbols in a document (batch operation).
|
||||||
|
|
||||||
|
This is more efficient than individual hover queries when processing
|
||||||
|
multiple locations in the same file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the source file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of symbol dictionaries with name, kind, range, etc.
|
||||||
|
Returns empty list on error or timeout.
|
||||||
|
"""
|
||||||
|
cache_key = f"symbols:{file_path}"
|
||||||
|
|
||||||
|
if self._is_cached(cache_key, file_path):
|
||||||
|
return self.cache[cache_key].data
|
||||||
|
|
||||||
|
symbols: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
if self.use_vscode_bridge:
|
||||||
|
# Legacy: VSCode Bridge HTTP mode
|
||||||
|
result = await self._request_vscode_bridge("get_document_symbols", {
|
||||||
|
"file_path": file_path,
|
||||||
|
})
|
||||||
|
|
||||||
|
if isinstance(result, list):
|
||||||
|
symbols = self._flatten_document_symbols(result)
|
||||||
|
else:
|
||||||
|
# Default: Standalone mode
|
||||||
|
manager = await self._ensure_manager()
|
||||||
|
result = await manager.get_document_symbols(file_path)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
symbols = self._flatten_document_symbols(result)
|
||||||
|
|
||||||
|
self._cache(cache_key, file_path, symbols)
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
def _flatten_document_symbols(
|
||||||
|
self, symbols: List[Dict[str, Any]], parent_name: str = ""
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Flatten nested document symbols into a flat list.
|
||||||
|
|
||||||
|
Document symbols can be nested (e.g., methods inside classes).
|
||||||
|
This flattens them for easier lookup by line number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbols: List of symbol dictionaries (may be nested)
|
||||||
|
parent_name: Name of parent symbol for qualification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Flat list of all symbols with their ranges
|
||||||
|
"""
|
||||||
|
flat: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for sym in symbols:
|
||||||
|
# Add the symbol itself
|
||||||
|
symbol_entry = {
|
||||||
|
"name": sym.get("name", "unknown"),
|
||||||
|
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
|
||||||
|
"range": sym.get("range", sym.get("location", {}).get("range", {})),
|
||||||
|
"selection_range": sym.get("selectionRange", {}),
|
||||||
|
"detail": sym.get("detail", ""),
|
||||||
|
"parent": parent_name,
|
||||||
|
}
|
||||||
|
flat.append(symbol_entry)
|
||||||
|
|
||||||
|
# Recursively process children
|
||||||
|
children = sym.get("children", [])
|
||||||
|
if children:
|
||||||
|
qualified_name = sym.get("name", "")
|
||||||
|
if parent_name:
|
||||||
|
qualified_name = f"{parent_name}.{qualified_name}"
|
||||||
|
flat.extend(self._flatten_document_symbols(children, qualified_name))
|
||||||
|
|
||||||
|
return flat
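A usage sketch of _flatten_document_symbols, assuming LspBridge is importable from codexlens.lsp.lsp_bridge; nested documentSymbol results are flattened with the parent name recorded for each child.

from codexlens.lsp.lsp_bridge import LspBridge

bridge = LspBridge()
nested = [{
    "name": "Parser", "kind": 5, "range": {}, "children": [
        {"name": "parse", "kind": 6, "range": {}, "children": []},
    ],
}]
for entry in bridge._flatten_document_symbols(nested):
    print(entry["name"], entry["kind"], entry["parent"] or "-")
# Parser class -
# parse method Parser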
|
||||||
|
|
||||||
|
def _symbol_kind_to_string(self, kind: int) -> str:
|
||||||
|
"""Convert LSP SymbolKind integer to string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kind: LSP SymbolKind enum value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Human-readable string representation
|
||||||
|
"""
|
||||||
|
# LSP SymbolKind enum (1-indexed)
|
||||||
|
kinds = {
|
||||||
|
1: "file",
|
||||||
|
2: "module",
|
||||||
|
3: "namespace",
|
||||||
|
4: "package",
|
||||||
|
5: "class",
|
||||||
|
6: "method",
|
||||||
|
7: "property",
|
||||||
|
8: "field",
|
||||||
|
9: "constructor",
|
||||||
|
10: "enum",
|
||||||
|
11: "interface",
|
||||||
|
12: "function",
|
||||||
|
13: "variable",
|
||||||
|
14: "constant",
|
||||||
|
15: "string",
|
||||||
|
16: "number",
|
||||||
|
17: "boolean",
|
||||||
|
18: "array",
|
||||||
|
19: "object",
|
||||||
|
20: "key",
|
||||||
|
21: "null",
|
||||||
|
22: "enum_member",
|
||||||
|
23: "struct",
|
||||||
|
24: "event",
|
||||||
|
25: "operator",
|
||||||
|
26: "type_parameter",
|
||||||
|
}
|
||||||
|
return kinds.get(kind, "unknown")
|
||||||
|
|
||||||
|
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
|
||||||
|
"""Get hover documentation for a symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: The code symbol to get hover info for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Hover documentation as string, or None if not available
|
||||||
|
"""
|
||||||
|
cache_key = f"hover:{symbol.id}"
|
||||||
|
|
||||||
|
if self._is_cached(cache_key, symbol.file_path):
|
||||||
|
return self.cache[cache_key].data
|
||||||
|
|
||||||
|
hover_text: Optional[str] = None
|
||||||
|
|
||||||
|
if self.use_vscode_bridge:
|
||||||
|
# Legacy: VSCode Bridge HTTP mode
|
||||||
|
result = await self._request_vscode_bridge("get_hover", {
|
||||||
|
"file_path": symbol.file_path,
|
||||||
|
"line": symbol.range.start_line,
|
||||||
|
"character": symbol.range.start_character,
|
||||||
|
})
|
||||||
|
|
||||||
|
if result:
|
||||||
|
hover_text = self._parse_hover_result(result)
|
||||||
|
else:
|
||||||
|
# Default: Standalone mode
|
||||||
|
manager = await self._ensure_manager()
|
||||||
|
hover_text = await manager.get_hover(
|
||||||
|
file_path=symbol.file_path,
|
||||||
|
line=symbol.range.start_line,
|
||||||
|
character=symbol.range.start_character,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cache(cache_key, symbol.file_path, hover_text)
|
||||||
|
return hover_text
|
||||||
|
|
||||||
|
def _parse_hover_result(self, result: Any) -> Optional[str]:
|
||||||
|
"""Parse hover result into string."""
|
||||||
|
if isinstance(result, str):
|
||||||
|
return result
|
||||||
|
elif isinstance(result, list):
|
||||||
|
parts = []
|
||||||
|
for item in result:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
value = item.get("value", item.get("contents", ""))
|
||||||
|
if value:
|
||||||
|
parts.append(str(value))
|
||||||
|
return "\n\n".join(parts) if parts else None
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
contents = result.get("contents", result.get("value", ""))
|
||||||
|
if isinstance(contents, str):
|
||||||
|
return contents
|
||||||
|
elif isinstance(contents, list):
|
||||||
|
parts = []
|
||||||
|
for c in contents:
|
||||||
|
if isinstance(c, str):
|
||||||
|
parts.append(c)
|
||||||
|
elif isinstance(c, dict):
|
||||||
|
parts.append(str(c.get("value", "")))
|
||||||
|
return "\n\n".join(parts) if parts else None
|
||||||
|
return None
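A usage sketch of _parse_hover_result, assuming LspBridge from codexlens.lsp.lsp_bridge; it covers the plain-string, list, and string-contents shapes handled above.

from codexlens.lsp.lsp_bridge import LspBridge

bridge = LspBridge()
print(bridge._parse_hover_result("def foo() -> int"))                   # def foo() -> int
print(bridge._parse_hover_result(["line one", {"value": "line two"}]))  # line one\n\nline two
print(bridge._parse_hover_result({"contents": "plain text"}))           # plain text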
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "LspBridge":
|
||||||
|
"""Async context manager entry."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
|
"""Async context manager exit - close connections."""
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Simple test
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
async def test_lsp_bridge():
|
||||||
|
"""Simple test of LspBridge functionality."""
|
||||||
|
print("Testing LspBridge (Standalone Mode)...")
|
||||||
|
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
|
||||||
|
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Create a test symbol pointing to this file
|
||||||
|
test_file = os.path.abspath(__file__)
|
||||||
|
test_symbol = CodeSymbolNode(
|
||||||
|
id=f"{test_file}:LspBridge:96",
|
||||||
|
name="LspBridge",
|
||||||
|
kind="class",
|
||||||
|
file_path=test_file,
|
||||||
|
range=Range(
|
||||||
|
start_line=96,
|
||||||
|
start_character=1,
|
||||||
|
end_line=200,
|
||||||
|
end_character=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Use standalone mode (default)
|
||||||
|
async with LspBridge(
|
||||||
|
workspace_root=str(Path(__file__).parent.parent.parent.parent),
|
||||||
|
) as bridge:
|
||||||
|
print("1. Testing get_document_symbols...")
|
||||||
|
try:
|
||||||
|
symbols = await bridge.get_document_symbols(test_file)
|
||||||
|
print(f" Found {len(symbols)} symbols")
|
||||||
|
for sym in symbols[:5]:
|
||||||
|
print(f" - {sym.get('name')} ({sym.get('kind')})")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("2. Testing get_definition...")
|
||||||
|
try:
|
||||||
|
definition = await bridge.get_definition(test_symbol)
|
||||||
|
if definition:
|
||||||
|
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
|
||||||
|
else:
|
||||||
|
print(" No definition found")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("3. Testing get_references...")
|
||||||
|
try:
|
||||||
|
refs = await bridge.get_references(test_symbol)
|
||||||
|
print(f" Found {len(refs)} references")
|
||||||
|
for ref in refs[:3]:
|
||||||
|
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("4. Testing get_hover...")
|
||||||
|
try:
|
||||||
|
hover = await bridge.get_hover(test_symbol)
|
||||||
|
if hover:
|
||||||
|
print(f" Hover: {hover[:100]}...")
|
||||||
|
else:
|
||||||
|
print(" No hover info found")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("5. Testing get_call_hierarchy...")
|
||||||
|
try:
|
||||||
|
calls = await bridge.get_call_hierarchy(test_symbol)
|
||||||
|
print(f" Found {len(calls)} call hierarchy items")
|
||||||
|
for call in calls[:3]:
|
||||||
|
print(f" - {call.name} in {os.path.basename(call.file_path)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Error: {e}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("6. Testing cache...")
|
||||||
|
print(f" Cache entries: {len(bridge.cache)}")
|
||||||
|
for key in list(bridge.cache.keys())[:5]:
|
||||||
|
print(f" - {key}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("Test complete!")
|
||||||
|
|
||||||
|
# Run the test
|
||||||
|
# Note: On Windows, use default ProactorEventLoop (supports subprocess creation)
|
||||||
|
|
||||||
|
asyncio.run(test_lsp_bridge())
|
||||||
375
codex-lens/build/lib/codexlens/lsp/lsp_graph_builder.py
Normal file
@@ -0,0 +1,375 @@
|
|||||||
|
"""Graph builder for code association graphs via LSP."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from codexlens.hybrid_search.data_structures import (
|
||||||
|
CallHierarchyItem,
|
||||||
|
CodeAssociationGraph,
|
||||||
|
CodeSymbolNode,
|
||||||
|
Range,
|
||||||
|
)
|
||||||
|
from codexlens.lsp.lsp_bridge import (
|
||||||
|
Location,
|
||||||
|
LspBridge,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LspGraphBuilder:
|
||||||
|
"""Builds code association graph by expanding from seed symbols using LSP."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_depth: int = 2,
|
||||||
|
max_nodes: int = 100,
|
||||||
|
max_concurrent: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize GraphBuilder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_depth: Maximum depth for BFS expansion from seeds.
|
||||||
|
max_nodes: Maximum number of nodes in the graph.
|
||||||
|
max_concurrent: Maximum concurrent LSP requests.
|
||||||
|
"""
|
||||||
|
self.max_depth = max_depth
|
||||||
|
self.max_nodes = max_nodes
|
||||||
|
self.max_concurrent = max_concurrent
|
||||||
|
# Cache for document symbols per file (avoids per-location hover queries)
|
||||||
|
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
async def build_from_seeds(
|
||||||
|
self,
|
||||||
|
seeds: List[CodeSymbolNode],
|
||||||
|
lsp_bridge: LspBridge,
|
||||||
|
) -> CodeAssociationGraph:
|
||||||
|
"""Build association graph by BFS expansion from seeds.
|
||||||
|
|
||||||
|
For each seed:
|
||||||
|
1. Get references via LSP
|
||||||
|
2. Get call hierarchy via LSP
|
||||||
|
3. Add nodes and edges to graph
|
||||||
|
4. Continue expanding until max_depth or max_nodes reached
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seeds: Initial seed symbols to expand from.
|
||||||
|
lsp_bridge: LSP bridge for querying language servers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CodeAssociationGraph with expanded nodes and relationships.
|
||||||
|
"""
|
||||||
|
graph = CodeAssociationGraph()
|
||||||
|
visited: Set[str] = set()
|
||||||
|
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
|
||||||
|
# Initialize queue with seeds at depth 0
|
||||||
|
queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]
|
||||||
|
|
||||||
|
# Add seed nodes to graph
|
||||||
|
for seed in seeds:
|
||||||
|
graph.add_node(seed)
|
||||||
|
|
||||||
|
# BFS expansion
|
||||||
|
while queue and len(graph.nodes) < self.max_nodes:
|
||||||
|
# Take a batch of nodes from queue
|
||||||
|
batch_size = min(self.max_concurrent, len(queue))
|
||||||
|
batch = queue[:batch_size]
|
||||||
|
queue = queue[batch_size:]
|
||||||
|
|
||||||
|
# Expand nodes in parallel
|
||||||
|
tasks = [
|
||||||
|
self._expand_node(
|
||||||
|
node, depth, graph, lsp_bridge, visited, semaphore
|
||||||
|
)
|
||||||
|
for node, depth in batch
|
||||||
|
]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Process results and add new nodes to queue
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning("Error expanding node: %s", result)
|
||||||
|
continue
|
||||||
|
if result:
|
||||||
|
# Add new nodes to queue if not at max depth
|
||||||
|
for new_node, new_depth in result:
|
||||||
|
if (
|
||||||
|
new_depth <= self.max_depth
|
||||||
|
and len(graph.nodes) < self.max_nodes
|
||||||
|
):
|
||||||
|
queue.append((new_node, new_depth))
|
||||||
|
|
||||||
|
return graph
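A usage sketch of build_from_seeds under stated assumptions: the codexlens package is importable, a language server is configured for the workspace, and the seed symbol, file path, and workspace root below are illustrative only.

import asyncio
from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp.lsp_bridge import LspBridge
from codexlens.lsp.lsp_graph_builder import LspGraphBuilder

async def build_demo_graph():
    seed = CodeSymbolNode(
        id="src/app.py:handle_request:42",  # illustrative symbol
        name="handle_request",
        kind="function",
        file_path="src/app.py",
        range=Range(start_line=42, start_character=1, end_line=60, end_character=1),
    )
    builder = LspGraphBuilder(max_depth=2, max_nodes=50)
    async with LspBridge(workspace_root=".") as bridge:
        graph = await builder.build_from_seeds([seed], bridge)
    print(len(graph.nodes), "nodes in the association graph")

# asyncio.run(build_demo_graph())  # requires a configured language server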
|
||||||
|
|
||||||
|
async def _expand_node(
|
||||||
|
self,
|
||||||
|
node: CodeSymbolNode,
|
||||||
|
depth: int,
|
||||||
|
graph: CodeAssociationGraph,
|
||||||
|
lsp_bridge: LspBridge,
|
||||||
|
visited: Set[str],
|
||||||
|
semaphore: asyncio.Semaphore,
|
||||||
|
) -> List[Tuple[CodeSymbolNode, int]]:
|
||||||
|
"""Expand a single node, return new nodes to process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Node to expand.
|
||||||
|
depth: Current depth in BFS.
|
||||||
|
graph: Graph to add nodes and edges to.
|
||||||
|
lsp_bridge: LSP bridge for queries.
|
||||||
|
visited: Set of visited node IDs.
|
||||||
|
semaphore: Semaphore for concurrency control.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (new_node, new_depth) tuples to add to queue.
|
||||||
|
"""
|
||||||
|
# Skip if already visited or at max depth
|
||||||
|
if node.id in visited:
|
||||||
|
return []
|
||||||
|
if depth > self.max_depth:
|
||||||
|
return []
|
||||||
|
if len(graph.nodes) >= self.max_nodes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
visited.add(node.id)
|
||||||
|
new_nodes: List[Tuple[CodeSymbolNode, int]] = []
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
# Get relationships in parallel
|
||||||
|
try:
|
||||||
|
refs_task = lsp_bridge.get_references(node)
|
||||||
|
calls_task = lsp_bridge.get_call_hierarchy(node)
|
||||||
|
|
||||||
|
refs, calls = await asyncio.gather(
|
||||||
|
refs_task, calls_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle reference results
|
||||||
|
if isinstance(refs, Exception):
|
||||||
|
logger.debug(
|
||||||
|
"Failed to get references for %s: %s", node.id, refs
|
||||||
|
)
|
||||||
|
refs = []
|
||||||
|
|
||||||
|
# Handle call hierarchy results
|
||||||
|
if isinstance(calls, Exception):
|
||||||
|
logger.debug(
|
||||||
|
"Failed to get call hierarchy for %s: %s",
|
||||||
|
node.id,
|
||||||
|
calls,
|
||||||
|
)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
# Process references
|
||||||
|
for ref in refs:
|
||||||
|
if len(graph.nodes) >= self.max_nodes:
|
||||||
|
break
|
||||||
|
|
||||||
|
ref_node = await self._location_to_node(ref, lsp_bridge)
|
||||||
|
if ref_node and ref_node.id != node.id:
|
||||||
|
if ref_node.id not in graph.nodes:
|
||||||
|
graph.add_node(ref_node)
|
||||||
|
new_nodes.append((ref_node, depth + 1))
|
||||||
|
# Use add_edge since both nodes should exist now
|
||||||
|
graph.add_edge(node.id, ref_node.id, "references")
|
||||||
|
|
||||||
|
# Process call hierarchy (incoming calls)
|
||||||
|
for call in calls:
|
||||||
|
if len(graph.nodes) >= self.max_nodes:
|
||||||
|
break
|
||||||
|
|
||||||
|
call_node = await self._call_hierarchy_to_node(
|
||||||
|
call, lsp_bridge
|
||||||
|
)
|
||||||
|
if call_node and call_node.id != node.id:
|
||||||
|
if call_node.id not in graph.nodes:
|
||||||
|
graph.add_node(call_node)
|
||||||
|
new_nodes.append((call_node, depth + 1))
|
||||||
|
# Incoming call: call_node calls node
|
||||||
|
graph.add_edge(call_node.id, node.id, "calls")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Error during node expansion for %s: %s", node.id, e
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_nodes
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the document symbols cache.
|
||||||
|
|
||||||
|
Call this between searches to free memory and ensure fresh data.
|
||||||
|
"""
|
||||||
|
self._document_symbols_cache.clear()
|
||||||
|
|
||||||
|
async def _get_symbol_at_location(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
line: int,
|
||||||
|
lsp_bridge: LspBridge,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Find symbol at location using cached document symbols.
|
||||||
|
|
||||||
|
This is much more efficient than individual hover queries because
|
||||||
|
document symbols are fetched once per file and cached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the source file.
|
||||||
|
line: Line number (1-based).
|
||||||
|
lsp_bridge: LSP bridge for fetching document symbols.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol dictionary with name, kind, range, etc., or None if not found.
|
||||||
|
"""
|
||||||
|
# Get or fetch document symbols for this file
|
||||||
|
if file_path not in self._document_symbols_cache:
|
||||||
|
symbols = await lsp_bridge.get_document_symbols(file_path)
|
||||||
|
self._document_symbols_cache[file_path] = symbols
|
||||||
|
|
||||||
|
symbols = self._document_symbols_cache[file_path]
|
||||||
|
|
||||||
|
# Find symbol containing this line (best match = smallest range)
|
||||||
|
best_match: Optional[Dict[str, Any]] = None
|
||||||
|
best_range_size = float("inf")
|
||||||
|
|
||||||
|
for symbol in symbols:
|
||||||
|
sym_range = symbol.get("range", {})
|
||||||
|
start = sym_range.get("start", {})
|
||||||
|
end = sym_range.get("end", {})
|
||||||
|
|
||||||
|
# LSP ranges are 0-based, our line is 1-based
|
||||||
|
start_line = start.get("line", 0) + 1
|
||||||
|
end_line = end.get("line", 0) + 1
|
||||||
|
|
||||||
|
if start_line <= line <= end_line:
|
||||||
|
range_size = end_line - start_line
|
||||||
|
if range_size < best_range_size:
|
||||||
|
best_match = symbol
|
||||||
|
best_range_size = range_size
|
||||||
|
|
||||||
|
return best_match
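A standalone sketch of the "smallest enclosing range wins" rule above: for a line that falls inside both a class and one of its methods, the method is preferred.

symbols = [
    {"name": "Parser", "range": {"start": {"line": 0}, "end": {"line": 99}}},
    {"name": "Parser.parse", "range": {"start": {"line": 10}, "end": {"line": 20}}},
]
line = 15  # 1-based, as passed by the caller
best, best_size = None, float("inf")
for sym in symbols:
    start = sym["range"]["start"]["line"] + 1  # LSP range is 0-based
    end = sym["range"]["end"]["line"] + 1
    if start <= line <= end and (end - start) < best_size:
        best, best_size = sym, end - start
print(best["name"])  # Parser.parse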
|
||||||
|
|
||||||
|
async def _location_to_node(
|
||||||
|
self,
|
||||||
|
location: Location,
|
||||||
|
lsp_bridge: LspBridge,
|
||||||
|
) -> Optional[CodeSymbolNode]:
|
||||||
|
"""Convert LSP location to CodeSymbolNode.
|
||||||
|
|
||||||
|
Uses cached document symbols instead of individual hover queries
|
||||||
|
for better performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: LSP location to convert.
|
||||||
|
lsp_bridge: LSP bridge for additional queries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CodeSymbolNode or None if conversion fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file_path = location.file_path
|
||||||
|
start_line = location.line
|
||||||
|
|
||||||
|
# Try to find symbol info from cached document symbols (fast)
|
||||||
|
symbol_info = await self._get_symbol_at_location(
|
||||||
|
file_path, start_line, lsp_bridge
|
||||||
|
)
|
||||||
|
|
||||||
|
if symbol_info:
|
||||||
|
name = symbol_info.get("name", f"symbol_L{start_line}")
|
||||||
|
kind = symbol_info.get("kind", "unknown")
|
||||||
|
|
||||||
|
# Extract range from symbol if available
|
||||||
|
sym_range = symbol_info.get("range", {})
|
||||||
|
start = sym_range.get("start", {})
|
||||||
|
end = sym_range.get("end", {})
|
||||||
|
|
||||||
|
location_range = Range(
|
||||||
|
start_line=start.get("line", start_line - 1) + 1,
|
||||||
|
start_character=start.get("character", location.character - 1) + 1,
|
||||||
|
end_line=end.get("line", start_line - 1) + 1,
|
||||||
|
end_character=end.get("character", location.character - 1) + 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to basic node without symbol info
|
||||||
|
name = f"symbol_L{start_line}"
|
||||||
|
kind = "unknown"
|
||||||
|
location_range = Range(
|
||||||
|
start_line=location.line,
|
||||||
|
start_character=location.character,
|
||||||
|
end_line=location.line,
|
||||||
|
end_character=location.character,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_id = self._create_node_id(file_path, name, start_line)
|
||||||
|
|
||||||
|
return CodeSymbolNode(
|
||||||
|
id=node_id,
|
||||||
|
name=name,
|
||||||
|
kind=kind,
|
||||||
|
file_path=file_path,
|
||||||
|
range=location_range,
|
||||||
|
docstring="", # Skip hover for performance
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to convert location to node: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _call_hierarchy_to_node(
|
||||||
|
self,
|
||||||
|
call_item: CallHierarchyItem,
|
||||||
|
lsp_bridge: LspBridge,
|
||||||
|
) -> Optional[CodeSymbolNode]:
|
||||||
|
"""Convert CallHierarchyItem to CodeSymbolNode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
call_item: Call hierarchy item to convert.
|
||||||
|
lsp_bridge: LSP bridge (unused, kept for API consistency).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CodeSymbolNode or None if conversion fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file_path = call_item.file_path
|
||||||
|
name = call_item.name
|
||||||
|
start_line = call_item.range.start_line
|
||||||
|
# CallHierarchyItem.kind is already a string
|
||||||
|
kind = call_item.kind
|
||||||
|
|
||||||
|
node_id = self._create_node_id(file_path, name, start_line)
|
||||||
|
|
||||||
|
return CodeSymbolNode(
|
||||||
|
id=node_id,
|
||||||
|
name=name,
|
||||||
|
kind=kind,
|
||||||
|
file_path=file_path,
|
||||||
|
range=call_item.range,
|
||||||
|
docstring=call_item.detail or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to convert call hierarchy item to node: %s", e
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_node_id(
|
||||||
|
self, file_path: str, name: str, line: int
|
||||||
|
) -> str:
|
||||||
|
"""Create unique node ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file.
|
||||||
|
name: Symbol name.
|
||||||
|
line: Line number (1-based, as produced by Location and CallHierarchyItem ranges).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unique node ID string.
|
||||||
|
"""
|
||||||
|
return f"{file_path}:{name}:{line}"
|
||||||
177
codex-lens/build/lib/codexlens/lsp/providers.py
Normal file
@@ -0,0 +1,177 @@
"""LSP feature providers."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.storage.registry import RegistryStore

logger = logging.getLogger(__name__)


@dataclass
class HoverInfo:
    """Hover information for a symbol."""

    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: tuple  # (start_line, end_line)


class HoverProvider:
    """Provides hover information for symbols."""

    def __init__(
        self,
        global_index: "GlobalSymbolIndex",
        registry: Optional["RegistryStore"] = None,
    ) -> None:
        """Initialize hover provider.

        Args:
            global_index: Global symbol index for lookups
            registry: Optional registry store for index path resolution
        """
        self.global_index = global_index
        self.registry = registry

    def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
        """Get hover information for a symbol.

        Args:
            symbol_name: Name of the symbol to look up

        Returns:
            HoverInfo or None if symbol not found
        """
        # Look up symbol in global index using exact match
        symbols = self.global_index.search(
            name=symbol_name,
            limit=1,
            prefix_mode=False,
        )

        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == symbol_name]

        if not exact_matches:
            return None

        symbol = exact_matches[0]

        # Extract signature from source file
        signature = self._extract_signature(symbol)

        # Symbol uses 'file' attribute and 'range' tuple
        file_path = symbol.file or ""
        start_line, end_line = symbol.range

        return HoverInfo(
            name=symbol.name,
            kind=symbol.kind,
            signature=signature,
            documentation=None,  # Symbol doesn't have docstring field
            file_path=file_path,
            line_range=(start_line, end_line),
        )

    def _extract_signature(self, symbol) -> str:
        """Extract function/class signature from source file.

        Args:
            symbol: Symbol object with file and range information

        Returns:
            Extracted signature string or fallback kind + name
        """
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return f"{symbol.kind} {symbol.name}"

            content = file_path.read_text(encoding="utf-8", errors="ignore")
            lines = content.split("\n")

            # Extract signature lines (first line of definition + continuation)
            start_line = symbol.range[0] - 1  # Convert 1-based to 0-based
            if start_line >= len(lines) or start_line < 0:
                return f"{symbol.kind} {symbol.name}"

            signature_lines = []
            first_line = lines[start_line]
            signature_lines.append(first_line)

            # Continue if multiline signature (no closing paren + colon yet)
            # Look for patterns like "def func(", "class Foo(", etc.
            i = start_line + 1
            max_lines = min(start_line + 5, len(lines))
            while i < max_lines:
                line = signature_lines[-1]
                # Stop if we see closing pattern
                if "):" in line or line.rstrip().endswith(":"):
                    break
                signature_lines.append(lines[i])
                i += 1

            return "\n".join(signature_lines)

        except Exception as e:
            logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
            return f"{symbol.kind} {symbol.name}"

    def format_hover_markdown(self, info: HoverInfo) -> str:
        """Format hover info as Markdown.

        Args:
            info: HoverInfo object to format

        Returns:
            Markdown-formatted hover content
        """
        parts = []

        # Detect language for code fence based on file extension
        ext = Path(info.file_path).suffix.lower() if info.file_path else ""
        lang_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".tsx": "typescript",
            ".jsx": "javascript",
            ".java": "java",
            ".go": "go",
            ".rs": "rust",
            ".c": "c",
            ".cpp": "cpp",
            ".h": "c",
            ".hpp": "cpp",
            ".cs": "csharp",
            ".rb": "ruby",
            ".php": "php",
        }
        lang = lang_map.get(ext, "")

        # Code block with signature
        parts.append(f"```{lang}\n{info.signature}\n```")

        # Documentation if available
        if info.documentation:
            parts.append(f"\n---\n\n{info.documentation}")

        # Location info
        file_name = Path(info.file_path).name if info.file_path else "unknown"
        parts.append(
            f"\n---\n\n*{info.kind}* defined in "
            f"`{file_name}` "
            f"(line {info.line_range[0]})"
        )

        return "\n".join(parts)
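A minimal usage sketch for the HoverProvider above, assuming a project index has already been built; the database path, project id, and symbol name are placeholders, and the constructor call mirrors what server.py does in initialize_components():

from pathlib import Path

from codexlens.lsp.providers import HoverProvider
from codexlens.storage.global_index import GlobalSymbolIndex

# Placeholder index location and project id; real values come from RegistryStore.
index = GlobalSymbolIndex(Path(".codexlens/global.db"), 1)
index.initialize()

provider = HoverProvider(global_index=index)
info = provider.get_hover_info("ChainSearchEngine")  # placeholder symbol name
if info is not None:
    print(provider.format_hover_markdown(info))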
codex-lens/build/lib/codexlens/lsp/server.py (new file, 263 lines)
@@ -0,0 +1,263 @@
|
|||||||
|
"""codex-lens LSP Server implementation using pygls.
|
||||||
|
|
||||||
|
This module provides the main Language Server class and entry point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lsprotocol import types as lsp
|
||||||
|
from pygls.lsp.server import LanguageServer
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
from codexlens.config import Config
|
||||||
|
from codexlens.search.chain_search import ChainSearchEngine
|
||||||
|
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||||
|
from codexlens.storage.path_mapper import PathMapper
|
||||||
|
from codexlens.storage.registry import RegistryStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CodexLensLanguageServer(LanguageServer):
|
||||||
|
"""Language Server for codex-lens code indexing.
|
||||||
|
|
||||||
|
Provides IDE features using codex-lens symbol index:
|
||||||
|
- Go to Definition
|
||||||
|
- Find References
|
||||||
|
- Code Completion
|
||||||
|
- Hover Information
|
||||||
|
- Workspace Symbol Search
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
registry: Global project registry for path lookups
|
||||||
|
mapper: Path mapper for source/index conversions
|
||||||
|
global_index: Project-wide symbol index
|
||||||
|
search_engine: Chain search engine for symbol search
|
||||||
|
workspace_root: Current workspace root path
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(name="codexlens-lsp", version="0.1.0")
|
||||||
|
|
||||||
|
self.registry: Optional[RegistryStore] = None
|
||||||
|
self.mapper: Optional[PathMapper] = None
|
||||||
|
self.global_index: Optional[GlobalSymbolIndex] = None
|
||||||
|
self.search_engine: Optional[ChainSearchEngine] = None
|
||||||
|
self.workspace_root: Optional[Path] = None
|
||||||
|
self._config: Optional[Config] = None
|
||||||
|
|
||||||
|
def initialize_components(self, workspace_root: Path) -> bool:
|
||||||
|
"""Initialize codex-lens components for the workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_root: Root path of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if initialization succeeded, False otherwise
|
||||||
|
"""
|
||||||
|
self.workspace_root = workspace_root.resolve()
|
||||||
|
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize registry
|
||||||
|
self.registry = RegistryStore()
|
||||||
|
self.registry.initialize()
|
||||||
|
|
||||||
|
# Initialize path mapper
|
||||||
|
self.mapper = PathMapper()
|
||||||
|
|
||||||
|
# Try to find project in registry
|
||||||
|
project_info = self.registry.find_by_source_path(str(self.workspace_root))
|
||||||
|
|
||||||
|
if project_info:
|
||||||
|
project_id = int(project_info["id"])
|
||||||
|
index_root = Path(project_info["index_root"])
|
||||||
|
|
||||||
|
# Initialize global symbol index
|
||||||
|
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
|
||||||
|
self.global_index = GlobalSymbolIndex(global_db, project_id)
|
||||||
|
self.global_index.initialize()
|
||||||
|
|
||||||
|
# Initialize search engine
|
||||||
|
self._config = Config()
|
||||||
|
self.search_engine = ChainSearchEngine(
|
||||||
|
registry=self.registry,
|
||||||
|
mapper=self.mapper,
|
||||||
|
config=self._config,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Workspace not indexed by codex-lens: %s. "
|
||||||
|
"Run 'codexlens index %s' to index first.",
|
||||||
|
self.workspace_root,
|
||||||
|
self.workspace_root,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to initialize codex-lens: %s", exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def shutdown_components(self) -> None:
|
||||||
|
"""Clean up codex-lens components."""
|
||||||
|
if self.global_index:
|
||||||
|
try:
|
||||||
|
self.global_index.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Error closing global index: %s", exc)
|
||||||
|
self.global_index = None
|
||||||
|
|
||||||
|
if self.search_engine:
|
||||||
|
try:
|
||||||
|
self.search_engine.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Error closing search engine: %s", exc)
|
||||||
|
self.search_engine = None
|
||||||
|
|
||||||
|
if self.registry:
|
||||||
|
try:
|
||||||
|
self.registry.close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Error closing registry: %s", exc)
|
||||||
|
self.registry = None
|
||||||
|
|
||||||
|
|
||||||
|
# Create server instance
|
||||||
|
server = CodexLensLanguageServer()
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.INITIALIZE)
|
||||||
|
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
|
||||||
|
"""Handle LSP initialize request."""
|
||||||
|
logger.info("LSP initialize request received")
|
||||||
|
|
||||||
|
# Get workspace root
|
||||||
|
workspace_root: Optional[Path] = None
|
||||||
|
if params.root_uri:
|
||||||
|
workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
|
||||||
|
elif params.root_path:
|
||||||
|
workspace_root = Path(params.root_path)
|
||||||
|
|
||||||
|
if workspace_root:
|
||||||
|
server.initialize_components(workspace_root)
|
||||||
|
|
||||||
|
# Declare server capabilities
|
||||||
|
return lsp.InitializeResult(
|
||||||
|
capabilities=lsp.ServerCapabilities(
|
||||||
|
text_document_sync=lsp.TextDocumentSyncOptions(
|
||||||
|
open_close=True,
|
||||||
|
change=lsp.TextDocumentSyncKind.Incremental,
|
||||||
|
save=lsp.SaveOptions(include_text=False),
|
||||||
|
),
|
||||||
|
definition_provider=True,
|
||||||
|
references_provider=True,
|
||||||
|
completion_provider=lsp.CompletionOptions(
|
||||||
|
trigger_characters=[".", ":"],
|
||||||
|
resolve_provider=False,
|
||||||
|
),
|
||||||
|
hover_provider=True,
|
||||||
|
workspace_symbol_provider=True,
|
||||||
|
),
|
||||||
|
server_info=lsp.ServerInfo(
|
||||||
|
name="codexlens-lsp",
|
||||||
|
version="0.1.0",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@server.feature(lsp.SHUTDOWN)
|
||||||
|
def lsp_shutdown(params: None) -> None:
|
||||||
|
"""Handle LSP shutdown request."""
|
||||||
|
logger.info("LSP shutdown request received")
|
||||||
|
server.shutdown_components()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""Entry point for codexlens-lsp command.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exit code (0 for success)
|
||||||
|
"""
|
||||||
|
# Import handlers to register them with the server
|
||||||
|
# This must be done before starting the server
|
||||||
|
import codexlens.lsp.handlers # noqa: F401
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="codex-lens Language Server",
|
||||||
|
prog="codexlens-lsp",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stdio",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Use stdio for communication (default)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tcp",
|
||||||
|
action="store_true",
|
||||||
|
help="Use TCP for communication",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
default="127.0.0.1",
|
||||||
|
help="TCP host (default: 127.0.0.1)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=2087,
|
||||||
|
help="TCP port (default: 2087)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
|
default="INFO",
|
||||||
|
help="Log level (default: INFO)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-file",
|
||||||
|
help="Log file path (optional)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
log_handlers = []
|
||||||
|
if args.log_file:
|
||||||
|
log_handlers.append(logging.FileHandler(args.log_file))
|
||||||
|
else:
|
||||||
|
log_handlers.append(logging.StreamHandler(sys.stderr))
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, args.log_level),
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=log_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Starting codexlens-lsp server")
|
||||||
|
|
||||||
|
if args.tcp:
|
||||||
|
logger.info("Starting TCP server on %s:%d", args.host, args.port)
|
||||||
|
server.start_tcp(args.host, args.port)
|
||||||
|
else:
|
||||||
|
logger.info("Starting stdio server")
|
||||||
|
server.start_io()
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
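The LSP request handlers themselves live in codexlens.lsp.handlers, which is not part of this diff. A hedged sketch of what a hover handler registered against the server instance above might look like with pygls/lsprotocol; the word-resolution step is reduced to a placeholder:

from typing import Optional

from lsprotocol import types as lsp

from codexlens.lsp.providers import HoverProvider
from codexlens.lsp.server import server


@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def on_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
    if server.global_index is None:
        return None
    symbol_name = "example_symbol"  # a real handler resolves the word at params.position
    provider = HoverProvider(server.global_index, server.registry)
    info = provider.get_hover_info(symbol_name)
    if info is None:
        return None
    return lsp.Hover(
        contents=lsp.MarkupContent(
            kind=lsp.MarkupKind.Markdown,
            value=provider.format_hover_markdown(info),
        )
    )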
codex-lens/build/lib/codexlens/lsp/standalone_manager.py (new file, 1159 lines)
File diff suppressed because it is too large.
codex-lens/build/lib/codexlens/mcp/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt

__all__ = [
    "MCPContext",
    "SymbolInfo",
    "ReferenceInfo",
    "RelatedSymbol",
    "MCPProvider",
    "HookManager",
    "create_context_for_prompt",
]
codex-lens/build/lib/codexlens/mcp/hooks.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
|||||||
|
"""Hook interfaces for Claude Code integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
|
from codexlens.mcp.schema import MCPContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from codexlens.mcp.provider import MCPProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HookManager:
|
||||||
|
"""Manages hook registration and execution."""
|
||||||
|
|
||||||
|
def __init__(self, mcp_provider: "MCPProvider") -> None:
|
||||||
|
self.mcp_provider = mcp_provider
|
||||||
|
self._pre_hooks: Dict[str, Callable] = {}
|
||||||
|
self._post_hooks: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
# Register default hooks
|
||||||
|
self._register_default_hooks()
|
||||||
|
|
||||||
|
def _register_default_hooks(self) -> None:
|
||||||
|
"""Register built-in hooks."""
|
||||||
|
self._pre_hooks["explain"] = self._pre_explain_hook
|
||||||
|
self._pre_hooks["refactor"] = self._pre_refactor_hook
|
||||||
|
self._pre_hooks["document"] = self._pre_document_hook
|
||||||
|
|
||||||
|
def execute_pre_hook(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
params: Dict[str, Any],
|
||||||
|
) -> Optional[MCPContext]:
|
||||||
|
"""Execute pre-tool hook to gather context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action being performed (e.g., "explain", "refactor")
|
||||||
|
params: Parameters for the action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCPContext to inject into prompt, or None
|
||||||
|
"""
|
||||||
|
hook = self._pre_hooks.get(action)
|
||||||
|
|
||||||
|
if not hook:
|
||||||
|
logger.debug(f"No pre-hook for action: {action}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return hook(params)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Pre-hook failed for {action}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def execute_post_hook(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
result: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Execute post-tool hook for proactive caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: The action that was performed
|
||||||
|
result: Result of the action
|
||||||
|
"""
|
||||||
|
hook = self._post_hooks.get(action)
|
||||||
|
|
||||||
|
if not hook:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
hook(result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Post-hook failed for {action}: {e}")
|
||||||
|
|
||||||
|
def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||||
|
"""Pre-hook for 'explain' action."""
|
||||||
|
symbol_name = params.get("symbol")
|
||||||
|
|
||||||
|
if not symbol_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.mcp_provider.build_context(
|
||||||
|
symbol_name=symbol_name,
|
||||||
|
context_type="symbol_explanation",
|
||||||
|
include_references=True,
|
||||||
|
include_related=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||||
|
"""Pre-hook for 'refactor' action."""
|
||||||
|
symbol_name = params.get("symbol")
|
||||||
|
|
||||||
|
if not symbol_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.mcp_provider.build_context(
|
||||||
|
symbol_name=symbol_name,
|
||||||
|
context_type="refactor_context",
|
||||||
|
include_references=True,
|
||||||
|
include_related=True,
|
||||||
|
max_references=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||||
|
"""Pre-hook for 'document' action."""
|
||||||
|
symbol_name = params.get("symbol")
|
||||||
|
file_path = params.get("file_path")
|
||||||
|
|
||||||
|
if symbol_name:
|
||||||
|
return self.mcp_provider.build_context(
|
||||||
|
symbol_name=symbol_name,
|
||||||
|
context_type="documentation_context",
|
||||||
|
include_references=False,
|
||||||
|
include_related=True,
|
||||||
|
)
|
||||||
|
elif file_path:
|
||||||
|
return self.mcp_provider.build_context_for_file(
|
||||||
|
Path(file_path),
|
||||||
|
context_type="file_documentation",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def register_pre_hook(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
|
||||||
|
) -> None:
|
||||||
|
"""Register a custom pre-tool hook."""
|
||||||
|
self._pre_hooks[action] = hook
|
||||||
|
|
||||||
|
def register_post_hook(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
hook: Callable[[Any], None],
|
||||||
|
) -> None:
|
||||||
|
"""Register a custom post-tool hook."""
|
||||||
|
self._post_hooks[action] = hook
|
||||||
|
|
||||||
|
|
||||||
|
def create_context_for_prompt(
|
||||||
|
mcp_provider: "MCPProvider",
|
||||||
|
action: str,
|
||||||
|
params: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Create context string for prompt injection.
|
||||||
|
|
||||||
|
This is the main entry point for Claude Code hook integration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_provider: The MCP provider instance
|
||||||
|
action: Action being performed
|
||||||
|
params: Action parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string for prompt injection
|
||||||
|
"""
|
||||||
|
manager = HookManager(mcp_provider)
|
||||||
|
context = manager.execute_pre_hook(action, params)
|
||||||
|
|
||||||
|
if context:
|
||||||
|
return context.to_prompt_injection()
|
||||||
|
|
||||||
|
return ""
|
||||||
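A short sketch of the hook entry point defined above; `provider` stands for an MCPProvider wired to a populated index (see provider.py below), and the symbol name is a placeholder:

from codexlens.mcp.hooks import create_context_for_prompt

injection = create_context_for_prompt(
    provider,  # an initialized MCPProvider (assumed available)
    action="explain",
    params={"symbol": "ChainSearchEngine"},  # placeholder symbol
)
if injection:
    prompt = f"{injection}\n\nExplain what this symbol does."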
codex-lens/build/lib/codexlens/mcp/provider.py (new file, 202 lines)
@@ -0,0 +1,202 @@
|
|||||||
|
"""MCP context provider."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, TYPE_CHECKING
|
||||||
|
|
||||||
|
from codexlens.mcp.schema import (
|
||||||
|
MCPContext,
|
||||||
|
SymbolInfo,
|
||||||
|
ReferenceInfo,
|
||||||
|
RelatedSymbol,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||||
|
from codexlens.storage.registry import RegistryStore
|
||||||
|
from codexlens.search.chain_search import ChainSearchEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPProvider:
|
||||||
|
"""Builds MCP context objects from codex-lens data."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
global_index: "GlobalSymbolIndex",
|
||||||
|
search_engine: "ChainSearchEngine",
|
||||||
|
registry: "RegistryStore",
|
||||||
|
) -> None:
|
||||||
|
self.global_index = global_index
|
||||||
|
self.search_engine = search_engine
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def build_context(
|
||||||
|
self,
|
||||||
|
symbol_name: str,
|
||||||
|
context_type: str = "symbol_explanation",
|
||||||
|
include_references: bool = True,
|
||||||
|
include_related: bool = True,
|
||||||
|
max_references: int = 10,
|
||||||
|
) -> Optional[MCPContext]:
|
||||||
|
"""Build comprehensive context for a symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_name: Name of the symbol to contextualize
|
||||||
|
context_type: Type of context being requested
|
||||||
|
include_references: Whether to include reference locations
|
||||||
|
include_related: Whether to include related symbols
|
||||||
|
max_references: Maximum number of references to include
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCPContext object or None if symbol not found
|
||||||
|
"""
|
||||||
|
# Look up symbol
|
||||||
|
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
|
||||||
|
|
||||||
|
if not symbols:
|
||||||
|
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
symbol = symbols[0]
|
||||||
|
|
||||||
|
# Build SymbolInfo
|
||||||
|
symbol_info = SymbolInfo(
|
||||||
|
name=symbol.name,
|
||||||
|
kind=symbol.kind,
|
||||||
|
file_path=symbol.file or "",
|
||||||
|
line_start=symbol.range[0],
|
||||||
|
line_end=symbol.range[1],
|
||||||
|
signature=None, # Symbol entity doesn't have signature
|
||||||
|
documentation=None, # Symbol entity doesn't have docstring
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract definition source code
|
||||||
|
definition = self._extract_definition(symbol)
|
||||||
|
|
||||||
|
# Get references
|
||||||
|
references = []
|
||||||
|
if include_references:
|
||||||
|
refs = self.search_engine.search_references(
|
||||||
|
symbol_name,
|
||||||
|
limit=max_references,
|
||||||
|
)
|
||||||
|
references = [
|
||||||
|
ReferenceInfo(
|
||||||
|
file_path=r.file_path,
|
||||||
|
line=r.line,
|
||||||
|
column=r.column,
|
||||||
|
context=r.context,
|
||||||
|
relationship_type=r.relationship_type,
|
||||||
|
)
|
||||||
|
for r in refs
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get related symbols
|
||||||
|
related_symbols = []
|
||||||
|
if include_related:
|
||||||
|
related_symbols = self._get_related_symbols(symbol)
|
||||||
|
|
||||||
|
return MCPContext(
|
||||||
|
context_type=context_type,
|
||||||
|
symbol=symbol_info,
|
||||||
|
definition=definition,
|
||||||
|
references=references,
|
||||||
|
related_symbols=related_symbols,
|
||||||
|
metadata={
|
||||||
|
"source": "codex-lens",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_definition(self, symbol) -> Optional[str]:
|
||||||
|
"""Extract source code for symbol definition."""
|
||||||
|
try:
|
||||||
|
file_path = Path(symbol.file) if symbol.file else None
|
||||||
|
if not file_path or not file_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = file_path.read_text(encoding='utf-8', errors='ignore')
|
||||||
|
lines = content.split("\n")
|
||||||
|
|
||||||
|
start = symbol.range[0] - 1
|
||||||
|
end = symbol.range[1]
|
||||||
|
|
||||||
|
if start >= len(lines):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return "\n".join(lines[start:end])
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to extract definition: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
|
||||||
|
"""Get symbols related to the given symbol."""
|
||||||
|
related = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Search for symbols that might be related by name patterns
|
||||||
|
# This is a simplified implementation - could be enhanced with relationship data
|
||||||
|
|
||||||
|
# Look for imports/callers via reference search
|
||||||
|
refs = self.search_engine.search_references(symbol.name, limit=20)
|
||||||
|
|
||||||
|
seen_names = set()
|
||||||
|
for ref in refs:
|
||||||
|
# Extract potential symbol name from context
|
||||||
|
if ref.relationship_type and ref.relationship_type not in seen_names:
|
||||||
|
related.append(RelatedSymbol(
|
||||||
|
name=f"{Path(ref.file_path).stem}",
|
||||||
|
kind="module",
|
||||||
|
relationship=ref.relationship_type,
|
||||||
|
file_path=ref.file_path,
|
||||||
|
))
|
||||||
|
seen_names.add(ref.relationship_type)
|
||||||
|
if len(related) >= 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get related symbols: {e}")
|
||||||
|
|
||||||
|
return related
|
||||||
|
|
||||||
|
def build_context_for_file(
|
||||||
|
self,
|
||||||
|
file_path: Path,
|
||||||
|
context_type: str = "file_overview",
|
||||||
|
) -> MCPContext:
|
||||||
|
"""Build context for an entire file."""
|
||||||
|
# Try to get symbols by searching with file path
|
||||||
|
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
|
||||||
|
symbols = []
|
||||||
|
|
||||||
|
# Search for common symbols that might be in this file
|
||||||
|
# This is a simplified approach - a full implementation would query by file path
|
||||||
|
try:
|
||||||
|
# Use the global index to search for symbols from this file
|
||||||
|
file_str = str(file_path.resolve())
|
||||||
|
# Get all symbols and filter by file path (not efficient but works)
|
||||||
|
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
|
||||||
|
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get file symbols: {e}")
|
||||||
|
|
||||||
|
related = [
|
||||||
|
RelatedSymbol(
|
||||||
|
name=s.name,
|
||||||
|
kind=s.kind,
|
||||||
|
relationship="defines",
|
||||||
|
)
|
||||||
|
for s in symbols
|
||||||
|
]
|
||||||
|
|
||||||
|
return MCPContext(
|
||||||
|
context_type=context_type,
|
||||||
|
related_symbols=related,
|
||||||
|
metadata={
|
||||||
|
"file_path": str(file_path),
|
||||||
|
"symbol_count": len(symbols),
|
||||||
|
},
|
||||||
|
)
|
||||||
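A usage sketch for MCPProvider.build_context above; global_index, search_engine, and registry are assumed to be initialized the same way CodexLensLanguageServer.initialize_components() does it, and the symbol name is a placeholder:

from codexlens.mcp.provider import MCPProvider

provider = MCPProvider(global_index, search_engine, registry)  # assumed initialized
ctx = provider.build_context(
    "ChainSearchEngine",  # placeholder symbol
    context_type="symbol_explanation",
    include_references=True,
    max_references=5,
)
if ctx is not None:
    print(ctx.to_json())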
codex-lens/build/lib/codexlens/mcp/schema.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""MCP data models."""

from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional


@dataclass
class SymbolInfo:
    """Information about a code symbol."""
    name: str
    kind: str
    file_path: str
    line_start: int
    line_end: int
    signature: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class ReferenceInfo:
    """Information about a symbol reference."""
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class RelatedSymbol:
    """Related symbol (import, call target, etc.)."""
    name: str
    kind: str
    relationship: str  # "imports", "calls", "inherits", "uses"
    file_path: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MCPContext:
    """Model Context Protocol context object.

    This is the structured context that gets injected into
    LLM prompts to provide code understanding.
    """
    version: str = "1.0"
    context_type: str = "code_context"
    symbol: Optional[SymbolInfo] = None
    definition: Optional[str] = None
    references: List[ReferenceInfo] = field(default_factory=list)
    related_symbols: List[RelatedSymbol] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        result = {
            "version": self.version,
            "context_type": self.context_type,
            "metadata": self.metadata,
        }

        if self.symbol:
            result["symbol"] = self.symbol.to_dict()
        if self.definition:
            result["definition"] = self.definition
        if self.references:
            result["references"] = [r.to_dict() for r in self.references]
        if self.related_symbols:
            result["related_symbols"] = [s.to_dict() for s in self.related_symbols]

        return result

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    def to_prompt_injection(self) -> str:
        """Format for injection into LLM prompt."""
        parts = ["<code_context>"]

        if self.symbol:
            parts.append(f"## Symbol: {self.symbol.name}")
            parts.append(f"Type: {self.symbol.kind}")
            parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")

        if self.definition:
            parts.append("\n## Definition")
            parts.append(f"```\n{self.definition}\n```")

        if self.references:
            parts.append(f"\n## References ({len(self.references)} found)")
            for ref in self.references[:5]:  # Limit to 5
                parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
                parts.append(f" ```\n {ref.context}\n ```")

        if self.related_symbols:
            parts.append("\n## Related Symbols")
            for sym in self.related_symbols[:10]:  # Limit to 10
                parts.append(f"- {sym.name} ({sym.relationship})")

        parts.append("</code_context>")
        return "\n".join(parts)
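A small sketch of building an MCPContext by hand and rendering it for prompt injection; the symbol values and line numbers are placeholders:

from codexlens.mcp.schema import MCPContext, SymbolInfo

ctx = MCPContext(
    context_type="symbol_explanation",
    symbol=SymbolInfo(
        name="read_file_safe",                      # placeholder values
        kind="function",
        file_path="codexlens/parsers/encoding.py",
        line_start=1,
        line_end=10,
    ),
    definition="def read_file_safe(path, ...): ...",
)
print(ctx.to_prompt_injection())  # <code_context> ... </code_context>
print(ctx.to_json())              # same data as JSON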
codex-lens/build/lib/codexlens/parsers/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""

from __future__ import annotations

from .factory import ParserFactory

__all__ = ["ParserFactory"]
codex-lens/build/lib/codexlens/parsers/encoding.py (new file, 202 lines)
@@ -0,0 +1,202 @@
|
|||||||
|
"""Optional encoding detection module for CodexLens.
|
||||||
|
|
||||||
|
Provides automatic encoding detection with graceful fallback to UTF-8.
|
||||||
|
Install with: pip install codexlens[encoding]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Feature flag for encoding detection availability
|
||||||
|
ENCODING_DETECTION_AVAILABLE = False
|
||||||
|
_import_error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Detect if chardet or charset-normalizer is available."""
|
||||||
|
try:
|
||||||
|
import chardet
|
||||||
|
return True, None
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from charset_normalizer import from_bytes
|
||||||
|
return True, None
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False, "chardet not available. Install with: pip install codexlens[encoding]"
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize on module load
|
||||||
|
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()
|
||||||
|
|
||||||
|
|
||||||
|
def check_encoding_available() -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Check if encoding detection dependencies are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (available, error_message)
|
||||||
|
"""
|
||||||
|
return ENCODING_DETECTION_AVAILABLE, _import_error
|
||||||
|
|
||||||
|
|
||||||
|
def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
|
||||||
|
"""Detect encoding from file content bytes.
|
||||||
|
|
||||||
|
Uses chardet or charset-normalizer with configurable confidence threshold.
|
||||||
|
Falls back to UTF-8 if confidence is too low or detection unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_bytes: Raw file content as bytes
|
||||||
|
confidence_threshold: Minimum confidence (0.0-1.0) to accept detection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk')
|
||||||
|
Returns 'utf-8' as fallback if detection fails or confidence too low
|
||||||
|
"""
|
||||||
|
if not ENCODING_DETECTION_AVAILABLE:
|
||||||
|
log.debug("Encoding detection not available, using UTF-8 fallback")
|
||||||
|
return "utf-8"
|
||||||
|
|
||||||
|
if not content_bytes:
|
||||||
|
return "utf-8"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try chardet first
|
||||||
|
try:
|
||||||
|
import chardet
|
||||||
|
result = chardet.detect(content_bytes)
|
||||||
|
encoding = result.get("encoding")
|
||||||
|
confidence = result.get("confidence", 0.0)
|
||||||
|
|
||||||
|
if encoding and confidence >= confidence_threshold:
|
||||||
|
log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
|
||||||
|
# Normalize encoding name: replace underscores with hyphens
|
||||||
|
return encoding.lower().replace('_', '-')
|
||||||
|
else:
|
||||||
|
log.debug(
|
||||||
|
f"Low confidence encoding detection: {encoding} "
|
||||||
|
f"(confidence: {confidence:.2f}), using UTF-8 fallback"
|
||||||
|
)
|
||||||
|
return "utf-8"
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback to charset-normalizer
|
||||||
|
try:
|
||||||
|
from charset_normalizer import from_bytes
|
||||||
|
results = from_bytes(content_bytes)
|
||||||
|
if results:
|
||||||
|
best = results.best()
|
||||||
|
if best and best.encoding:
|
||||||
|
log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
|
||||||
|
# Normalize encoding name: replace underscores with hyphens
|
||||||
|
return best.encoding.lower().replace('_', '-')
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")
|
||||||
|
|
||||||
|
return "utf-8"
|
||||||
|
|
||||||
|
|
||||||
|
def read_file_safe(
|
||||||
|
path: Path | str,
|
||||||
|
confidence_threshold: float = 0.7,
|
||||||
|
max_detection_bytes: int = 100_000
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""Read file with automatic encoding detection and safe decoding.
|
||||||
|
|
||||||
|
Reads file bytes, detects encoding, and decodes with error replacement
|
||||||
|
to preserve file structure even with encoding issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to file to read
|
||||||
|
confidence_threshold: Minimum confidence for encoding detection
|
||||||
|
max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (content, detected_encoding)
|
||||||
|
- content: Decoded file content (with <20> for unmappable bytes)
|
||||||
|
- detected_encoding: Detected encoding name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If file cannot be read
|
||||||
|
IsADirectoryError: If path is a directory
|
||||||
|
"""
|
||||||
|
file_path = Path(path) if isinstance(path, str) else path
|
||||||
|
|
||||||
|
# Read file bytes
|
||||||
|
try:
|
||||||
|
content_bytes = file_path.read_bytes()
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to read file {file_path}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Detect encoding from first N bytes for performance
|
||||||
|
detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
|
||||||
|
encoding = detect_encoding(detection_sample, confidence_threshold)
|
||||||
|
|
||||||
|
# Decode with error replacement to preserve structure
|
||||||
|
try:
|
||||||
|
content = content_bytes.decode(encoding, errors='replace')
|
||||||
|
log.debug(f"Successfully decoded {file_path} using {encoding}")
|
||||||
|
return content, encoding
|
||||||
|
except Exception as e:
|
||||||
|
# Final fallback to UTF-8 with replacement
|
||||||
|
log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
|
||||||
|
content = content_bytes.decode('utf-8', errors='replace')
|
||||||
|
return content, 'utf-8'
|
||||||
|
|
||||||
|
|
||||||
|
def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
|
||||||
|
"""Check if file is likely binary by sampling first bytes.
|
||||||
|
|
||||||
|
Uses heuristic: if >30% of sample bytes are null or non-text, consider binary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to file to check
|
||||||
|
sample_size: Number of bytes to sample (default 8KB)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file appears to be binary, False otherwise
|
||||||
|
"""
|
||||||
|
file_path = Path(path) if isinstance(path, str) else path
|
||||||
|
|
||||||
|
try:
|
||||||
|
with file_path.open('rb') as f:
|
||||||
|
sample = f.read(sample_size)
|
||||||
|
|
||||||
|
if not sample:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Count null bytes and non-printable characters
|
||||||
|
null_count = sample.count(b'\x00')
|
||||||
|
non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))
|
||||||
|
|
||||||
|
# If >30% null bytes or >50% non-text, consider binary
|
||||||
|
null_ratio = null_count / len(sample)
|
||||||
|
non_text_ratio = non_text_count / len(sample)
|
||||||
|
|
||||||
|
return null_ratio > 0.3 or non_text_ratio > 0.5
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ENCODING_DETECTION_AVAILABLE",
|
||||||
|
"check_encoding_available",
|
||||||
|
"detect_encoding",
|
||||||
|
"read_file_safe",
|
||||||
|
"is_binary_file",
|
||||||
|
]
|
||||||
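A usage sketch for the encoding helpers above; the path is a placeholder:

from pathlib import Path

from codexlens.parsers.encoding import is_binary_file, read_file_safe

path = Path("src/example.py")  # placeholder path
if not is_binary_file(path):
    text, encoding = read_file_safe(path, confidence_threshold=0.7)
    print(f"decoded {len(text)} characters as {encoding}")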
codex-lens/build/lib/codexlens/parsers/factory.py (new file, 385 lines)
@@ -0,0 +1,385 @@
|
|||||||
|
"""Parser factory for CodexLens.
|
||||||
|
|
||||||
|
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
|
||||||
|
available. Regex fallbacks are retained to preserve the existing parser
|
||||||
|
interface and behavior in minimal environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
from codexlens.config import Config
|
||||||
|
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||||
|
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||||
|
|
||||||
|
|
||||||
|
class Parser(Protocol):
|
||||||
|
def parse(self, text: str, path: Path) -> IndexedFile: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SimpleRegexParser:
|
||||||
|
language_id: str
|
||||||
|
|
||||||
|
def parse(self, text: str, path: Path) -> IndexedFile:
|
||||||
|
# Try tree-sitter first for supported languages
|
||||||
|
if self.language_id in {"python", "javascript", "typescript"}:
|
||||||
|
ts_parser = TreeSitterSymbolParser(self.language_id, path)
|
||||||
|
if ts_parser.is_available():
|
||||||
|
indexed = ts_parser.parse(text, path)
|
||||||
|
if indexed is not None:
|
||||||
|
return indexed
|
||||||
|
|
||||||
|
# Fallback to regex parsing
|
||||||
|
if self.language_id == "python":
|
||||||
|
symbols = _parse_python_symbols_regex(text)
|
||||||
|
relationships = _parse_python_relationships_regex(text, path)
|
||||||
|
elif self.language_id in {"javascript", "typescript"}:
|
||||||
|
symbols = _parse_js_ts_symbols_regex(text)
|
||||||
|
relationships = _parse_js_ts_relationships_regex(text, path)
|
||||||
|
elif self.language_id == "java":
|
||||||
|
symbols = _parse_java_symbols(text)
|
||||||
|
relationships = []
|
||||||
|
elif self.language_id == "go":
|
||||||
|
symbols = _parse_go_symbols(text)
|
||||||
|
relationships = []
|
||||||
|
elif self.language_id == "markdown":
|
||||||
|
symbols = _parse_markdown_symbols(text)
|
||||||
|
relationships = []
|
||||||
|
elif self.language_id == "text":
|
||||||
|
symbols = _parse_text_symbols(text)
|
||||||
|
relationships = []
|
||||||
|
else:
|
||||||
|
symbols = _parse_generic_symbols(text)
|
||||||
|
relationships = []
|
||||||
|
|
||||||
|
return IndexedFile(
|
||||||
|
path=str(path.resolve()),
|
||||||
|
language=self.language_id,
|
||||||
|
symbols=symbols,
|
||||||
|
chunks=[],
|
||||||
|
relationships=relationships,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ParserFactory:
|
||||||
|
def __init__(self, config: Config) -> None:
|
||||||
|
self.config = config
|
||||||
|
self._parsers: Dict[str, Parser] = {}
|
||||||
|
|
||||||
|
def get_parser(self, language_id: str) -> Parser:
|
||||||
|
if language_id not in self._parsers:
|
||||||
|
self._parsers[language_id] = SimpleRegexParser(language_id)
|
||||||
|
return self._parsers[language_id]
|
||||||
|
|
||||||
|
|
||||||
|
# Regex-based fallback parsers
|
||||||
|
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
|
||||||
|
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
|
||||||
|
|
||||||
|
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
|
||||||
|
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_symbols(text: str) -> List[Symbol]:
|
||||||
|
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
|
||||||
|
ts_parser = TreeSitterSymbolParser("python")
|
||||||
|
if ts_parser.is_available():
|
||||||
|
symbols = ts_parser.parse_symbols(text)
|
||||||
|
if symbols is not None:
|
||||||
|
return symbols
|
||||||
|
return _parse_python_symbols_regex(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_js_ts_symbols(
|
||||||
|
text: str,
|
||||||
|
language_id: str = "javascript",
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> List[Symbol]:
|
||||||
|
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
|
||||||
|
ts_parser = TreeSitterSymbolParser(language_id, path)
|
||||||
|
if ts_parser.is_available():
|
||||||
|
symbols = ts_parser.parse_symbols(text)
|
||||||
|
if symbols is not None:
|
||||||
|
return symbols
|
||||||
|
return _parse_js_ts_symbols_regex(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
current_class_indent: Optional[int] = None
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
class_match = _PY_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
current_class_indent = len(line) - len(line.lstrip(" "))
|
||||||
|
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||||
|
continue
|
||||||
|
def_match = _PY_DEF_RE.match(line)
|
||||||
|
if def_match:
|
||||||
|
indent = len(line) - len(line.lstrip(" "))
|
||||||
|
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
|
||||||
|
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
|
||||||
|
continue
|
||||||
|
if current_class_indent is not None:
|
||||||
|
indent = len(line) - len(line.lstrip(" "))
|
||||||
|
if line.strip() and indent <= current_class_indent:
|
||||||
|
current_class_indent = None
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||||
|
relationships: List[CodeRelationship] = []
|
||||||
|
current_scope: str | None = None
|
||||||
|
source_file = str(path.resolve())
|
||||||
|
|
||||||
|
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||||
|
class_match = _PY_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
current_scope = class_match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
def_match = _PY_DEF_RE.match(line)
|
||||||
|
if def_match:
|
||||||
|
current_scope = def_match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_scope is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
import_match = _PY_IMPORT_RE.search(line)
|
||||||
|
if import_match:
|
||||||
|
import_target = import_match.group(1) or import_match.group(2)
|
||||||
|
if import_target:
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=current_scope,
|
||||||
|
target_symbol=import_target.strip(),
|
||||||
|
relationship_type=RelationshipType.IMPORTS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=line_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for call_match in _PY_CALL_RE.finditer(line):
|
||||||
|
call_name = call_match.group(1)
|
||||||
|
if call_name in {
|
||||||
|
"if",
|
||||||
|
"for",
|
||||||
|
"while",
|
||||||
|
"return",
|
||||||
|
"print",
|
||||||
|
"len",
|
||||||
|
"str",
|
||||||
|
"int",
|
||||||
|
"float",
|
||||||
|
"list",
|
||||||
|
"dict",
|
||||||
|
"set",
|
||||||
|
"tuple",
|
||||||
|
current_scope,
|
||||||
|
}:
|
||||||
|
continue
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=current_scope,
|
||||||
|
target_symbol=call_name,
|
||||||
|
relationship_type=RelationshipType.CALL,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=line_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return relationships
|
||||||
|
|
||||||
|
|
||||||
|
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
|
||||||
|
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
|
||||||
|
_JS_ARROW_RE = re.compile(
|
||||||
|
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
|
||||||
|
)
|
||||||
|
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
|
||||||
|
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
|
||||||
|
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
in_class = False
|
||||||
|
class_brace_depth = 0
|
||||||
|
brace_depth = 0
|
||||||
|
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
brace_depth += line.count("{") - line.count("}")
|
||||||
|
|
||||||
|
class_match = _JS_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||||
|
in_class = True
|
||||||
|
class_brace_depth = brace_depth
|
||||||
|
continue
|
||||||
|
|
||||||
|
if in_class and brace_depth < class_brace_depth:
|
||||||
|
in_class = False
|
||||||
|
|
||||||
|
func_match = _JS_FUNC_RE.match(line)
|
||||||
|
if func_match:
|
||||||
|
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
|
||||||
|
continue
|
||||||
|
|
||||||
|
arrow_match = _JS_ARROW_RE.match(line)
|
||||||
|
if arrow_match:
|
||||||
|
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if in_class:
|
||||||
|
method_match = _JS_METHOD_RE.match(line)
|
||||||
|
if method_match:
|
||||||
|
name = method_match.group(1)
|
||||||
|
if name != "constructor":
|
||||||
|
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||||
|
relationships: List[CodeRelationship] = []
|
||||||
|
current_scope: str | None = None
|
||||||
|
source_file = str(path.resolve())
|
||||||
|
|
||||||
|
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||||
|
class_match = _JS_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
current_scope = class_match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
func_match = _JS_FUNC_RE.match(line)
|
||||||
|
if func_match:
|
||||||
|
current_scope = func_match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
arrow_match = _JS_ARROW_RE.match(line)
|
||||||
|
if arrow_match:
|
||||||
|
current_scope = arrow_match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_scope is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
import_match = _JS_IMPORT_RE.search(line)
|
||||||
|
if import_match:
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=current_scope,
|
||||||
|
target_symbol=import_match.group(1),
|
||||||
|
relationship_type=RelationshipType.IMPORTS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=line_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for call_match in _JS_CALL_RE.finditer(line):
|
||||||
|
call_name = call_match.group(1)
|
||||||
|
if call_name in {current_scope}:
|
||||||
|
continue
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=current_scope,
|
||||||
|
target_symbol=call_name,
|
||||||
|
relationship_type=RelationshipType.CALL,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=line_num,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return relationships
|
||||||
|
|
||||||
|
|
||||||
|
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
|
||||||
|
_JAVA_METHOD_RE = re.compile(
|
||||||
|
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_java_symbols(text: str) -> List[Symbol]:
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
class_match = _JAVA_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||||
|
continue
|
||||||
|
method_match = _JAVA_METHOD_RE.match(line)
|
||||||
|
if method_match:
|
||||||
|
symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
|
||||||
|
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_go_symbols(text: str) -> List[Symbol]:
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
type_match = _GO_TYPE_RE.match(line)
|
||||||
|
if type_match:
|
||||||
|
symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
|
||||||
|
continue
|
||||||
|
func_match = _GO_FUNC_RE.match(line)
|
||||||
|
if func_match:
|
||||||
|
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
|
||||||
|
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_generic_symbols(text: str) -> List[Symbol]:
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
class_match = _GENERIC_CLASS_RE.match(line)
|
||||||
|
if class_match:
|
||||||
|
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
|
||||||
|
continue
|
||||||
|
def_match = _GENERIC_DEF_RE.match(line)
|
||||||
|
if def_match:
|
||||||
|
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
# Markdown heading regex: # Heading, ## Heading, etc.
|
||||||
|
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_markdown_symbols(text: str) -> List[Symbol]:
|
||||||
|
"""Parse Markdown headings as symbols.
|
||||||
|
|
||||||
|
Extracts # headings as 'section' symbols with heading level as kind suffix.
|
||||||
|
"""
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
for i, line in enumerate(text.splitlines(), start=1):
|
||||||
|
heading_match = _MD_HEADING_RE.match(line)
|
||||||
|
if heading_match:
|
||||||
|
level = len(heading_match.group(1))
|
||||||
|
title = heading_match.group(2).strip()
|
||||||
|
# Use 'section' kind with level indicator
|
||||||
|
kind = f"h{level}"
|
||||||
|
symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_text_symbols(text: str) -> List[Symbol]:
|
||||||
|
"""Parse plain text files - no symbols, just index content."""
|
||||||
|
# Text files don't have structured symbols, return empty list
|
||||||
|
# The file content will still be indexed for FTS search
|
||||||
|
return []
|
||||||
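A usage sketch for the ParserFactory above; the source path is a placeholder:

from pathlib import Path

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory

factory = ParserFactory(Config())
parser = factory.get_parser("python")

source = Path("src/example.py")  # placeholder path
indexed = parser.parse(source.read_text(encoding="utf-8"), source)
print(indexed.language, [s.name for s in indexed.symbols])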
codex-lens/build/lib/codexlens/parsers/tokenizer.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.

Provides accurate token counting using tiktoken with character count fallback.
"""

from __future__ import annotations

from typing import Optional

try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False


class Tokenizer:
    """Token counter with tiktoken primary and character count fallback."""

    def __init__(self, encoding_name: str = "cl100k_base") -> None:
        """Initialize tokenizer.

        Args:
            encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
        """
        self._encoding: Optional[object] = None
        self._encoding_name = encoding_name

        if TIKTOKEN_AVAILABLE:
            try:
                self._encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                # Fallback to character counting if encoding fails
                self._encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text.

        Uses tiktoken if available, otherwise falls back to character count / 4.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated token count
        """
        if not text:
            return 0

        if self._encoding is not None:
            try:
                return len(self._encoding.encode(text))  # type: ignore[attr-defined]
            except Exception:
                # Fall through to character count fallback
                pass

        # Fallback: rough estimate using character count
        # Average of ~4 characters per token for English text
        return max(1, len(text) // 4)

    def is_using_tiktoken(self) -> bool:
        """Check if tiktoken is being used.

        Returns:
            True if tiktoken is available and initialized
        """
        return self._encoding is not None


# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None


def get_default_tokenizer() -> Tokenizer:
    """Get the global default tokenizer instance.

    Returns:
        Shared Tokenizer instance
    """
    global _default_tokenizer
    if _default_tokenizer is None:
        _default_tokenizer = Tokenizer()
    return _default_tokenizer


def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
    """Count tokens in text using default or provided tokenizer.

    Args:
        text: Text to count tokens for
        tokenizer: Optional tokenizer instance (uses default if None)

    Returns:
        Estimated token count
    """
    if tokenizer is None:
        tokenizer = get_default_tokenizer()
    return tokenizer.count_tokens(text)
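
# Illustrative sketch (not part of the original diff): the tokenizer's degradation
# path. With tiktoken installed the exact encoding is used; otherwise count_tokens()
# falls back to the ~4-characters-per-token estimate defined above.
def _demo_token_counting() -> None:
    tok = Tokenizer()
    text = "def hello(): return 42"
    n = tok.count_tokens(text)
    if tok.is_using_tiktoken():
        print(f"tiktoken count: {n}")
    else:
        print(f"fallback estimate: {n}")  # len(text) // 4 == 5 here
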
809
codex-lens/build/lib/codexlens/parsers/treesitter_parser.py
Normal file
@@ -0,0 +1,809 @@
|
|||||||
|
"""Tree-sitter based parser for CodexLens.
|
||||||
|
|
||||||
|
Provides precise AST-level parsing via tree-sitter.
|
||||||
|
|
||||||
|
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
|
||||||
|
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
|
||||||
|
return `None`; callers should use a regex-based fallback such as
|
||||||
|
`codexlens.parsers.factory.SimpleRegexParser`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tree_sitter import Language as TreeSitterLanguage
|
||||||
|
from tree_sitter import Node as TreeSitterNode
|
||||||
|
from tree_sitter import Parser as TreeSitterParser
|
||||||
|
TREE_SITTER_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TreeSitterLanguage = None # type: ignore[assignment]
|
||||||
|
TreeSitterNode = None # type: ignore[assignment]
|
||||||
|
TreeSitterParser = None # type: ignore[assignment]
|
||||||
|
TREE_SITTER_AVAILABLE = False
|
||||||
|
|
||||||
|
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||||
|
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TreeSitterSymbolParser:
|
||||||
|
"""Parser using tree-sitter for AST-level symbol extraction."""
|
||||||
|
|
||||||
|
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
|
||||||
|
"""Initialize tree-sitter parser for a language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||||
|
path: Optional file path for language variant detection (e.g., .tsx)
|
||||||
|
"""
|
||||||
|
self.language_id = language_id
|
||||||
|
self.path = path
|
||||||
|
self._parser: Optional[object] = None
|
||||||
|
self._language: Optional[TreeSitterLanguage] = None
|
||||||
|
self._tokenizer = get_default_tokenizer()
|
||||||
|
|
||||||
|
if TREE_SITTER_AVAILABLE:
|
||||||
|
self._initialize_parser()
|
||||||
|
|
||||||
|
def _initialize_parser(self) -> None:
|
||||||
|
"""Initialize tree-sitter parser and language."""
|
||||||
|
if TreeSitterParser is None or TreeSitterLanguage is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load language grammar
|
||||||
|
if self.language_id == "python":
|
||||||
|
import tree_sitter_python
|
||||||
|
self._language = TreeSitterLanguage(tree_sitter_python.language())
|
||||||
|
elif self.language_id == "javascript":
|
||||||
|
import tree_sitter_javascript
|
||||||
|
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
|
||||||
|
elif self.language_id == "typescript":
|
||||||
|
import tree_sitter_typescript
|
||||||
|
# Detect TSX files by extension
|
||||||
|
if self.path is not None and self.path.suffix.lower() == ".tsx":
|
||||||
|
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
|
||||||
|
else:
|
||||||
|
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create parser
|
||||||
|
self._parser = TreeSitterParser()
|
||||||
|
if hasattr(self._parser, "set_language"):
|
||||||
|
self._parser.set_language(self._language) # type: ignore[attr-defined]
|
||||||
|
else:
|
||||||
|
self._parser.language = self._language # type: ignore[assignment]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Gracefully handle missing language bindings
|
||||||
|
self._parser = None
|
||||||
|
self._language = None
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""Check if tree-sitter parser is available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if parser is initialized and ready
|
||||||
|
"""
|
||||||
|
return self._parser is not None and self._language is not None
|
||||||
|
|
||||||
|
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
|
||||||
|
if not self.is_available() or self._parser is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
source_bytes = text.encode("utf8")
|
||||||
|
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||||
|
return source_bytes, tree.root_node
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
|
||||||
|
"""Parse source code and extract symbols without creating IndexedFile.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Source code text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of symbols if parsing succeeds, None if tree-sitter unavailable
|
||||||
|
"""
|
||||||
|
parsed = self._parse_tree(text)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
source_bytes, root = parsed
|
||||||
|
try:
|
||||||
|
return self._extract_symbols(source_bytes, root)
|
||||||
|
except Exception:
|
||||||
|
# Gracefully handle extraction errors
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
|
||||||
|
"""Parse source code and extract symbols.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Source code text
|
||||||
|
path: File path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IndexedFile if parsing succeeds, None if tree-sitter unavailable
|
||||||
|
"""
|
||||||
|
parsed = self._parse_tree(text)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
source_bytes, root = parsed
|
||||||
|
try:
|
||||||
|
symbols = self._extract_symbols(source_bytes, root)
|
||||||
|
relationships = self._extract_relationships(source_bytes, root, path)
|
||||||
|
|
||||||
|
return IndexedFile(
|
||||||
|
path=str(path.resolve()),
|
||||||
|
language=self.language_id,
|
||||||
|
symbols=symbols,
|
||||||
|
chunks=[],
|
||||||
|
relationships=relationships,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Gracefully handle parsing errors
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||||
|
"""Extract symbols from AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_bytes: Source code as bytes
|
||||||
|
root: Root AST node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of extracted symbols
|
||||||
|
"""
|
||||||
|
if self.language_id == "python":
|
||||||
|
return self._extract_python_symbols(source_bytes, root)
|
||||||
|
elif self.language_id in {"javascript", "typescript"}:
|
||||||
|
return self._extract_js_ts_symbols(source_bytes, root)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_relationships(
|
||||||
|
self,
|
||||||
|
source_bytes: bytes,
|
||||||
|
root: TreeSitterNode,
|
||||||
|
path: Path,
|
||||||
|
) -> List[CodeRelationship]:
|
||||||
|
if self.language_id == "python":
|
||||||
|
return self._extract_python_relationships(source_bytes, root, path)
|
||||||
|
if self.language_id in {"javascript", "typescript"}:
|
||||||
|
return self._extract_js_ts_relationships(source_bytes, root, path)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_python_relationships(
|
||||||
|
self,
|
||||||
|
source_bytes: bytes,
|
||||||
|
root: TreeSitterNode,
|
||||||
|
path: Path,
|
||||||
|
) -> List[CodeRelationship]:
|
||||||
|
source_file = str(path.resolve())
|
||||||
|
relationships: List[CodeRelationship] = []
|
||||||
|
|
||||||
|
scope_stack: List[str] = []
|
||||||
|
alias_stack: List[Dict[str, str]] = [{}]
|
||||||
|
|
||||||
|
def record_import(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.IMPORTS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_call(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
base = target_symbol.split(".", 1)[0]
|
||||||
|
if base in {"self", "cls"}:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.CALL,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.INHERITS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit(node: TreeSitterNode) -> None:
|
||||||
|
pushed_scope = False
|
||||||
|
pushed_aliases = False
|
||||||
|
|
||||||
|
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is not None:
|
||||||
|
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if scope_name:
|
||||||
|
scope_stack.append(scope_name)
|
||||||
|
pushed_scope = True
|
||||||
|
alias_stack.append(dict(alias_stack[-1]))
|
||||||
|
pushed_aliases = True
|
||||||
|
|
||||||
|
if node.type == "class_definition" and pushed_scope:
|
||||||
|
superclasses = node.child_by_field_name("superclasses")
|
||||||
|
if superclasses is not None:
|
||||||
|
for child in superclasses.children:
|
||||||
|
dotted = self._python_expression_to_dotted(source_bytes, child)
|
||||||
|
if not dotted:
|
||||||
|
continue
|
||||||
|
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||||
|
record_inherits(resolved, self._node_start_line(node))
|
||||||
|
|
||||||
|
if node.type in {"import_statement", "import_from_statement"}:
|
||||||
|
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
|
||||||
|
if updates:
|
||||||
|
alias_stack[-1].update(updates)
|
||||||
|
for target_symbol in imported_targets:
|
||||||
|
record_import(target_symbol, self._node_start_line(node))
|
||||||
|
|
||||||
|
if node.type == "call":
|
||||||
|
fn_node = node.child_by_field_name("function")
|
||||||
|
if fn_node is not None:
|
||||||
|
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
|
||||||
|
if dotted:
|
||||||
|
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||||
|
record_call(resolved, self._node_start_line(node))
|
||||||
|
|
||||||
|
for child in node.children:
|
||||||
|
visit(child)
|
||||||
|
|
||||||
|
if pushed_aliases:
|
||||||
|
alias_stack.pop()
|
||||||
|
if pushed_scope:
|
||||||
|
scope_stack.pop()
|
||||||
|
|
||||||
|
visit(root)
|
||||||
|
return relationships
|
||||||
|
|
||||||
|
def _extract_js_ts_relationships(
|
||||||
|
self,
|
||||||
|
source_bytes: bytes,
|
||||||
|
root: TreeSitterNode,
|
||||||
|
path: Path,
|
||||||
|
) -> List[CodeRelationship]:
|
||||||
|
source_file = str(path.resolve())
|
||||||
|
relationships: List[CodeRelationship] = []
|
||||||
|
|
||||||
|
scope_stack: List[str] = []
|
||||||
|
alias_stack: List[Dict[str, str]] = [{}]
|
||||||
|
|
||||||
|
def record_import(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.IMPORTS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_call(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
base = target_symbol.split(".", 1)[0]
|
||||||
|
if base in {"this", "super"}:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.CALL,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||||
|
if not target_symbol.strip() or not scope_stack:
|
||||||
|
return
|
||||||
|
relationships.append(
|
||||||
|
CodeRelationship(
|
||||||
|
source_symbol=scope_stack[-1],
|
||||||
|
target_symbol=target_symbol,
|
||||||
|
relationship_type=RelationshipType.INHERITS,
|
||||||
|
source_file=source_file,
|
||||||
|
target_file=None,
|
||||||
|
source_line=source_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def visit(node: TreeSitterNode) -> None:
|
||||||
|
pushed_scope = False
|
||||||
|
pushed_aliases = False
|
||||||
|
|
||||||
|
if node.type in {"function_declaration", "generator_function_declaration"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is not None:
|
||||||
|
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if scope_name:
|
||||||
|
scope_stack.append(scope_name)
|
||||||
|
pushed_scope = True
|
||||||
|
alias_stack.append(dict(alias_stack[-1]))
|
||||||
|
pushed_aliases = True
|
||||||
|
|
||||||
|
if node.type in {"class_declaration", "class"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is not None:
|
||||||
|
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if scope_name:
|
||||||
|
scope_stack.append(scope_name)
|
||||||
|
pushed_scope = True
|
||||||
|
alias_stack.append(dict(alias_stack[-1]))
|
||||||
|
pushed_aliases = True
|
||||||
|
|
||||||
|
if pushed_scope:
|
||||||
|
superclass = node.child_by_field_name("superclass")
|
||||||
|
if superclass is not None:
|
||||||
|
dotted = self._js_expression_to_dotted(source_bytes, superclass)
|
||||||
|
if dotted:
|
||||||
|
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||||
|
record_inherits(resolved, self._node_start_line(node))
|
||||||
|
|
||||||
|
if node.type == "variable_declarator":
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
value_node = node.child_by_field_name("value")
|
||||||
|
if (
|
||||||
|
name_node is not None
|
||||||
|
and value_node is not None
|
||||||
|
and name_node.type in {"identifier", "property_identifier"}
|
||||||
|
and value_node.type == "arrow_function"
|
||||||
|
):
|
||||||
|
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if scope_name:
|
||||||
|
scope_stack.append(scope_name)
|
||||||
|
pushed_scope = True
|
||||||
|
alias_stack.append(dict(alias_stack[-1]))
|
||||||
|
pushed_aliases = True
|
||||||
|
|
||||||
|
if node.type == "method_definition" and self._has_class_ancestor(node):
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is not None:
|
||||||
|
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if scope_name and scope_name != "constructor":
|
||||||
|
scope_stack.append(scope_name)
|
||||||
|
pushed_scope = True
|
||||||
|
alias_stack.append(dict(alias_stack[-1]))
|
||||||
|
pushed_aliases = True
|
||||||
|
|
||||||
|
if node.type in {"import_declaration", "import_statement"}:
|
||||||
|
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
|
||||||
|
if updates:
|
||||||
|
alias_stack[-1].update(updates)
|
||||||
|
for target_symbol in imported_targets:
|
||||||
|
record_import(target_symbol, self._node_start_line(node))
|
||||||
|
|
||||||
|
# Best-effort support for CommonJS require() imports:
|
||||||
|
# const fs = require("fs")
|
||||||
|
if node.type == "variable_declarator":
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
value_node = node.child_by_field_name("value")
|
||||||
|
if (
|
||||||
|
name_node is not None
|
||||||
|
and value_node is not None
|
||||||
|
and name_node.type == "identifier"
|
||||||
|
and value_node.type == "call_expression"
|
||||||
|
):
|
||||||
|
callee = value_node.child_by_field_name("function")
|
||||||
|
args = value_node.child_by_field_name("arguments")
|
||||||
|
if (
|
||||||
|
callee is not None
|
||||||
|
and self._node_text(source_bytes, callee).strip() == "require"
|
||||||
|
and args is not None
|
||||||
|
):
|
||||||
|
module_name = self._js_first_string_argument(source_bytes, args)
|
||||||
|
if module_name:
|
||||||
|
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
|
||||||
|
record_import(module_name, self._node_start_line(node))
|
||||||
|
|
||||||
|
if node.type == "call_expression":
|
||||||
|
fn_node = node.child_by_field_name("function")
|
||||||
|
if fn_node is not None:
|
||||||
|
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
|
||||||
|
if dotted:
|
||||||
|
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||||
|
record_call(resolved, self._node_start_line(node))
|
||||||
|
|
||||||
|
for child in node.children:
|
||||||
|
visit(child)
|
||||||
|
|
||||||
|
if pushed_aliases:
|
||||||
|
alias_stack.pop()
|
||||||
|
if pushed_scope:
|
||||||
|
scope_stack.pop()
|
||||||
|
|
||||||
|
visit(root)
|
||||||
|
return relationships
|
||||||
|
|
||||||
|
def _node_start_line(self, node: TreeSitterNode) -> int:
|
||||||
|
return node.start_point[0] + 1
|
||||||
|
|
||||||
|
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
|
||||||
|
dotted = (dotted or "").strip()
|
||||||
|
if not dotted:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
base, sep, rest = dotted.partition(".")
|
||||||
|
resolved_base = aliases.get(base, base)
|
||||||
|
if not rest:
|
||||||
|
return resolved_base
|
||||||
|
if resolved_base and rest:
|
||||||
|
return f"{resolved_base}.{rest}"
|
||||||
|
return resolved_base
|
||||||
|
|
||||||
|
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||||
|
if node.type in {"identifier", "dotted_name"}:
|
||||||
|
return self._node_text(source_bytes, node).strip()
|
||||||
|
if node.type == "attribute":
|
||||||
|
obj = node.child_by_field_name("object")
|
||||||
|
attr = node.child_by_field_name("attribute")
|
||||||
|
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||||
|
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
|
||||||
|
if obj_text and attr_text:
|
||||||
|
return f"{obj_text}.{attr_text}"
|
||||||
|
return obj_text or attr_text
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _python_import_aliases_and_targets(
|
||||||
|
self,
|
||||||
|
source_bytes: bytes,
|
||||||
|
node: TreeSitterNode,
|
||||||
|
) -> tuple[Dict[str, str], List[str]]:
|
||||||
|
aliases: Dict[str, str] = {}
|
||||||
|
targets: List[str] = []
|
||||||
|
|
||||||
|
if node.type == "import_statement":
|
||||||
|
for child in node.children:
|
||||||
|
if child.type == "aliased_import":
|
||||||
|
name_node = child.child_by_field_name("name")
|
||||||
|
alias_node = child.child_by_field_name("alias")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
module_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if not module_name:
|
||||||
|
continue
|
||||||
|
bound_name = (
|
||||||
|
self._node_text(source_bytes, alias_node).strip()
|
||||||
|
if alias_node is not None
|
||||||
|
else module_name.split(".", 1)[0]
|
||||||
|
)
|
||||||
|
if bound_name:
|
||||||
|
aliases[bound_name] = module_name
|
||||||
|
targets.append(module_name)
|
||||||
|
elif child.type == "dotted_name":
|
||||||
|
module_name = self._node_text(source_bytes, child).strip()
|
||||||
|
if not module_name:
|
||||||
|
continue
|
||||||
|
bound_name = module_name.split(".", 1)[0]
|
||||||
|
if bound_name:
|
||||||
|
aliases[bound_name] = bound_name
|
||||||
|
targets.append(module_name)
|
||||||
|
|
||||||
|
if node.type == "import_from_statement":
|
||||||
|
module_name = ""
|
||||||
|
module_node = node.child_by_field_name("module_name")
|
||||||
|
if module_node is None:
|
||||||
|
for child in node.children:
|
||||||
|
if child.type == "dotted_name":
|
||||||
|
module_node = child
|
||||||
|
break
|
||||||
|
if module_node is not None:
|
||||||
|
module_name = self._node_text(source_bytes, module_node).strip()
|
||||||
|
|
||||||
|
for child in node.children:
|
||||||
|
if child.type == "aliased_import":
|
||||||
|
name_node = child.child_by_field_name("name")
|
||||||
|
alias_node = child.child_by_field_name("alias")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
imported_name = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if not imported_name or imported_name == "*":
|
||||||
|
continue
|
||||||
|
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||||
|
bound_name = (
|
||||||
|
self._node_text(source_bytes, alias_node).strip()
|
||||||
|
if alias_node is not None
|
||||||
|
else imported_name
|
||||||
|
)
|
||||||
|
if bound_name:
|
||||||
|
aliases[bound_name] = target
|
||||||
|
targets.append(target)
|
||||||
|
elif child.type == "identifier":
|
||||||
|
imported_name = self._node_text(source_bytes, child).strip()
|
||||||
|
if not imported_name or imported_name in {"from", "import", "*"}:
|
||||||
|
continue
|
||||||
|
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||||
|
aliases[imported_name] = target
|
||||||
|
targets.append(target)
|
||||||
|
|
||||||
|
return aliases, targets
|
||||||
|
|
||||||
|
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||||
|
if node.type in {"this", "super"}:
|
||||||
|
return node.type
|
||||||
|
if node.type in {"identifier", "property_identifier"}:
|
||||||
|
return self._node_text(source_bytes, node).strip()
|
||||||
|
if node.type == "member_expression":
|
||||||
|
obj = node.child_by_field_name("object")
|
||||||
|
prop = node.child_by_field_name("property")
|
||||||
|
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||||
|
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
|
||||||
|
if obj_text and prop_text:
|
||||||
|
return f"{obj_text}.{prop_text}"
|
||||||
|
return obj_text or prop_text
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _js_import_aliases_and_targets(
|
||||||
|
self,
|
||||||
|
source_bytes: bytes,
|
||||||
|
node: TreeSitterNode,
|
||||||
|
) -> tuple[Dict[str, str], List[str]]:
|
||||||
|
aliases: Dict[str, str] = {}
|
||||||
|
targets: List[str] = []
|
||||||
|
|
||||||
|
module_name = ""
|
||||||
|
source_node = node.child_by_field_name("source")
|
||||||
|
if source_node is not None:
|
||||||
|
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
|
||||||
|
if module_name:
|
||||||
|
targets.append(module_name)
|
||||||
|
|
||||||
|
for child in node.children:
|
||||||
|
if child.type == "import_clause":
|
||||||
|
for clause_child in child.children:
|
||||||
|
if clause_child.type == "identifier":
|
||||||
|
# Default import: import React from "react"
|
||||||
|
local = self._node_text(source_bytes, clause_child).strip()
|
||||||
|
if local and module_name:
|
||||||
|
aliases[local] = module_name
|
||||||
|
if clause_child.type == "namespace_import":
|
||||||
|
# Namespace import: import * as fs from "fs"
|
||||||
|
name_node = clause_child.child_by_field_name("name")
|
||||||
|
if name_node is not None and module_name:
|
||||||
|
local = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if local:
|
||||||
|
aliases[local] = module_name
|
||||||
|
if clause_child.type == "named_imports":
|
||||||
|
for spec in clause_child.children:
|
||||||
|
if spec.type != "import_specifier":
|
||||||
|
continue
|
||||||
|
name_node = spec.child_by_field_name("name")
|
||||||
|
alias_node = spec.child_by_field_name("alias")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
imported = self._node_text(source_bytes, name_node).strip()
|
||||||
|
if not imported:
|
||||||
|
continue
|
||||||
|
local = (
|
||||||
|
self._node_text(source_bytes, alias_node).strip()
|
||||||
|
if alias_node is not None
|
||||||
|
else imported
|
||||||
|
)
|
||||||
|
if local and module_name:
|
||||||
|
aliases[local] = f"{module_name}.{imported}"
|
||||||
|
targets.append(f"{module_name}.{imported}")
|
||||||
|
|
||||||
|
return aliases, targets
|
||||||
|
|
||||||
|
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
|
||||||
|
for child in args_node.children:
|
||||||
|
if child.type == "string":
|
||||||
|
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||||
|
"""Extract Python symbols from AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_bytes: Source code as bytes
|
||||||
|
root: Root AST node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Python symbols (classes, functions, methods)
|
||||||
|
"""
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
|
||||||
|
for node in self._iter_nodes(root):
|
||||||
|
if node.type == "class_definition":
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=self._node_text(source_bytes, name_node),
|
||||||
|
kind="class",
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
elif node.type in {"function_definition", "async_function_definition"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=self._node_text(source_bytes, name_node),
|
||||||
|
kind=self._python_function_kind(node),
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||||
|
"""Extract JavaScript/TypeScript symbols from AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_bytes: Source code as bytes
|
||||||
|
root: Root AST node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of JS/TS symbols (classes, functions, methods)
|
||||||
|
"""
|
||||||
|
symbols: List[Symbol] = []
|
||||||
|
|
||||||
|
for node in self._iter_nodes(root):
|
||||||
|
if node.type in {"class_declaration", "class"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=self._node_text(source_bytes, name_node),
|
||||||
|
kind="class",
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
elif node.type in {"function_declaration", "generator_function_declaration"}:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=self._node_text(source_bytes, name_node),
|
||||||
|
kind="function",
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
elif node.type == "variable_declarator":
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
value_node = node.child_by_field_name("value")
|
||||||
|
if (
|
||||||
|
name_node is None
|
||||||
|
or value_node is None
|
||||||
|
or name_node.type not in {"identifier", "property_identifier"}
|
||||||
|
or value_node.type != "arrow_function"
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=self._node_text(source_bytes, name_node),
|
||||||
|
kind="function",
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
elif node.type == "method_definition" and self._has_class_ancestor(node):
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node is None:
|
||||||
|
continue
|
||||||
|
name = self._node_text(source_bytes, name_node)
|
||||||
|
if name == "constructor":
|
||||||
|
continue
|
||||||
|
symbols.append(Symbol(
|
||||||
|
name=name,
|
||||||
|
kind="method",
|
||||||
|
range=self._node_range(node),
|
||||||
|
))
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
|
||||||
|
def _python_function_kind(self, node: TreeSitterNode) -> str:
|
||||||
|
"""Determine if Python function is a method or standalone function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Function definition node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
'method' if inside a class, 'function' otherwise
|
||||||
|
"""
|
||||||
|
parent = node.parent
|
||||||
|
while parent is not None:
|
||||||
|
if parent.type in {"function_definition", "async_function_definition"}:
|
||||||
|
return "function"
|
||||||
|
if parent.type == "class_definition":
|
||||||
|
return "method"
|
||||||
|
parent = parent.parent
|
||||||
|
return "function"
|
||||||
|
|
||||||
|
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
|
||||||
|
"""Check if node has a class ancestor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST node to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if node is inside a class
|
||||||
|
"""
|
||||||
|
parent = node.parent
|
||||||
|
while parent is not None:
|
||||||
|
if parent.type in {"class_declaration", "class"}:
|
||||||
|
return True
|
||||||
|
parent = parent.parent
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _iter_nodes(self, root: TreeSitterNode):
|
||||||
|
"""Iterate over all nodes in AST.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root: Root node to start iteration
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
AST nodes in depth-first order
|
||||||
|
"""
|
||||||
|
stack = [root]
|
||||||
|
while stack:
|
||||||
|
node = stack.pop()
|
||||||
|
yield node
|
||||||
|
for child in reversed(node.children):
|
||||||
|
stack.append(child)
|
||||||
|
|
||||||
|
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||||
|
"""Extract text for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_bytes: Source code as bytes
|
||||||
|
node: AST node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text content of node
|
||||||
|
"""
|
||||||
|
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||||
|
|
||||||
|
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
|
||||||
|
"""Get line range for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(start_line, end_line) tuple, 1-based inclusive
|
||||||
|
"""
|
||||||
|
start_line = node.start_point[0] + 1
|
||||||
|
end_line = node.end_point[0] + 1
|
||||||
|
return (start_line, max(start_line, end_line))
|
||||||
|
|
||||||
|
def count_tokens(self, text: str) -> int:
|
||||||
|
"""Count tokens in text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
return self._tokenizer.count_tokens(text)
|
||||||
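
# Illustrative sketch (not part of the original diff): the None-return contract of
# TreeSitterSymbolParser shown above. When tree-sitter or the language binding is
# missing, parse()/parse_symbols() return None and the caller is expected to fall
# back to a regex parser (the module docstring names
# codexlens.parsers.factory.SimpleRegexParser; its API is not shown here).
def _symbols_or_none(source: str):
    parser = TreeSitterSymbolParser("python")
    if not parser.is_available():
        return None  # caller should switch to the regex-based fallback
    return parser.parse_symbols(source)
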
53
codex-lens/build/lib/codexlens/search/__init__.py
Normal file
@@ -0,0 +1,53 @@
from .chain_search import (
    ChainSearchEngine,
    SearchOptions,
    SearchStats,
    ChainSearchResult,
    quick_search,
)

# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None

try:
    from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
    from .clustering import check_clustering_available
    CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
    _clustering_import_error = str(e)

    def check_clustering_available() -> tuple[bool, str | None]:
        """Fallback when clustering module not loadable."""
        return False, _clustering_import_error


# Clustering module exports (conditional)
try:
    from .clustering import (
        BaseClusteringStrategy,
        ClusteringConfig,
        ClusteringStrategyFactory,
        get_strategy,
    )
    _clustering_exports = [
        "BaseClusteringStrategy",
        "ClusteringConfig",
        "ClusteringStrategyFactory",
        "get_strategy",
    ]
except ImportError:
    _clustering_exports = []


__all__ = [
    "ChainSearchEngine",
    "SearchOptions",
    "SearchStats",
    "ChainSearchResult",
    "quick_search",
    # Clustering
    "CLUSTERING_AVAILABLE",
    "check_clustering_available",
    *_clustering_exports,
]
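
# Illustrative sketch (not part of the original diff): consulting the optional
# clustering support exposed above. Assumes the imported check_clustering_available
# follows the same (bool, error-message) contract as the fallback defined in this file.
def _clustering_status_message() -> str:
    available, error = check_clustering_available()
    if available:
        return "clustering enabled"
    return f"clustering disabled: {error or 'optional dependencies missing'}"
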
@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.

This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""

from .builder import AssociationTreeBuilder
from .data_structures import (
    CallTree,
    TreeNode,
    UniqueNode,
)
from .deduplicator import ResultDeduplicator

__all__ = [
    "AssociationTreeBuilder",
    "CallTree",
    "TreeNode",
    "UniqueNode",
    "ResultDeduplicator",
]
@@ -0,0 +1,450 @@
|
|||||||
|
"""Association tree builder using LSP call hierarchy.
|
||||||
|
|
||||||
|
Builds call relationship trees by recursively expanding from seed locations
|
||||||
|
using Language Server Protocol (LSP) call hierarchy capabilities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
|
||||||
|
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||||
|
from .data_structures import CallTree, TreeNode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AssociationTreeBuilder:
|
||||||
|
"""Builds association trees from seed locations using LSP call hierarchy.
|
||||||
|
|
||||||
|
Uses depth-first recursive expansion to build a tree of code relationships
|
||||||
|
starting from seed locations (typically from vector search results).
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Start from seed locations (vector search results)
|
||||||
|
- For each seed, get call hierarchy items via LSP
|
||||||
|
- Recursively expand incoming calls (callers) if expand_callers=True
|
||||||
|
- Recursively expand outgoing calls (callees) if expand_callees=True
|
||||||
|
- Track visited nodes to prevent cycles
|
||||||
|
- Stop at max_depth or when no more relations found
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lsp_manager: StandaloneLspManager for LSP communication
|
||||||
|
visited: Set of visited node IDs to prevent cycles
|
||||||
|
timeout: Timeout for individual LSP requests (seconds)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lsp_manager: StandaloneLspManager,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
analysis_wait: float = 2.0,
|
||||||
|
):
|
||||||
|
"""Initialize AssociationTreeBuilder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lsp_manager: StandaloneLspManager instance for LSP communication
|
||||||
|
timeout: Timeout for individual LSP requests in seconds
|
||||||
|
analysis_wait: Time to wait for LSP analysis on first file (seconds)
|
||||||
|
"""
|
||||||
|
self.lsp_manager = lsp_manager
|
||||||
|
self.timeout = timeout
|
||||||
|
self.analysis_wait = analysis_wait
|
||||||
|
self.visited: Set[str] = set()
|
||||||
|
self._analyzed_files: Set[str] = set() # Track files already analyzed
|
||||||
|
|
||||||
|
async def build_tree(
|
||||||
|
self,
|
||||||
|
seed_file_path: str,
|
||||||
|
seed_line: int,
|
||||||
|
seed_character: int = 1,
|
||||||
|
max_depth: int = 5,
|
||||||
|
expand_callers: bool = True,
|
||||||
|
expand_callees: bool = True,
|
||||||
|
) -> CallTree:
|
||||||
|
"""Build call tree from a single seed location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed_file_path: Path to the seed file
|
||||||
|
seed_line: Line number of the seed symbol (1-based)
|
||||||
|
seed_character: Character position (1-based, default 1)
|
||||||
|
max_depth: Maximum recursion depth (default 5)
|
||||||
|
expand_callers: Whether to expand incoming calls (callers)
|
||||||
|
expand_callees: Whether to expand outgoing calls (callees)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallTree containing all discovered nodes and relationships
|
||||||
|
"""
|
||||||
|
tree = CallTree()
|
||||||
|
self.visited.clear()
|
||||||
|
|
||||||
|
# Determine wait time - only wait for analysis on first encounter of file
|
||||||
|
wait_time = 0.0
|
||||||
|
if seed_file_path not in self._analyzed_files:
|
||||||
|
wait_time = self.analysis_wait
|
||||||
|
self._analyzed_files.add(seed_file_path)
|
||||||
|
|
||||||
|
# Get call hierarchy items for the seed position
|
||||||
|
try:
|
||||||
|
hierarchy_items = await asyncio.wait_for(
|
||||||
|
self.lsp_manager.get_call_hierarchy_items(
|
||||||
|
file_path=seed_file_path,
|
||||||
|
line=seed_line,
|
||||||
|
character=seed_character,
|
||||||
|
wait_for_analysis=wait_time,
|
||||||
|
),
|
||||||
|
timeout=self.timeout + wait_time,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Timeout getting call hierarchy items for %s:%d",
|
||||||
|
seed_file_path,
|
||||||
|
seed_line,
|
||||||
|
)
|
||||||
|
return tree
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Error getting call hierarchy items for %s:%d: %s",
|
||||||
|
seed_file_path,
|
||||||
|
seed_line,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return tree
|
||||||
|
|
||||||
|
if not hierarchy_items:
|
||||||
|
logger.debug(
|
||||||
|
"No call hierarchy items found for %s:%d",
|
||||||
|
seed_file_path,
|
||||||
|
seed_line,
|
||||||
|
)
|
||||||
|
return tree
|
||||||
|
|
||||||
|
# Create root nodes from hierarchy items
|
||||||
|
for item_dict in hierarchy_items:
|
||||||
|
# Convert LSP dict to CallHierarchyItem
|
||||||
|
item = self._dict_to_call_hierarchy_item(item_dict)
|
||||||
|
if not item:
|
||||||
|
continue
|
||||||
|
|
||||||
|
root_node = TreeNode(
|
||||||
|
item=item,
|
||||||
|
depth=0,
|
||||||
|
path_from_root=[self._create_node_id(item)],
|
||||||
|
)
|
||||||
|
tree.roots.append(root_node)
|
||||||
|
tree.add_node(root_node)
|
||||||
|
|
||||||
|
# Mark as visited
|
||||||
|
self.visited.add(root_node.node_id)
|
||||||
|
|
||||||
|
# Recursively expand the tree
|
||||||
|
await self._expand_node(
|
||||||
|
node=root_node,
|
||||||
|
node_dict=item_dict,
|
||||||
|
tree=tree,
|
||||||
|
current_depth=0,
|
||||||
|
max_depth=max_depth,
|
||||||
|
expand_callers=expand_callers,
|
||||||
|
expand_callees=expand_callees,
|
||||||
|
)
|
||||||
|
|
||||||
|
tree.depth_reached = max_depth
|
||||||
|
return tree
|
||||||
|
|
||||||
|
async def _expand_node(
|
||||||
|
self,
|
||||||
|
node: TreeNode,
|
||||||
|
node_dict: Dict,
|
||||||
|
tree: CallTree,
|
||||||
|
current_depth: int,
|
||||||
|
max_depth: int,
|
||||||
|
expand_callers: bool,
|
||||||
|
expand_callees: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Recursively expand a node by fetching its callers and callees.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: TreeNode to expand
|
||||||
|
node_dict: LSP CallHierarchyItem dict (for LSP requests)
|
||||||
|
tree: CallTree to add discovered nodes to
|
||||||
|
current_depth: Current recursion depth
|
||||||
|
max_depth: Maximum allowed depth
|
||||||
|
expand_callers: Whether to expand incoming calls
|
||||||
|
expand_callees: Whether to expand outgoing calls
|
||||||
|
"""
|
||||||
|
# Stop if max depth reached
|
||||||
|
if current_depth >= max_depth:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Prepare tasks for parallel expansion
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
if expand_callers:
|
||||||
|
tasks.append(
|
||||||
|
self._expand_incoming_calls(
|
||||||
|
node=node,
|
||||||
|
node_dict=node_dict,
|
||||||
|
tree=tree,
|
||||||
|
current_depth=current_depth,
|
||||||
|
max_depth=max_depth,
|
||||||
|
expand_callers=expand_callers,
|
||||||
|
expand_callees=expand_callees,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if expand_callees:
|
||||||
|
tasks.append(
|
||||||
|
self._expand_outgoing_calls(
|
||||||
|
node=node,
|
||||||
|
node_dict=node_dict,
|
||||||
|
tree=tree,
|
||||||
|
current_depth=current_depth,
|
||||||
|
max_depth=max_depth,
|
||||||
|
expand_callers=expand_callers,
|
||||||
|
expand_callees=expand_callees,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute expansions in parallel
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
async def _expand_incoming_calls(
|
||||||
|
self,
|
||||||
|
node: TreeNode,
|
||||||
|
node_dict: Dict,
|
||||||
|
tree: CallTree,
|
||||||
|
current_depth: int,
|
||||||
|
max_depth: int,
|
||||||
|
expand_callers: bool,
|
||||||
|
expand_callees: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Expand incoming calls (callers) for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: TreeNode being expanded
|
||||||
|
node_dict: LSP dict for the node
|
||||||
|
tree: CallTree to add nodes to
|
||||||
|
current_depth: Current depth
|
||||||
|
max_depth: Maximum depth
|
||||||
|
expand_callers: Whether to continue expanding callers
|
||||||
|
expand_callees: Whether to expand callees
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
incoming_calls = await asyncio.wait_for(
|
||||||
|
self.lsp_manager.get_incoming_calls(item=node_dict),
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.debug("Timeout getting incoming calls for %s", node.node_id)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not incoming_calls:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process each incoming call
|
||||||
|
for call_dict in incoming_calls:
|
||||||
|
caller_dict = call_dict.get("from")
|
||||||
|
if not caller_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert to CallHierarchyItem
|
||||||
|
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
|
||||||
|
if not caller_item:
|
||||||
|
continue
|
||||||
|
|
||||||
|
caller_id = self._create_node_id(caller_item)
|
||||||
|
|
||||||
|
# Check for cycles
|
||||||
|
if caller_id in self.visited:
|
||||||
|
# Create cycle marker node
|
||||||
|
cycle_node = TreeNode(
|
||||||
|
item=caller_item,
|
||||||
|
depth=current_depth + 1,
|
||||||
|
is_cycle=True,
|
||||||
|
path_from_root=node.path_from_root + [caller_id],
|
||||||
|
)
|
||||||
|
node.parents.append(cycle_node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create new caller node
|
||||||
|
caller_node = TreeNode(
|
||||||
|
item=caller_item,
|
||||||
|
depth=current_depth + 1,
|
||||||
|
path_from_root=node.path_from_root + [caller_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to tree
|
||||||
|
tree.add_node(caller_node)
|
||||||
|
tree.add_edge(caller_node, node)
|
||||||
|
|
||||||
|
# Update relationships
|
||||||
|
node.parents.append(caller_node)
|
||||||
|
caller_node.children.append(node)
|
||||||
|
|
||||||
|
# Mark as visited
|
||||||
|
self.visited.add(caller_id)
|
||||||
|
|
||||||
|
# Recursively expand the caller
|
||||||
|
await self._expand_node(
|
||||||
|
node=caller_node,
|
||||||
|
node_dict=caller_dict,
|
||||||
|
tree=tree,
|
||||||
|
current_depth=current_depth + 1,
|
||||||
|
max_depth=max_depth,
|
||||||
|
expand_callers=expand_callers,
|
||||||
|
expand_callees=expand_callees,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _expand_outgoing_calls(
|
||||||
|
self,
|
||||||
|
node: TreeNode,
|
||||||
|
node_dict: Dict,
|
||||||
|
tree: CallTree,
|
||||||
|
current_depth: int,
|
||||||
|
max_depth: int,
|
||||||
|
expand_callers: bool,
|
||||||
|
expand_callees: bool,
|
||||||
|
) -> None:
|
||||||
|
"""Expand outgoing calls (callees) for a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: TreeNode being expanded
|
||||||
|
node_dict: LSP dict for the node
|
||||||
|
tree: CallTree to add nodes to
|
||||||
|
current_depth: Current depth
|
||||||
|
max_depth: Maximum depth
|
||||||
|
expand_callers: Whether to expand callers
|
||||||
|
expand_callees: Whether to continue expanding callees
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
outgoing_calls = await asyncio.wait_for(
|
||||||
|
self.lsp_manager.get_outgoing_calls(item=node_dict),
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not outgoing_calls:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process each outgoing call
|
||||||
|
for call_dict in outgoing_calls:
|
||||||
|
callee_dict = call_dict.get("to")
|
||||||
|
if not callee_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert to CallHierarchyItem
|
||||||
|
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
|
||||||
|
if not callee_item:
|
||||||
|
continue
|
||||||
|
|
||||||
|
callee_id = self._create_node_id(callee_item)
|
||||||
|
|
||||||
|
# Check for cycles
|
||||||
|
if callee_id in self.visited:
|
||||||
|
# Create cycle marker node
|
||||||
|
cycle_node = TreeNode(
|
||||||
|
item=callee_item,
|
||||||
|
depth=current_depth + 1,
|
||||||
|
is_cycle=True,
|
||||||
|
path_from_root=node.path_from_root + [callee_id],
|
||||||
|
)
|
||||||
|
node.children.append(cycle_node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create new callee node
|
||||||
|
callee_node = TreeNode(
|
||||||
|
item=callee_item,
|
||||||
|
depth=current_depth + 1,
|
||||||
|
path_from_root=node.path_from_root + [callee_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to tree
|
||||||
|
tree.add_node(callee_node)
|
||||||
|
tree.add_edge(node, callee_node)
|
||||||
|
|
||||||
|
# Update relationships
|
||||||
|
node.children.append(callee_node)
|
||||||
|
callee_node.parents.append(node)
|
||||||
|
|
||||||
|
# Mark as visited
|
||||||
|
self.visited.add(callee_id)
|
||||||
|
|
||||||
|
# Recursively expand the callee
|
||||||
|
await self._expand_node(
|
||||||
|
node=callee_node,
|
||||||
|
node_dict=callee_dict,
|
||||||
|
tree=tree,
|
||||||
|
current_depth=current_depth + 1,
|
||||||
|
max_depth=max_depth,
|
||||||
|
expand_callers=expand_callers,
|
||||||
|
expand_callees=expand_callees,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _dict_to_call_hierarchy_item(
|
||||||
|
self, item_dict: Dict
|
||||||
|
) -> Optional[CallHierarchyItem]:
|
||||||
|
"""Convert LSP dict to CallHierarchyItem.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item_dict: LSP CallHierarchyItem dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CallHierarchyItem or None if conversion fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Extract URI and convert to file path
|
||||||
|
uri = item_dict.get("uri", "")
|
||||||
|
file_path = uri.replace("file:///", "").replace("file://", "")
|
||||||
|
|
||||||
|
# Handle Windows paths (file:///C:/...)
|
||||||
|
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||||
|
file_path = file_path[1:]
|
||||||
|
|
||||||
|
# Extract range
|
||||||
|
range_dict = item_dict.get("range", {})
|
||||||
|
start = range_dict.get("start", {})
|
||||||
|
end = range_dict.get("end", {})
|
||||||
|
|
||||||
|
# Create Range (convert from 0-based to 1-based)
|
||||||
|
item_range = Range(
|
||||||
|
start_line=start.get("line", 0) + 1,
|
||||||
|
start_character=start.get("character", 0) + 1,
|
||||||
|
end_line=end.get("line", 0) + 1,
|
||||||
|
end_character=end.get("character", 0) + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CallHierarchyItem(
|
||||||
|
name=item_dict.get("name", "unknown"),
|
||||||
|
kind=str(item_dict.get("kind", "unknown")),
|
||||||
|
file_path=file_path,
|
||||||
|
range=item_range,
|
||||||
|
detail=item_dict.get("detail"),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_node_id(self, item: CallHierarchyItem) -> str:
|
||||||
|
"""Create unique node ID from CallHierarchyItem.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: CallHierarchyItem
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unique node ID string
|
||||||
|
"""
|
||||||
|
return f"{item.file_path}:{item.name}:{item.range.start_line}"
|
||||||
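
# Illustrative sketch (not part of the original diff): driving the builder above from
# a single seed location. Assumes an already-initialized StandaloneLspManager; its
# construction and startup are outside this diff, and the seed path is hypothetical.
async def _demo_build_tree(lsp_manager) -> None:
    builder = AssociationTreeBuilder(lsp_manager, timeout=5.0, analysis_wait=2.0)
    tree = await builder.build_tree(
        seed_file_path="src/example.py",  # hypothetical seed file
        seed_line=42,
        max_depth=3,
        expand_callers=True,
        expand_callees=False,
    )
    print(tree)  # e.g. CallTree(roots=1, nodes=12, depth=3)
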
@@ -0,0 +1,191 @@
"""Data structures for association tree building.

Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range


@dataclass
class TreeNode:
    """Node in the call association tree.

    Represents a single function/method in the tree, including its position
    in the hierarchy and relationships.

    Attributes:
        item: LSP CallHierarchyItem containing symbol information
        depth: Distance from the root node (seed) - 0 for roots
        children: List of child nodes (functions called by this node)
        parents: List of parent nodes (functions that call this node)
        is_cycle: Whether this node creates a circular reference
        path_from_root: Path (list of node IDs) from root to this node
    """

    item: CallHierarchyItem
    depth: int = 0
    children: List[TreeNode] = field(default_factory=list)
    parents: List[TreeNode] = field(default_factory=list)
    is_cycle: bool = False
    path_from_root: List[str] = field(default_factory=list)

    @property
    def node_id(self) -> str:
        """Unique identifier for this node."""
        return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"

    def __hash__(self) -> int:
        """Hash based on node ID."""
        return hash(self.node_id)

    def __eq__(self, other: object) -> bool:
        """Equality based on node ID."""
        if not isinstance(other, TreeNode):
            return False
        return self.node_id == other.node_id

    def __repr__(self) -> str:
        """String representation of the node."""
        cycle_marker = " [CYCLE]" if self.is_cycle else ""
        return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"


@dataclass
class CallTree:
    """Complete call tree structure built from seeds.

    Contains all nodes discovered through recursive expansion and
    the relationships between them.

    Attributes:
        roots: List of root nodes (seed symbols)
        all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
        node_list: Flat list of all nodes in tree order
        edges: List of (from_node_id, to_node_id) tuples representing calls
        depth_reached: Maximum depth achieved in expansion
    """

    roots: List[TreeNode] = field(default_factory=list)
    all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
    node_list: List[TreeNode] = field(default_factory=list)
    edges: List[tuple[str, str]] = field(default_factory=list)
    depth_reached: int = 0

    def add_node(self, node: TreeNode) -> None:
        """Add a node to the tree.

        Args:
            node: TreeNode to add
        """
        if node.node_id not in self.all_nodes:
            self.all_nodes[node.node_id] = node
            self.node_list.append(node)

    def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
        """Add an edge between two nodes.

        Args:
            from_node: Source node
            to_node: Target node
        """
        edge = (from_node.node_id, to_node.node_id)
        if edge not in self.edges:
            self.edges.append(edge)

    def get_node(self, node_id: str) -> Optional[TreeNode]:
        """Get a node by ID.

        Args:
            node_id: Node identifier

        Returns:
            TreeNode if found, None otherwise
        """
        return self.all_nodes.get(node_id)

    def __len__(self) -> int:
        """Return total number of nodes in tree."""
        return len(self.all_nodes)

    def __repr__(self) -> str:
        """String representation of the tree."""
        return (
            f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
            f"depth={self.depth_reached})"
        )


@dataclass
class UniqueNode:
    """Deduplicated unique code symbol from the tree.

    Represents a single unique code location that may appear multiple times
    in the tree under different contexts. Contains aggregated information
    about all occurrences.

    Attributes:
        file_path: Absolute path to the file
        name: Symbol name (function, method, class, etc.)
        kind: Symbol kind (function, method, class, etc.)
        range: Code range in the file
        min_depth: Minimum depth at which this node appears in the tree
        occurrences: Number of times this node appears in the tree
        paths: List of paths from roots to this node
        context_nodes: Related nodes from the tree
        score: Composite relevance score (higher is better)
    """

    file_path: str
    name: str
    kind: str
    range: Range
    min_depth: int = 0
    occurrences: int = 1
    paths: List[List[str]] = field(default_factory=list)
    context_nodes: List[str] = field(default_factory=list)
    score: float = 0.0

    @property
    def node_key(self) -> tuple[str, int, int]:
        """Unique key for deduplication.

        Uses (file_path, start_line, end_line) as the unique identifier
        for this symbol across all occurrences.
        """
        return (
            self.file_path,
            self.range.start_line,
            self.range.end_line,
        )

    def add_path(self, path: List[str]) -> None:
        """Add a path from root to this node.

        Args:
            path: List of node IDs from root to this node
        """
        if path not in self.paths:
            self.paths.append(path)

    def __hash__(self) -> int:
        """Hash based on node key."""
        return hash(self.node_key)

    def __eq__(self, other: object) -> bool:
        """Equality based on node key."""
        if not isinstance(other, UniqueNode):
            return False
        return self.node_key == other.node_key

    def __repr__(self) -> str:
        """String representation of the unique node."""
        return (
            f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
            f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
        )
@@ -0,0 +1,301 @@
|
|||||||
|
"""Result deduplication for association tree nodes.
|
||||||
|
|
||||||
|
Provides functionality to extract unique nodes from a call tree and assign
|
||||||
|
relevance scores based on various factors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .data_structures import (
|
||||||
|
CallTree,
|
||||||
|
TreeNode,
|
||||||
|
UniqueNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
    # Functions and methods are primary targets
    "function": 1.0,
    "method": 1.0,
    "12": 1.0,  # LSP SymbolKind.Function
    "6": 1.0,   # LSP SymbolKind.Method
    # Classes are important but secondary
    "class": 0.8,
    "5": 0.8,   # LSP SymbolKind.Class
    # Interfaces and types
    "interface": 0.7,
    "11": 0.7,  # LSP SymbolKind.Interface
    "type": 0.6,
    # Constructors
    "constructor": 0.9,
    "9": 0.9,   # LSP SymbolKind.Constructor
    # Variables and constants
    "variable": 0.4,
    "13": 0.4,  # LSP SymbolKind.Variable
    "constant": 0.5,
    "14": 0.5,  # LSP SymbolKind.Constant
    # Default for unknown kinds
    "unknown": 0.3,
}
|
||||||
|
|
||||||
|
|
||||||
|
class ResultDeduplicator:
|
||||||
|
"""Extracts and scores unique nodes from call trees.
|
||||||
|
|
||||||
|
Processes a CallTree to extract unique code locations, merging duplicates
|
||||||
|
and assigning relevance scores based on:
|
||||||
|
- Depth: Shallower nodes (closer to seeds) score higher
|
||||||
|
- Frequency: Nodes appearing multiple times score higher
|
||||||
|
- Kind: Function/method > class > variable
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
depth_weight: Weight for depth factor in scoring (default 0.4)
|
||||||
|
frequency_weight: Weight for frequency factor (default 0.3)
|
||||||
|
kind_weight: Weight for symbol kind factor (default 0.3)
|
||||||
|
max_depth_penalty: Maximum depth before full penalty applied
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depth_weight: float = 0.4,
|
||||||
|
frequency_weight: float = 0.3,
|
||||||
|
kind_weight: float = 0.3,
|
||||||
|
max_depth_penalty: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize ResultDeduplicator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth_weight: Weight for depth factor (0.0-1.0)
|
||||||
|
frequency_weight: Weight for frequency factor (0.0-1.0)
|
||||||
|
kind_weight: Weight for symbol kind factor (0.0-1.0)
|
||||||
|
max_depth_penalty: Depth at which score becomes 0 for depth factor
|
||||||
|
"""
|
||||||
|
self.depth_weight = depth_weight
|
||||||
|
self.frequency_weight = frequency_weight
|
||||||
|
self.kind_weight = kind_weight
|
||||||
|
self.max_depth_penalty = max_depth_penalty
|
||||||
|
|
||||||
|
def deduplicate(
|
||||||
|
self,
|
||||||
|
tree: CallTree,
|
||||||
|
max_results: Optional[int] = None,
|
||||||
|
) -> List[UniqueNode]:
|
||||||
|
"""Extract unique nodes from the call tree.
|
||||||
|
|
||||||
|
Traverses the tree, groups nodes by their unique key (file_path,
|
||||||
|
start_line, end_line), and merges duplicate occurrences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tree: CallTree to process
|
||||||
|
max_results: Maximum number of results to return (None = all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UniqueNode objects, sorted by score descending
|
||||||
|
"""
|
||||||
|
if not tree.node_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Group nodes by unique key
|
||||||
|
unique_map: Dict[tuple, UniqueNode] = {}
|
||||||
|
|
||||||
|
for node in tree.node_list:
|
||||||
|
if node.is_cycle:
|
||||||
|
# Skip cycle markers - they point to already-counted nodes
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = self._get_node_key(node)
|
||||||
|
|
||||||
|
if key in unique_map:
|
||||||
|
# Update existing unique node
|
||||||
|
unique_node = unique_map[key]
|
||||||
|
unique_node.occurrences += 1
|
||||||
|
unique_node.min_depth = min(unique_node.min_depth, node.depth)
|
||||||
|
unique_node.add_path(node.path_from_root)
|
||||||
|
|
||||||
|
# Collect context from relationships
|
||||||
|
for parent in node.parents:
|
||||||
|
if not parent.is_cycle:
|
||||||
|
unique_node.context_nodes.append(parent.node_id)
|
||||||
|
for child in node.children:
|
||||||
|
if not child.is_cycle:
|
||||||
|
unique_node.context_nodes.append(child.node_id)
|
||||||
|
else:
|
||||||
|
# Create new unique node
|
||||||
|
unique_node = UniqueNode(
|
||||||
|
file_path=node.item.file_path,
|
||||||
|
name=node.item.name,
|
||||||
|
kind=node.item.kind,
|
||||||
|
range=node.item.range,
|
||||||
|
min_depth=node.depth,
|
||||||
|
occurrences=1,
|
||||||
|
paths=[node.path_from_root.copy()],
|
||||||
|
context_nodes=[],
|
||||||
|
score=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect initial context
|
||||||
|
for parent in node.parents:
|
||||||
|
if not parent.is_cycle:
|
||||||
|
unique_node.context_nodes.append(parent.node_id)
|
||||||
|
for child in node.children:
|
||||||
|
if not child.is_cycle:
|
||||||
|
unique_node.context_nodes.append(child.node_id)
|
||||||
|
|
||||||
|
unique_map[key] = unique_node
|
||||||
|
|
||||||
|
# Calculate scores for all unique nodes
|
||||||
|
unique_nodes = list(unique_map.values())
|
||||||
|
|
||||||
|
# Find max frequency for normalization
|
||||||
|
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
|
||||||
|
|
||||||
|
for node in unique_nodes:
|
||||||
|
node.score = self._score_node(node, max_frequency)
|
||||||
|
|
||||||
|
# Sort by score descending
|
||||||
|
unique_nodes.sort(key=lambda n: n.score, reverse=True)
|
||||||
|
|
||||||
|
# Apply max_results limit
|
||||||
|
if max_results is not None and max_results > 0:
|
||||||
|
unique_nodes = unique_nodes[:max_results]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Deduplicated %d tree nodes to %d unique nodes",
|
||||||
|
len(tree.node_list),
|
||||||
|
len(unique_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
|
return unique_nodes
|
||||||
|
|
||||||
|
def _score_node(
|
||||||
|
self,
|
||||||
|
node: UniqueNode,
|
||||||
|
max_frequency: int,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate composite score for a unique node.
|
||||||
|
|
||||||
|
Score = depth_weight * depth_score +
|
||||||
|
frequency_weight * frequency_score +
|
||||||
|
kind_weight * kind_score
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: UniqueNode to score
|
||||||
|
max_frequency: Maximum occurrence count for normalization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composite score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
# Depth score: closer to root = higher score
|
||||||
|
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
|
||||||
|
depth_score = max(
|
||||||
|
0.0,
|
||||||
|
1.0 - (node.min_depth / self.max_depth_penalty),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Frequency score: more occurrences = higher score
|
||||||
|
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
|
||||||
|
|
||||||
|
# Kind score: function/method > class > variable
|
||||||
|
kind_str = str(node.kind).lower()
|
||||||
|
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
|
||||||
|
|
||||||
|
# Composite score
|
||||||
|
score = (
|
||||||
|
self.depth_weight * depth_score
|
||||||
|
+ self.frequency_weight * frequency_score
|
||||||
|
+ self.kind_weight * kind_score
|
||||||
|
)
|
||||||
|
|
||||||
|
return score
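    # Worked example (a sketch, not part of the module): with the default weights
    # (0.4 / 0.3 / 0.3) and max_depth_penalty=10, a "function" node first seen at
    # depth 2 that occurs 3 times against a max frequency of 5 scores:
    #
    #     depth_score     = 1.0 - 2 / 10              = 0.8
    #     frequency_score = 3 / 5                     = 0.6
    #     kind_score      = KIND_WEIGHTS["function"]  = 1.0
    #     score           = 0.4*0.8 + 0.3*0.6 + 0.3*1.0 = 0.80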
|
||||||
|
|
||||||
|
def _get_node_key(self, node: TreeNode) -> tuple:
|
||||||
|
"""Get unique key for a tree node.
|
||||||
|
|
||||||
|
Uses (file_path, start_line, end_line) as the unique identifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: TreeNode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple key for deduplication
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
node.item.file_path,
|
||||||
|
node.item.range.start_line,
|
||||||
|
node.item.range.end_line,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter_by_kind(
|
||||||
|
self,
|
||||||
|
nodes: List[UniqueNode],
|
||||||
|
kinds: List[str],
|
||||||
|
) -> List[UniqueNode]:
|
||||||
|
"""Filter unique nodes by symbol kind.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: List of UniqueNode to filter
|
||||||
|
kinds: List of allowed kinds (e.g., ["function", "method"])
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of UniqueNode
|
||||||
|
"""
|
||||||
|
kinds_lower = [k.lower() for k in kinds]
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in nodes
|
||||||
|
if str(node.kind).lower() in kinds_lower
|
||||||
|
]
|
||||||
|
|
||||||
|
def filter_by_file(
|
||||||
|
self,
|
||||||
|
nodes: List[UniqueNode],
|
||||||
|
file_patterns: List[str],
|
||||||
|
) -> List[UniqueNode]:
|
||||||
|
"""Filter unique nodes by file path patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: List of UniqueNode to filter
|
||||||
|
file_patterns: List of path substrings to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of UniqueNode
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
node
|
||||||
|
for node in nodes
|
||||||
|
if any(pattern in node.file_path for pattern in file_patterns)
|
||||||
|
]
|
||||||
|
|
||||||
|
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
|
||||||
|
"""Convert list of UniqueNode to JSON-serializable dicts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: List of UniqueNode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"file_path": node.file_path,
|
||||||
|
"name": node.name,
|
||||||
|
"kind": node.kind,
|
||||||
|
"range": {
|
||||||
|
"start_line": node.range.start_line,
|
||||||
|
"start_character": node.range.start_character,
|
||||||
|
"end_line": node.range.end_line,
|
||||||
|
"end_character": node.range.end_character,
|
||||||
|
},
|
||||||
|
"min_depth": node.min_depth,
|
||||||
|
"occurrences": node.occurrences,
|
||||||
|
"path_count": len(node.paths),
|
||||||
|
"score": round(node.score, 4),
|
||||||
|
}
|
||||||
|
for node in nodes
|
||||||
|
]
|
||||||
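# Usage sketch (illustrative, not part of the module). Assumes `tree` is a
# CallTree already populated by the expansion step and that ResultDeduplicator
# is imported from this package; the weights shown are simply the defaults.
#
#     dedup = ResultDeduplicator(depth_weight=0.4, frequency_weight=0.3, kind_weight=0.3)
#     unique_nodes = dedup.deduplicate(tree, max_results=20)
#     functions_only = dedup.filter_by_kind(unique_nodes, ["function", "method"])
#     payload = dedup.to_dict_list(functions_only)   # JSON-serializable dicts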
277
codex-lens/build/lib/codexlens/search/binary_searcher.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Binary vector searcher for cascade search.
|
||||||
|
|
||||||
|
This module provides fast binary vector search using Hamming distance
|
||||||
|
for the first stage of cascade search (coarse filtering).
|
||||||
|
|
||||||
|
Supports two loading modes:
|
||||||
|
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
|
||||||
|
2. Database loading (fallback): Loads all vectors into RAM
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Pre-computed popcount lookup table for vectorized Hamming distance
# Each byte value (0-255) maps to its bit count
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
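# Sketch (not part of the original module): how the lookup table turns XOR'd
# bytes into a Hamming distance for a single pair of 0/1 vectors. The
# BinarySearcher.search() method below applies the same idea to the whole
# packed matrix at once.
def _hamming_distance_example(a_bits: np.ndarray, b_bits: np.ndarray) -> int:
    """Hamming distance between two 0/1 vectors via packbits + popcount table."""
    a_packed = np.packbits(a_bits.astype(np.uint8))
    b_packed = np.packbits(b_bits.astype(np.uint8))
    xor_bytes = np.bitwise_xor(a_packed, b_packed)   # 1-bits mark differing positions
    return int(_POPCOUNT_TABLE[xor_bytes].sum())     # count set bits byte by byte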
|
||||||
|
|
||||||
|
|
||||||
|
class BinarySearcher:
|
||||||
|
"""Fast binary vector search using Hamming distance.
|
||||||
|
|
||||||
|
This class implements the first stage of cascade search:
|
||||||
|
fast, approximate retrieval using binary vectors and Hamming distance.
|
||||||
|
|
||||||
|
The binary vectors are derived from dense embeddings by thresholding:
|
||||||
|
binary[i] = 1 if dense[i] > 0 else 0
|
||||||
|
|
||||||
|
Hamming distance between two binary vectors counts the number of
|
||||||
|
differing bits, which can be computed very efficiently using XOR
|
||||||
|
and population count.
|
||||||
|
|
||||||
|
Supports two loading modes:
|
||||||
|
- Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
|
||||||
|
- Database (fallback): Loads all vectors into memory from SQLite
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, index_root_or_meta_path: Path) -> None:
|
||||||
|
"""Initialize BinarySearcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_root_or_meta_path: Either:
|
||||||
|
- Path to index root directory (containing _binary_vectors.mmap)
|
||||||
|
- Path to _vectors_meta.db (legacy mode, loads from DB)
|
||||||
|
"""
|
||||||
|
path = Path(index_root_or_meta_path)
|
||||||
|
|
||||||
|
# Determine if this is an index root or a specific DB path
|
||||||
|
if path.suffix == '.db':
|
||||||
|
# Legacy mode: specific DB path
|
||||||
|
self.index_root = path.parent
|
||||||
|
self.meta_store_path = path
|
||||||
|
else:
|
||||||
|
# New mode: index root directory
|
||||||
|
self.index_root = path
|
||||||
|
self.meta_store_path = path / "_vectors_meta.db"
|
||||||
|
|
||||||
|
self._chunk_ids: Optional[np.ndarray] = None
|
||||||
|
self._binary_matrix: Optional[np.ndarray] = None
|
||||||
|
self._is_memmap = False
|
||||||
|
self._loaded = False
|
||||||
|
|
||||||
|
def load(self) -> bool:
|
||||||
|
"""Load binary vectors using memory-mapped file or database fallback.
|
||||||
|
|
||||||
|
Tries to load from memory-mapped file first (preferred for large indexes),
|
||||||
|
falls back to database loading if mmap file doesn't exist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if vectors were loaded successfully.
|
||||||
|
"""
|
||||||
|
if self._loaded:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Try memory-mapped file first (preferred)
|
||||||
|
mmap_path = self.index_root / "_binary_vectors.mmap"
|
||||||
|
meta_path = mmap_path.with_suffix('.meta.json')
|
||||||
|
|
||||||
|
if mmap_path.exists() and meta_path.exists():
|
||||||
|
try:
|
||||||
|
with open(meta_path, 'r') as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
shape = tuple(meta['shape'])
|
||||||
|
self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)
|
||||||
|
|
||||||
|
# Memory-map the binary matrix (read-only)
|
||||||
|
self._binary_matrix = np.memmap(
|
||||||
|
str(mmap_path),
|
||||||
|
dtype=np.uint8,
|
||||||
|
mode='r',
|
||||||
|
shape=shape
|
||||||
|
)
|
||||||
|
self._is_memmap = True
|
||||||
|
self._loaded = True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Memory-mapped %d binary vectors (%d bytes each)",
|
||||||
|
len(self._chunk_ids), shape[1]
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)
|
||||||
|
|
||||||
|
# Fallback: load from database
|
||||||
|
return self._load_from_db()
|
||||||
|
|
||||||
|
def _load_from_db(self) -> bool:
|
||||||
|
"""Load binary vectors from database (legacy/fallback mode).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if vectors were loaded successfully.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from codexlens.storage.vector_meta_store import VectorMetadataStore
|
||||||
|
|
||||||
|
with VectorMetadataStore(self.meta_store_path) as store:
|
||||||
|
rows = store.get_all_binary_vectors()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.warning("No binary vectors found in %s", self.meta_store_path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Convert to numpy arrays for fast computation
|
||||||
|
self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)
|
||||||
|
|
||||||
|
# Unpack bytes to numpy array
|
||||||
|
binary_arrays = []
|
||||||
|
for _, vec_bytes in rows:
|
||||||
|
arr = np.frombuffer(vec_bytes, dtype=np.uint8)
|
||||||
|
binary_arrays.append(arr)
|
||||||
|
|
||||||
|
self._binary_matrix = np.vstack(binary_arrays)
|
||||||
|
self._is_memmap = False
|
||||||
|
self._loaded = True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded %d binary vectors from DB (%d bytes each)",
|
||||||
|
len(self._chunk_ids), self._binary_matrix.shape[1]
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to load binary vectors: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query_vector: np.ndarray,
|
||||||
|
top_k: int = 100
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""Search for similar vectors using Hamming distance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_vector: Dense query vector (will be binarized).
|
||||||
|
top_k: Number of top results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (chunk_id, hamming_distance) tuples sorted by distance.
|
||||||
|
"""
|
||||||
|
if not self._loaded and not self.load():
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Binarize query vector
|
||||||
|
query_binary = (query_vector > 0).astype(np.uint8)
|
||||||
|
query_packed = np.packbits(query_binary)
|
||||||
|
|
||||||
|
# Compute Hamming distances using XOR and popcount
|
||||||
|
# XOR gives 1 for differing bits
|
||||||
|
xor_result = np.bitwise_xor(self._binary_matrix, query_packed)
|
||||||
|
|
||||||
|
# Vectorized popcount using lookup table (orders of magnitude faster)
|
||||||
|
# Sum the bit counts for each byte across all columns
|
||||||
|
distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)
|
||||||
|
|
||||||
|
# Get top-k with smallest distances
|
||||||
|
if top_k >= len(distances):
|
||||||
|
top_indices = np.argsort(distances)
|
||||||
|
else:
|
||||||
|
# Partial sort for efficiency
|
||||||
|
top_indices = np.argpartition(distances, top_k)[:top_k]
|
||||||
|
top_indices = top_indices[np.argsort(distances[top_indices])]
|
||||||
|
|
||||||
|
results = [
|
||||||
|
(int(self._chunk_ids[i]), int(distances[i]))
|
||||||
|
for i in top_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_with_rerank(
|
||||||
|
self,
|
||||||
|
query_dense: np.ndarray,
|
||||||
|
dense_vectors: np.ndarray,
|
||||||
|
dense_chunk_ids: np.ndarray,
|
||||||
|
top_k: int = 10,
|
||||||
|
candidates: int = 100
|
||||||
|
) -> List[Tuple[int, float]]:
|
||||||
|
"""Two-stage cascade search: binary filter + dense rerank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_dense: Dense query vector.
|
||||||
|
dense_vectors: Dense vectors for reranking (from HNSW or stored).
|
||||||
|
dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
|
||||||
|
top_k: Final number of results.
|
||||||
|
candidates: Number of candidates from binary search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (chunk_id, cosine_similarity) tuples.
|
||||||
|
"""
|
||||||
|
# Stage 1: Binary filtering
|
||||||
|
binary_results = self.search(query_dense, top_k=candidates)
|
||||||
|
if not binary_results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
candidate_ids = {r[0] for r in binary_results}
|
||||||
|
|
||||||
|
# Stage 2: Dense reranking
|
||||||
|
# Find indices of candidates in dense_vectors
|
||||||
|
candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
|
||||||
|
candidate_indices = np.where(candidate_mask)[0]
|
||||||
|
|
||||||
|
if len(candidate_indices) == 0:
|
||||||
|
# Fallback: return binary results with normalized distance
|
||||||
|
max_dist = max(r[1] for r in binary_results) if binary_results else 1
|
||||||
|
return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]
|
||||||
|
|
||||||
|
# Compute cosine similarities for candidates
|
||||||
|
candidate_vectors = dense_vectors[candidate_indices]
|
||||||
|
candidate_ids_array = dense_chunk_ids[candidate_indices]
|
||||||
|
|
||||||
|
# Normalize vectors
|
||||||
|
query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
|
||||||
|
cand_norms = candidate_vectors / (
|
||||||
|
np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cosine similarities
|
||||||
|
similarities = np.dot(cand_norms, query_norm)
|
||||||
|
|
||||||
|
# Sort by similarity (descending)
|
||||||
|
sorted_indices = np.argsort(-similarities)[:top_k]
|
||||||
|
|
||||||
|
results = [
|
||||||
|
(int(candidate_ids_array[i]), float(similarities[i]))
|
||||||
|
for i in sorted_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vector_count(self) -> int:
|
||||||
|
"""Get number of loaded binary vectors."""
|
||||||
|
return len(self._chunk_ids) if self._chunk_ids is not None else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_memmap(self) -> bool:
|
||||||
|
"""Check if using memory-mapped file (vs in-memory array)."""
|
||||||
|
return self._is_memmap
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear loaded vectors from memory."""
|
||||||
|
# For memmap, just delete the reference (OS will handle cleanup)
|
||||||
|
if self._is_memmap and self._binary_matrix is not None:
|
||||||
|
del self._binary_matrix
|
||||||
|
self._chunk_ids = None
|
||||||
|
self._binary_matrix = None
|
||||||
|
self._is_memmap = False
|
||||||
|
self._loaded = False
|
||||||
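# Usage sketch (illustrative): two-stage cascade search against an index root.
# The index path, query vector, and dense rerank inputs below are placeholders;
# only the BinarySearcher API itself comes from this module.
#
#     from pathlib import Path
#     from codexlens.search.binary_searcher import BinarySearcher
#
#     searcher = BinarySearcher(Path("/path/to/index_root"))
#     if searcher.load():
#         coarse = searcher.search(query_dense, top_k=100)          # (chunk_id, hamming)
#         final = searcher.search_with_rerank(
#             query_dense, dense_vectors, dense_chunk_ids,
#             top_k=10, candidates=100,
#         )                                                          # (chunk_id, cosine)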
3268
codex-lens/build/lib/codexlens/search/chain_search.py
Normal file
File diff suppressed because it is too large
124
codex-lens/build/lib/codexlens/search/clustering/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Clustering strategies for the staged hybrid search pipeline.
|
||||||
|
|
||||||
|
This module provides extensible clustering infrastructure for grouping
|
||||||
|
similar search results and selecting representative results.
|
||||||
|
|
||||||
|
Install with: pip install codexlens[clustering]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from codexlens.search.clustering import (
|
||||||
|
... CLUSTERING_AVAILABLE,
|
||||||
|
... ClusteringConfig,
|
||||||
|
... get_strategy,
|
||||||
|
... )
|
||||||
|
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||||
|
>>> # Auto-select best available strategy with fallback
|
||||||
|
>>> strategy = get_strategy("auto", config)
|
||||||
|
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||||
|
>>>
|
||||||
|
>>> # Or explicitly use a specific strategy
|
||||||
|
>>> if CLUSTERING_AVAILABLE:
|
||||||
|
... from codexlens.search.clustering import HDBSCANStrategy
|
||||||
|
... strategy = HDBSCANStrategy(config)
|
||||||
|
... representatives = strategy.fit_predict(embeddings, results)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Always export base classes and factory (no heavy dependencies)
|
||||||
|
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||||
|
from .factory import (
|
||||||
|
ClusteringStrategyFactory,
|
||||||
|
check_clustering_strategy_available,
|
||||||
|
get_strategy,
|
||||||
|
)
|
||||||
|
from .noop_strategy import NoOpStrategy
|
||||||
|
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||||
|
|
||||||
|
# Feature flag for clustering availability (hdbscan + sklearn)
|
||||||
|
CLUSTERING_AVAILABLE = False
|
||||||
|
HDBSCAN_AVAILABLE = False
|
||||||
|
DBSCAN_AVAILABLE = False
|
||||||
|
_import_error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
|
||||||
|
"""Detect if clustering dependencies are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
|
||||||
|
"""
|
||||||
|
hdbscan_ok = False
|
||||||
|
dbscan_ok = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import hdbscan # noqa: F401
|
||||||
|
hdbscan_ok = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sklearn.cluster import DBSCAN # noqa: F401
|
||||||
|
dbscan_ok = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
all_ok = hdbscan_ok and dbscan_ok
|
||||||
|
error = None
|
||||||
|
if not all_ok:
|
||||||
|
missing = []
|
||||||
|
if not hdbscan_ok:
|
||||||
|
missing.append("hdbscan")
|
||||||
|
if not dbscan_ok:
|
||||||
|
missing.append("scikit-learn")
|
||||||
|
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
|
||||||
|
|
||||||
|
return all_ok, hdbscan_ok, dbscan_ok, error
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize on module load
|
||||||
|
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
|
||||||
|
_detect_clustering_available()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_clustering_available() -> tuple[bool, str | None]:
|
||||||
|
"""Check if all clustering dependencies are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_available, error_message).
|
||||||
|
error_message is None if available, otherwise contains install instructions.
|
||||||
|
"""
|
||||||
|
return CLUSTERING_AVAILABLE, _import_error
|
||||||
|
|
||||||
|
|
||||||
|
# Conditionally export strategy implementations
|
||||||
|
__all__ = [
|
||||||
|
# Feature flags
|
||||||
|
"CLUSTERING_AVAILABLE",
|
||||||
|
"HDBSCAN_AVAILABLE",
|
||||||
|
"DBSCAN_AVAILABLE",
|
||||||
|
"check_clustering_available",
|
||||||
|
# Base classes
|
||||||
|
"BaseClusteringStrategy",
|
||||||
|
"ClusteringConfig",
|
||||||
|
# Factory
|
||||||
|
"ClusteringStrategyFactory",
|
||||||
|
"get_strategy",
|
||||||
|
"check_clustering_strategy_available",
|
||||||
|
# Always-available strategies
|
||||||
|
"NoOpStrategy",
|
||||||
|
"FrequencyStrategy",
|
||||||
|
"FrequencyConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Conditionally add strategy classes to __all__ and module namespace
|
||||||
|
if HDBSCAN_AVAILABLE:
|
||||||
|
from .hdbscan_strategy import HDBSCANStrategy
|
||||||
|
|
||||||
|
__all__.append("HDBSCANStrategy")
|
||||||
|
|
||||||
|
if DBSCAN_AVAILABLE:
|
||||||
|
from .dbscan_strategy import DBSCANStrategy
|
||||||
|
|
||||||
|
__all__.append("DBSCANStrategy")
|
||||||
153
codex-lens/build/lib/codexlens/search/clustering/base.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Base classes for clustering strategies in the hybrid search pipeline.
|
||||||
|
|
||||||
|
This module defines the abstract base class for clustering strategies used
|
||||||
|
in the staged hybrid search pipeline. Strategies cluster search results
|
||||||
|
based on their embeddings and select representative results from each cluster.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import numpy as np
|
||||||
|
from codexlens.entities import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClusteringConfig:
|
||||||
|
"""Configuration parameters for clustering strategies.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
min_cluster_size: Minimum number of results to form a cluster.
|
||||||
|
HDBSCAN default is 5, but for search results 2-3 is often better.
|
||||||
|
min_samples: Number of samples in a neighborhood for a point to be
|
||||||
|
considered a core point. Lower values allow more clusters.
|
||||||
|
metric: Distance metric for clustering. Common options:
|
||||||
|
- 'euclidean': Standard L2 distance
|
||||||
|
- 'cosine': Cosine distance (1 - cosine_similarity)
|
||||||
|
- 'manhattan': L1 distance
|
||||||
|
cluster_selection_epsilon: Distance threshold for cluster selection.
|
||||||
|
Results within this distance may be merged into the same cluster.
|
||||||
|
allow_single_cluster: If True, allow all results to form one cluster.
|
||||||
|
Useful when results are very similar.
|
||||||
|
prediction_data: If True, generate prediction data for new points.
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_cluster_size: int = 3
|
||||||
|
min_samples: int = 2
|
||||||
|
metric: str = "cosine"
|
||||||
|
cluster_selection_epsilon: float = 0.0
|
||||||
|
allow_single_cluster: bool = True
|
||||||
|
prediction_data: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate configuration parameters."""
|
||||||
|
if self.min_cluster_size < 2:
|
||||||
|
raise ValueError("min_cluster_size must be >= 2")
|
||||||
|
if self.min_samples < 1:
|
||||||
|
raise ValueError("min_samples must be >= 1")
|
||||||
|
if self.metric not in ("euclidean", "cosine", "manhattan"):
|
||||||
|
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
|
||||||
|
if self.cluster_selection_epsilon < 0:
|
||||||
|
raise ValueError("cluster_selection_epsilon must be >= 0")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClusteringStrategy(ABC):
|
||||||
|
"""Abstract base class for clustering strategies.
|
||||||
|
|
||||||
|
Clustering strategies are used in the staged hybrid search pipeline to
|
||||||
|
group similar search results and select representative results from each
|
||||||
|
cluster, reducing redundancy while maintaining diversity.
|
||||||
|
|
||||||
|
Subclasses must implement:
|
||||||
|
- cluster(): Group results into clusters based on embeddings
|
||||||
|
- select_representatives(): Choose best result(s) from each cluster
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||||
|
"""Initialize the clustering strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Clustering configuration. Uses defaults if not provided.
|
||||||
|
"""
|
||||||
|
self.config = config or ClusteringConfig()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cluster(
|
||||||
|
self,
|
||||||
|
embeddings: "np.ndarray",
|
||||||
|
results: List["SearchResult"],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Cluster search results based on their embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||||
|
containing the embedding vectors for each result.
|
||||||
|
results: List of SearchResult objects corresponding to embeddings.
|
||||||
|
Used for additional metadata during clustering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of clusters, where each cluster is a list of indices
|
||||||
|
into the results list. Results not assigned to any cluster
|
||||||
|
(noise points) should be returned as single-element clusters.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> strategy = HDBSCANStrategy()
|
||||||
|
>>> clusters = strategy.cluster(embeddings, results)
|
||||||
|
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
|
||||||
|
>>> # Result indices 0, 2, 5 are in cluster 0
|
||||||
|
>>> # Result indices 1, 3 are in cluster 1
|
||||||
|
>>> # Result index 4 is a noise point (singleton cluster)
|
||||||
|
>>> # Result indices 6, 7, 8 are in cluster 2
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_representatives(
|
||||||
|
self,
|
||||||
|
clusters: List[List[int]],
|
||||||
|
results: List["SearchResult"],
|
||||||
|
embeddings: Optional["np.ndarray"] = None,
|
||||||
|
) -> List["SearchResult"]:
|
||||||
|
"""Select representative results from each cluster.
|
||||||
|
|
||||||
|
This method chooses the best result(s) from each cluster to include
|
||||||
|
in the final search results. The selection can be based on:
|
||||||
|
- Highest score within cluster
|
||||||
|
- Closest to cluster centroid
|
||||||
|
- Custom selection logic
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clusters: List of clusters from cluster() method.
|
||||||
|
results: Original list of SearchResult objects.
|
||||||
|
embeddings: Optional embeddings array for centroid-based selection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of representative SearchResult objects, one or more per cluster,
|
||||||
|
ordered by relevance (highest score first).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> representatives = strategy.select_representatives(clusters, results)
|
||||||
|
>>> # Returns best result from each cluster
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def fit_predict(
|
||||||
|
self,
|
||||||
|
embeddings: "np.ndarray",
|
||||||
|
results: List["SearchResult"],
|
||||||
|
) -> List["SearchResult"]:
|
||||||
|
"""Convenience method to cluster and select representatives in one call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||||
|
results: List of SearchResult objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of representative SearchResult objects.
|
||||||
|
"""
|
||||||
|
clusters = self.cluster(embeddings, results)
|
||||||
|
return self.select_representatives(clusters, results, embeddings)
|
||||||
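# Sketch (not part of the module): the smallest conforming subclass of
# BaseClusteringStrategy. Every result becomes its own cluster and is its own
# representative. It relies only on the contract documented above plus the
# SearchResult.score attribute that the bundled strategies also use.
class _SingletonStrategyExample(BaseClusteringStrategy):
    def cluster(self, embeddings, results):
        # One singleton cluster per result index; embeddings are not needed here.
        return [[i] for i in range(len(results))]

    def select_representatives(self, clusters, results, embeddings=None):
        # The only member of each cluster is its representative, highest score first.
        reps = [results[indices[0]] for indices in clusters if indices]
        reps.sort(key=lambda r: r.score, reverse=True)
        return reps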
@@ -0,0 +1,197 @@
|
|||||||
|
"""DBSCAN-based clustering strategy for search results.
|
||||||
|
|
||||||
|
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
|
||||||
|
is the fallback clustering strategy when HDBSCAN is not available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import numpy as np
|
||||||
|
from codexlens.entities import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class DBSCANStrategy(BaseClusteringStrategy):
|
||||||
|
"""DBSCAN-based clustering strategy.
|
||||||
|
|
||||||
|
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
|
||||||
|
DBSCAN requires an explicit eps parameter, which is auto-computed from the
|
||||||
|
distance distribution if not provided.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
|
||||||
|
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||||
|
>>> strategy = DBSCANStrategy(config)
|
||||||
|
>>> clusters = strategy.cluster(embeddings, results)
|
||||||
|
>>> representatives = strategy.select_representatives(clusters, results)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default eps percentile for auto-computation
|
||||||
|
DEFAULT_EPS_PERCENTILE: float = 15.0
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Optional[ClusteringConfig] = None,
|
||||||
|
eps: Optional[float] = None,
|
||||||
|
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize DBSCAN clustering strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Clustering configuration. Uses defaults if not provided.
|
||||||
|
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
|
||||||
|
from the distance distribution.
|
||||||
|
eps_percentile: Percentile of pairwise distances to use for
|
||||||
|
auto-computing eps. Default is 15th percentile.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If sklearn is not installed.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
self.eps = eps
|
||||||
|
self.eps_percentile = eps_percentile
|
||||||
|
|
||||||
|
# Validate sklearn is available
|
||||||
|
try:
|
||||||
|
from sklearn.cluster import DBSCAN # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"scikit-learn package is required for DBSCANStrategy. "
|
||||||
|
"Install with: pip install codexlens[clustering]"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def _compute_eps(self, embeddings: "np.ndarray") -> float:
|
||||||
|
"""Auto-compute eps from pairwise distance distribution.
|
||||||
|
|
||||||
|
Uses the specified percentile of pairwise distances as eps,
|
||||||
|
which typically captures local density well.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Computed eps value.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import pairwise_distances
|
||||||
|
|
||||||
|
# Compute pairwise distances
|
||||||
|
distances = pairwise_distances(embeddings, metric=self.config.metric)
|
||||||
|
|
||||||
|
# Get upper triangle (excluding diagonal)
|
||||||
|
upper_tri = distances[np.triu_indices_from(distances, k=1)]
|
||||||
|
|
||||||
|
if len(upper_tri) == 0:
|
||||||
|
# Only one point, return a default small eps
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
# Use percentile of distances as eps
|
||||||
|
eps = float(np.percentile(upper_tri, self.eps_percentile))
|
||||||
|
|
||||||
|
# Ensure eps is positive
|
||||||
|
return max(eps, 1e-6)
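    # Illustration (not part of the class): the auto-eps boils down to taking a
    # low percentile of all pairwise distances, so it adapts to however tight or
    # spread out the current result embeddings happen to be.
    #
    #     from sklearn.metrics import pairwise_distances
    #     d = pairwise_distances(embeddings, metric="cosine")
    #     eps = float(np.percentile(d[np.triu_indices_from(d, k=1)], 15.0))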
|
||||||
|
|
||||||
|
def cluster(
|
||||||
|
self,
|
||||||
|
embeddings: "np.ndarray",
|
||||||
|
results: List["SearchResult"],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Cluster search results using DBSCAN algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||||
|
containing the embedding vectors for each result.
|
||||||
|
results: List of SearchResult objects corresponding to embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of clusters, where each cluster is a list of indices
|
||||||
|
into the results list. Noise points are returned as singleton clusters.
|
||||||
|
"""
|
||||||
|
from sklearn.cluster import DBSCAN
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
n_results = len(results)
|
||||||
|
if n_results == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Handle edge case: single result
|
||||||
|
if n_results == 1:
|
||||||
|
return [[0]]
|
||||||
|
|
||||||
|
# Determine eps value
|
||||||
|
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
|
||||||
|
|
||||||
|
# Configure DBSCAN clusterer
|
||||||
|
# Note: DBSCAN min_samples corresponds to min_cluster_size concept
|
||||||
|
clusterer = DBSCAN(
|
||||||
|
eps=eps,
|
||||||
|
min_samples=self.config.min_samples,
|
||||||
|
metric=self.config.metric,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fit and get cluster labels
|
||||||
|
# Labels: -1 = noise, 0+ = cluster index
|
||||||
|
labels = clusterer.fit_predict(embeddings)
|
||||||
|
|
||||||
|
# Group indices by cluster label
|
||||||
|
cluster_map: dict[int, list[int]] = {}
|
||||||
|
for idx, label in enumerate(labels):
|
||||||
|
if label not in cluster_map:
|
||||||
|
cluster_map[label] = []
|
||||||
|
cluster_map[label].append(idx)
|
||||||
|
|
||||||
|
# Build result: non-noise clusters first, then noise as singletons
|
||||||
|
clusters: List[List[int]] = []
|
||||||
|
|
||||||
|
# Add proper clusters (label >= 0)
|
||||||
|
for label in sorted(cluster_map.keys()):
|
||||||
|
if label >= 0:
|
||||||
|
clusters.append(cluster_map[label])
|
||||||
|
|
||||||
|
# Add noise points as singleton clusters (label == -1)
|
||||||
|
if -1 in cluster_map:
|
||||||
|
for idx in cluster_map[-1]:
|
||||||
|
clusters.append([idx])
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def select_representatives(
|
||||||
|
self,
|
||||||
|
clusters: List[List[int]],
|
||||||
|
results: List["SearchResult"],
|
||||||
|
embeddings: Optional["np.ndarray"] = None,
|
||||||
|
) -> List["SearchResult"]:
|
||||||
|
"""Select representative results from each cluster.
|
||||||
|
|
||||||
|
Selects the result with the highest score from each cluster.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clusters: List of clusters from cluster() method.
|
||||||
|
results: Original list of SearchResult objects.
|
||||||
|
embeddings: Optional embeddings (not used in score-based selection).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of representative SearchResult objects, one per cluster,
|
||||||
|
ordered by score (highest first).
|
||||||
|
"""
|
||||||
|
if not clusters or not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
representatives: List["SearchResult"] = []
|
||||||
|
|
||||||
|
for cluster_indices in clusters:
|
||||||
|
if not cluster_indices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find the result with the highest score in this cluster
|
||||||
|
best_idx = max(cluster_indices, key=lambda i: results[i].score)
|
||||||
|
representatives.append(results[best_idx])
|
||||||
|
|
||||||
|
# Sort by score descending
|
||||||
|
representatives.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
|
||||||
|
return representatives
|
||||||
202
codex-lens/build/lib/codexlens/search/clustering/factory.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Factory for creating clustering strategies.
|
||||||
|
|
||||||
|
Provides a unified interface for instantiating different clustering backends
|
||||||
|
with automatic fallback chain: hdbscan -> dbscan -> noop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||||
|
from .noop_strategy import NoOpStrategy
|
||||||
|
|
||||||
|
|
||||||
|
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
|
||||||
|
"""Check whether a specific clustering strategy can be used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy: Strategy name to check. Options:
|
||||||
|
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
|
||||||
|
- "dbscan": DBSCAN clustering (requires sklearn)
|
||||||
|
- "frequency": Frequency-based clustering (always available)
|
||||||
|
- "noop": No-op strategy (always available)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_available, error_message).
|
||||||
|
error_message is None if available, otherwise contains install instructions.
|
||||||
|
"""
|
||||||
|
strategy = (strategy or "").strip().lower()
|
||||||
|
|
||||||
|
if strategy == "hdbscan":
|
||||||
|
try:
|
||||||
|
import hdbscan # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
return False, (
|
||||||
|
"hdbscan package not available. "
|
||||||
|
"Install with: pip install codexlens[clustering]"
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
if strategy == "dbscan":
|
||||||
|
try:
|
||||||
|
from sklearn.cluster import DBSCAN # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
return False, (
|
||||||
|
"scikit-learn package not available. "
|
||||||
|
"Install with: pip install codexlens[clustering]"
|
||||||
|
)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
if strategy == "frequency":
|
||||||
|
# Frequency strategy is always available (no external deps)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
if strategy == "noop":
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
return False, (
|
||||||
|
f"Invalid clustering strategy: {strategy}. "
|
||||||
|
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_strategy(
|
||||||
|
strategy: str = "hdbscan",
|
||||||
|
config: Optional[ClusteringConfig] = None,
|
||||||
|
*,
|
||||||
|
fallback: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseClusteringStrategy:
|
||||||
|
"""Factory function to create clustering strategy with fallback chain.
|
||||||
|
|
||||||
|
The fallback chain is: hdbscan -> dbscan -> frequency -> noop
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy: Clustering strategy to use. Options:
|
||||||
|
- "hdbscan": HDBSCAN clustering (default, recommended)
|
||||||
|
- "dbscan": DBSCAN clustering (fallback)
|
||||||
|
- "frequency": Frequency-based clustering (groups by symbol occurrence)
|
||||||
|
- "noop": No-op strategy (returns all results ungrouped)
|
||||||
|
- "auto": Try hdbscan, then dbscan, then noop
|
||||||
|
config: Clustering configuration. Uses defaults if not provided.
|
||||||
|
For frequency strategy, pass FrequencyConfig for full control.
|
||||||
|
fallback: If True (default), automatically fall back to next strategy
|
||||||
|
in the chain when primary is unavailable. If False, raise ImportError
|
||||||
|
when requested strategy is unavailable.
|
||||||
|
**kwargs: Additional strategy-specific arguments.
|
||||||
|
For DBSCANStrategy: eps, eps_percentile
|
||||||
|
For FrequencyStrategy: group_by, min_frequency, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseClusteringStrategy: Configured clustering strategy instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If strategy is not recognized.
|
||||||
|
ImportError: If required dependencies are not installed and fallback=False.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
|
||||||
|
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||||
|
>>> # Auto-select best available strategy
|
||||||
|
>>> strategy = get_strategy("auto", config)
|
||||||
|
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
|
||||||
|
>>> strategy = get_strategy("hdbscan", config)
|
||||||
|
>>> # Use frequency-based strategy
|
||||||
|
>>> from codexlens.search.clustering import FrequencyConfig
|
||||||
|
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||||
|
>>> strategy = get_strategy("frequency", freq_config)
|
||||||
|
"""
|
||||||
|
strategy = (strategy or "").strip().lower()
|
||||||
|
|
||||||
|
# Handle "auto" - try strategies in order
|
||||||
|
if strategy == "auto":
|
||||||
|
return _get_best_available_strategy(config, **kwargs)
|
||||||
|
|
||||||
|
if strategy == "hdbscan":
|
||||||
|
ok, err = check_clustering_strategy_available("hdbscan")
|
||||||
|
if ok:
|
||||||
|
from .hdbscan_strategy import HDBSCANStrategy
|
||||||
|
return HDBSCANStrategy(config)
|
||||||
|
|
||||||
|
if fallback:
|
||||||
|
# Try dbscan fallback
|
||||||
|
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
|
||||||
|
if ok_dbscan:
|
||||||
|
from .dbscan_strategy import DBSCANStrategy
|
||||||
|
return DBSCANStrategy(config, **kwargs)
|
||||||
|
# Final fallback to noop
|
||||||
|
return NoOpStrategy(config)
|
||||||
|
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
if strategy == "dbscan":
|
||||||
|
ok, err = check_clustering_strategy_available("dbscan")
|
||||||
|
if ok:
|
||||||
|
from .dbscan_strategy import DBSCANStrategy
|
||||||
|
return DBSCANStrategy(config, **kwargs)
|
||||||
|
|
||||||
|
if fallback:
|
||||||
|
# Fallback to noop
|
||||||
|
return NoOpStrategy(config)
|
||||||
|
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
if strategy == "frequency":
|
||||||
|
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||||
|
# If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
|
||||||
|
if config is None or not isinstance(config, FrequencyConfig):
|
||||||
|
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
|
||||||
|
else:
|
||||||
|
freq_config = config
|
||||||
|
return FrequencyStrategy(freq_config)
|
||||||
|
|
||||||
|
if strategy == "noop":
|
||||||
|
return NoOpStrategy(config)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown clustering strategy: {strategy}. "
|
||||||
|
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_best_available_strategy(
|
||||||
|
config: Optional[ClusteringConfig] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseClusteringStrategy:
|
||||||
|
"""Get the best available clustering strategy.
|
||||||
|
|
||||||
|
Tries strategies in order: hdbscan -> dbscan -> noop
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Clustering configuration.
|
||||||
|
**kwargs: Additional strategy-specific arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Best available clustering strategy instance.
|
||||||
|
"""
|
||||||
|
# Try HDBSCAN first
|
||||||
|
ok, _ = check_clustering_strategy_available("hdbscan")
|
||||||
|
if ok:
|
||||||
|
from .hdbscan_strategy import HDBSCANStrategy
|
||||||
|
return HDBSCANStrategy(config)
|
||||||
|
|
||||||
|
# Try DBSCAN second
|
||||||
|
ok, _ = check_clustering_strategy_available("dbscan")
|
||||||
|
if ok:
|
||||||
|
from .dbscan_strategy import DBSCANStrategy
|
||||||
|
return DBSCANStrategy(config, **kwargs)
|
||||||
|
|
||||||
|
# Fallback to NoOp
|
||||||
|
return NoOpStrategy(config)
|
||||||
|
|
||||||
|
|
||||||
|
# Alias for backward compatibility
|
||||||
|
ClusteringStrategyFactory = type(
|
||||||
|
"ClusteringStrategyFactory",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"get_strategy": staticmethod(get_strategy),
|
||||||
|
"check_available": staticmethod(check_clustering_strategy_available),
|
||||||
|
},
|
||||||
|
)
|
||||||
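# Usage sketch (illustrative): the module-level get_strategy() and the
# ClusteringStrategyFactory alias resolve strategies through the same code path,
# so older call sites that expect a factory class keep working.
#
#     from codexlens.search.clustering import ClusteringStrategyFactory, get_strategy
#
#     strategy = get_strategy("auto")                            # best available backend
#     legacy = ClusteringStrategyFactory.get_strategy("dbscan")  # falls back if sklearn is missing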
@@ -0,0 +1,263 @@
|
|||||||
|
"""Frequency-based clustering strategy for search result deduplication.
|
||||||
|
|
||||||
|
This strategy groups search results by symbol/method name and prunes based on
|
||||||
|
occurrence frequency. High-frequency symbols (frequently referenced methods)
|
||||||
|
are considered more important and retained, while low-frequency results
|
||||||
|
(potentially noise) can be filtered out.
|
||||||
|
|
||||||
|
Use cases:
|
||||||
|
- Prioritize commonly called methods/functions
|
||||||
|
- Filter out one-off results that may be less relevant
|
||||||
|
- Deduplicate results pointing to the same symbol from different locations
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
|
||||||
|
|
||||||
|
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import numpy as np
|
||||||
|
from codexlens.entities import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrequencyConfig(ClusteringConfig):
|
||||||
|
"""Configuration for frequency-based clustering strategy.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
group_by: Field to group results by for frequency counting.
|
||||||
|
- 'symbol': Group by symbol_name (default, for method/function dedup)
|
||||||
|
- 'file': Group by file path
|
||||||
|
- 'symbol_kind': Group by symbol type (function, class, etc.)
|
||||||
|
min_frequency: Minimum occurrence count to keep a result.
|
||||||
|
Results appearing less than this are considered noise and pruned.
|
||||||
|
max_representatives_per_group: Maximum results to keep per symbol group.
|
||||||
|
frequency_weight: How much to boost score based on frequency.
|
||||||
|
Final score = original_score * (1 + frequency_weight * log(frequency))
|
||||||
|
keep_mode: How to handle low-frequency results.
|
||||||
|
- 'filter': Remove results below min_frequency
|
||||||
|
- 'demote': Keep but lower their score ranking
|
||||||
|
"""
|
||||||
|
|
||||||
|
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
|
||||||
|
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
|
||||||
|
max_representatives_per_group: int = 3
|
||||||
|
frequency_weight: float = 0.1 # Boost factor for frequency
|
||||||
|
keep_mode: Literal["filter", "demote"] = "demote"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Validate configuration parameters."""
|
||||||
|
# Skip parent validation since we don't use HDBSCAN params
|
||||||
|
if self.min_frequency < 1:
|
||||||
|
raise ValueError("min_frequency must be >= 1")
|
||||||
|
if self.max_representatives_per_group < 1:
|
||||||
|
raise ValueError("max_representatives_per_group must be >= 1")
|
||||||
|
if self.frequency_weight < 0:
|
||||||
|
raise ValueError("frequency_weight must be >= 0")
|
||||||
|
if self.group_by not in ("symbol", "file", "symbol_kind"):
|
||||||
|
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
|
||||||
|
if self.keep_mode not in ("filter", "demote"):
|
||||||
|
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyStrategy(BaseClusteringStrategy):
|
||||||
|
"""Frequency-based clustering strategy for search result deduplication.
|
||||||
|
|
||||||
|
This strategy groups search results by symbol name (or file/kind) and:
|
||||||
|
1. Counts how many times each symbol appears in results
|
||||||
|
2. Higher frequency = more important (frequently referenced method)
|
||||||
|
3. Filters or demotes low-frequency results
|
||||||
|
4. Selects top representatives from each frequency group
|
||||||
|
|
||||||
|
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
|
||||||
|
- Does NOT require embeddings (works with metadata only)
|
||||||
|
- Is very fast (O(n) complexity)
|
||||||
|
- Is deterministic (no random initialization)
|
||||||
|
- Works well for symbol-level deduplication
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||||
|
>>> strategy = FrequencyStrategy(config)
|
||||||
|
>>> # Results with symbol "authenticate" appearing 5 times
|
||||||
|
>>> # will be prioritized over "helper_func" appearing once
|
||||||
|
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
|
||||||
|
"""Initialize the frequency strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Frequency configuration. Uses defaults if not provided.
|
||||||
|
"""
|
||||||
|
self.config: FrequencyConfig = config or FrequencyConfig()
|
||||||
|
|
||||||
|
def _get_group_key(self, result: "SearchResult") -> str:
|
||||||
|
"""Extract grouping key from a search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: SearchResult to extract key from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
String key for grouping (symbol name, file path, or kind).
|
||||||
|
"""
|
||||||
|
if self.config.group_by == "symbol":
|
||||||
|
# Use symbol_name if available, otherwise fall back to file:line
|
||||||
|
symbol = getattr(result, "symbol_name", None)
|
||||||
|
if symbol:
|
||||||
|
return str(symbol)
|
||||||
|
# Fallback: use file path + start_line as pseudo-symbol
|
||||||
|
start_line = getattr(result, "start_line", 0) or 0
|
||||||
|
return f"{result.path}:{start_line}"
|
||||||
|
|
||||||
|
elif self.config.group_by == "file":
|
||||||
|
return str(result.path)
|
||||||
|
|
||||||
|
elif self.config.group_by == "symbol_kind":
|
||||||
|
kind = getattr(result, "symbol_kind", None)
|
||||||
|
return str(kind) if kind else "unknown"
|
||||||
|
|
||||||
|
return str(result.path) # Default fallback
|
||||||
|
|
||||||
|
def cluster(
|
||||||
|
self,
|
||||||
|
embeddings: "np.ndarray",
|
||||||
|
results: List["SearchResult"],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""Group search results by frequency of occurrence.
|
||||||
|
|
||||||
|
Note: This method ignores embeddings and groups by metadata only.
|
||||||
|
The embeddings parameter is kept for interface compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: Ignored (kept for interface compatibility).
|
||||||
|
results: List of SearchResult objects to cluster.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of clusters (groups), where each cluster contains indices
|
||||||
|
of results with the same grouping key. Clusters are ordered by
|
||||||
|
frequency (highest frequency first).
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Group results by key
|
||||||
|
groups: Dict[str, List[int]] = defaultdict(list)
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
key = self._get_group_key(result)
|
||||||
|
groups[key].append(idx)
|
||||||
|
|
||||||
|
# Sort groups by frequency (descending) then by key (for stability)
|
||||||
|
sorted_groups = sorted(
|
||||||
|
groups.items(),
|
||||||
|
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to list of clusters
|
||||||
|
clusters = [indices for _, indices in sorted_groups]
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def select_representatives(
|
||||||
|
self,
|
||||||
|
clusters: List[List[int]],
|
||||||
|
results: List["SearchResult"],
|
||||||
|
embeddings: Optional["np.ndarray"] = None,
|
||||||
|
) -> List["SearchResult"]:
|
||||||
|
"""Select representative results based on frequency and score.
|
||||||
|
|
||||||
|
For each frequency group:
|
||||||
|
1. If frequency < min_frequency: filter or demote based on keep_mode
|
||||||
|
2. Sort by score within group
|
||||||
|
3. Apply frequency boost to scores
|
||||||
|
4. Select top N representatives
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clusters: List of clusters from cluster() method.
|
||||||
|
results: Original list of SearchResult objects.
|
||||||
|
embeddings: Optional embeddings (used for tie-breaking if provided).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of representative SearchResult objects, ordered by
|
||||||
|
frequency-adjusted score (highest first).
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
if not clusters or not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
representatives: List["SearchResult"] = []
|
||||||
|
demoted: List["SearchResult"] = []
|
||||||
|
|
||||||
|
for cluster_indices in clusters:
|
||||||
|
if not cluster_indices:
|
||||||
|
continue
|
||||||
|
|
||||||
|
frequency = len(cluster_indices)
|
||||||
|
|
||||||
|
# Get results in this cluster, sorted by score
|
||||||
|
cluster_results = [results[i] for i in cluster_indices]
|
||||||
|
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||||
|
|
||||||
|
# Check frequency threshold
|
||||||
|
if frequency < self.config.min_frequency:
|
||||||
|
if self.config.keep_mode == "filter":
|
||||||
|
# Skip low-frequency results entirely
|
||||||
|
continue
|
||||||
|
else: # demote mode
|
||||||
|
# Keep but add to demoted list (lower priority)
|
||||||
|
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||||
|
demoted.append(result)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Apply frequency boost and select top representatives
|
||||||
|
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||||
|
# Calculate frequency-boosted score
|
||||||
|
original_score = getattr(result, "score", 0.0)
|
||||||
|
# log(frequency + 1) to handle frequency=1 case smoothly
|
||||||
|
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
|
||||||
|
boosted_score = original_score * frequency_boost
|
||||||
|
|
||||||
|
# Create new result with boosted score and frequency metadata
|
||||||
|
# Note: SearchResult might be immutable, so we preserve original
|
||||||
|
# and track boosted score in metadata
|
||||||
|
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
|
||||||
|
result.metadata["frequency"] = frequency
|
||||||
|
result.metadata["frequency_boosted_score"] = boosted_score
|
||||||
|
|
||||||
|
representatives.append(result)
|
||||||
|
|
||||||
|
# Sort representatives by boosted score (or original score as fallback)
|
||||||
|
def get_sort_score(r: "SearchResult") -> float:
|
||||||
|
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
|
||||||
|
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
|
||||||
|
return getattr(r, "score", 0.0)
|
||||||
|
|
||||||
|
representatives.sort(key=get_sort_score, reverse=True)
|
||||||
|
|
||||||
|
# Add demoted results at the end
|
||||||
|
if demoted:
|
||||||
|
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||||
|
representatives.extend(demoted)
|
||||||
|
|
||||||
|
return representatives
|
||||||
|
|
||||||
|
def fit_predict(
|
||||||
|
self,
|
||||||
|
embeddings: "np.ndarray",
|
||||||
|
results: List["SearchResult"],
|
||||||
|
) -> List["SearchResult"]:
|
||||||
|
"""Convenience method to cluster and select representatives in one call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings: NumPy array (may be ignored for frequency-based clustering).
|
||||||
|
results: List of SearchResult objects.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of representative SearchResult objects.
|
||||||
|
"""
|
||||||
|
clusters = self.cluster(embeddings, results)
|
||||||
|
return self.select_representatives(clusters, results, embeddings)
|
||||||
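Editor's note, not part of the commit: a minimal usage sketch for FrequencyStrategy. The import path is an assumption (the HDBSCAN docstring below imports from codexlens.search.clustering), and a small stand-in dataclass is used instead of the real SearchResult constructor, since the strategy only reads path, score, symbol_name, start_line and a mutable metadata dict.

# Hypothetical usage sketch for FrequencyStrategy.
from dataclasses import dataclass, field
from typing import Optional

from codexlens.search.clustering import FrequencyConfig, FrequencyStrategy  # assumed export path


@dataclass
class FakeResult:  # stand-in for SearchResult; only the fields the strategy reads
    path: str
    score: float
    symbol_name: Optional[str] = None
    start_line: int = 1
    metadata: dict = field(default_factory=dict)


results = [
    FakeResult("auth.py", 0.9, "authenticate"),
    FakeResult("views.py", 0.7, "authenticate"),
    FakeResult("util.py", 0.8, "helper_func"),
]
strategy = FrequencyStrategy(FrequencyConfig(min_frequency=2, keep_mode="demote"))
deduped = strategy.fit_predict(embeddings=None, results=results)  # embeddings are ignored
# "authenticate" appears twice, so its hits are boosted and come first;
# "helper_func" falls below min_frequency and is demoted to the tail.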
@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.

HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


class HDBSCANStrategy(BaseClusteringStrategy):
    """HDBSCAN-based clustering strategy.

    Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
    HDBSCAN is preferred over DBSCAN because it:
    - Automatically determines the number of clusters
    - Handles varying density clusters well
    - Identifies noise points (outliers) effectively

    Example:
        >>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
        >>> strategy = HDBSCANStrategy(config)
        >>> clusters = strategy.cluster(embeddings, results)
        >>> representatives = strategy.select_representatives(clusters, results)
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize HDBSCAN clustering strategy.

        Args:
            config: Clustering configuration. Uses defaults if not provided.

        Raises:
            ImportError: If hdbscan package is not installed.
        """
        super().__init__(config)
        # Validate hdbscan is available
        try:
            import hdbscan  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "hdbscan package is required for HDBSCANStrategy. "
                "Install with: pip install codexlens[clustering]"
            ) from exc

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results using HDBSCAN algorithm.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Noise points are returned as singleton clusters.
        """
        import hdbscan
        import numpy as np

        n_results = len(results)
        if n_results == 0:
            return []

        # Handle edge case: fewer results than min_cluster_size
        if n_results < self.config.min_cluster_size:
            # Return each result as its own singleton cluster
            return [[i] for i in range(n_results)]

        # Configure HDBSCAN clusterer
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=self.config.min_cluster_size,
            min_samples=self.config.min_samples,
            metric=self.config.metric,
            cluster_selection_epsilon=self.config.cluster_selection_epsilon,
            allow_single_cluster=self.config.allow_single_cluster,
            prediction_data=self.config.prediction_data,
        )

        # Fit and get cluster labels
        # Labels: -1 = noise, 0+ = cluster index
        labels = clusterer.fit_predict(embeddings)

        # Group indices by cluster label
        cluster_map: dict[int, list[int]] = {}
        for idx, label in enumerate(labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(idx)

        # Build result: non-noise clusters first, then noise as singletons
        clusters: List[List[int]] = []

        # Add proper clusters (label >= 0)
        for label in sorted(cluster_map.keys()):
            if label >= 0:
                clusters.append(cluster_map[label])

        # Add noise points as singleton clusters (label == -1)
        if -1 in cluster_map:
            for idx in cluster_map[-1]:
                clusters.append([idx])

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        Selects the result with the highest score from each cluster.

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used in score-based selection).

        Returns:
            List of representative SearchResult objects, one per cluster,
            ordered by score (highest first).
        """
        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            # Find the result with the highest score in this cluster
            best_idx = max(cluster_indices, key=lambda i: results[i].score)
            representatives.append(results[best_idx])

        # Sort by score descending
        representatives.sort(key=lambda r: r.score, reverse=True)

        return representatives
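Editor's note, not part of the commit: the label-to-cluster grouping above is the only non-obvious step in HDBSCANStrategy.cluster(). A standalone sketch of that same transformation, using hard-coded labels in place of a fitted clusterer:

# Standalone sketch of the grouping step: noise labels (-1) become singleton
# clusters, all other labels are grouped by index.
from collections import defaultdict
from typing import Dict, List

labels = [0, 0, 1, -1, 1, -1]  # example output of clusterer.fit_predict()

cluster_map: Dict[int, List[int]] = defaultdict(list)
for idx, label in enumerate(labels):
    cluster_map[label].append(idx)

clusters = [cluster_map[lbl] for lbl in sorted(cluster_map) if lbl >= 0]
clusters += [[idx] for idx in cluster_map.get(-1, [])]
print(clusters)  # [[0, 1], [2, 4], [3], [5]]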
@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.

NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


class NoOpStrategy(BaseClusteringStrategy):
    """No-op clustering strategy that returns all results ungrouped.

    This strategy is used as a final fallback when no clustering dependencies
    are available, or when clustering is explicitly disabled. Each result
    is treated as its own singleton cluster.

    Example:
        >>> from codexlens.search.clustering import NoOpStrategy
        >>> strategy = NoOpStrategy()
        >>> clusters = strategy.cluster(embeddings, results)
        >>> # Returns [[0], [1], [2], ...] - each result in its own cluster
        >>> representatives = strategy.select_representatives(clusters, results)
        >>> # Returns all results sorted by score
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize NoOp clustering strategy.

        Args:
            config: Clustering configuration. Ignored for NoOpStrategy
                but accepted for interface compatibility.
        """
        super().__init__(config)

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Return each result as its own singleton cluster.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).
                Not used but accepted for interface compatibility.
            results: List of SearchResult objects.

        Returns:
            List of singleton clusters, one per result.
        """
        return [[i] for i in range(len(results))]

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Return all results sorted by score.

        Since each cluster is a singleton, this effectively returns all
        results sorted by score descending.

        Args:
            clusters: List of singleton clusters.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used).

        Returns:
            All SearchResult objects sorted by score (highest first).
        """
        if not results:
            return []

        # Return all results sorted by score
        return sorted(results, key=lambda r: r.score, reverse=True)
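Editor's note, not part of the commit: a minimal sketch of how NoOpStrategy can serve as the documented last-resort fallback. The package may already provide its own factory for this; the import path and the pick_strategy helper are assumptions for illustration only.

# Hypothetical fallback selection: HDBSCANStrategy raises ImportError when the
# hdbscan package is missing, so NoOpStrategy takes over.
from codexlens.search.clustering import HDBSCANStrategy, NoOpStrategy  # assumed exports


def pick_strategy(config=None):
    try:
        return HDBSCANStrategy(config)  # raises ImportError without hdbscan installed
    except ImportError:
        return NoOpStrategy(config)     # every result kept, sorted by score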
171
codex-lens/build/lib/codexlens/search/enrichment.py
Normal file
@@ -0,0 +1,171 @@
# codex-lens/src/codexlens/search/enrichment.py
"""Relationship enrichment for search results."""
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper


class RelationshipEnricher:
    """Enriches search results with code graph relationships."""

    def __init__(self, index_path: Path):
        """Initialize with path to index database.

        Args:
            index_path: Path to _index.db SQLite database
        """
        self.index_path = index_path
        self.db_conn: Optional[sqlite3.Connection] = None
        self._connect()

    def _connect(self) -> None:
        """Establish read-only database connection."""
        if self.index_path.exists():
            self.db_conn = sqlite3.connect(
                f"file:{self.index_path}?mode=ro",
                uri=True,
                check_same_thread=False
            )
            self.db_conn.row_factory = sqlite3.Row

    def enrich(self, results: List[Dict[str, Any]], limit: int = 10) -> List[Dict[str, Any]]:
        """Add relationship data to search results.

        Args:
            results: List of search result dictionaries
            limit: Maximum number of results to enrich

        Returns:
            Results with relationships field added
        """
        if not self.db_conn:
            return results

        for result in results[:limit]:
            file_path = result.get('file') or result.get('path')
            symbol_name = result.get('symbol')
            result['relationships'] = self._find_relationships(file_path, symbol_name)
        return results

    def _find_relationships(self, file_path: Optional[str], symbol_name: Optional[str]) -> List[Dict[str, Any]]:
        """Query relationships for a symbol.

        Args:
            file_path: Path to file containing the symbol
            symbol_name: Name of the symbol

        Returns:
            List of relationship dictionaries with type, direction, target/source, file, line
        """
        if not self.db_conn or not symbol_name:
            return []

        relationships = []
        cursor = self.db_conn.cursor()

        try:
            # Find symbol ID(s) by name and optionally file
            if file_path:
                cursor.execute(
                    'SELECT id FROM symbols WHERE name = ? AND file_path = ?',
                    (symbol_name, file_path)
                )
            else:
                cursor.execute('SELECT id FROM symbols WHERE name = ?', (symbol_name,))

            symbol_ids = [row[0] for row in cursor.fetchall()]

            if not symbol_ids:
                return []

            # Query outgoing relationships (symbol is source)
            placeholders = ','.join('?' * len(symbol_ids))
            cursor.execute(f'''
                SELECT sr.relationship_type, sr.target_symbol_fqn, sr.file_path, sr.line
                FROM symbol_relationships sr
                WHERE sr.source_symbol_id IN ({placeholders})
            ''', symbol_ids)

            for row in cursor.fetchall():
                relationships.append({
                    'type': row[0],
                    'direction': 'outgoing',
                    'target': row[1],
                    'file': row[2],
                    'line': row[3],
                })

            # Query incoming relationships (symbol is target)
            # Match against symbol name or qualified name patterns
            cursor.execute('''
                SELECT sr.relationship_type, s.name AS source_name, sr.file_path, sr.line
                FROM symbol_relationships sr
                JOIN symbols s ON sr.source_symbol_id = s.id
                WHERE sr.target_symbol_fqn = ? OR sr.target_symbol_fqn LIKE ?
            ''', (symbol_name, f'%.{symbol_name}'))

            for row in cursor.fetchall():
                rel_type = row[0]
                # Convert to incoming type
                incoming_type = self._to_incoming_type(rel_type)
                relationships.append({
                    'type': incoming_type,
                    'direction': 'incoming',
                    'source': row[1],
                    'file': row[2],
                    'line': row[3],
                })

        except sqlite3.Error:
            return []

        return relationships

    def _to_incoming_type(self, outgoing_type: str) -> str:
        """Convert outgoing relationship type to incoming type.

        Args:
            outgoing_type: The outgoing relationship type (e.g., 'calls', 'imports')

        Returns:
            Corresponding incoming type (e.g., 'called_by', 'imported_by')
        """
        type_map = {
            'calls': 'called_by',
            'imports': 'imported_by',
            'extends': 'extended_by',
        }
        return type_map.get(outgoing_type, f'{outgoing_type}_by')

    def close(self) -> None:
        """Close database connection."""
        if self.db_conn:
            self.db_conn.close()
            self.db_conn = None

    def __enter__(self) -> 'RelationshipEnricher':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()


class SearchEnrichmentPipeline:
    """Search post-processing pipeline (optional enrichments)."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._config = config
        self._graph_expander = GraphExpander(mapper, config=config)

    def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
        """Expand base results with related symbols when enabled in config."""
        if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
            return []

        depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
        return self._graph_expander.expand(results, depth=depth)
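Editor's note, not part of the commit: a usage sketch for RelationshipEnricher as a context manager. The result-dict keys ('path', 'symbol') and the relationship fields follow the code above; the on-disk location of _index.db is an assumption.

# Hypothetical usage of RelationshipEnricher; the index path is illustrative.
from pathlib import Path

from codexlens.search.enrichment import RelationshipEnricher  # module path per the file comment above

hits = [{"path": "src/auth.py", "symbol": "authenticate", "score": 0.9}]

with RelationshipEnricher(Path(".codexlens/_index.db")) as enricher:
    enriched = enricher.enrich(hits, limit=5)

for rel in enriched[0].get("relationships", []):
    # outgoing rows carry 'target', incoming rows carry 'source'
    print(rel["direction"], rel["type"], rel.get("target") or rel.get("source"))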
264
codex-lens/build/lib/codexlens/search/graph_expander.py
Normal file
@@ -0,0 +1,264 @@
"""Graph expansion for search results using precomputed neighbors.

Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""

from __future__ import annotations

import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper

logger = logging.getLogger(__name__)


def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
    return (result.path, result.symbol_name, result.start_line, result.end_line)


def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
    if content is None:
        return None
    if start_line is None or end_line is None:
        return None
    if start_line < 1 or end_line < start_line:
        return None

    lines = content.splitlines()
    start_idx = max(0, start_line - 1)
    end_idx = min(len(lines), end_line)
    if start_idx >= len(lines):
        return None
    return "\n".join(lines[start_idx:end_idx])


class GraphExpander:
    """Expands SearchResult lists with related symbols from the code graph."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._mapper = mapper
        self._config = config
        self._logger = logging.getLogger(__name__)

    def expand(
        self,
        results: Sequence[SearchResult],
        *,
        depth: Optional[int] = None,
        max_expand: int = 10,
        max_related: int = 50,
    ) -> List[SearchResult]:
        """Expand top results with related symbols.

        Args:
            results: Base ranked results.
            depth: Maximum relationship depth to include (defaults to Config or 2).
            max_expand: Only expand the top-N base results to bound cost.
            max_related: Maximum related results to return.

        Returns:
            A list of related SearchResult objects with relationship_depth metadata.
        """
        if not results:
            return []

        configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
        max_depth = int(depth if depth is not None else configured_depth)
        if max_depth <= 0:
            return []
        max_depth = min(max_depth, 2)

        expand_count = max(0, int(max_expand))
        related_limit = max(0, int(max_related))
        if expand_count == 0 or related_limit == 0:
            return []

        seen = {_result_key(r) for r in results}
        related_results: List[SearchResult] = []
        conn_cache: Dict[Path, sqlite3.Connection] = {}

        try:
            for base in list(results)[:expand_count]:
                if len(related_results) >= related_limit:
                    break

                if not base.symbol_name or not base.path:
                    continue

                index_path = self._mapper.source_to_index_db(Path(base.path).parent)
                conn = conn_cache.get(index_path)
                if conn is None:
                    conn = self._connect_readonly(index_path)
                    if conn is None:
                        continue
                    conn_cache[index_path] = conn

                source_ids = self._resolve_source_symbol_ids(
                    conn,
                    file_path=base.path,
                    symbol_name=base.symbol_name,
                    symbol_kind=base.symbol_kind,
                )
                if not source_ids:
                    continue

                for source_id in source_ids:
                    neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
                    for neighbor_id, rel_depth in neighbors:
                        if len(related_results) >= related_limit:
                            break
                        row = self._get_symbol_details(conn, neighbor_id)
                        if row is None:
                            continue

                        path = str(row["full_path"])
                        symbol_name = str(row["name"])
                        symbol_kind = str(row["kind"])
                        start_line = int(row["start_line"]) if row["start_line"] is not None else None
                        end_line = int(row["end_line"]) if row["end_line"] is not None else None
                        content_block = _slice_content_block(
                            str(row["content"]) if row["content"] is not None else "",
                            start_line,
                            end_line,
                        )

                        score = float(base.score) * (0.5 ** int(rel_depth))
                        candidate = SearchResult(
                            path=path,
                            score=max(0.0, score),
                            excerpt=None,
                            content=content_block,
                            start_line=start_line,
                            end_line=end_line,
                            symbol_name=symbol_name,
                            symbol_kind=symbol_kind,
                            metadata={"relationship_depth": int(rel_depth)},
                        )

                        key = _result_key(candidate)
                        if key in seen:
                            continue
                        seen.add(key)
                        related_results.append(candidate)

        finally:
            for conn in conn_cache.values():
                try:
                    conn.close()
                except Exception:
                    pass

        return related_results

    def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
        try:
            if not index_path.exists() or index_path.stat().st_size == 0:
                return None
        except OSError:
            return None

        try:
            conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            return conn
        except Exception as exc:
            self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
            return None

    def _resolve_source_symbol_ids(
        self,
        conn: sqlite3.Connection,
        *,
        file_path: str,
        symbol_name: str,
        symbol_kind: Optional[str],
    ) -> List[int]:
        try:
            if symbol_kind:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
                    """,
                    (file_path, symbol_name, symbol_kind),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ?
                    """,
                    (file_path, symbol_name),
                ).fetchall()
        except sqlite3.Error:
            return []

        ids: List[int] = []
        for row in rows:
            try:
                ids.append(int(row["id"]))
            except Exception:
                continue
        return ids

    def _get_neighbors(
        self,
        conn: sqlite3.Connection,
        source_symbol_id: int,
        *,
        max_depth: int,
        limit: int,
    ) -> List[Tuple[int, int]]:
        try:
            rows = conn.execute(
                """
                SELECT neighbor_symbol_id, relationship_depth
                FROM graph_neighbors
                WHERE source_symbol_id = ? AND relationship_depth <= ?
                ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
                LIMIT ?
                """,
                (int(source_symbol_id), int(max_depth), int(limit)),
            ).fetchall()
        except sqlite3.Error:
            return []

        neighbors: List[Tuple[int, int]] = []
        for row in rows:
            try:
                neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
            except Exception:
                continue
        return neighbors

    def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
        try:
            return conn.execute(
                """
                SELECT
                    s.id,
                    s.name,
                    s.kind,
                    s.start_line,
                    s.end_line,
                    f.full_path,
                    f.content
                FROM symbols s
                JOIN files f ON f.id = s.file_id
                WHERE s.id = ?
                """,
                (int(symbol_id),),
            ).fetchone()
        except sqlite3.Error:
            return None
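Editor's note, not part of the commit: a worked example of the depth decay used in GraphExpander.expand(), where a related symbol inherits the base result's score halved once per hop.

# score = base.score * (0.5 ** relationship_depth)
base_score = 0.8
for depth in (1, 2):
    print(depth, base_score * (0.5 ** depth))  # depth 1 -> 0.4, depth 2 -> 0.2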
1409
codex-lens/build/lib/codexlens/search/hybrid_search.py
Normal file
File diff suppressed because it is too large
242
codex-lens/build/lib/codexlens/search/query_parser.py
Normal file
@@ -0,0 +1,242 @@
"""Query preprocessing for CodexLens search.

Provides query expansion for better identifier matching:
- CamelCase splitting: UserAuth → User OR Auth
- snake_case splitting: user_auth → user OR auth
- Preserves original query for exact matching
"""

from __future__ import annotations

import logging
import re
from typing import Set, List

log = logging.getLogger(__name__)


class QueryParser:
    """Parser for preprocessing search queries before FTS5 execution.

    Expands identifier-style queries (CamelCase, snake_case) into OR queries
    to improve recall when searching for code symbols.

    Example transformations:
    - 'UserAuth' → 'UserAuth OR User OR Auth'
    - 'user_auth' → 'user_auth OR user OR auth'
    - 'getUserData' → 'getUserData OR get OR User OR Data'
    """

    # Patterns for identifier splitting
    CAMEL_CASE_PATTERN = re.compile(r'([a-z])([A-Z])')
    SNAKE_CASE_PATTERN = re.compile(r'_+')
    KEBAB_CASE_PATTERN = re.compile(r'-+')

    # Minimum token length to include in expansion (avoid noise from single chars)
    MIN_TOKEN_LENGTH = 2

    # All-caps acronyms pattern (e.g., HTTP, SQL, API)
    ALL_CAPS_PATTERN = re.compile(r'^[A-Z]{2,}$')

    def __init__(self, enable: bool = True, min_token_length: int = 2):
        """Initialize query parser.

        Args:
            enable: Whether to enable query preprocessing
            min_token_length: Minimum token length to include in expansion
        """
        self.enable = enable
        self.min_token_length = min_token_length

    def preprocess_query(self, query: str) -> str:
        """Preprocess query with identifier expansion.

        Args:
            query: Original search query

        Returns:
            Expanded query with OR operator connecting original and split tokens

        Example:
            >>> parser = QueryParser()
            >>> parser.preprocess_query('UserAuth')
            'UserAuth OR User OR Auth'
            >>> parser.preprocess_query('get_user_data')
            'get_user_data OR get OR user OR data'
        """
        if not self.enable:
            return query

        query = query.strip()
        if not query:
            return query

        # Extract tokens from query (handle multiple words/terms)
        # For simple queries, just process the whole thing
        # For complex FTS5 queries with operators, preserve structure
        if self._is_simple_query(query):
            return self._expand_simple_query(query)
        else:
            # Complex query with FTS5 operators, don't expand
            log.debug(f"Skipping expansion for complex FTS5 query: {query}")
            return query

    def _is_simple_query(self, query: str) -> bool:
        """Check if query is simple (no FTS5 operators).

        Args:
            query: Search query

        Returns:
            True if query is simple (safe to expand), False otherwise
        """
        # Check for FTS5 operators that indicate complex query
        fts5_operators = ['OR', 'AND', 'NOT', 'NEAR', '*', '^', '"']
        return not any(op in query for op in fts5_operators)

    def _expand_simple_query(self, query: str) -> str:
        """Expand a simple query with identifier splitting.

        Args:
            query: Simple search query

        Returns:
            Expanded query with OR operators
        """
        tokens: Set[str] = set()

        # Always include original query
        tokens.add(query)

        # Split on whitespace first
        words = query.split()

        for word in words:
            # Extract tokens from this word
            word_tokens = self._extract_tokens(word)
            tokens.update(word_tokens)

        # Filter out short tokens and duplicates
        filtered_tokens = [
            t for t in tokens
            if len(t) >= self.min_token_length
        ]

        # Remove duplicates while preserving original query first
        unique_tokens: List[str] = []
        seen: Set[str] = set()

        # Always put original query first
        if query not in seen and len(query) >= self.min_token_length:
            unique_tokens.append(query)
            seen.add(query)

        # Add other tokens
        for token in filtered_tokens:
            if token not in seen:
                unique_tokens.append(token)
                seen.add(token)

        # Join with OR operator (only if we have multiple tokens)
        if len(unique_tokens) > 1:
            expanded = ' OR '.join(unique_tokens)
            log.debug(f"Expanded query: '{query}' → '{expanded}'")
            return expanded
        else:
            return query

    def _extract_tokens(self, word: str) -> Set[str]:
        """Extract tokens from a single word using various splitting strategies.

        Args:
            word: Single word/identifier to split

        Returns:
            Set of extracted tokens
        """
        tokens: Set[str] = set()

        # Add original word
        tokens.add(word)

        # Handle all-caps acronyms (don't split)
        if self.ALL_CAPS_PATTERN.match(word):
            return tokens

        # CamelCase splitting
        camel_tokens = self._split_camel_case(word)
        tokens.update(camel_tokens)

        # snake_case splitting
        snake_tokens = self._split_snake_case(word)
        tokens.update(snake_tokens)

        # kebab-case splitting
        kebab_tokens = self._split_kebab_case(word)
        tokens.update(kebab_tokens)

        return tokens

    def _split_camel_case(self, word: str) -> List[str]:
        """Split CamelCase identifier into tokens.

        Args:
            word: CamelCase identifier (e.g., 'getUserData')

        Returns:
            List of tokens (e.g., ['get', 'User', 'Data'])
        """
        # Insert space before uppercase letters preceded by lowercase
        spaced = self.CAMEL_CASE_PATTERN.sub(r'\1 \2', word)
        # Split on spaces and filter empty
        return [t for t in spaced.split() if t]

    def _split_snake_case(self, word: str) -> List[str]:
        """Split snake_case identifier into tokens.

        Args:
            word: snake_case identifier (e.g., 'get_user_data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on underscores
        return [t for t in self.SNAKE_CASE_PATTERN.split(word) if t]

    def _split_kebab_case(self, word: str) -> List[str]:
        """Split kebab-case identifier into tokens.

        Args:
            word: kebab-case identifier (e.g., 'get-user-data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on hyphens
        return [t for t in self.KEBAB_CASE_PATTERN.split(word) if t]


# Global default parser instance
_default_parser = QueryParser(enable=True)


def preprocess_query(query: str, enable: bool = True) -> str:
    """Convenience function for query preprocessing.

    Args:
        query: Original search query
        enable: Whether to enable preprocessing

    Returns:
        Preprocessed query with identifier expansion
    """
    if not enable:
        return query

    return _default_parser.preprocess_query(query)


__all__ = [
    "QueryParser",
    "preprocess_query",
]
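Editor's note, not part of the commit: a quick sketch of QueryParser's guard rails. The module path follows the file location above; note that expansion tokens after the original query come from a set, so their relative order may vary between runs.

# Hypothetical quick check of QueryParser behavior.
from codexlens.search.query_parser import QueryParser  # path per the diff header above

parser = QueryParser()
print(parser.preprocess_query("getUserData"))     # original first, then get/User/Data joined with OR
print(parser.preprocess_query("HTTP"))            # 'HTTP' unchanged: all-caps acronyms are not split
print(parser.preprocess_query('"exact phrase"'))  # unchanged: contains an FTS5 operator (")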
942
codex-lens/build/lib/codexlens/search/ranking.py
Normal file
942
codex-lens/build/lib/codexlens/search/ranking.py
Normal file
@@ -0,0 +1,942 @@
|
|||||||
|
"""Ranking algorithms for hybrid search result fusion.
|
||||||
|
|
||||||
|
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
|
||||||
|
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from codexlens.entities import SearchResult, AdditionalLocation
|
||||||
|
|
||||||
|
|
||||||
|
# Default RRF weights for SPLADE-based hybrid search
|
||||||
|
DEFAULT_WEIGHTS = {
|
||||||
|
"splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1)
|
||||||
|
"vector": 0.5,
|
||||||
|
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
|
||||||
|
}
|
||||||
|
|
||||||
|
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
|
||||||
|
FTS_FALLBACK_WEIGHTS = {
|
||||||
|
"exact": 0.25,
|
||||||
|
"fuzzy": 0.1,
|
||||||
|
"vector": 0.5,
|
||||||
|
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class QueryIntent(str, Enum):
|
||||||
|
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
|
||||||
|
|
||||||
|
KEYWORD = "keyword"
|
||||||
|
SEMANTIC = "semantic"
|
||||||
|
MIXED = "mixed"
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
|
||||||
|
"""Normalize weights to sum to 1.0 (best-effort)."""
|
||||||
|
total = sum(float(v) for v in weights.values() if v is not None)
|
||||||
|
|
||||||
|
# NaN total: do not attempt to normalize (division would propagate NaNs).
|
||||||
|
if math.isnan(total):
|
||||||
|
return dict(weights)
|
||||||
|
|
||||||
|
# Infinite total: do not attempt to normalize (division yields 0 or NaN).
|
||||||
|
if not math.isfinite(total):
|
||||||
|
return dict(weights)
|
||||||
|
|
||||||
|
# Zero/negative total: do not attempt to normalize (invalid denominator).
|
||||||
|
if total <= 0:
|
||||||
|
return dict(weights)
|
||||||
|
|
||||||
|
return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_query_intent(query: str) -> QueryIntent:
|
||||||
|
"""Detect whether a query is code-like, natural-language, or mixed.
|
||||||
|
|
||||||
|
Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
|
||||||
|
"""
|
||||||
|
trimmed = (query or "").strip()
|
||||||
|
if not trimmed:
|
||||||
|
return QueryIntent.MIXED
|
||||||
|
|
||||||
|
lower = trimmed.lower()
|
||||||
|
word_count = len([w for w in re.split(r"\s+", trimmed) if w])
|
||||||
|
|
||||||
|
has_code_signals = bool(
|
||||||
|
re.search(r"(::|->|\.)", trimmed)
|
||||||
|
or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
|
||||||
|
or re.search(r"\b\w+_\w+\b", trimmed)
|
||||||
|
or re.search(
|
||||||
|
r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
|
||||||
|
lower,
|
||||||
|
flags=re.IGNORECASE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
has_natural_signals = bool(
|
||||||
|
word_count > 5
|
||||||
|
or "?" in trimmed
|
||||||
|
or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
|
||||||
|
or re.search(
|
||||||
|
r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
|
||||||
|
trimmed,
|
||||||
|
flags=re.IGNORECASE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_code_signals and has_natural_signals:
|
||||||
|
return QueryIntent.MIXED
|
||||||
|
if has_code_signals:
|
||||||
|
return QueryIntent.KEYWORD
|
||||||
|
if has_natural_signals:
|
||||||
|
return QueryIntent.SEMANTIC
|
||||||
|
return QueryIntent.MIXED
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_weights_by_intent(
|
||||||
|
intent: QueryIntent,
|
||||||
|
base_weights: Dict[str, float],
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Adjust RRF weights based on query intent."""
|
||||||
|
# Check if using SPLADE or FTS mode
|
||||||
|
use_splade = "splade" in base_weights
|
||||||
|
|
||||||
|
if intent == QueryIntent.KEYWORD:
|
||||||
|
if use_splade:
|
||||||
|
target = {"splade": 0.6, "vector": 0.4}
|
||||||
|
else:
|
||||||
|
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
|
||||||
|
elif intent == QueryIntent.SEMANTIC:
|
||||||
|
if use_splade:
|
||||||
|
target = {"splade": 0.3, "vector": 0.7}
|
||||||
|
else:
|
||||||
|
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
|
||||||
|
else:
|
||||||
|
target = dict(base_weights)
|
||||||
|
|
||||||
|
# Filter to active backends
|
||||||
|
keys = list(base_weights.keys())
|
||||||
|
filtered = {k: float(target.get(k, 0.0)) for k in keys}
|
||||||
|
return normalize_weights(filtered)
|
||||||
|
|
||||||
|
|
||||||
|
def get_rrf_weights(
|
||||||
|
query: str,
|
||||||
|
base_weights: Dict[str, float],
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Compute adaptive RRF weights from query intent."""
|
||||||
|
return adjust_weights_by_intent(detect_query_intent(query), base_weights)
|
||||||
|
|
||||||
|
|
||||||
|
# File extensions to category mapping for fast lookup
|
||||||
|
_EXT_TO_CATEGORY: Dict[str, str] = {
|
||||||
|
# Code extensions
|
||||||
|
".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
|
||||||
|
".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
|
||||||
|
".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
|
||||||
|
".rs": "code",
|
||||||
|
# Doc extensions
|
||||||
|
".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_category(path: str) -> Optional[str]:
|
||||||
|
"""Get file category ('code' or 'doc') from path extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: File path string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
'code', 'doc', or None if unknown
|
||||||
|
"""
|
||||||
|
ext = Path(path).suffix.lower()
|
||||||
|
return _EXT_TO_CATEGORY.get(ext)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_results_by_category(
|
||||||
|
results: List[SearchResult],
|
||||||
|
intent: QueryIntent,
|
||||||
|
allow_mixed: bool = True,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Filter results by category based on query intent.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- KEYWORD (code intent): Only return code files
|
||||||
|
- SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
|
||||||
|
- MIXED: Return all results
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of SearchResult objects
|
||||||
|
intent: Query intent from detect_query_intent()
|
||||||
|
allow_mixed: If True, SEMANTIC intent includes code files with lower priority
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered and re-ranked list of SearchResult objects
|
||||||
|
"""
|
||||||
|
if not results or intent == QueryIntent.MIXED:
|
||||||
|
return results
|
||||||
|
|
||||||
|
code_results = []
|
||||||
|
doc_results = []
|
||||||
|
unknown_results = []
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
category = get_file_category(r.path)
|
||||||
|
if category == "code":
|
||||||
|
code_results.append(r)
|
||||||
|
elif category == "doc":
|
||||||
|
doc_results.append(r)
|
||||||
|
else:
|
||||||
|
unknown_results.append(r)
|
||||||
|
|
||||||
|
if intent == QueryIntent.KEYWORD:
|
||||||
|
# Code intent: return only code files + unknown (might be code)
|
||||||
|
filtered = code_results + unknown_results
|
||||||
|
elif intent == QueryIntent.SEMANTIC:
|
||||||
|
if allow_mixed:
|
||||||
|
# Semantic intent with mixed: docs first, then code
|
||||||
|
filtered = doc_results + code_results + unknown_results
|
||||||
|
else:
|
||||||
|
# Semantic intent strict: only docs
|
||||||
|
filtered = doc_results + unknown_results
|
||||||
|
else:
|
||||||
|
filtered = results
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def simple_weighted_fusion(
|
||||||
|
results_map: Dict[str, List[SearchResult]],
|
||||||
|
weights: Dict[str, float] = None,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Combine search results using simple weighted sum of normalized scores.
|
||||||
|
|
||||||
|
This is an alternative to RRF that preserves score magnitude information.
|
||||||
|
Scores are min-max normalized per source before weighted combination.
|
||||||
|
|
||||||
|
Formula: score(d) = Σ weight_source * normalized_score_source(d)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results_map: Dictionary mapping source name to list of SearchResult objects
|
||||||
|
Sources: 'exact', 'fuzzy', 'vector', 'splade'
|
||||||
|
weights: Dictionary mapping source name to weight (default: equal weights)
|
||||||
|
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects sorted by fused score (descending)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
|
||||||
|
>>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
|
||||||
|
>>> results_map = {'exact': fts_results, 'vector': vector_results}
|
||||||
|
>>> fused = simple_weighted_fusion(results_map)
|
||||||
|
"""
|
||||||
|
if not results_map:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Default equal weights if not provided
|
||||||
|
if weights is None:
|
||||||
|
num_sources = len(results_map)
|
||||||
|
weights = {source: 1.0 / num_sources for source in results_map}
|
||||||
|
|
||||||
|
# Normalize weights to sum to 1.0
|
||||||
|
weight_sum = sum(weights.values())
|
||||||
|
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
|
||||||
|
weights = {source: w / weight_sum for source, w in weights.items()}
|
||||||
|
|
||||||
|
# Compute min-max normalization parameters per source
|
||||||
|
source_stats: Dict[str, tuple] = {}
|
||||||
|
for source_name, results in results_map.items():
|
||||||
|
if not results:
|
||||||
|
continue
|
||||||
|
scores = [r.score for r in results]
|
||||||
|
min_s, max_s = min(scores), max(scores)
|
||||||
|
source_stats[source_name] = (min_s, max_s)
|
||||||
|
|
||||||
|
def normalize_score(score: float, source: str) -> float:
|
||||||
|
"""Normalize score to [0, 1] range using min-max scaling."""
|
||||||
|
if source not in source_stats:
|
||||||
|
return 0.0
|
||||||
|
min_s, max_s = source_stats[source]
|
||||||
|
if max_s == min_s:
|
||||||
|
return 1.0 if score >= min_s else 0.0
|
||||||
|
return (score - min_s) / (max_s - min_s)
|
||||||
|
|
||||||
|
# Build unified result set with weighted scores
|
||||||
|
path_to_result: Dict[str, SearchResult] = {}
|
||||||
|
path_to_fusion_score: Dict[str, float] = {}
|
||||||
|
path_to_source_scores: Dict[str, Dict[str, float]] = {}
|
||||||
|
|
||||||
|
for source_name, results in results_map.items():
|
||||||
|
weight = weights.get(source_name, 0.0)
|
||||||
|
if weight == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
path = result.path
|
||||||
|
normalized = normalize_score(result.score, source_name)
|
||||||
|
contribution = weight * normalized
|
||||||
|
|
||||||
|
if path not in path_to_fusion_score:
|
||||||
|
path_to_fusion_score[path] = 0.0
|
||||||
|
path_to_result[path] = result
|
||||||
|
path_to_source_scores[path] = {}
|
||||||
|
|
||||||
|
path_to_fusion_score[path] += contribution
|
||||||
|
path_to_source_scores[path][source_name] = normalized
|
||||||
|
|
||||||
|
# Create final results with fusion scores
|
||||||
|
fused_results = []
|
||||||
|
for path, base_result in path_to_result.items():
|
||||||
|
fusion_score = path_to_fusion_score[path]
|
||||||
|
|
||||||
|
fused_result = SearchResult(
|
||||||
|
path=base_result.path,
|
||||||
|
score=fusion_score,
|
||||||
|
excerpt=base_result.excerpt,
|
||||||
|
content=base_result.content,
|
||||||
|
symbol=base_result.symbol,
|
||||||
|
chunk=base_result.chunk,
|
||||||
|
metadata={
|
||||||
|
**base_result.metadata,
|
||||||
|
"fusion_method": "simple_weighted",
|
||||||
|
"fusion_score": fusion_score,
|
||||||
|
"original_score": base_result.score,
|
||||||
|
"source_scores": path_to_source_scores[path],
|
||||||
|
},
|
||||||
|
start_line=base_result.start_line,
|
||||||
|
end_line=base_result.end_line,
|
||||||
|
symbol_name=base_result.symbol_name,
|
||||||
|
symbol_kind=base_result.symbol_kind,
|
||||||
|
)
|
||||||
|
fused_results.append(fused_result)
|
||||||
|
|
||||||
|
fused_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
return fused_results
|
||||||
|
|
||||||
|
|
||||||
|
def reciprocal_rank_fusion(
|
||||||
|
results_map: Dict[str, List[SearchResult]],
|
||||||
|
weights: Dict[str, float] = None,
|
||||||
|
k: int = 60,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Combine search results from multiple sources using Reciprocal Rank Fusion.
|
||||||
|
|
||||||
|
RRF formula: score(d) = Σ weight_source / (k + rank_source(d))
|
||||||
|
|
||||||
|
Supports three-way fusion with FTS, Vector, and SPLADE sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results_map: Dictionary mapping source name to list of SearchResult objects
|
||||||
|
Sources: 'exact', 'fuzzy', 'vector', 'splade'
|
||||||
|
weights: Dictionary mapping source name to weight (default: equal weights)
|
||||||
|
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
|
||||||
|
Or: {'splade': 0.4, 'vector': 0.6}
|
||||||
|
k: Constant to avoid division by zero and control rank influence (default 60)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects sorted by fused score (descending)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
|
||||||
|
>>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
|
||||||
|
>>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
|
||||||
|
>>> fused = reciprocal_rank_fusion(results_map)
|
||||||
|
|
||||||
|
# Three-way fusion with SPLADE
|
||||||
|
>>> results_map = {
|
||||||
|
... 'exact': exact_results,
|
||||||
|
... 'vector': vector_results,
|
||||||
|
... 'splade': splade_results
|
||||||
|
... }
|
||||||
|
>>> fused = reciprocal_rank_fusion(results_map, k=60)
|
||||||
|
"""
|
||||||
|
if not results_map:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Default equal weights if not provided
|
||||||
|
if weights is None:
|
||||||
|
num_sources = len(results_map)
|
||||||
|
weights = {source: 1.0 / num_sources for source in results_map}
|
||||||
|
|
||||||
|
# Validate weights sum to 1.0
|
||||||
|
weight_sum = sum(weights.values())
|
||||||
|
if not math.isclose(weight_sum, 1.0, abs_tol=0.01):
|
||||||
|
# Normalize weights to sum to 1.0
|
||||||
|
weights = {source: w / weight_sum for source, w in weights.items()}
|
||||||
|
|
||||||
|
# Build unified result set with RRF scores
|
||||||
|
path_to_result: Dict[str, SearchResult] = {}
|
||||||
|
path_to_fusion_score: Dict[str, float] = {}
|
||||||
|
path_to_source_ranks: Dict[str, Dict[str, int]] = {}
|
||||||
|
|
||||||
|
for source_name, results in results_map.items():
|
||||||
|
weight = weights.get(source_name, 0.0)
|
||||||
|
if weight == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for rank, result in enumerate(results, start=1):
|
||||||
|
path = result.path
|
||||||
|
rrf_contribution = weight / (k + rank)
|
||||||
|
|
||||||
|
# Initialize or accumulate fusion score
|
||||||
|
if path not in path_to_fusion_score:
|
||||||
|
path_to_fusion_score[path] = 0.0
|
||||||
|
path_to_result[path] = result
|
||||||
|
path_to_source_ranks[path] = {}
|
||||||
|
|
||||||
|
path_to_fusion_score[path] += rrf_contribution
|
||||||
|
path_to_source_ranks[path][source_name] = rank
|
||||||
|
|
||||||
|
# Create final results with fusion scores
|
||||||
|
fused_results = []
|
||||||
|
for path, base_result in path_to_result.items():
|
||||||
|
fusion_score = path_to_fusion_score[path]
|
||||||
|
|
||||||
|
# Create new SearchResult with fusion_score in metadata
|
||||||
|
fused_result = SearchResult(
|
||||||
|
path=base_result.path,
|
||||||
|
score=fusion_score,
|
||||||
|
excerpt=base_result.excerpt,
|
||||||
|
content=base_result.content,
|
||||||
|
symbol=base_result.symbol,
|
||||||
|
chunk=base_result.chunk,
|
||||||
|
metadata={
|
||||||
|
**base_result.metadata,
|
||||||
|
"fusion_method": "rrf",
|
||||||
|
"fusion_score": fusion_score,
|
||||||
|
"original_score": base_result.score,
|
||||||
|
"rrf_k": k,
|
||||||
|
"source_ranks": path_to_source_ranks[path],
|
||||||
|
},
|
||||||
|
start_line=base_result.start_line,
|
||||||
|
end_line=base_result.end_line,
|
||||||
|
symbol_name=base_result.symbol_name,
|
||||||
|
symbol_kind=base_result.symbol_kind,
|
||||||
|
)
|
||||||
|
fused_results.append(fused_result)
|
||||||
|
|
||||||
|
# Sort by fusion score descending
|
||||||
|
fused_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
|
||||||
|
return fused_results
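As a quick sanity check of the accumulation loop above, here is a minimal sketch of the same RRF formula on plain rank lists; the sources, weights, and document names are made-up values, not taken from this diff.

# Minimal RRF sketch on plain rank lists; weights and k are illustrative.
def rrf_scores(rankings, weights, k=60):
    """rankings: {source: [doc ids, best first]} -> {doc id: fused score}."""
    scores = {}
    for source, docs in rankings.items():
        w = weights.get(source, 0.0)
        for rank, doc in enumerate(docs, start=1):
            scores[doc] = scores.get(doc, 0.0) + w / (k + rank)
    return scores

# "a.py" is ranked 1st by vector search and 2nd by exact search:
# 0.6/(60+1) + 0.4/(60+2) = 0.00984 + 0.00645 = 0.01629
print(rrf_scores({"vector": ["a.py", "b.py"], "exact": ["b.py", "a.py"]},
                 {"vector": 0.6, "exact": 0.4}))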
|
||||||
|
|
||||||
|
|
||||||
|
def apply_symbol_boost(
|
||||||
|
results: List[SearchResult],
|
||||||
|
boost_factor: float = 1.5,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Boost fused scores for results that include an explicit symbol match.
|
||||||
|
|
||||||
|
The boost is multiplicative on the current result.score (typically the RRF fusion score).
|
||||||
|
When boosted, the original score is preserved in metadata["original_fusion_score"] and
|
||||||
|
metadata["boosted"] is set to True.
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if boost_factor <= 1.0:
|
||||||
|
# Still return new objects to follow immutable transformation pattern.
|
||||||
|
return [
|
||||||
|
SearchResult(
|
||||||
|
path=r.path,
|
||||||
|
score=r.score,
|
||||||
|
excerpt=r.excerpt,
|
||||||
|
content=r.content,
|
||||||
|
symbol=r.symbol,
|
||||||
|
chunk=r.chunk,
|
||||||
|
metadata={**r.metadata},
|
||||||
|
start_line=r.start_line,
|
||||||
|
end_line=r.end_line,
|
||||||
|
symbol_name=r.symbol_name,
|
||||||
|
symbol_kind=r.symbol_kind,
|
||||||
|
additional_locations=list(r.additional_locations),
|
||||||
|
)
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
|
||||||
|
boosted_results: List[SearchResult] = []
|
||||||
|
for result in results:
|
||||||
|
has_symbol = bool(result.symbol_name)
|
||||||
|
original_score = float(result.score)
|
||||||
|
boosted_score = original_score * boost_factor if has_symbol else original_score
|
||||||
|
|
||||||
|
metadata = {**result.metadata}
|
||||||
|
if has_symbol:
|
||||||
|
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
|
||||||
|
metadata["boosted"] = True
|
||||||
|
metadata["symbol_boost_factor"] = boost_factor
|
||||||
|
|
||||||
|
boosted_results.append(
|
||||||
|
SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=boosted_score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata=metadata,
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
additional_locations=list(result.additional_locations),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
boosted_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
return boosted_results
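The boost is a plain multiplier on the fused score, so its effect is easy to reason about; with the default factor of 1.5 and two invented scores:

fused_with_symbol = 0.0163              # result with symbol_name set
fused_without_symbol = 0.0162           # plain text chunk, left untouched
print(round(fused_with_symbol * 1.5, 5))  # 0.02445 -> symbol match now ranks clearly first
print(fused_without_symbol)               # 0.0162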
|
||||||
|
|
||||||
|
|
||||||
|
def rerank_results(
|
||||||
|
query: str,
|
||||||
|
results: List[SearchResult],
|
||||||
|
embedder: Any,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Re-rank results with embedding cosine similarity, combined with current score.
|
||||||
|
|
||||||
|
Combined score formula:
|
||||||
|
0.5 * rrf_score + 0.5 * cosine_similarity
|
||||||
|
|
||||||
|
If embedder is None or embedding fails, returns results as-is.
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if embedder is None or top_k <= 0:
|
||||||
|
return results
|
||||||
|
|
||||||
|
rerank_count = min(int(top_k), len(results))
|
||||||
|
|
||||||
|
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
|
||||||
|
# Defensive: handle mismatched lengths and zero vectors.
|
||||||
|
n = min(len(vec_a), len(vec_b))
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
dot = 0.0
|
||||||
|
norm_a = 0.0
|
||||||
|
norm_b = 0.0
|
||||||
|
for i in range(n):
|
||||||
|
a = float(vec_a[i])
|
||||||
|
b = float(vec_b[i])
|
||||||
|
dot += a * b
|
||||||
|
norm_a += a * a
|
||||||
|
norm_b += b * b
|
||||||
|
if norm_a <= 0.0 or norm_b <= 0.0:
|
||||||
|
return 0.0
|
||||||
|
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
|
||||||
|
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
|
||||||
|
return max(0.0, min(1.0, sim))
|
||||||
|
|
||||||
|
def text_for_embedding(r: SearchResult) -> str:
|
||||||
|
if r.excerpt and r.excerpt.strip():
|
||||||
|
return r.excerpt
|
||||||
|
if r.content and r.content.strip():
|
||||||
|
return r.content
|
||||||
|
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||||
|
return r.chunk.content
|
||||||
|
# Fallback: stable, non-empty text.
|
||||||
|
return r.symbol_name or r.path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(embedder, "embed_single"):
|
||||||
|
query_vec = embedder.embed_single(query)
|
||||||
|
else:
|
||||||
|
query_vec = embedder.embed(query)[0]
|
||||||
|
|
||||||
|
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
|
||||||
|
doc_vecs = embedder.embed(doc_texts)
|
||||||
|
except Exception:
|
||||||
|
return results
|
||||||
|
|
||||||
|
reranked_results: List[SearchResult] = []
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
if idx < rerank_count:
|
||||||
|
rrf_score = float(result.score)
|
||||||
|
sim = cosine_similarity(query_vec, doc_vecs[idx])
|
||||||
|
combined_score = 0.5 * rrf_score + 0.5 * sim
|
||||||
|
|
||||||
|
reranked_results.append(
|
||||||
|
SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=combined_score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata={
|
||||||
|
**result.metadata,
|
||||||
|
"rrf_score": rrf_score,
|
||||||
|
"cosine_similarity": sim,
|
||||||
|
"reranked": True,
|
||||||
|
},
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
additional_locations=list(result.additional_locations),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Preserve remaining results without re-ranking, but keep immutability.
|
||||||
|
reranked_results.append(
|
||||||
|
SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=result.score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata={**result.metadata},
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
additional_locations=list(result.additional_locations),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
return reranked_results
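A minimal sketch of the 50/50 blend above on hard-coded toy vectors; no embedder is involved and the numbers are assumptions.

import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return 0.0 if na == 0.0 or nb == 0.0 else max(0.0, min(1.0, dot / (na * nb)))

rrf_score = 0.016                              # score carried over from fusion
sim = cosine([0.1, 0.9], [0.2, 0.8])           # query embedding vs. document embedding, ~0.99
print(round(0.5 * rrf_score + 0.5 * sim, 3))   # 0.503: similarity dominates the small RRF score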
|
||||||
|
|
||||||
|
|
||||||
|
def cross_encoder_rerank(
|
||||||
|
query: str,
|
||||||
|
results: List[SearchResult],
|
||||||
|
reranker: Any,
|
||||||
|
top_k: int = 50,
|
||||||
|
batch_size: int = 32,
|
||||||
|
chunk_type_weights: Optional[Dict[str, float]] = None,
|
||||||
|
test_file_penalty: float = 0.0,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Second-stage reranking using a cross-encoder model.
|
||||||
|
|
||||||
|
This function is dependency-agnostic: callers can pass any object that exposes
|
||||||
|
a compatible `score_pairs(pairs, batch_size=...)` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query string
|
||||||
|
results: List of search results to rerank
|
||||||
|
reranker: Cross-encoder model with score_pairs or predict method
|
||||||
|
top_k: Number of top results to rerank
|
||||||
|
batch_size: Batch size for reranking
|
||||||
|
chunk_type_weights: Optional weights for different chunk types.
|
||||||
|
Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
|
||||||
|
test_file_penalty: Penalty applied to test files (0.0-1.0).
|
||||||
|
Example: 0.2 means test files get 20% score reduction
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if reranker is None or top_k <= 0:
|
||||||
|
return results
|
||||||
|
|
||||||
|
rerank_count = min(int(top_k), len(results))
|
||||||
|
|
||||||
|
def text_for_pair(r: SearchResult) -> str:
|
||||||
|
if r.excerpt and r.excerpt.strip():
|
||||||
|
return r.excerpt
|
||||||
|
if r.content and r.content.strip():
|
||||||
|
return r.content
|
||||||
|
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||||
|
return r.chunk.content
|
||||||
|
return r.symbol_name or r.path
|
||||||
|
|
||||||
|
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(reranker, "score_pairs"):
|
||||||
|
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
|
||||||
|
elif hasattr(reranker, "predict"):
|
||||||
|
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
|
||||||
|
else:
|
||||||
|
return results
|
||||||
|
except Exception:
|
||||||
|
return results
|
||||||
|
|
||||||
|
if not raw_scores or len(raw_scores) != rerank_count:
|
||||||
|
return results
|
||||||
|
|
||||||
|
scores = [float(s) for s in raw_scores]
|
||||||
|
min_s = min(scores)
|
||||||
|
max_s = max(scores)
|
||||||
|
|
||||||
|
def sigmoid(x: float) -> float:
|
||||||
|
# Clamp to keep exp() stable.
|
||||||
|
x = max(-50.0, min(50.0, x))
|
||||||
|
return 1.0 / (1.0 + math.exp(-x))
|
||||||
|
|
||||||
|
if 0.0 <= min_s and max_s <= 1.0:
|
||||||
|
probs = scores
|
||||||
|
else:
|
||||||
|
probs = [sigmoid(s) for s in scores]
|
||||||
|
|
||||||
|
reranked_results: List[SearchResult] = []
|
||||||
|
|
||||||
|
# Helper to detect test files
|
||||||
|
def is_test_file(path: str) -> bool:
|
||||||
|
if not path:
|
||||||
|
return False
|
||||||
|
basename = path.split("/")[-1].split("\\")[-1]
|
||||||
|
return (
|
||||||
|
basename.startswith("test_") or
|
||||||
|
basename.endswith("_test.py") or
|
||||||
|
basename.endswith(".test.ts") or
|
||||||
|
basename.endswith(".test.js") or
|
||||||
|
basename.endswith(".spec.ts") or
|
||||||
|
basename.endswith(".spec.js") or
|
||||||
|
"/tests/" in path or
|
||||||
|
"\\tests\\" in path or
|
||||||
|
"/test/" in path or
|
||||||
|
"\\test\\" in path
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
if idx < rerank_count:
|
||||||
|
prev_score = float(result.score)
|
||||||
|
ce_score = scores[idx]
|
||||||
|
ce_prob = probs[idx]
|
||||||
|
|
||||||
|
# Base combined score
|
||||||
|
combined_score = 0.5 * prev_score + 0.5 * ce_prob
|
||||||
|
|
||||||
|
# Apply chunk_type weight adjustment
|
||||||
|
if chunk_type_weights:
|
||||||
|
chunk_type = None
|
||||||
|
if result.chunk and hasattr(result.chunk, "metadata"):
|
||||||
|
chunk_type = result.chunk.metadata.get("chunk_type")
|
||||||
|
elif result.metadata:
|
||||||
|
chunk_type = result.metadata.get("chunk_type")
|
||||||
|
|
||||||
|
if chunk_type and chunk_type in chunk_type_weights:
|
||||||
|
weight = chunk_type_weights[chunk_type]
|
||||||
|
# Apply weight to CE contribution only
|
||||||
|
combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight
|
||||||
|
|
||||||
|
# Apply test file penalty
|
||||||
|
if test_file_penalty > 0 and is_test_file(result.path):
|
||||||
|
combined_score = combined_score * (1.0 - test_file_penalty)
|
||||||
|
|
||||||
|
reranked_results.append(
|
||||||
|
SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=combined_score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata={
|
||||||
|
**result.metadata,
|
||||||
|
"pre_cross_encoder_score": prev_score,
|
||||||
|
"cross_encoder_score": ce_score,
|
||||||
|
"cross_encoder_prob": ce_prob,
|
||||||
|
"cross_encoder_reranked": True,
|
||||||
|
},
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
additional_locations=list(result.additional_locations),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reranked_results.append(
|
||||||
|
SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=result.score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata={**result.metadata},
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
additional_locations=list(result.additional_locations),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
return reranked_results
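The two adjustments above compose differently: the chunk-type weight scales only the cross-encoder contribution, while the test-file penalty scales the final combined score. A worked example with assumed numbers (a docstring chunk inside a test file):

prev_score = 0.50        # score entering this stage (e.g. RRF output)
ce_prob = 0.90           # cross-encoder probability, already in [0, 1]
docstring_weight = 0.7   # from chunk_type_weights={"code": 1.0, "docstring": 0.7}
test_file_penalty = 0.2  # 20% reduction for test files

combined = 0.5 * prev_score + 0.5 * ce_prob * docstring_weight  # 0.25 + 0.315 = 0.565
combined *= (1.0 - test_file_penalty)                           # 0.565 * 0.8 = 0.452
print(round(combined, 3))                                       # 0.452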
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_bm25_score(score: float) -> float:
    """Normalize BM25 scores from SQLite FTS5 to 0-1 range.

    SQLite FTS5 returns negative BM25 scores (more negative = better match).
    Uses sigmoid transformation for normalization.

    Args:
        score: Raw BM25 score from SQLite (typically negative)

    Returns:
        Normalized score in range [0, 1]

    Examples:
        >>> round(normalize_bm25_score(-10.5), 2)  # Good match
        0.74
        >>> round(normalize_bm25_score(-1.2), 2)  # Weak match
        0.53
    """
    # Take absolute value (BM25 is negative in SQLite)
    abs_score = abs(score)

    # Sigmoid transformation: 1 / (1 + e^(-x))
    # Scale factor of 0.1 maps typical BM25 range (-20 to 0) to (0, 1)
    normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))

    return normalized
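For reference, the 0.1 scale factor produces roughly this mapping across the typical FTS5 score range (computed with the function above):

for raw in (-1.0, -5.0, -10.0, -20.0):
    print(raw, round(normalize_bm25_score(raw), 2))
# -1.0  -> 0.52
# -5.0  -> 0.62
# -10.0 -> 0.73
# -20.0 -> 0.88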
|
||||||
|
|
||||||
|
|
||||||
|
def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
|
||||||
|
"""Tag search results with their source for RRF tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of SearchResult objects
|
||||||
|
source: Source identifier ('exact', 'fuzzy', 'vector', 'splade')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with 'search_source' in metadata
|
||||||
|
"""
|
||||||
|
tagged_results = []
|
||||||
|
for result in results:
|
||||||
|
tagged_result = SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
score=result.score,
|
||||||
|
excerpt=result.excerpt,
|
||||||
|
content=result.content,
|
||||||
|
symbol=result.symbol,
|
||||||
|
chunk=result.chunk,
|
||||||
|
metadata={**result.metadata, "search_source": source},
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
symbol_name=result.symbol_name,
|
||||||
|
symbol_kind=result.symbol_kind,
|
||||||
|
)
|
||||||
|
tagged_results.append(tagged_result)
|
||||||
|
|
||||||
|
return tagged_results
|
||||||
|
|
||||||
|
|
||||||
|
def group_similar_results(
|
||||||
|
results: List[SearchResult],
|
||||||
|
score_threshold_abs: float = 0.01,
|
||||||
|
content_field: str = "excerpt"
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Group search results by content and score similarity.
|
||||||
|
|
||||||
|
Groups results that have similar content and similar scores into a single
|
||||||
|
representative result, with other locations stored in additional_locations.
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
1. Group results by content (using excerpt or content field)
|
||||||
|
2. Within each content group, create subgroups based on score similarity
|
||||||
|
3. Select highest-scoring result as representative for each subgroup
|
||||||
|
4. Store other results in subgroup as additional_locations
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: A list of SearchResult objects (typically sorted by score)
|
||||||
|
score_threshold_abs: Absolute score difference to consider results similar.
|
||||||
|
Results with |score_a - score_b| <= threshold are grouped.
|
||||||
|
Default 0.01 is suitable for RRF fusion scores.
|
||||||
|
content_field: The field to use for content grouping ('excerpt' or 'content')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new list of SearchResult objects where similar items are grouped.
|
||||||
|
The list is sorted by score descending.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
|
||||||
|
... SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
|
||||||
|
>>> grouped = group_similar_results(results)
|
||||||
|
>>> len(grouped) # Two results merged into one
|
||||||
|
1
|
||||||
|
>>> len(grouped[0].additional_locations) # One additional location
|
||||||
|
1
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Group results by content
|
||||||
|
content_map: Dict[str, List[SearchResult]] = {}
|
||||||
|
unidentifiable_results: List[SearchResult] = []
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
key = getattr(r, content_field, None)
|
||||||
|
if key and key.strip():
|
||||||
|
content_map.setdefault(key, []).append(r)
|
||||||
|
else:
|
||||||
|
# Results without content can't be grouped by content
|
||||||
|
unidentifiable_results.append(r)
|
||||||
|
|
||||||
|
final_results: List[SearchResult] = []
|
||||||
|
|
||||||
|
# Process each content group
|
||||||
|
for content_group in content_map.values():
|
||||||
|
# Sort by score descending within group
|
||||||
|
content_group.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
|
||||||
|
while content_group:
|
||||||
|
# Take highest scoring as representative
|
||||||
|
representative = content_group.pop(0)
|
||||||
|
others_in_group = []
|
||||||
|
remaining_for_next_pass = []
|
||||||
|
|
||||||
|
# Find results with similar scores
|
||||||
|
for item in content_group:
|
||||||
|
if abs(representative.score - item.score) <= score_threshold_abs:
|
||||||
|
others_in_group.append(item)
|
||||||
|
else:
|
||||||
|
remaining_for_next_pass.append(item)
|
||||||
|
|
||||||
|
# Create grouped result with additional locations
|
||||||
|
if others_in_group:
|
||||||
|
# Build new result with additional_locations populated
|
||||||
|
grouped_result = SearchResult(
|
||||||
|
path=representative.path,
|
||||||
|
score=representative.score,
|
||||||
|
excerpt=representative.excerpt,
|
||||||
|
content=representative.content,
|
||||||
|
symbol=representative.symbol,
|
||||||
|
chunk=representative.chunk,
|
||||||
|
metadata={
|
||||||
|
**representative.metadata,
|
||||||
|
"grouped_count": len(others_in_group) + 1,
|
||||||
|
},
|
||||||
|
start_line=representative.start_line,
|
||||||
|
end_line=representative.end_line,
|
||||||
|
symbol_name=representative.symbol_name,
|
||||||
|
symbol_kind=representative.symbol_kind,
|
||||||
|
additional_locations=[
|
||||||
|
AdditionalLocation(
|
||||||
|
path=other.path,
|
||||||
|
score=other.score,
|
||||||
|
start_line=other.start_line,
|
||||||
|
end_line=other.end_line,
|
||||||
|
symbol_name=other.symbol_name,
|
||||||
|
) for other in others_in_group
|
||||||
|
],
|
||||||
|
)
|
||||||
|
final_results.append(grouped_result)
|
||||||
|
else:
|
||||||
|
final_results.append(representative)
|
||||||
|
|
||||||
|
content_group = remaining_for_next_pass
|
||||||
|
|
||||||
|
# Add ungroupable results
|
||||||
|
final_results.extend(unidentifiable_results)
|
||||||
|
|
||||||
|
# Sort final results by score descending
|
||||||
|
final_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
|
||||||
|
return final_results
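To make the score-threshold subgrouping concrete, here is the same pass applied to bare scores (one excerpt found at three locations; the numbers are invented):

scores = [0.500, 0.495, 0.300]          # same excerpt, three locations
groups, remaining = [], sorted(scores, reverse=True)
while remaining:
    rep = remaining.pop(0)
    near = [s for s in remaining if abs(rep - s) <= 0.01]
    remaining = [s for s in remaining if abs(rep - s) > 0.01]
    groups.append([rep, *near])
print(groups)   # [[0.5, 0.495], [0.3]] -> one result with an additional location, plus a separate result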
|
||||||
codex-lens/build/lib/codexlens/semantic/__init__.py (new file, 118 lines)
@@ -0,0 +1,118 @@
|
|||||||
|
"""Optional semantic search module for CodexLens.
|
||||||
|
|
||||||
|
Install with: pip install codexlens[semantic]
|
||||||
|
Uses fastembed (ONNX-based, lightweight ~200MB)
|
||||||
|
|
||||||
|
GPU Acceleration:
|
||||||
|
- Automatic GPU detection and usage when available
|
||||||
|
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
|
||||||
|
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
SEMANTIC_AVAILABLE = False
|
||||||
|
SEMANTIC_BACKEND: str | None = None
|
||||||
|
GPU_AVAILABLE = False
|
||||||
|
LITELLM_AVAILABLE = False
|
||||||
|
_import_error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
|
||||||
|
"""Detect if fastembed and GPU are available."""
|
||||||
|
try:
|
||||||
|
import numpy as np  # noqa: F401 - imported only to verify availability
|
||||||
|
except ImportError as e:
|
||||||
|
return False, None, False, f"numpy not available: {e}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
except ImportError:
|
||||||
|
return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"
|
||||||
|
|
||||||
|
# Check GPU availability
|
||||||
|
gpu_available = False
|
||||||
|
try:
|
||||||
|
from .gpu_support import is_gpu_available
|
||||||
|
gpu_available = is_gpu_available()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True, "fastembed", gpu_available, None
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize on module load
|
||||||
|
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()
|
||||||
|
|
||||||
|
|
||||||
|
def check_semantic_available() -> tuple[bool, str | None]:
|
||||||
|
"""Check if semantic search dependencies are available."""
|
||||||
|
return SEMANTIC_AVAILABLE, _import_error
|
||||||
|
|
||||||
|
|
||||||
|
def check_gpu_available() -> tuple[bool, str]:
|
||||||
|
"""Check if GPU acceleration is available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_available, status_message)
|
||||||
|
"""
|
||||||
|
if not SEMANTIC_AVAILABLE:
|
||||||
|
return False, "Semantic search not available"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .gpu_support import is_gpu_available, get_gpu_summary
|
||||||
|
if is_gpu_available():
|
||||||
|
return True, get_gpu_summary()
|
||||||
|
return False, "No GPU detected (using CPU)"
|
||||||
|
except ImportError:
|
||||||
|
return False, "GPU support module not available"
|
||||||
|
|
||||||
|
|
||||||
|
# Export embedder components
|
||||||
|
# BaseEmbedder is always available (abstract base class)
|
||||||
|
from .base import BaseEmbedder
|
||||||
|
|
||||||
|
# Factory function for creating embedders
|
||||||
|
from .factory import get_embedder as get_embedder_factory
|
||||||
|
|
||||||
|
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
|
||||||
|
try:
|
||||||
|
import ccw_litellm # noqa: F401
|
||||||
|
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||||
|
LITELLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
LiteLLMEmbedderWrapper = None
|
||||||
|
LITELLM_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
|
||||||
|
"""Check whether a specific embedding backend can be used.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
|
||||||
|
- "litellm" requires ccw-litellm to be installed in the same environment.
|
||||||
|
"""
|
||||||
|
backend = (backend or "").strip().lower()
|
||||||
|
if backend == "fastembed":
|
||||||
|
if SEMANTIC_AVAILABLE:
|
||||||
|
return True, None
|
||||||
|
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
|
||||||
|
if backend == "litellm":
|
||||||
|
if LITELLM_AVAILABLE:
|
||||||
|
return True, None
|
||||||
|
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
|
||||||
|
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SEMANTIC_AVAILABLE",
|
||||||
|
"SEMANTIC_BACKEND",
|
||||||
|
"GPU_AVAILABLE",
|
||||||
|
"LITELLM_AVAILABLE",
|
||||||
|
"check_semantic_available",
|
||||||
|
"is_embedding_backend_available",
|
||||||
|
"check_gpu_available",
|
||||||
|
"BaseEmbedder",
|
||||||
|
"get_embedder_factory",
|
||||||
|
"LiteLLMEmbedderWrapper",
|
||||||
|
]
|
||||||
codex-lens/build/lib/codexlens/semantic/ann_index.py (new file, 1068 lines; diff suppressed because it is too large)

codex-lens/build/lib/codexlens/semantic/base.py (new file, 61 lines)
@@ -0,0 +1,61 @@
|
|||||||
|
"""Base class for embedders.
|
||||||
|
|
||||||
|
Defines the interface that all embedders must implement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbedder(ABC):
|
||||||
|
"""Base class for all embedders.
|
||||||
|
|
||||||
|
All embedder implementations must inherit from this class and implement
|
||||||
|
the abstract methods to ensure a consistent interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def embedding_dim(self) -> int:
|
||||||
|
"""Return embedding dimensions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Dimension of the embedding vectors.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Return model name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Name or identifier of the underlying model.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_tokens(self) -> int:
|
||||||
|
"""Return maximum token limit for embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Maximum number of tokens that can be embedded at once.
|
||||||
|
Default is 8192 if not overridden by implementation.
|
||||||
|
"""
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||||
|
"""Embed texts to numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Single text or iterable of texts to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
|
||||||
|
"""
|
||||||
|
...
|
||||||
codex-lens/build/lib/codexlens/semantic/chunker.py (new file, 821 lines)
@@ -0,0 +1,821 @@
|
|||||||
|
"""Code chunking strategies for semantic search.
|
||||||
|
|
||||||
|
This module provides various chunking strategies for breaking down source code
|
||||||
|
into semantic chunks suitable for embedding and search.
|
||||||
|
|
||||||
|
Lightweight Mode:
|
||||||
|
The ChunkConfig supports a `skip_token_count` option for performance optimization.
|
||||||
|
When enabled, token counting uses a fast character-based estimation (char/4)
|
||||||
|
instead of expensive tiktoken encoding.
|
||||||
|
|
||||||
|
Use cases for lightweight mode:
|
||||||
|
- Large-scale indexing where speed is critical
|
||||||
|
- Scenarios where approximate token counts are acceptable
|
||||||
|
- Memory-constrained environments
|
||||||
|
- Initial prototyping and development
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Default mode (accurate tiktoken encoding)
|
||||||
|
config = ChunkConfig()
|
||||||
|
chunker = Chunker(config)
|
||||||
|
|
||||||
|
# Lightweight mode (fast char/4 estimation)
|
||||||
|
config = ChunkConfig(skip_token_count=True)
|
||||||
|
chunker = Chunker(config)
|
||||||
|
chunks = chunker.chunk_file(content, symbols, path, language)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from codexlens.entities import SemanticChunk, Symbol
|
||||||
|
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkConfig:
|
||||||
|
"""Configuration for chunking strategies."""
|
||||||
|
max_chunk_size: int = 1000 # Max characters per chunk
|
||||||
|
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
|
||||||
|
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
|
||||||
|
min_chunk_size: int = 50 # Minimum chunk size
|
||||||
|
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
|
||||||
|
strip_comments: bool = True # Remove comments from chunk content for embedding
|
||||||
|
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
|
||||||
|
preserve_original: bool = True # Store original content in metadata when stripping
|
||||||
|
|
||||||
|
|
||||||
|
class CommentStripper:
|
||||||
|
"""Remove comments from source code while preserving structure."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def strip_python_comments(content: str) -> str:
|
||||||
|
"""Strip Python comments (# style) but preserve docstrings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Python source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with comments removed
|
||||||
|
"""
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
result_lines: List[str] = []
|
||||||
|
in_string = False
|
||||||
|
string_char = None
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
new_line = []
|
||||||
|
i = 0
|
||||||
|
while i < len(line):
|
||||||
|
char = line[i]
|
||||||
|
|
||||||
|
# Handle string literals
|
||||||
|
if char in ('"', "'") and not in_string:
|
||||||
|
# Check for triple quotes
|
||||||
|
if line[i:i+3] in ('"""', "'''"):
|
||||||
|
in_string = True
|
||||||
|
string_char = line[i:i+3]
|
||||||
|
new_line.append(line[i:i+3])
|
||||||
|
i += 3
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
in_string = True
|
||||||
|
string_char = char
|
||||||
|
elif in_string:
|
||||||
|
if string_char and len(string_char) == 3:
|
||||||
|
if line[i:i+3] == string_char:
|
||||||
|
in_string = False
|
||||||
|
new_line.append(line[i:i+3])
|
||||||
|
i += 3
|
||||||
|
string_char = None
|
||||||
|
continue
|
||||||
|
elif char == string_char:
|
||||||
|
# Check for escape
|
||||||
|
if i > 0 and line[i-1] != '\\':
|
||||||
|
in_string = False
|
||||||
|
string_char = None
|
||||||
|
|
||||||
|
# Handle comments (only outside strings)
|
||||||
|
if char == '#' and not in_string:
|
||||||
|
# Rest of line is comment, skip it
|
||||||
|
new_line.append('\n' if line.endswith('\n') else '')
|
||||||
|
break
|
||||||
|
|
||||||
|
new_line.append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
result_lines.append(''.join(new_line))
|
||||||
|
|
||||||
|
return ''.join(result_lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def strip_c_style_comments(content: str) -> str:
|
||||||
|
"""Strip C-style comments (// and /* */) from code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code with C-style comments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with comments removed
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
in_string = False
|
||||||
|
string_char = None
|
||||||
|
in_multiline_comment = False
|
||||||
|
|
||||||
|
while i < len(content):
|
||||||
|
# Handle multi-line comment end
|
||||||
|
if in_multiline_comment:
|
||||||
|
if content[i:i+2] == '*/':
|
||||||
|
in_multiline_comment = False
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
char = content[i]
|
||||||
|
|
||||||
|
# Handle string literals
|
||||||
|
if char in ('"', "'", '`') and not in_string:
|
||||||
|
in_string = True
|
||||||
|
string_char = char
|
||||||
|
result.append(char)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
elif in_string:
|
||||||
|
result.append(char)
|
||||||
|
if char == string_char and (i == 0 or content[i-1] != '\\'):
|
||||||
|
in_string = False
|
||||||
|
string_char = None
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle comments
|
||||||
|
if content[i:i+2] == '//':
|
||||||
|
# Single line comment - skip to end of line
|
||||||
|
while i < len(content) and content[i] != '\n':
|
||||||
|
i += 1
|
||||||
|
if i < len(content):
|
||||||
|
result.append('\n')
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if content[i:i+2] == '/*':
|
||||||
|
in_multiline_comment = True
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def strip_comments(cls, content: str, language: str) -> str:
|
||||||
|
"""Strip comments based on language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
language: Programming language
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with comments removed
|
||||||
|
"""
|
||||||
|
if language == "python":
|
||||||
|
return cls.strip_python_comments(content)
|
||||||
|
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
|
||||||
|
return cls.strip_c_style_comments(content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
class DocstringStripper:
|
||||||
|
"""Remove docstrings from source code."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def strip_python_docstrings(content: str) -> str:
|
||||||
|
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Python source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with docstrings removed
|
||||||
|
"""
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
result_lines: List[str] = []
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i]
|
||||||
|
stripped = line.strip()
|
||||||
|
|
||||||
|
# Check for docstring start
|
||||||
|
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||||
|
quote_type = '"""' if stripped.startswith('"""') else "'''"
|
||||||
|
|
||||||
|
# Single line docstring
|
||||||
|
if stripped.count(quote_type) >= 2:
|
||||||
|
# Skip this line (docstring)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Multi-line docstring - skip until closing
|
||||||
|
i += 1
|
||||||
|
while i < len(lines):
|
||||||
|
if quote_type in lines[i]:
|
||||||
|
i += 1
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
result_lines.append(line)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ''.join(result_lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def strip_jsdoc_comments(content: str) -> str:
|
||||||
|
"""Strip JSDoc comments (/** ... */) from code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: JavaScript/TypeScript source code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with JSDoc comments removed
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
in_jsdoc = False
|
||||||
|
|
||||||
|
while i < len(content):
|
||||||
|
if in_jsdoc:
|
||||||
|
if content[i:i+2] == '*/':
|
||||||
|
in_jsdoc = False
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for JSDoc start (/** but not /*)
|
||||||
|
if content[i:i+3] == '/**':
|
||||||
|
in_jsdoc = True
|
||||||
|
i += 3
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(content[i])
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def strip_docstrings(cls, content: str, language: str) -> str:
|
||||||
|
"""Strip docstrings based on language.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
language: Programming language
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Code with docstrings removed
|
||||||
|
"""
|
||||||
|
if language == "python":
|
||||||
|
return cls.strip_python_docstrings(content)
|
||||||
|
elif language in {"javascript", "typescript"}:
|
||||||
|
return cls.strip_jsdoc_comments(content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
class Chunker:
|
||||||
|
"""Chunk code files for semantic embedding."""
|
||||||
|
|
||||||
|
def __init__(self, config: ChunkConfig | None = None) -> None:
|
||||||
|
self.config = config or ChunkConfig()
|
||||||
|
self._tokenizer = get_default_tokenizer()
|
||||||
|
self._comment_stripper = CommentStripper()
|
||||||
|
self._docstring_stripper = DocstringStripper()
|
||||||
|
|
||||||
|
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
|
||||||
|
"""Process chunk content by stripping comments/docstrings if configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Original chunk content
|
||||||
|
language: Programming language
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (processed_content, original_content_if_preserved)
|
||||||
|
"""
|
||||||
|
original = content if self.config.preserve_original else None
|
||||||
|
processed = content
|
||||||
|
|
||||||
|
if self.config.strip_comments:
|
||||||
|
processed = self._comment_stripper.strip_comments(processed, language)
|
||||||
|
|
||||||
|
if self.config.strip_docstrings:
|
||||||
|
processed = self._docstring_stripper.strip_docstrings(processed, language)
|
||||||
|
|
||||||
|
# If nothing changed, don't store original
|
||||||
|
if processed == content:
|
||||||
|
original = None
|
||||||
|
|
||||||
|
return processed, original
|
||||||
|
|
||||||
|
def _estimate_token_count(self, text: str) -> int:
|
||||||
|
"""Estimate token count based on config.
|
||||||
|
|
||||||
|
If skip_token_count is True, uses character-based estimation (char/4).
|
||||||
|
Otherwise, uses accurate tiktoken encoding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if self.config.skip_token_count:
|
||||||
|
# Fast character-based estimation: ~4 chars per token
|
||||||
|
return max(1, len(text) // 4)
|
||||||
|
return self._tokenizer.count_tokens(text)
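The char/4 shortcut trades accuracy for speed; the lightweight path is literally this (a standalone sketch, no tokenizer needed):

text = "def add(a, b):\n    return a + b\n"
estimate = max(1, len(text) // 4)   # lightweight path: roughly 4 characters per token
print(len(text), estimate)          # 32 characters -> estimate of 8 tokens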
|
||||||
|
|
||||||
|
def chunk_by_symbol(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
file_path: str | Path,
|
||||||
|
language: str,
|
||||||
|
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||||
|
) -> List[SemanticChunk]:
|
||||||
|
"""Chunk code by extracted symbols (functions, classes).
|
||||||
|
|
||||||
|
Each symbol becomes one chunk with its full content.
|
||||||
|
Large symbols exceeding max_chunk_size are recursively split using sliding window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
symbols: List of extracted symbols
|
||||||
|
file_path: Path to source file
|
||||||
|
language: Programming language
|
||||||
|
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||||
|
"""
|
||||||
|
chunks: List[SemanticChunk] = []
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
|
||||||
|
for symbol in symbols:
|
||||||
|
start_line, end_line = symbol.range
|
||||||
|
# Convert to 0-indexed
|
||||||
|
start_idx = max(0, start_line - 1)
|
||||||
|
end_idx = min(len(lines), end_line)
|
||||||
|
|
||||||
|
chunk_content = "".join(lines[start_idx:end_idx])
|
||||||
|
if len(chunk_content.strip()) < self.config.min_chunk_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if symbol content exceeds max_chunk_size
|
||||||
|
if len(chunk_content) > self.config.max_chunk_size:
|
||||||
|
# Create line mapping for correct line number tracking
|
||||||
|
line_mapping = list(range(start_line, end_line + 1))
|
||||||
|
|
||||||
|
# Use sliding window to split large symbol
|
||||||
|
sub_chunks = self.chunk_sliding_window(
|
||||||
|
chunk_content,
|
||||||
|
file_path=file_path,
|
||||||
|
language=language,
|
||||||
|
line_mapping=line_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update sub_chunks with parent symbol metadata
|
||||||
|
for sub_chunk in sub_chunks:
|
||||||
|
sub_chunk.metadata["symbol_name"] = symbol.name
|
||||||
|
sub_chunk.metadata["symbol_kind"] = symbol.kind
|
||||||
|
sub_chunk.metadata["strategy"] = "symbol_split"
|
||||||
|
sub_chunk.metadata["chunk_type"] = "code"
|
||||||
|
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
|
||||||
|
|
||||||
|
chunks.extend(sub_chunks)
|
||||||
|
else:
|
||||||
|
# Process content (strip comments/docstrings if configured)
|
||||||
|
processed_content, original_content = self._process_content(chunk_content, language)
|
||||||
|
|
||||||
|
# Skip if processed content is too small
|
||||||
|
if len(processed_content.strip()) < self.config.min_chunk_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Calculate token count if not provided
|
||||||
|
token_count = None
|
||||||
|
if symbol_token_counts and symbol.name in symbol_token_counts:
|
||||||
|
token_count = symbol_token_counts[symbol.name]
|
||||||
|
else:
|
||||||
|
token_count = self._estimate_token_count(processed_content)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"file": str(file_path),
|
||||||
|
"language": language,
|
||||||
|
"symbol_name": symbol.name,
|
||||||
|
"symbol_kind": symbol.kind,
|
||||||
|
"start_line": start_line,
|
||||||
|
"end_line": end_line,
|
||||||
|
"strategy": "symbol",
|
||||||
|
"chunk_type": "code",
|
||||||
|
"token_count": token_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store original content if it was modified
|
||||||
|
if original_content is not None:
|
||||||
|
metadata["original_content"] = original_content
|
||||||
|
|
||||||
|
chunks.append(SemanticChunk(
|
||||||
|
content=processed_content,
|
||||||
|
embedding=None,
|
||||||
|
metadata=metadata
|
||||||
|
))
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def chunk_sliding_window(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
file_path: str | Path,
|
||||||
|
language: str,
|
||||||
|
line_mapping: Optional[List[int]] = None,
|
||||||
|
) -> List[SemanticChunk]:
|
||||||
|
"""Chunk code using sliding window approach.
|
||||||
|
|
||||||
|
Used for files without clear symbol boundaries or very long functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
file_path: Path to source file
|
||||||
|
language: Programming language
|
||||||
|
line_mapping: Optional list mapping content line indices to original line numbers
|
||||||
|
(1-indexed). If provided, line_mapping[i] is the original line number
|
||||||
|
for the i-th line in content.
|
||||||
|
"""
|
||||||
|
chunks: List[SemanticChunk] = []
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Calculate lines per chunk based on average line length
|
||||||
|
avg_line_len = len(content) / max(len(lines), 1)
|
||||||
|
lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
|
||||||
|
overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
|
||||||
|
# Ensure overlap is less than chunk size to prevent infinite loop
|
||||||
|
overlap_lines = min(overlap_lines, lines_per_chunk - 1)
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
chunk_idx = 0
|
||||||
|
|
||||||
|
while start < len(lines):
|
||||||
|
end = min(start + lines_per_chunk, len(lines))
|
||||||
|
chunk_content = "".join(lines[start:end])
|
||||||
|
|
||||||
|
if len(chunk_content.strip()) >= self.config.min_chunk_size:
|
||||||
|
# Process content (strip comments/docstrings if configured)
|
||||||
|
processed_content, original_content = self._process_content(chunk_content, language)
|
||||||
|
|
||||||
|
# Skip if processed content is too small
|
||||||
|
if len(processed_content.strip()) < self.config.min_chunk_size:
|
||||||
|
# Move window forward
|
||||||
|
step = lines_per_chunk - overlap_lines
|
||||||
|
if step <= 0:
|
||||||
|
step = 1
|
||||||
|
start += step
|
||||||
|
continue
|
||||||
|
|
||||||
|
token_count = self._estimate_token_count(processed_content)
|
||||||
|
|
||||||
|
# Calculate correct line numbers
|
||||||
|
if line_mapping:
|
||||||
|
# Use line mapping to get original line numbers
|
||||||
|
start_line = line_mapping[start]
|
||||||
|
end_line = line_mapping[end - 1]
|
||||||
|
else:
|
||||||
|
# Default behavior: treat content as starting at line 1
|
||||||
|
start_line = start + 1
|
||||||
|
end_line = end
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"file": str(file_path),
|
||||||
|
"language": language,
|
||||||
|
"chunk_index": chunk_idx,
|
||||||
|
"start_line": start_line,
|
||||||
|
"end_line": end_line,
|
||||||
|
"strategy": "sliding_window",
|
||||||
|
"chunk_type": "code",
|
||||||
|
"token_count": token_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store original content if it was modified
|
||||||
|
if original_content is not None:
|
||||||
|
metadata["original_content"] = original_content
|
||||||
|
|
||||||
|
chunks.append(SemanticChunk(
|
||||||
|
content=processed_content,
|
||||||
|
embedding=None,
|
||||||
|
metadata=metadata
|
||||||
|
))
|
||||||
|
chunk_idx += 1
|
||||||
|
|
||||||
|
# Move window, accounting for overlap
|
||||||
|
step = lines_per_chunk - overlap_lines
|
||||||
|
if step <= 0:
|
||||||
|
step = 1 # Failsafe to prevent infinite loop
|
||||||
|
start += step
|
||||||
|
|
||||||
|
# Break if we've reached the end
|
||||||
|
if end >= len(lines):
|
||||||
|
break
|
||||||
|
|
||||||
|
return chunks
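The window geometry above follows from the average line length; with the default config (max_chunk_size=1000, overlap=200) and an assumed 40-character average line, the arithmetic is:

max_chunk_size, overlap = 1000, 200
avg_line_len = 40.0
lines_per_chunk = max(10, int(max_chunk_size / max(avg_line_len, 1)))  # 25 lines per window
overlap_lines = max(2, int(overlap / max(avg_line_len, 1)))            # 5 lines of overlap
overlap_lines = min(overlap_lines, lines_per_chunk - 1)                # still 5
step = lines_per_chunk - overlap_lines                                 # window advances 20 lines
print(lines_per_chunk, overlap_lines, step)                            # 25 5 20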
|
||||||
|
|
||||||
|
def chunk_file(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
file_path: str | Path,
|
||||||
|
language: str,
|
||||||
|
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||||
|
) -> List[SemanticChunk]:
|
||||||
|
"""Chunk a file using the best strategy.
|
||||||
|
|
||||||
|
Uses symbol-based chunking if symbols available,
|
||||||
|
falls back to sliding window for files without symbols.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
symbols: List of extracted symbols
|
||||||
|
file_path: Path to source file
|
||||||
|
language: Programming language
|
||||||
|
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||||
|
"""
|
||||||
|
if symbols:
|
||||||
|
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
|
||||||
|
return self.chunk_sliding_window(content, file_path, language)
|
||||||
|
|
||||||
|
class DocstringExtractor:
|
||||||
|
"""Extract docstrings from source code."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
|
||||||
|
"""Extract Python docstrings with their line ranges.
|
||||||
|
|
||||||
|
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||||
|
"""
|
||||||
|
docstrings: List[Tuple[str, int, int]] = []
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i]
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||||
|
quote_type = '"""' if stripped.startswith('"""') else "'''"
|
||||||
|
start_line = i + 1
|
||||||
|
|
||||||
|
if stripped.count(quote_type) >= 2:
|
||||||
|
docstring_content = line
|
||||||
|
end_line = i + 1
|
||||||
|
docstrings.append((docstring_content, start_line, end_line))
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
docstring_lines = [line]
|
||||||
|
i += 1
|
||||||
|
while i < len(lines):
|
||||||
|
docstring_lines.append(lines[i])
|
||||||
|
if quote_type in lines[i]:
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
end_line = i + 1
|
||||||
|
docstring_content = "".join(docstring_lines)
|
||||||
|
docstrings.append((docstring_content, start_line, end_line))
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return docstrings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
|
||||||
|
"""Extract JSDoc comments with their line ranges.
|
||||||
|
|
||||||
|
Returns: List of (comment_content, start_line, end_line) tuples
|
||||||
|
"""
|
||||||
|
comments: List[Tuple[str, int, int]] = []
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i]
|
||||||
|
stripped = line.strip()
|
||||||
|
|
||||||
|
if stripped.startswith('/**'):
|
||||||
|
start_line = i + 1
|
||||||
|
comment_lines = [line]
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
while i < len(lines):
|
||||||
|
comment_lines.append(lines[i])
|
||||||
|
if '*/' in lines[i]:
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
end_line = i + 1
|
||||||
|
comment_content = "".join(comment_lines)
|
||||||
|
comments.append((comment_content, start_line, end_line))
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return comments
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_docstrings(
|
||||||
|
cls,
|
||||||
|
content: str,
|
||||||
|
language: str
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
"""Extract docstrings based on language.
|
||||||
|
|
||||||
|
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||||
|
"""
|
||||||
|
if language == "python":
|
||||||
|
return cls.extract_python_docstrings(content)
|
||||||
|
elif language in {"javascript", "typescript"}:
|
||||||
|
return cls.extract_jsdoc_comments(content)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class HybridChunker:
|
||||||
|
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
|
||||||
|
|
||||||
|
Composition-based strategy that:
|
||||||
|
1. Extracts docstrings as dedicated chunks
|
||||||
|
2. For remaining code, uses base chunker (symbol or sliding window)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_chunker: Chunker | None = None,
|
||||||
|
config: ChunkConfig | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Initialize hybrid chunker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_chunker: Chunker to use for non-docstring content
|
||||||
|
config: Configuration for chunking
|
||||||
|
"""
|
||||||
|
self.config = config or ChunkConfig()
|
||||||
|
self.base_chunker = base_chunker or Chunker(self.config)
|
||||||
|
self.docstring_extractor = DocstringExtractor()
|
||||||
|
|
||||||
|
def _get_excluded_line_ranges(
|
||||||
|
self,
|
||||||
|
docstrings: List[Tuple[str, int, int]]
|
||||||
|
) -> set[int]:
|
||||||
|
"""Get set of line numbers that are part of docstrings."""
|
||||||
|
excluded_lines: set[int] = set()
|
||||||
|
for _, start_line, end_line in docstrings:
|
||||||
|
for line_num in range(start_line, end_line + 1):
|
||||||
|
excluded_lines.add(line_num)
|
||||||
|
return excluded_lines
|
||||||
|
|
||||||
|
def _filter_symbols_outside_docstrings(
|
||||||
|
self,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
excluded_lines: set[int]
|
||||||
|
) -> List[Symbol]:
|
||||||
|
"""Filter symbols to exclude those completely within docstrings."""
|
||||||
|
filtered: List[Symbol] = []
|
||||||
|
for symbol in symbols:
|
||||||
|
start_line, end_line = symbol.range
|
||||||
|
symbol_lines = set(range(start_line, end_line + 1))
|
||||||
|
if not symbol_lines.issubset(excluded_lines):
|
||||||
|
filtered.append(symbol)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
def _find_parent_symbol(
|
||||||
|
self,
|
||||||
|
start_line: int,
|
||||||
|
end_line: int,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
) -> Optional[Symbol]:
|
||||||
|
"""Find the smallest symbol range that fully contains a docstring span."""
|
||||||
|
candidates: List[Symbol] = []
|
||||||
|
for symbol in symbols:
|
||||||
|
sym_start, sym_end = symbol.range
|
||||||
|
if sym_start <= start_line and end_line <= sym_end:
|
||||||
|
candidates.append(symbol)
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
|
||||||
|
|
||||||
|
def chunk_file(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
file_path: str | Path,
|
||||||
|
language: str,
|
||||||
|
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||||
|
) -> List[SemanticChunk]:
|
||||||
|
"""Chunk file using hybrid strategy.
|
||||||
|
|
||||||
|
Extracts docstrings first, then chunks remaining code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Source code content
|
||||||
|
symbols: List of extracted symbols
|
||||||
|
file_path: Path to source file
|
||||||
|
language: Programming language
|
||||||
|
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||||
|
"""
|
||||||
|
chunks: List[SemanticChunk] = []
|
||||||
|
        # Step 1: Extract docstrings as dedicated chunks
        docstrings: List[Tuple[str, int, int]] = []
        if language == "python":
            # Fast path: avoid expensive docstring extraction if delimiters are absent.
            if '"""' in content or "'''" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        elif language in {"javascript", "typescript"}:
            if "/**" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        else:
            docstrings = self.docstring_extractor.extract_docstrings(content, language)

        # Fast path: no docstrings -> delegate to base chunker directly.
        if not docstrings:
            if symbols:
                base_chunks = self.base_chunker.chunk_by_symbol(
                    content, symbols, file_path, language, symbol_token_counts
                )
            else:
                base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)

            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
            return base_chunks

        for docstring_content, start_line, end_line in docstrings:
            if len(docstring_content.strip()) >= self.config.min_chunk_size:
                parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
                # Use base chunker's token estimation method
                token_count = self.base_chunker._estimate_token_count(docstring_content)
                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_type": "docstring",
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "hybrid",
                    "token_count": token_count,
                }
                if parent_symbol is not None:
                    metadata["parent_symbol"] = parent_symbol.name
                    metadata["parent_symbol_kind"] = parent_symbol.kind
                    metadata["parent_symbol_range"] = parent_symbol.range
                chunks.append(SemanticChunk(
                    content=docstring_content,
                    embedding=None,
                    metadata=metadata
                ))

        # Step 2: Get line ranges occupied by docstrings
        excluded_lines = self._get_excluded_line_ranges(docstrings)

        # Step 3: Filter symbols to exclude docstring-only ranges
        filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)

        # Step 4: Chunk remaining content using base chunker
        if filtered_symbols:
            base_chunks = self.base_chunker.chunk_by_symbol(
                content, filtered_symbols, file_path, language, symbol_token_counts
            )
            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
                chunks.append(chunk)
        else:
            lines = content.splitlines(keepends=True)
            remaining_lines: List[str] = []

            for i, line in enumerate(lines, start=1):
                if i not in excluded_lines:
                    remaining_lines.append(line)

            if remaining_lines:
                remaining_content = "".join(remaining_lines)
                if len(remaining_content.strip()) >= self.config.min_chunk_size:
                    base_chunks = self.base_chunker.chunk_sliding_window(
                        remaining_content, file_path, language
                    )
                    for chunk in base_chunks:
                        chunk.metadata["strategy"] = "hybrid"
                        chunk.metadata["chunk_type"] = "code"
                        chunks.append(chunk)

        return chunks
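A minimal consumer sketch (not part of this commit) of the "chunk_type"/"strategy" metadata set above; it only assumes the SemanticChunk objects produced by this hybrid chunker.

def split_by_chunk_type(chunks):
    """Group hybrid chunks into docstring and code buckets (illustrative helper)."""
    docs = [c for c in chunks if c.metadata.get("chunk_type") == "docstring"]
    code = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
    return docs, code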
codex-lens/build/lib/codexlens/semantic/code_extractor.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import SearchResult, Symbol


def extract_complete_code_block(
    result: SearchResult,
    source_file_path: Optional[str] = None,
    context_lines: int = 0,
) -> str:
    """Extract complete code block from a search result.

    Args:
        result: SearchResult from semantic search.
        source_file_path: Optional path to source file for re-reading.
        context_lines: Additional lines of context to include above/below.

    Returns:
        Complete code block as string.
    """
    # If we have full content stored, use it
    if result.content:
        if context_lines == 0:
            return result.content
        # Need to add context, read from file

    # Try to read from source file
    file_path = source_file_path or result.path
    if not file_path or not Path(file_path).exists():
        # Fall back to excerpt
        return result.excerpt or ""

    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        # Get line range
        start_line = result.start_line or 1
        end_line = result.end_line or len(lines)

        # Add context
        start_idx = max(0, start_line - 1 - context_lines)
        end_idx = min(len(lines), end_line + context_lines)

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return result.excerpt or result.content or ""


def extract_symbol_with_context(
    file_path: str,
    symbol: Symbol,
    include_docstring: bool = True,
    include_decorators: bool = True,
) -> str:
    """Extract a symbol (function/class) with its docstring and decorators.

    Args:
        file_path: Path to source file.
        symbol: Symbol to extract.
        include_docstring: Include docstring if present.
        include_decorators: Include decorators/annotations above symbol.

    Returns:
        Complete symbol code with context.
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        start_line, end_line = symbol.range
        start_idx = start_line - 1
        end_idx = end_line

        # Look for decorators above the symbol
        if include_decorators and start_idx > 0:
            decorator_start = start_idx
            # Search backwards for decorators
            i = start_idx - 1
            while i >= 0 and i >= start_idx - 20:  # Look up to 20 lines back
                line = lines[i].strip()
                if line.startswith("@"):
                    decorator_start = i
                    i -= 1
                elif line == "" or line.startswith("#"):
                    # Skip empty lines and comments, continue looking
                    i -= 1
                elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
                    # JavaScript/Java style comments
                    decorator_start = i
                    i -= 1
                else:
                    # Found non-decorator, non-comment line, stop
                    break
            start_idx = decorator_start

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return ""


def format_search_result_code(
    result: SearchResult,
    max_lines: Optional[int] = None,
    show_line_numbers: bool = True,
    highlight_match: bool = False,
) -> str:
    """Format search result code for display.

    Args:
        result: SearchResult to format.
        max_lines: Maximum lines to show (None for all).
        show_line_numbers: Include line numbers in output.
        highlight_match: Add markers for matched region.

    Returns:
        Formatted code string.
    """
    content = result.content or result.excerpt or ""
    if not content:
        return ""

    lines = content.splitlines()

    # Truncate if needed
    truncated = False
    if max_lines and len(lines) > max_lines:
        lines = lines[:max_lines]
        truncated = True

    # Format with line numbers
    if show_line_numbers:
        start = result.start_line or 1
        formatted_lines = []
        for i, line in enumerate(lines):
            line_num = start + i
            formatted_lines.append(f"{line_num:4d} | {line}")
        output = "\n".join(formatted_lines)
    else:
        output = "\n".join(lines)

    if truncated:
        output += "\n... (truncated)"

    return output


def get_code_block_summary(result: SearchResult) -> str:
    """Get a concise summary of a code block.

    Args:
        result: SearchResult to summarize.

    Returns:
        Summary string like "function hello_world (lines 10-25)"
    """
    parts = []

    if result.symbol_kind:
        parts.append(result.symbol_kind)

    if result.symbol_name:
        parts.append(f"`{result.symbol_name}`")
    elif result.excerpt:
        # Extract first meaningful identifier
        first_line = result.excerpt.split("\n")[0][:50]
        parts.append(f'"{first_line}..."')

    if result.start_line and result.end_line:
        if result.start_line == result.end_line:
            parts.append(f"(line {result.start_line})")
        else:
            parts.append(f"(lines {result.start_line}-{result.end_line})")

    if result.path:
        file_name = Path(result.path).name
        parts.append(f"in {file_name}")

    return " ".join(parts) if parts else "unknown code block"


class CodeBlockResult:
    """Enhanced search result with complete code block."""

    def __init__(self, result: SearchResult, source_path: Optional[str] = None):
        self.result = result
        self.source_path = source_path or result.path
        self._full_code: Optional[str] = None

    @property
    def score(self) -> float:
        return self.result.score

    @property
    def path(self) -> str:
        return self.result.path

    @property
    def file_name(self) -> str:
        return Path(self.result.path).name

    @property
    def symbol_name(self) -> Optional[str]:
        return self.result.symbol_name

    @property
    def symbol_kind(self) -> Optional[str]:
        return self.result.symbol_kind

    @property
    def line_range(self) -> Tuple[int, int]:
        return (
            self.result.start_line or 1,
            self.result.end_line or 1
        )

    @property
    def full_code(self) -> str:
        """Get full code block content."""
        if self._full_code is None:
            self._full_code = extract_complete_code_block(self.result, self.source_path)
        return self._full_code

    @property
    def excerpt(self) -> str:
        """Get short excerpt."""
        return self.result.excerpt or ""

    @property
    def summary(self) -> str:
        """Get code block summary."""
        return get_code_block_summary(self.result)

    def format(
        self,
        max_lines: Optional[int] = None,
        show_line_numbers: bool = True,
    ) -> str:
        """Format code for display."""
        # Use full code if available
        display_result = SearchResult(
            path=self.result.path,
            score=self.result.score,
            content=self.full_code,
            start_line=self.result.start_line,
            end_line=self.result.end_line,
        )
        return format_search_result_code(
            display_result,
            max_lines=max_lines,
            show_line_numbers=show_line_numbers
        )

    def __repr__(self) -> str:
        return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"


def enhance_search_results(
    results: List[SearchResult],
) -> List[CodeBlockResult]:
    """Enhance search results with complete code block access.

    Args:
        results: List of SearchResult from semantic search.

    Returns:
        List of CodeBlockResult with full code access.
    """
    return [CodeBlockResult(r) for r in results]
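A usage sketch (not part of this commit) for the code_extractor module above; run_semantic_search is a hypothetical stand-in for whatever search call produces SearchResult objects.

from codexlens.semantic.code_extractor import enhance_search_results

results = run_semantic_search("parse config file")   # hypothetical search call
for block in enhance_search_results(results):
    print(block.summary)                  # e.g. 'function `load_config` (lines 10-25) in config.py'
    print(block.format(max_lines=20))     # line-numbered view, truncated if longer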
codex-lens/build/lib/codexlens/semantic/embedder.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.

Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""

from __future__ import annotations

import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional

import numpy as np

from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id

logger = logging.getLogger(__name__)

# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()


def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
    """Get or create a cached Embedder instance (thread-safe singleton).

    This function provides significant performance improvement by reusing
    Embedder instances across multiple searches, avoiding repeated model
    loading overhead (~0.8s per load).

    Args:
        profile: Model profile ("fast", "code", "multilingual", "balanced")
        use_gpu: If True, use GPU acceleration when available (default: True)

    Returns:
        Cached Embedder instance for the given profile
    """
    global _embedder_cache

    # Cache key includes GPU preference to support mixed configurations
    cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"

    # All cache access is protected by _cache_lock to avoid races with
    # clear_embedder_cache() during concurrent access.
    with _cache_lock:
        embedder = _embedder_cache.get(cache_key)
        if embedder is not None:
            return embedder

        # Create new embedder and cache it
        embedder = Embedder(profile=profile, use_gpu=use_gpu)
        # Pre-load model to ensure it's ready
        embedder._load_model()
        _embedder_cache[cache_key] = embedder

        # Log GPU status on first embedder creation
        if use_gpu and is_gpu_available():
            logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
        elif use_gpu:
            logger.debug("GPU not available, using CPU for embeddings")

        return embedder


def clear_embedder_cache() -> None:
    """Clear the embedder cache and release ONNX resources.

    This method ensures proper cleanup of ONNX model resources to prevent
    memory leaks when embedders are no longer needed.
    """
    global _embedder_cache
    with _cache_lock:
        # Release ONNX resources before clearing cache
        for embedder in _embedder_cache.values():
            if embedder._model is not None:
                del embedder._model
                embedder._model = None
        _embedder_cache.clear()
        gc.collect()


class Embedder(BaseEmbedder):
    """Generate embeddings for code chunks using fastembed (ONNX-based).

    Supported Model Profiles:
    - fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
    - code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
    - multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
    - balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
    """

    # Model profiles for different use cases
    MODELS = {
        "fast": "BAAI/bge-small-en-v1.5",  # 384 dim - Fast, lightweight
        "code": "jinaai/jina-embeddings-v2-base-code",  # 768 dim - Code-optimized
        "multilingual": "intfloat/multilingual-e5-large",  # 1024 dim - Multilingual
        "balanced": "mixedbread-ai/mxbai-embed-large-v1",  # 1024 dim - High accuracy
    }

    # Dimension mapping for each model
    MODEL_DIMS = {
        "BAAI/bge-small-en-v1.5": 384,
        "jinaai/jina-embeddings-v2-base-code": 768,
        "intfloat/multilingual-e5-large": 1024,
        "mixedbread-ai/mxbai-embed-large-v1": 1024,
    }

    # Default model (fast profile)
    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
    DEFAULT_PROFILE = "fast"

    def __init__(
        self,
        model_name: str | None = None,
        profile: str | None = None,
        use_gpu: bool = True,
        providers: List[str] | None = None,
    ) -> None:
        """Initialize embedder with model or profile.

        Args:
            model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
            profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
                If both provided, model_name takes precedence.
            use_gpu: If True, use GPU acceleration when available (default: True)
            providers: Explicit ONNX providers list (overrides use_gpu if provided)
        """
        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        # Resolve model name from profile or use explicit name
        if model_name:
            self._model_name = model_name
        elif profile and profile in self.MODELS:
            self._model_name = self.MODELS[profile]
        else:
            self._model_name = self.DEFAULT_MODEL

        # Configure ONNX execution providers with device_id options for GPU selection
        # Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
        if providers is not None:
            self._providers = providers
        else:
            self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)

        self._use_gpu = use_gpu
        self._model = None

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimension for current model."""
        return self.MODEL_DIMS.get(self._model_name, 768)  # Default to 768 if unknown

    @property
    def max_tokens(self) -> int:
        """Get maximum token limit for current model.

        Returns:
            int: Maximum number of tokens based on model profile.
            - fast: 512 (lightweight, optimized for speed)
            - code: 8192 (code-optimized, larger context)
            - multilingual: 512 (standard multilingual model)
            - balanced: 512 (general purpose)
        """
        # Determine profile from model name
        profile = None
        for prof, model in self.MODELS.items():
            if model == self._model_name:
                profile = prof
                break

        # Return token limit based on profile
        if profile == "code":
            return 8192
        elif profile in ("fast", "multilingual", "balanced"):
            return 512
        else:
            # Default for unknown models
            return 512

    @property
    def providers(self) -> List[str]:
        """Get configured ONNX execution providers."""
        return self._providers

    @property
    def is_gpu_enabled(self) -> bool:
        """Check if GPU acceleration is enabled for this embedder."""
        gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
                         "DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
        # Handle both string providers and tuple providers (name, options)
        for p in self._providers:
            provider_name = p[0] if isinstance(p, tuple) else p
            if provider_name in gpu_providers:
                return True
        return False

    def _load_model(self) -> None:
        """Lazy load the embedding model with configured providers."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding

        # providers already include device_id options via get_optimal_providers(with_device_options=True)
        # DO NOT pass device_ids separately - fastembed ignores it when providers is specified
        # See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
        try:
            self._model = TextEmbedding(
                model_name=self.model_name,
                providers=self._providers,
            )
            logger.debug(f"Model loaded with providers: {self._providers}")
        except TypeError:
            # Fallback for older fastembed versions without providers parameter
            logger.warning(
                "fastembed version doesn't support 'providers' parameter. "
                "Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
            )
            self._model = TextEmbedding(model_name=self.model_name)

    def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
        """Generate embeddings for one or more texts.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            List of embedding vectors (each is a list of floats).

        Note:
            This method converts numpy arrays to Python lists for backward compatibility.
            For memory-efficient processing, use embed_to_numpy() instead.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        embeddings = list(self._model.embed(texts))
        return [emb.tolist() for emb in embeddings]

    def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Generate embeddings for one or more texts (returns numpy arrays).

        This method is more memory-efficient than embed() as it avoids converting
        numpy arrays to Python lists, which can significantly reduce memory usage
        during batch processing.

        Args:
            texts: Single text or iterable of texts to embed.
            batch_size: Optional batch size for fastembed processing.
                Larger values improve GPU utilization but use more memory.

        Returns:
            numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Pass batch_size to fastembed for optimal GPU utilization
        # Default batch_size in fastembed is 256, but larger values can improve throughput
        if batch_size is not None:
            embeddings = list(self._model.embed(texts, batch_size=batch_size))
        else:
            embeddings = list(self._model.embed(texts))
        return np.array(embeddings)

    def embed_single(self, text: str) -> List[float]:
        """Generate embedding for a single text."""
        return self.embed(text)[0]
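A usage sketch (not part of this commit) for the cached fastembed-based Embedder above; it assumes the optional semantic dependencies are installed.

from codexlens.semantic.embedder import get_embedder, clear_embedder_cache

embedder = get_embedder(profile="code", use_gpu=True)    # cached per (profile, gpu/cpu)
vectors = embedder.embed_to_numpy(["def add(a, b): return a + b"], batch_size=64)
print(vectors.shape)      # (1, 768) for the jina code profile, per MODEL_DIMS above
clear_embedder_cache()    # release ONNX resources when done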
codex-lens/build/lib/codexlens/semantic/factory.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""Factory for creating embedders.

Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Dict, List, Optional

from .base import BaseEmbedder

# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def get_embedder(
    backend: str = "fastembed",
    profile: str = "code",
    model: str = "default",
    use_gpu: bool = True,
    endpoints: Optional[List[Dict[str, Any]]] = None,
    strategy: str = "latency_aware",
    cooldown: float = 60.0,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create embedder based on backend.

    Args:
        backend: Embedder backend to use. Options:
            - "fastembed": Use fastembed (ONNX-based) embedder (default)
            - "litellm": Use ccw-litellm embedder
        profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
            Used only when backend="fastembed". Default: "code"
        model: Model identifier for litellm backend.
            Used only when backend="litellm". Default: "default"
        use_gpu: Whether to use GPU acceleration when available (default: True).
            Used only when backend="fastembed".
        endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
            Each endpoint is a dict with keys: model, api_key, api_base, weight.
            Used only when backend="litellm" and multiple endpoints provided.
        strategy: Selection strategy for multi-endpoint mode:
            "round_robin", "latency_aware", "weighted_random".
            Default: "latency_aware"
        cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
        **kwargs: Additional backend-specific arguments

    Returns:
        BaseEmbedder: Configured embedder instance

    Raises:
        ValueError: If backend is not recognized
        ImportError: If required backend dependencies are not installed

    Examples:
        Create fastembed embedder with code profile:
        >>> embedder = get_embedder(backend="fastembed", profile="code")

        Create fastembed embedder with fast profile and CPU only:
        >>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)

        Create litellm embedder:
        >>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")

        Create rotational embedder with multiple endpoints:
        >>> endpoints = [
        ...     {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
        ...     {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ... ]
        >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
    """
    # Build cache key from immutable configuration
    if backend == "fastembed":
        cache_key = ("fastembed", profile, None, use_gpu)
    elif backend == "litellm":
        # For litellm, use model as part of cache key
        # Multi-endpoint mode is not cached as it's more complex
        if endpoints and len(endpoints) > 1:
            cache_key = None  # Skip cache for multi-endpoint
        else:
            effective_model = endpoints[0]["model"] if endpoints else model
            cache_key = ("litellm", None, effective_model, None)
    else:
        cache_key = None

    # Check cache first (thread-safe)
    if cache_key is not None:
        with _cache_lock:
            if cache_key in _embedder_cache:
                _logger.debug("Returning cached embedder for %s", cache_key)
                return _embedder_cache[cache_key]

    # Create new embedder instance
    embedder: Optional[BaseEmbedder] = None

    if backend == "fastembed":
        from .embedder import Embedder
        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
    elif backend == "litellm":
        # Check if multi-endpoint mode is requested
        if endpoints and len(endpoints) > 1:
            from .rotational_embedder import create_rotational_embedder
            # Multi-endpoint is not cached
            return create_rotational_embedder(
                endpoints_config=endpoints,
                strategy=strategy,
                default_cooldown=cooldown,
            )
        elif endpoints and len(endpoints) == 1:
            # Single endpoint in list - use it directly
            ep = endpoints[0]
            ep_kwargs = {**kwargs}
            if "api_key" in ep:
                ep_kwargs["api_key"] = ep["api_key"]
            if "api_base" in ep:
                ep_kwargs["api_base"] = ep["api_base"]
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
        else:
            # No endpoints list - use model parameter
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. "
            f"Supported backends: 'fastembed', 'litellm'"
        )

    # Cache the embedder for future use (thread-safe)
    if cache_key is not None and embedder is not None:
        with _cache_lock:
            # Double-check to avoid race condition
            if cache_key not in _embedder_cache:
                _embedder_cache[cache_key] = embedder
                _logger.debug("Cached new embedder for %s", cache_key)
            else:
                # Another thread created it already, use that one
                embedder = _embedder_cache[cache_key]

    return embedder  # type: ignore


def clear_embedder_cache() -> int:
    """Clear the embedder cache.

    Returns:
        Number of embedders cleared from cache
    """
    with _cache_lock:
        count = len(_embedder_cache)
        _embedder_cache.clear()
        _logger.debug("Cleared %d embedders from cache", count)
        return count
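A sketch (not part of this commit) illustrating the cache-key behaviour of the factory above: identical configurations are served from the module-level cache.

from codexlens.semantic.factory import get_embedder, clear_embedder_cache

a = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
b = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
assert a is b                  # same ("fastembed", "fast", None, False) cache key
print(clear_embedder_cache())  # number of cached embedders cleared (1 here)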
codex-lens/build/lib/codexlens/semantic/gpu_support.py (new file, 431 lines)
@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.

This module provides GPU detection, initialization, and fallback handling
for ONNX-based embedding generation.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List, Optional

logger = logging.getLogger(__name__)


@dataclass
class GPUDevice:
    """Individual GPU device info."""
    device_id: int
    name: str
    is_discrete: bool  # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
    vendor: str  # "nvidia", "amd", "intel", "unknown"


@dataclass
class GPUInfo:
    """GPU availability and configuration info."""

    gpu_available: bool = False
    cuda_available: bool = False
    gpu_count: int = 0
    gpu_name: Optional[str] = None
    onnx_providers: List[str] = None
    devices: List[GPUDevice] = None  # List of detected GPU devices
    preferred_device_id: Optional[int] = None  # Preferred GPU for embedding

    def __post_init__(self):
        if self.onnx_providers is None:
            self.onnx_providers = ["CPUExecutionProvider"]
        if self.devices is None:
            self.devices = []


_gpu_info_cache: Optional[GPUInfo] = None


def _enumerate_gpus() -> List[GPUDevice]:
    """Enumerate available GPU devices using WMI on Windows.

    Returns:
        List of GPUDevice with device info, ordered by device_id.
    """
    devices = []

    try:
        import subprocess
        import sys

        if sys.platform == "win32":
            # Use PowerShell to query GPU information via WMI
            cmd = [
                "powershell", "-NoProfile", "-Command",
                "Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

            if result.returncode == 0 and result.stdout.strip():
                import json
                gpu_data = json.loads(result.stdout)

                # Handle single GPU case (returns dict instead of list)
                if isinstance(gpu_data, dict):
                    gpu_data = [gpu_data]

                for idx, gpu in enumerate(gpu_data):
                    name = gpu.get("Name", "Unknown GPU")
                    compat = gpu.get("AdapterCompatibility", "").lower()

                    # Determine vendor
                    name_lower = name.lower()
                    if "nvidia" in name_lower or "nvidia" in compat:
                        vendor = "nvidia"
                        is_discrete = True
                    elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
                        vendor = "amd"
                        is_discrete = True
                    elif "intel" in name_lower or "intel" in compat:
                        vendor = "intel"
                        # Intel UHD/Iris are integrated, Intel Arc is discrete
                        is_discrete = "arc" in name_lower
                    else:
                        vendor = "unknown"
                        is_discrete = False

                    devices.append(GPUDevice(
                        device_id=idx,
                        name=name,
                        is_discrete=is_discrete,
                        vendor=vendor
                    ))
                    logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")

    except Exception as e:
        logger.debug(f"GPU enumeration failed: {e}")

    return devices


def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
    """Determine the preferred GPU device_id for embedding.

    Preference order:
    1. NVIDIA discrete GPU (best DirectML/CUDA support)
    2. AMD discrete GPU
    3. Intel Arc (discrete)
    4. Intel integrated (fallback)

    Returns:
        device_id of preferred GPU, or None to use default.
    """
    if not devices:
        return None

    # Priority: NVIDIA > AMD > Intel Arc > Intel integrated
    priority_order = [
        ("nvidia", True),   # NVIDIA discrete
        ("amd", True),      # AMD discrete
        ("intel", True),    # Intel Arc (discrete)
        ("intel", False),   # Intel integrated (fallback)
    ]

    for target_vendor, target_discrete in priority_order:
        for device in devices:
            if device.vendor == target_vendor and device.is_discrete == target_discrete:
                logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
                return device.device_id

    # If no match, use first device
    if devices:
        return devices[0].device_id

    return None


def detect_gpu(force_refresh: bool = False) -> GPUInfo:
    """Detect available GPU resources for embedding acceleration.

    Args:
        force_refresh: If True, re-detect GPU even if cached.

    Returns:
        GPUInfo with detection results.
    """
    global _gpu_info_cache

    if _gpu_info_cache is not None and not force_refresh:
        return _gpu_info_cache

    info = GPUInfo()

    # Enumerate GPU devices first
    info.devices = _enumerate_gpus()
    info.gpu_count = len(info.devices)
    if info.devices:
        # Set preferred device (discrete GPU preferred over integrated)
        info.preferred_device_id = _get_preferred_device_id(info.devices)
        # Set gpu_name to preferred device name
        for dev in info.devices:
            if dev.device_id == info.preferred_device_id:
                info.gpu_name = dev.name
                break

    # Check PyTorch CUDA availability (most reliable detection)
    try:
        import torch
        if torch.cuda.is_available():
            info.cuda_available = True
            info.gpu_available = True
            info.gpu_count = torch.cuda.device_count()
            if info.gpu_count > 0:
                info.gpu_name = torch.cuda.get_device_name(0)
            logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
    except ImportError:
        logger.debug("PyTorch not available for GPU detection")

    # Check ONNX Runtime providers with validation
    try:
        import onnxruntime as ort
        available_providers = ort.get_available_providers()

        # Build provider list with priority order
        providers = []

        # Test each provider to ensure it actually works
        def test_provider(provider_name: str) -> bool:
            """Test if a provider actually works by creating a dummy session."""
            try:
                # Create a minimal ONNX model to test provider
                import numpy as np
                # Simple test: just check if provider can be instantiated
                sess_options = ort.SessionOptions()
                sess_options.log_severity_level = 4  # Suppress warnings
                return True
            except Exception:
                return False

        # CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
        if "CUDAExecutionProvider" in available_providers:
            # Verify CUDA is actually usable by checking for cuBLAS
            cuda_works = False
            try:
                import ctypes
                # Try to load cuBLAS to verify CUDA installation
                try:
                    ctypes.CDLL("cublas64_12.dll")
                    cuda_works = True
                except OSError:
                    try:
                        ctypes.CDLL("cublas64_11.dll")
                        cuda_works = True
                    except OSError:
                        pass
            except Exception:
                pass

            if cuda_works:
                providers.append("CUDAExecutionProvider")
                info.gpu_available = True
                logger.debug("ONNX CUDAExecutionProvider available and working")
            else:
                logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")

        # TensorRT provider (optimized NVIDIA inference)
        if "TensorrtExecutionProvider" in available_providers:
            # TensorRT requires additional libraries, skip for now
            logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")

        # DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
        if "DmlExecutionProvider" in available_providers:
            providers.append("DmlExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX DmlExecutionProvider available (DirectML)")

        # ROCm provider (AMD GPU on Linux)
        if "ROCMExecutionProvider" in available_providers:
            providers.append("ROCMExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX ROCMExecutionProvider available (AMD)")

        # CoreML provider (Apple Silicon)
        if "CoreMLExecutionProvider" in available_providers:
            providers.append("CoreMLExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX CoreMLExecutionProvider available (Apple)")

        # Always include CPU as fallback
        providers.append("CPUExecutionProvider")

        info.onnx_providers = providers

    except ImportError:
        logger.debug("ONNX Runtime not available")
        info.onnx_providers = ["CPUExecutionProvider"]

    _gpu_info_cache = info
    return info


def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
    """Get optimal ONNX execution providers based on availability.

    Args:
        use_gpu: If True, include GPU providers when available.
            If False, force CPU-only execution.
        with_device_options: If True, return providers as tuples with device_id options
            for proper GPU device selection (required for DirectML).

    Returns:
        List of provider names or tuples (provider_name, options_dict) in priority order.
    """
    if not use_gpu:
        return ["CPUExecutionProvider"]

    gpu_info = detect_gpu()

    # Check if GPU was requested but not available - log warning
    if not gpu_info.gpu_available:
        try:
            import onnxruntime as ort
            available_providers = ort.get_available_providers()
        except ImportError:
            available_providers = []
        logger.warning(
            "GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
            f"was found. Available providers: {available_providers}. Falling back to CPU."
        )
    else:
        # Log which GPU provider is being used
        gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
        if gpu_providers:
            logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")

    if not with_device_options:
        return gpu_info.onnx_providers

    # Build providers with device_id options for GPU providers
    device_id = get_selected_device_id()
    providers = []

    for provider in gpu_info.onnx_providers:
        if provider == "DmlExecutionProvider" and device_id is not None:
            # DirectML requires device_id in provider_options tuple
            providers.append(("DmlExecutionProvider", {"device_id": device_id}))
            logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
        elif provider == "CUDAExecutionProvider" and device_id is not None:
            # CUDA also supports device_id in provider_options
            providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
            logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
        elif provider == "ROCMExecutionProvider" and device_id is not None:
            # ROCm supports device_id
            providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
            logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
        else:
            # CPU and other providers don't need device_id
            providers.append(provider)

    return providers


def is_gpu_available() -> bool:
    """Check if any GPU acceleration is available."""
    return detect_gpu().gpu_available


def get_gpu_summary() -> str:
    """Get human-readable GPU status summary."""
    info = detect_gpu()

    if not info.gpu_available:
        return "GPU: Not available (using CPU)"

    parts = []
    if info.gpu_name:
        parts.append(f"GPU: {info.gpu_name}")
    if info.gpu_count > 1:
        parts.append(f"({info.gpu_count} devices)")

    # Show active providers (excluding CPU fallback)
    gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
    if gpu_providers:
        parts.append(f"Providers: {', '.join(gpu_providers)}")

    return " | ".join(parts) if parts else "GPU: Available"


def clear_gpu_cache() -> None:
    """Clear cached GPU detection info."""
    global _gpu_info_cache
    _gpu_info_cache = None


# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None


def get_gpu_devices() -> List[dict]:
    """Get list of available GPU devices for frontend selection.

    Returns:
        List of dicts with device info for each GPU.
    """
    info = detect_gpu()
    devices = []

    for dev in info.devices:
        devices.append({
            "device_id": dev.device_id,
            "name": dev.name,
            "vendor": dev.vendor,
            "is_discrete": dev.is_discrete,
            "is_preferred": dev.device_id == info.preferred_device_id,
            "is_selected": dev.device_id == get_selected_device_id(),
        })

    return devices


def get_selected_device_id() -> Optional[int]:
    """Get the user-selected GPU device_id.

    Returns:
        User-selected device_id, or auto-detected preferred device_id if not set.
    """
    global _selected_device_id

    if _selected_device_id is not None:
        return _selected_device_id

    # Fall back to auto-detected preferred device
    info = detect_gpu()
    return info.preferred_device_id


def set_selected_device_id(device_id: Optional[int]) -> bool:
    """Set the GPU device_id to use for embeddings.

    Args:
        device_id: GPU device_id to use, or None to use auto-detection.

    Returns:
        True if device_id is valid, False otherwise.
    """
    global _selected_device_id

    if device_id is None:
        _selected_device_id = None
        logger.info("GPU selection reset to auto-detection")
        return True

    # Validate device_id exists
    info = detect_gpu()
    valid_ids = [dev.device_id for dev in info.devices]

    if device_id in valid_ids:
        _selected_device_id = device_id
        device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
        logger.info(f"GPU selection set to device {device_id}: {device_name}")
        return True
    else:
        logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
        return False
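A sketch (not part of this commit) of the device-selection flow exposed by gpu_support above; the printed values are illustrative only.

from codexlens.semantic import gpu_support

print(gpu_support.get_gpu_summary())            # e.g. "GPU: ... | Providers: DmlExecutionProvider"
for dev in gpu_support.get_gpu_devices():
    print(dev["device_id"], dev["name"], "preferred" if dev["is_preferred"] else "")

gpu_support.set_selected_device_id(0)           # pin embeddings to device 0 (validated against detected IDs)
providers = gpu_support.get_optimal_providers(use_gpu=True, with_device_options=True)
print(providers)   # e.g. [("DmlExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]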
codex-lens/build/lib/codexlens/semantic/litellm_embedder.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.

Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
"""

from __future__ import annotations

from typing import Iterable

import numpy as np

from .base import BaseEmbedder


class LiteLLMEmbedderWrapper(BaseEmbedder):
    """Wrapper for ccw-litellm LiteLLMEmbedder.

    This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
    BaseEmbedder interface, enabling seamless integration with CodexLens
    semantic search functionality.

    Args:
        model: Model identifier for LiteLLM (default: "default")
        **kwargs: Additional arguments passed to LiteLLMEmbedder

    Raises:
        ImportError: If ccw-litellm package is not installed
    """

    def __init__(self, model: str = "default", **kwargs) -> None:
        """Initialize LiteLLM embedder wrapper.

        Args:
            model: Model identifier for LiteLLM (default: "default")
            **kwargs: Additional arguments passed to LiteLLMEmbedder

        Raises:
            ImportError: If ccw-litellm package is not installed
        """
        try:
            from ccw_litellm import LiteLLMEmbedder
            self._embedder = LiteLLMEmbedder(model=model, **kwargs)
        except ImportError as e:
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from e

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimensions from LiteLLMEmbedder.

        Returns:
            int: Dimension of the embedding vectors.
        """
        return self._embedder.dimensions

    @property
    def model_name(self) -> str:
        """Return model name from LiteLLMEmbedder.

        Returns:
            str: Name or identifier of the underlying model.
        """
        return self._embedder.model_name

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit for the embedding model.

        Returns:
            int: Maximum number of tokens that can be embedded at once.
                Reads from LiteLLM config's max_input_tokens property.
        """
        # Get from LiteLLM embedder's max_input_tokens property (now exposed)
        if hasattr(self._embedder, 'max_input_tokens'):
            return self._embedder.max_input_tokens

        # Fallback: infer from model name
        model_name_lower = self.model_name.lower()

        # Large models (8B or "large" in name)
        if '8b' in model_name_lower or 'large' in model_name_lower:
            return 32768

        # OpenAI text-embedding-3-* models
        if 'text-embedding-3' in model_name_lower:
            return 8191

        # Default fallback
        return 8192

    def _sanitize_text(self, text: str) -> str:
        """Sanitize text to work around ModelScope API routing bug.

        ModelScope incorrectly routes text starting with lowercase 'import'
        to an Ollama endpoint, causing failures. This adds a leading space
        to work around the issue without affecting embedding quality.

        Args:
            text: Text to sanitize.

        Returns:
            Sanitized text safe for embedding API.
        """
        if text.startswith('import'):
            return ' ' + text
        return text

    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts to numpy array using LiteLLMEmbedder.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments (ignored for LiteLLM backend).
                Accepts batch_size for API compatibility with fastembed.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Sanitize texts to avoid ModelScope routing bug
        texts = [self._sanitize_text(t) for t in texts]

        # LiteLLM handles batching internally, ignore batch_size parameter
        return self._embedder.embed(texts)

    def embed_single(self, text: str) -> list[float]:
        """Generate embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            list[float]: Embedding vector as a list of floats.
        """
        # Sanitize text before embedding
        sanitized = self._sanitize_text(text)
        embedding = self._embedder.embed([sanitized])
        return embedding[0].tolist()
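A usage sketch (not part of this commit) for the wrapper above; it requires the optional ccw-litellm dependency and a configured embedding endpoint, and the model name is only an example.

from codexlens.semantic.litellm_embedder import LiteLLMEmbedderWrapper

embedder = LiteLLMEmbedderWrapper(model="text-embedding-3-small")  # example model id
vec = embedder.embed_single("import os")   # text starting with 'import' is sanitized internally
print(len(vec), embedder.embedding_dim)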
codex-lens/build/lib/codexlens/semantic/reranker/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.

This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""

from __future__ import annotations

from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available

__all__ = [
    "BaseReranker",
    "check_reranker_available",
    "get_reranker",
    "CrossEncoderReranker",
    "check_cross_encoder_available",
    "FastEmbedReranker",
    "check_fastembed_reranker_available",
    "ONNXReranker",
    "check_onnx_reranker_available",
]
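A minimal import sketch (not part of this commit): the subpackage re-exports every reranker backend so callers can import from one place; constructor and factory arguments live in the individual modules and are not shown in this hunk.

from codexlens.semantic.reranker import BaseReranker, get_reranker  # unified entry points (arguments defined in factory.py, not shown here)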
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py (new file, 403 lines)
@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.

Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""

from __future__ import annotations

import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"


def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback."""
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Try loading from .env files
    try:
        from codexlens.env_config import get_env
        return get_env(key, workspace_root=workspace_root)
    except ImportError:
        return None


def check_httpx_available() -> tuple[bool, str | None]:
    try:
        import httpx  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"httpx not available: {exc}. Install with: pip install httpx"
    return True, None


class APIReranker(BaseReranker):
    """Reranker backed by a remote reranking HTTP API."""

    _PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
        "siliconflow": {
            "api_base": "https://api.siliconflow.cn",
            "endpoint": "/v1/rerank",
            "default_model": "BAAI/bge-reranker-v2-m3",
        },
        "cohere": {
            "api_base": "https://api.cohere.ai",
            "endpoint": "/v1/rerank",
            "default_model": "rerank-english-v3.0",
        },
        "jina": {
            "api_base": "https://api.jina.ai",
            "endpoint": "/v1/rerank",
            "default_model": "jina-reranker-v2-base-multilingual",
        },
    }

    def __init__(
        self,
        *,
        provider: str = "siliconflow",
        model_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        backoff_base_s: float = 0.5,
        backoff_max_s: float = 8.0,
        env_api_key: str = _DEFAULT_ENV_API_KEY,
        workspace_root: Path | str | None = None,
        max_input_tokens: int | None = None,
    ) -> None:
        ok, err = check_httpx_available()
        if not ok:  # pragma: no cover - exercised via factory availability tests
            raise ImportError(err)

        import httpx

        self._workspace_root = Path(workspace_root) if workspace_root else None

        self.provider = (provider or "").strip().lower()
        if self.provider not in self._PROVIDER_DEFAULTS:
            raise ValueError(
                f"Unknown reranker provider: {provider}. "
                f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
            )

        defaults = self._PROVIDER_DEFAULTS[self.provider]

        # Load api_base from env with .env fallback
        env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
        self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
        self.endpoint = defaults["endpoint"]

        # Load model from env with .env fallback
        env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
        self.model_name = (model_name or env_model or defaults["default_model"]).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        # Load API key from env with .env fallback
        resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
        resolved_key = resolved_key.strip()
        if not resolved_key:
            raise ValueError(
                f"Missing API key for reranker provider '{self.provider}'. "
                f"Pass api_key=... or set ${env_api_key}."
            )
        self._api_key = resolved_key

        self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
        self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
        self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
        self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        if self.provider == "cohere":
            headers.setdefault("Cohere-Version", "2022-12-06")

        self._client = httpx.Client(
            base_url=self.api_base,
            headers=headers,
            timeout=self.timeout_s,
        )

        # Store max_input_tokens with model-aware defaults
        if max_input_tokens is not None:
            self._max_input_tokens = max_input_tokens
        else:
            # Infer from model name
            model_lower = self.model_name.lower()
            if '8b' in model_lower or 'large' in model_lower:
                self._max_input_tokens = 32768
            else:
                self._max_input_tokens = 8192

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking."""
        return self._max_input_tokens

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # pragma: no cover - defensive
            return

    def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
        if retry_after_s is not None and retry_after_s > 0:
            time.sleep(min(float(retry_after_s), self.backoff_max_s))
            return

        exp = self.backoff_base_s * (2**attempt)
        jitter = random.uniform(0, min(0.5, self.backoff_base_s))
        time.sleep(min(self.backoff_max_s, exp + jitter))

    @staticmethod
    def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
        value = (headers.get("Retry-After") or "").strip()
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    @staticmethod
    def _should_retry_status(status_code: int) -> bool:
        return status_code == 429 or 500 <= status_code <= 599

    def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
        last_exc: Exception | None = None

        for attempt in range(self.max_retries + 1):
            try:
                response = self._client.post(self.endpoint, json=dict(payload))
            except Exception as exc:  # httpx is optional at import-time
|
||||||
|
last_exc = exc
|
||||||
|
if attempt < self.max_retries:
|
||||||
|
self._sleep_backoff(attempt)
|
||||||
|
continue
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank request failed for provider '{self.provider}' after "
|
||||||
|
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
status = int(getattr(response, "status_code", 0) or 0)
|
||||||
|
if status >= 400:
|
||||||
|
body_preview = ""
|
||||||
|
try:
|
||||||
|
body_preview = (response.text or "").strip()
|
||||||
|
except Exception:
|
||||||
|
body_preview = ""
|
||||||
|
if len(body_preview) > 300:
|
||||||
|
body_preview = body_preview[:300] + "…"
|
||||||
|
|
||||||
|
if self._should_retry_status(status) and attempt < self.max_retries:
|
||||||
|
retry_after = self._parse_retry_after_seconds(response.headers)
|
||||||
|
logger.warning(
|
||||||
|
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
|
||||||
|
self.api_base,
|
||||||
|
self.endpoint,
|
||||||
|
status,
|
||||||
|
attempt + 1,
|
||||||
|
self.max_retries + 1,
|
||||||
|
)
|
||||||
|
self._sleep_backoff(attempt, retry_after_s=retry_after)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if status in {401, 403}:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
|
||||||
|
"Check your API key."
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
|
||||||
|
f"Response: {body_preview or '<empty>'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank response from provider '{self.provider}' is not valid JSON: "
|
||||||
|
f"{type(exc).__name__}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank response from provider '{self.provider}' must be a JSON object; "
|
||||||
|
f"got {type(data).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
|
||||||
|
if not isinstance(results, list):
|
||||||
|
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
|
||||||
|
|
||||||
|
scores: list[float] = [0.0 for _ in range(expected)]
|
||||||
|
filled = 0
|
||||||
|
|
||||||
|
for item in results:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
idx = item.get("index")
|
||||||
|
score = item.get("relevance_score", item.get("score"))
|
||||||
|
if idx is None or score is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
idx_int = int(idx)
|
||||||
|
score_f = float(score)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
if 0 <= idx_int < expected:
|
||||||
|
scores[idx_int] = score_f
|
||||||
|
filled += 1
|
||||||
|
|
||||||
|
if filled != expected:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Rerank response contained {filled}/{expected} scored documents; "
|
||||||
|
"ensure top_n matches the number of documents."
|
||||||
|
)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"query": query,
|
||||||
|
"documents": list(documents),
|
||||||
|
"top_n": len(documents),
|
||||||
|
"return_documents": False,
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
|
||||||
|
    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using fast heuristic.

        Uses len(text) // 4 as approximation (~4 chars per token for English).
        Not perfectly accurate for all models/languages but sufficient for
        batch sizing decisions where exact counts aren't critical.
        """
        return len(text) // 4

    def _create_token_aware_batches(
        self,
        query: str,
        documents: Sequence[str],
    ) -> list[list[tuple[int, str]]]:
        """Split documents into batches that fit within token limits.

        Uses 90% of max_input_tokens as safety margin.
        Each batch includes the query tokens overhead.
        """
        max_tokens = int(self._max_input_tokens * 0.9)
        query_tokens = self._estimate_tokens(query)

        batches: list[list[tuple[int, str]]] = []
        current_batch: list[tuple[int, str]] = []
        current_tokens = query_tokens  # Start with query overhead

        for idx, doc in enumerate(documents):
            doc_tokens = self._estimate_tokens(doc)

            # Warn if single document exceeds token limit (will be truncated by API)
            if doc_tokens > max_tokens - query_tokens:
                logger.warning(
                    f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
                    f"(limit: {max_tokens - query_tokens} after query overhead). "
                    "Document will likely be truncated by the API."
                )

            # If batch would exceed limit, start new batch
            if current_tokens + doc_tokens > max_tokens and current_batch:
                batches.append(current_batch)
                current_batch = []
                current_tokens = query_tokens

            current_batch.append((idx, doc))
            current_tokens += doc_tokens

        if current_batch:
            batches.append(current_batch)

        return batches

    def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
        if not documents:
            return []

        # Create token-aware batches
        batches = self._create_token_aware_batches(query, documents)

        if len(batches) == 1:
            # Single batch - original behavior
            payload = self._build_payload(query=query, documents=documents)
            data = self._request_json(payload)
            results = data.get("results")
            return self._extract_scores_from_results(results, expected=len(documents))

        # Multiple batches - process each and merge results
        logger.info(
            f"Splitting {len(documents)} documents into {len(batches)} batches "
            f"(max_input_tokens: {self._max_input_tokens})"
        )

        all_scores: list[float] = [0.0] * len(documents)

        for batch in batches:
            batch_docs = [doc for _, doc in batch]
            payload = self._build_payload(query=query, documents=batch_docs)
            data = self._request_json(payload)
            results = data.get("results")
            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))

            # Map scores back to original indices
            for (orig_idx, _), score in zip(batch, batch_scores):
                all_scores[orig_idx] = score

        return all_scores

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # noqa: ARG002 - kept for BaseReranker compatibility
    ) -> list[float]:
        if not pairs:
            return []

        grouped: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            grouped.setdefault(str(query), []).append((idx, str(doc)))

        scores: list[float] = [0.0 for _ in range(len(pairs))]

        for query, items in grouped.items():
            documents = [doc for _, doc in items]
            query_scores = self._rerank_one_query(query=query, documents=documents)
            for (orig_idx, _), score in zip(items, query_scores):
                scores[orig_idx] = float(score)

        return scores
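A minimal usage sketch for the APIReranker above. The constructor keywords and the score_pairs signature come from the code; the query, documents, and timeout values are placeholders, and a real run needs a valid RERANKER_API_KEY (or an explicit api_key=...) plus network access.

from codexlens.semantic.reranker.api_reranker import APIReranker

def rerank_with_api() -> list[float]:
    # SiliconFlow is the default provider; the API key is resolved from
    # RERANKER_API_KEY (with .env fallback) unless passed explicitly.
    reranker = APIReranker(provider="siliconflow", max_retries=2, timeout=15.0)
    try:
        pairs = [
            ("how to parse json in python", "Use the json module's loads function."),
            ("how to parse json in python", "The capital of France is Paris."),
        ]
        # One HTTP request per distinct query; oversized document sets are split
        # into token-aware batches automatically.
        return reranker.score_pairs(pairs)
    finally:
        reranker.close()
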
codex-lens/build/lib/codexlens/semantic/reranker/base.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Base class for rerankers.

Defines the interface that all rerankers must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence


class BaseReranker(ABC):
    """Base class for all rerankers.

    All reranker implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking.

        Returns:
            int: Maximum number of tokens that can be processed at once.
                Default is 8192 if not overridden by implementation.
        """
        return 8192

    @abstractmethod
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair).
        """
        ...
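A minimal sketch of a custom reranker implementing the BaseReranker interface above; the word-overlap scoring heuristic is purely illustrative and not part of the library.

from typing import Sequence

from codexlens.semantic.reranker.base import BaseReranker

class KeywordOverlapReranker(BaseReranker):
    """Toy reranker: scores pairs by word overlap between query and document."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # unused; kept for interface compatibility
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            q_words = set(query.lower().split())
            d_words = set(doc.lower().split())
            scores.append(len(q_words & d_words) / max(1, len(q_words)))
        return scores
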
codex-lens/build/lib/codexlens/semantic/reranker/factory.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Factory for creating rerankers.

Provides a unified interface for instantiating different reranker backends.
"""

from __future__ import annotations

from typing import Any

from .base import BaseReranker


def check_reranker_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific reranker backend can be used.

    Notes:
        - "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
        - "onnx" redirects to "fastembed" for backward compatibility.
        - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
        - "api" uses a remote reranking HTTP API (requires httpx).
        - "litellm" uses `ccw-litellm` for unified access to LLM providers.
    """
    backend = (backend or "").strip().lower()

    if backend == "legacy":
        from .legacy import check_cross_encoder_available

        return check_cross_encoder_available()

    if backend == "fastembed":
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "litellm":
        try:
            import ccw_litellm  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional dependency
            return (
                False,
                f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
            )

        try:
            from .litellm_reranker import LiteLLMReranker  # noqa: F401
        except Exception as exc:  # pragma: no cover - defensive
            return False, f"LiteLLM reranker backend not available: {exc}"

        return True, None

    if backend == "api":
        from .api_reranker import check_httpx_available

        return check_httpx_available()

    return False, (
        f"Invalid reranker backend: {backend}. "
        "Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
    )


def get_reranker(
    backend: str = "fastembed",
    model_name: str | None = None,
    *,
    device: str | None = None,
    **kwargs: Any,
) -> BaseReranker:
    """Factory function to create reranker based on backend.

    Args:
        backend: Reranker backend to use. Options:
            - "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
            - "onnx": Redirects to fastembed for backward compatibility
            - "api": HTTP API backend (remote providers)
            - "litellm": LiteLLM backend (LLM-based, for API mode)
            - "legacy": sentence-transformers CrossEncoder backend (optional)
        model_name: Model identifier for model-based backends. Defaults depend on backend:
            - fastembed: Xenova/ms-marco-MiniLM-L-6-v2
            - onnx: (redirects to fastembed)
            - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
            - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: default
        device: Optional device string for backends that support it (legacy only).
        **kwargs: Additional backend-specific arguments.

    Returns:
        BaseReranker: Configured reranker instance.

    Raises:
        ValueError: If backend is not recognized.
        ImportError: If required backend dependencies are not installed or backend is unavailable.
    """
    backend = (backend or "").strip().lower()

    if backend == "fastembed":
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "legacy":
        ok, err = check_reranker_available("legacy")
        if not ok:
            raise ImportError(err)

        from .legacy import CrossEncoderReranker

        resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
        return CrossEncoderReranker(model_name=resolved_model_name, device=device)

    if backend == "litellm":
        ok, err = check_reranker_available("litellm")
        if not ok:
            raise ImportError(err)

        from .litellm_reranker import LiteLLMReranker

        _ = device  # Device selection is not applicable to remote LLM backends.
        resolved_model_name = (model_name or "").strip() or "default"
        return LiteLLMReranker(model=resolved_model_name, **kwargs)

    if backend == "api":
        ok, err = check_reranker_available("api")
        if not ok:
            raise ImportError(err)

        from .api_reranker import APIReranker

        _ = device  # Device selection is not applicable to remote HTTP backends.
        resolved_model_name = (model_name or "").strip() or None
        return APIReranker(model_name=resolved_model_name, **kwargs)

    raise ValueError(
        f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
    )
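A hedged usage sketch for the factory above: check availability first, then construct. The fallback order and model name are illustrative choices, not library policy; the "api" fallback additionally requires the provider's API key to be configured.

from codexlens.semantic.reranker.factory import check_reranker_available, get_reranker

def build_default_reranker():
    ok, err = check_reranker_available("fastembed")
    if ok:
        # Xenova/ms-marco-MiniLM-L-6-v2 is the documented fastembed default.
        return get_reranker("fastembed", model_name="Xenova/ms-marco-MiniLM-L-6-v2")
    # Fall back to the remote API backend if fastembed is not installed.
    ok_api, err_api = check_reranker_available("api")
    if not ok_api:
        raise ImportError(f"No reranker backend available: {err}; {err_api}")
    return get_reranker("api")
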
@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.

This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.

Install:
    pip install fastembed>=0.4.0
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_fastembed_reranker_available() -> tuple[bool, str | None]:
    """Check whether fastembed reranker dependencies are available."""
    try:
        import fastembed  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
        )

    try:
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed TextCrossEncoder not available: {exc}. "
            "Upgrade with: pip install fastembed>=0.4.0",
        )

    return True, None

class FastEmbedReranker(BaseReranker):
|
||||||
|
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
# Alternative models supported by fastembed:
|
||||||
|
# - "BAAI/bge-reranker-base"
|
||||||
|
# - "BAAI/bge-reranker-large"
|
||||||
|
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str | None = None,
|
||||||
|
*,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
cache_dir: str | None = None,
|
||||||
|
threads: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize FastEmbed reranker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
|
||||||
|
use_gpu: Whether to use GPU acceleration when available.
|
||||||
|
cache_dir: Optional directory for caching downloaded models.
|
||||||
|
threads: Optional number of threads for ONNX Runtime.
|
||||||
|
"""
|
||||||
|
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||||
|
if not self.model_name:
|
||||||
|
raise ValueError("model_name cannot be blank")
|
||||||
|
|
||||||
|
self.use_gpu = bool(use_gpu)
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
self.threads = threads
|
||||||
|
|
||||||
|
self._encoder: Any | None = None
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Lazy-load the TextCrossEncoder model."""
|
||||||
|
if self._encoder is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
ok, err = check_fastembed_reranker_available()
|
||||||
|
if not ok:
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._encoder is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from fastembed.rerank.cross_encoder import TextCrossEncoder
|
||||||
|
|
||||||
|
# Determine providers based on GPU preference
|
||||||
|
providers: list[str] | None = None
|
||||||
|
if self.use_gpu:
|
||||||
|
try:
|
||||||
|
from ..gpu_support import get_optimal_providers
|
||||||
|
|
||||||
|
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
|
||||||
|
except Exception:
|
||||||
|
# Fallback: let fastembed decide
|
||||||
|
providers = None
|
||||||
|
|
||||||
|
# Build initialization kwargs
|
||||||
|
init_kwargs: dict[str, Any] = {}
|
||||||
|
if self.cache_dir:
|
||||||
|
init_kwargs["cache_dir"] = self.cache_dir
|
||||||
|
if self.threads is not None:
|
||||||
|
init_kwargs["threads"] = self.threads
|
||||||
|
if providers:
|
||||||
|
init_kwargs["providers"] = providers
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
|
||||||
|
self.model_name,
|
||||||
|
self.use_gpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._encoder = TextCrossEncoder(
|
||||||
|
model_name=self.model_name,
|
||||||
|
**init_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("FastEmbed reranker model loaded successfully")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sigmoid(x: float) -> float:
|
||||||
|
"""Numerically stable sigmoid function."""
|
||||||
|
if x < -709:
|
||||||
|
return 0.0
|
||||||
|
if x > 709:
|
||||||
|
return 1.0
|
||||||
|
import math
|
||||||
|
return 1.0 / (1.0 + math.exp(-x))
|
||||||
|
|
||||||
|
def score_pairs(
|
||||||
|
self,
|
||||||
|
pairs: Sequence[tuple[str, str]],
|
||||||
|
*,
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> list[float]:
|
||||||
|
"""Score (query, doc) pairs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pairs: Sequence of (query, doc) string pairs to score.
|
||||||
|
batch_size: Batch size for scoring.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (one per pair), normalized to [0, 1] range.
|
||||||
|
"""
|
||||||
|
if not pairs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._encoder is None: # pragma: no cover - defensive
|
||||||
|
return []
|
||||||
|
|
||||||
|
# FastEmbed's TextCrossEncoder.rerank() expects a query and list of documents.
|
||||||
|
# For batch scoring of multiple query-doc pairs, we need to process them.
|
||||||
|
# Group by query for efficiency when same query appears multiple times.
|
||||||
|
query_to_docs: dict[str, list[tuple[int, str]]] = {}
|
||||||
|
for idx, (query, doc) in enumerate(pairs):
|
||||||
|
if query not in query_to_docs:
|
||||||
|
query_to_docs[query] = []
|
||||||
|
query_to_docs[query].append((idx, doc))
|
||||||
|
|
||||||
|
# Score each query group
|
||||||
|
scores: list[float] = [0.0] * len(pairs)
|
||||||
|
|
||||||
|
for query, indexed_docs in query_to_docs.items():
|
||||||
|
docs = [doc for _, doc in indexed_docs]
|
||||||
|
indices = [idx for idx, _ in indexed_docs]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||||
|
raw_scores = list(
|
||||||
|
self._encoder.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map scores back to original positions and normalize with sigmoid
|
||||||
|
for i, raw_score in enumerate(raw_scores):
|
||||||
|
if i < len(indices):
|
||||||
|
original_idx = indices[i]
|
||||||
|
# Normalize score to [0, 1] using stable sigmoid
|
||||||
|
scores[original_idx] = self._sigmoid(float(raw_score))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
|
||||||
|
# Leave scores as 0.0 for failed queries
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def rerank(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: Sequence[str],
|
||||||
|
*,
|
||||||
|
top_k: int | None = None,
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> list[tuple[float, str, int]]:
|
||||||
|
"""Rerank documents for a single query.
|
||||||
|
|
||||||
|
This is a convenience method that provides results in ranked order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query string.
|
||||||
|
documents: List of documents to rerank.
|
||||||
|
top_k: Return only top K results. None returns all.
|
||||||
|
batch_size: Batch size for scoring.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (score, document, original_index) tuples, sorted by score descending.
|
||||||
|
"""
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._encoder is None: # pragma: no cover - defensive
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||||
|
raw_scores = list(
|
||||||
|
self._encoder.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=list(documents),
|
||||||
|
batch_size=batch_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to our format: (normalized_score, document, original_index)
|
||||||
|
ranked = []
|
||||||
|
for idx, raw_score in enumerate(raw_scores):
|
||||||
|
if idx < len(documents):
|
||||||
|
# Normalize score to [0, 1] using stable sigmoid
|
||||||
|
normalized = self._sigmoid(float(raw_score))
|
||||||
|
ranked.append((normalized, documents[idx], idx))
|
||||||
|
|
||||||
|
# Sort by score descending
|
||||||
|
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
|
||||||
|
if top_k is not None and top_k > 0:
|
||||||
|
ranked = ranked[:top_k]
|
||||||
|
|
||||||
|
return ranked
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
|
||||||
|
return []
|
||||||
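A short usage sketch for the FastEmbedReranker defined above. The query and documents are placeholders; running it requires fastembed>=0.4.0 and downloads the default model on first use.

from codexlens.semantic.reranker.fastembed_reranker import FastEmbedReranker

def demo_fastembed_rerank() -> None:
    reranker = FastEmbedReranker(use_gpu=False)  # defaults to Xenova/ms-marco-MiniLM-L-6-v2
    ranked = reranker.rerank(
        "what is a cross encoder",
        [
            "A cross encoder scores a query and document jointly.",
            "Bananas are rich in potassium.",
        ],
        top_k=1,
    )
    # Each entry is (sigmoid-normalized score, document, original index).
    for score, doc, original_index in ranked:
        print(f"{score:.3f}  [{original_index}] {doc}")
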
codex-lens/build/lib/codexlens/semantic/reranker/legacy.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.

Install with: pip install codexlens[reranker-legacy]
"""

from __future__ import annotations

import logging
import threading
from typing import List, Sequence, Tuple

from .base import BaseReranker

logger = logging.getLogger(__name__)

try:
    from sentence_transformers import CrossEncoder as _CrossEncoder

    CROSS_ENCODER_AVAILABLE = True
    _import_error: str | None = None
except ImportError as exc:  # pragma: no cover - optional dependency
    _CrossEncoder = None  # type: ignore[assignment]
    CROSS_ENCODER_AVAILABLE = False
    _import_error = str(exc)


def check_cross_encoder_available() -> tuple[bool, str | None]:
    if CROSS_ENCODER_AVAILABLE:
        return True, None
    return (
        False,
        _import_error
        or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
    )


class CrossEncoderReranker(BaseReranker):
    """Cross-encoder reranker with lazy model loading."""

    def __init__(self, model_name: str, *, device: str | None = None) -> None:
        self.model_name = (model_name or "").strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.device = (device or "").strip() or None
        self._model = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None:
            return

        ok, err = check_cross_encoder_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None:
                return

            try:
                if self.device:
                    self._model = _CrossEncoder(self.model_name, device=self.device)  # type: ignore[misc]
                else:
                    self._model = _CrossEncoder(self.model_name)  # type: ignore[misc]
            except Exception as exc:
                logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
                raise

    def score_pairs(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> List[float]:
        """Score (query, doc) pairs using the cross-encoder.

        Returns:
            List of scores (one per pair) in the model's native scale (usually logits).
        """
        if not pairs:
            return []

        self._load_model()

        if self._model is None:  # pragma: no cover - defensive
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores = self._model.predict(list(pairs), batch_size=bs)  # type: ignore[union-attr]
        return [float(s) for s in scores]
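A usage sketch for the legacy backend above. As its docstring notes, score_pairs returns scores in the model's native scale (usually logits), so a sigmoid is applied here for a [0, 1] view; the query/document pairs are placeholders and sentence-transformers must be installed.

import math

from codexlens.semantic.reranker.legacy import CrossEncoderReranker, check_cross_encoder_available

def demo_legacy_scores() -> list[float]:
    ok, err = check_cross_encoder_available()
    if not ok:
        raise ImportError(err)
    reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
    logits = reranker.score_pairs([
        ("best pasta recipe", "Cook spaghetti and toss with garlic and olive oil."),
        ("best pasta recipe", "The stock market closed higher today."),
    ])
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]
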
@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.

This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.

Notes:
    - This backend is experimental and may be slow/expensive compared to local
      rerankers.
    - It relies on `ccw-litellm` for a unified LLM API across providers.
"""

from __future__ import annotations

import json
import logging
import re
import threading
import time
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _coerce_score_to_unit_interval(score: float) -> float:
    """Coerce a numeric score into [0, 1].

    The prompt asks for a float in [0, 1], but some models may respond with 0-10
    or 0-100 scales. This function attempts a conservative normalization.
    """
    if 0.0 <= score <= 1.0:
        return score
    if 0.0 <= score <= 10.0:
        return score / 10.0
    if 0.0 <= score <= 100.0:
        return score / 100.0
    return max(0.0, min(1.0, score))


def _extract_score(text: str) -> float | None:
    """Extract a numeric relevance score from an LLM response."""
    content = (text or "").strip()
    if not content:
        return None

    # Prefer JSON if present.
    if "{" in content and "}" in content:
        try:
            start = content.index("{")
            end = content.rindex("}") + 1
            payload = json.loads(content[start:end])
            if isinstance(payload, dict) and "score" in payload:
                return float(payload["score"])
        except Exception:
            pass

    match = _NUMBER_RE.search(content)
    if not match:
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None

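# Illustrative examples (not part of the original file) of how the parsing helpers
# above behave, traced directly from the code:
#     _extract_score('{"score": 0.8}')   -> 0.8        (JSON branch)
#     _extract_score("Relevance: 7")     -> 7.0, then _coerce_score_to_unit_interval(7.0) -> 0.7
#     _extract_score("no numeric output") -> None      (caller falls back to default_score)
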
class LiteLLMReranker(BaseReranker):
|
||||||
|
"""Experimental reranker that uses a LiteLLM-compatible model.
|
||||||
|
|
||||||
|
This reranker scores each (query, doc) pair in isolation (single-pair mode)
|
||||||
|
to improve prompt reliability across providers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a relevance scoring assistant.\n"
|
||||||
|
"Given a search query and a document snippet, output a single numeric "
|
||||||
|
"relevance score between 0 and 1.\n\n"
|
||||||
|
"Scoring guidance:\n"
|
||||||
|
"- 1.0: The document directly answers the query.\n"
|
||||||
|
"- 0.5: The document is partially relevant.\n"
|
||||||
|
"- 0.0: The document is unrelated.\n\n"
|
||||||
|
"Output requirements:\n"
|
||||||
|
"- Output ONLY the number (e.g., 0.73).\n"
|
||||||
|
"- Do not include any other text."
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "default",
|
||||||
|
*,
|
||||||
|
requests_per_minute: float | None = None,
|
||||||
|
min_interval_seconds: float | None = None,
|
||||||
|
default_score: float = 0.0,
|
||||||
|
max_doc_chars: int = 8000,
|
||||||
|
**litellm_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the reranker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name from ccw-litellm configuration (default: "default").
|
||||||
|
requests_per_minute: Optional rate limit in requests per minute.
|
||||||
|
min_interval_seconds: Optional minimum interval between requests. If set,
|
||||||
|
it takes precedence over requests_per_minute.
|
||||||
|
default_score: Score to use when an API call fails or parsing fails.
|
||||||
|
max_doc_chars: Maximum number of document characters to include in the prompt.
|
||||||
|
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If ccw-litellm is not installed.
|
||||||
|
ValueError: If model is blank.
|
||||||
|
"""
|
||||||
|
self.model_name = (model or "").strip()
|
||||||
|
if not self.model_name:
|
||||||
|
raise ValueError("model cannot be blank")
|
||||||
|
|
||||||
|
self.default_score = float(default_score)
|
||||||
|
|
||||||
|
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
|
||||||
|
|
||||||
|
if min_interval_seconds is not None:
|
||||||
|
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
|
||||||
|
elif requests_per_minute is not None and float(requests_per_minute) > 0:
|
||||||
|
self._min_interval_seconds = 60.0 / float(requests_per_minute)
|
||||||
|
else:
|
||||||
|
self._min_interval_seconds = 0.0
|
||||||
|
|
||||||
|
# Prefer deterministic output by default; allow overrides via kwargs.
|
||||||
|
litellm_kwargs = dict(litellm_kwargs)
|
||||||
|
litellm_kwargs.setdefault("temperature", 0.0)
|
||||||
|
litellm_kwargs.setdefault("max_tokens", 16)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ccw_litellm import ChatMessage, LiteLLMClient
|
||||||
|
except ImportError as exc: # pragma: no cover - optional dependency
|
||||||
|
raise ImportError(
|
||||||
|
"ccw-litellm not installed. Install with: pip install ccw-litellm"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
self._ChatMessage = ChatMessage
|
||||||
|
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
|
||||||
|
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._last_request_at = 0.0
|
||||||
|
|
||||||
|
def _sanitize_text(self, text: str) -> str:
|
||||||
|
# Keep consistent with LiteLLMEmbedderWrapper workaround.
|
||||||
|
if text.startswith("import"):
|
||||||
|
return " " + text
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _rate_limit(self) -> None:
|
||||||
|
if self._min_interval_seconds <= 0:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
now = time.monotonic()
|
||||||
|
elapsed = now - self._last_request_at
|
||||||
|
if elapsed < self._min_interval_seconds:
|
||||||
|
time.sleep(self._min_interval_seconds - elapsed)
|
||||||
|
self._last_request_at = time.monotonic()
|
||||||
|
|
||||||
|
def _build_user_prompt(self, query: str, doc: str) -> str:
|
||||||
|
sanitized_query = self._sanitize_text(query or "")
|
||||||
|
sanitized_doc = self._sanitize_text(doc or "")
|
||||||
|
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
|
||||||
|
sanitized_doc = sanitized_doc[: self.max_doc_chars]
|
||||||
|
|
||||||
|
return (
|
||||||
|
"Query:\n"
|
||||||
|
f"{sanitized_query}\n\n"
|
||||||
|
"Document:\n"
|
||||||
|
f"{sanitized_doc}\n\n"
|
||||||
|
"Return the relevance score (0 to 1) as a single number:"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _score_single_pair(self, query: str, doc: str) -> float:
|
||||||
|
messages = [
|
||||||
|
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
|
||||||
|
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._rate_limit()
|
||||||
|
response = self._client.chat(messages)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("LiteLLM reranker request failed: %s", exc)
|
||||||
|
return self.default_score
|
||||||
|
|
||||||
|
raw = getattr(response, "content", "") or ""
|
||||||
|
score = _extract_score(raw)
|
||||||
|
if score is None:
|
||||||
|
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
|
||||||
|
return self.default_score
|
||||||
|
return _coerce_score_to_unit_interval(float(score))
|
||||||
|
|
||||||
|
def score_pairs(
|
||||||
|
self,
|
||||||
|
pairs: Sequence[tuple[str, str]],
|
||||||
|
*,
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> list[float]:
|
||||||
|
"""Score (query, doc) pairs with per-pair LLM calls."""
|
||||||
|
if not pairs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||||
|
|
||||||
|
scores: list[float] = []
|
||||||
|
for i in range(0, len(pairs), bs):
|
||||||
|
batch = pairs[i : i + bs]
|
||||||
|
for query, doc in batch:
|
||||||
|
scores.append(self._score_single_pair(query, doc))
|
||||||
|
return scores
|
||||||
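A hedged usage sketch for the LiteLLMReranker above. The "default" model alias and the rate limit are illustrative; ccw-litellm must be installed and configured, and each (query, doc) pair costs one LLM call.

from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

def demo_litellm_rerank() -> list[float]:
    reranker = LiteLLMReranker(
        model="default",
        requests_per_minute=30,   # throttle to roughly one request every two seconds
        default_score=0.0,        # used when a call fails or parsing fails
    )
    return reranker.score_pairs([
        ("error handling in rust", "Use Result and the ? operator to propagate errors."),
        ("error handling in rust", "Photosynthesis converts light into chemical energy."),
    ])
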
@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.

This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Iterable, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_onnx_reranker_available() -> tuple[bool, str | None]:
    """Check whether Optimum + ONNXRuntime reranker dependencies are available."""
    try:
        import numpy  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

class ONNXReranker(BaseReranker):
|
||||||
|
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str | None = None,
|
||||||
|
*,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
providers: list[Any] | None = None,
|
||||||
|
max_length: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||||
|
if not self.model_name:
|
||||||
|
raise ValueError("model_name cannot be blank")
|
||||||
|
|
||||||
|
self.use_gpu = bool(use_gpu)
|
||||||
|
self.providers = providers
|
||||||
|
|
||||||
|
self.max_length = int(max_length) if max_length is not None else None
|
||||||
|
|
||||||
|
self._tokenizer: Any | None = None
|
||||||
|
self._model: Any | None = None
|
||||||
|
self._model_input_names: set[str] | None = None
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
if self._model is not None and self._tokenizer is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
ok, err = check_onnx_reranker_available()
|
||||||
|
if not ok:
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._model is not None and self._tokenizer is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
from optimum.onnxruntime import ORTModelForSequenceClassification
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
if self.providers is None:
|
||||||
|
from ..gpu_support import get_optimal_providers
|
||||||
|
|
||||||
|
# Include device_id options for DirectML/CUDA selection when available.
|
||||||
|
self.providers = get_optimal_providers(
|
||||||
|
use_gpu=self.use_gpu, with_device_options=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some Optimum versions accept `providers`, others accept a single `provider`.
|
||||||
|
# Prefer passing the full providers list, with a conservative fallback.
|
||||||
|
model_kwargs: dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
|
||||||
|
if "providers" in params:
|
||||||
|
model_kwargs["providers"] = self.providers
|
||||||
|
elif "provider" in params:
|
||||||
|
provider_name = "CPUExecutionProvider"
|
||||||
|
if self.providers:
|
||||||
|
first = self.providers[0]
|
||||||
|
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||||
|
model_kwargs["provider"] = provider_name
|
||||||
|
except Exception:
|
||||||
|
model_kwargs = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._model = ORTModelForSequenceClassification.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
# Fallback for older Optimum versions: retry without provider arguments.
|
||||||
|
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||||
|
|
||||||
|
# Cache model input names to filter tokenizer outputs defensively.
|
||||||
|
input_names: set[str] | None = None
|
||||||
|
for attr in ("input_names", "model_input_names"):
|
||||||
|
names = getattr(self._model, attr, None)
|
||||||
|
if isinstance(names, (list, tuple)) and names:
|
||||||
|
input_names = {str(n) for n in names}
|
||||||
|
break
|
||||||
|
if input_names is None:
|
||||||
|
try:
|
||||||
|
session = getattr(self._model, "model", None)
|
||||||
|
if session is not None and hasattr(session, "get_inputs"):
|
||||||
|
input_names = {i.name for i in session.get_inputs()}
|
||||||
|
except Exception:
|
||||||
|
input_names = None
|
||||||
|
self._model_input_names = input_names
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sigmoid(x: "Any") -> "Any":
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
x = np.clip(x, -50.0, 50.0)
|
||||||
|
return 1.0 / (1.0 + np.exp(-x))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _select_relevance_logit(logits: "Any") -> "Any":
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
arr = np.asarray(logits)
|
||||||
|
if arr.ndim == 0:
|
||||||
|
return arr.reshape(1)
|
||||||
|
if arr.ndim == 1:
|
||||||
|
return arr
|
||||||
|
if arr.ndim >= 2:
|
||||||
|
# Common cases:
|
||||||
|
# - Regression: (batch, 1)
|
||||||
|
# - Binary classification: (batch, 2)
|
||||||
|
if arr.shape[-1] == 1:
|
||||||
|
return arr[..., 0]
|
||||||
|
if arr.shape[-1] == 2:
|
||||||
|
# Convert 2-logit softmax into a single logit via difference.
|
||||||
|
return arr[..., 1] - arr[..., 0]
|
||||||
|
return arr.max(axis=-1)
|
||||||
|
return arr.reshape(-1)
|
||||||
|
|
||||||
|
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
|
||||||
|
|
||||||
|
queries = [q for q, _ in batch]
|
||||||
|
docs = [d for _, d in batch]
|
||||||
|
|
||||||
|
tokenizer_kwargs: dict[str, Any] = {
|
||||||
|
"text": queries,
|
||||||
|
"text_pair": docs,
|
||||||
|
"padding": True,
|
||||||
|
"truncation": True,
|
||||||
|
"return_tensors": "np",
|
||||||
|
}
|
||||||
|
|
||||||
|
max_len = self.max_length
|
||||||
|
if max_len is None:
|
||||||
|
try:
|
||||||
|
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
|
||||||
|
if 0 < model_max < 10_000:
|
||||||
|
max_len = model_max
|
||||||
|
else:
|
||||||
|
max_len = 512
|
||||||
|
except Exception:
|
||||||
|
max_len = 512
|
||||||
|
if max_len is not None and max_len > 0:
|
||||||
|
tokenizer_kwargs["max_length"] = int(max_len)
|
||||||
|
|
||||||
|
encoded = self._tokenizer(**tokenizer_kwargs)
|
||||||
|
inputs = dict(encoded)
|
||||||
|
|
||||||
|
# Some models do not accept token_type_ids; filter to known input names if available.
|
||||||
|
if self._model_input_names:
|
||||||
|
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
|
||||||
|
if self._model is None:
|
||||||
|
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
|
||||||
|
|
||||||
|
outputs = self._model(**inputs)
|
||||||
|
if hasattr(outputs, "logits"):
|
||||||
|
return outputs.logits
|
||||||
|
if isinstance(outputs, dict) and "logits" in outputs:
|
||||||
|
return outputs["logits"]
|
||||||
|
if isinstance(outputs, (list, tuple)) and outputs:
|
||||||
|
return outputs[0]
|
||||||
|
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
|
||||||
|
|
||||||
|
def score_pairs(
|
||||||
|
self,
|
||||||
|
pairs: Sequence[tuple[str, str]],
|
||||||
|
*,
|
||||||
|
batch_size: int = 32,
|
||||||
|
) -> list[float]:
|
||||||
|
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
|
||||||
|
if not pairs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
|
||||||
|
return []
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||||
|
scores: list[float] = []
|
||||||
|
|
||||||
|
for batch in _iter_batches(list(pairs), bs):
|
||||||
|
inputs = self._tokenize_batch(batch)
|
||||||
|
logits = self._forward_logits(inputs)
|
||||||
|
rel_logits = self._select_relevance_logit(logits)
|
||||||
|
probs = self._sigmoid(rel_logits)
|
||||||
|
probs = np.clip(probs, 0.0, 1.0)
|
||||||
|
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
|
||||||
|
|
||||||
|
if len(scores) != len(pairs):
|
||||||
|
logger.debug(
|
||||||
|
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
|
||||||
|
)
|
||||||
|
return scores[: len(pairs)]
|
||||||
|
|
||||||
|
return scores
|
||||||
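A usage sketch for the ONNXReranker above; scores come back sigmoid-normalized to [0, 1]. The module path is assumed here (this file's header is not shown in the diff excerpt), and onnxruntime, optimum[onnxruntime], and transformers must be installed.

# Module path assumed; adjust to the actual file name if it differs.
from codexlens.semantic.reranker.onnx_reranker import ONNXReranker, check_onnx_reranker_available

def demo_onnx_rerank() -> list[float]:
    ok, err = check_onnx_reranker_available()
    if not ok:
        raise ImportError(err)
    reranker = ONNXReranker(use_gpu=False)  # defaults to Xenova/ms-marco-MiniLM-L-6-v2
    return reranker.score_pairs([
        ("unit testing in go", "Use the testing package and go test to run tests."),
        ("unit testing in go", "The Great Barrier Reef lies off the coast of Australia."),
    ])
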
codex-lens/build/lib/codexlens/semantic/rotational_embedder.py (new file, 434 lines)
@@ -0,0 +1,434 @@
"""Rotational embedder for multi-endpoint API load balancing.

Provides intelligent load balancing across multiple LiteLLM embedding endpoints
to maximize throughput while respecting rate limits.
"""

from __future__ import annotations

import logging
import random
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional

import numpy as np

from .base import BaseEmbedder

logger = logging.getLogger(__name__)


class EndpointStatus(Enum):
    """Status of an API endpoint."""
    AVAILABLE = "available"
    COOLING = "cooling"  # Rate limited, temporarily unavailable
    FAILED = "failed"  # Permanent failure (auth error, etc.)


class SelectionStrategy(Enum):
    """Strategy for selecting endpoints."""
    ROUND_ROBIN = "round_robin"
    LATENCY_AWARE = "latency_aware"
    WEIGHTED_RANDOM = "weighted_random"


@dataclass
class EndpointConfig:
    """Configuration for a single API endpoint."""
    model: str
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    weight: float = 1.0  # Higher weight = more requests
    max_concurrent: int = 4  # Max concurrent requests to this endpoint

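# Illustrative sketch (not part of the original file): endpoint configurations as
# they might be passed to the RotationalEmbedder defined further below. The model
# strings, api_base, and api_key are hypothetical placeholders.
_EXAMPLE_ENDPOINTS = [
    EndpointConfig(model="openai/text-embedding-3-small", weight=2.0, max_concurrent=8),
    EndpointConfig(
        model="openai/text-embedding-3-small",
        api_base="https://example.invalid/v1",
        api_key="sk-example",
        weight=1.0,
    ),
]
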
@dataclass
|
||||||
|
class EndpointState:
|
||||||
|
"""Runtime state for an endpoint."""
|
||||||
|
config: EndpointConfig
|
||||||
|
embedder: Any = None # LiteLLMEmbedderWrapper instance
|
||||||
|
|
||||||
|
# Health metrics
|
||||||
|
status: EndpointStatus = EndpointStatus.AVAILABLE
|
||||||
|
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
|
||||||
|
|
||||||
|
# Performance metrics
|
||||||
|
total_requests: int = 0
|
||||||
|
total_failures: int = 0
|
||||||
|
avg_latency_ms: float = 0.0
|
||||||
|
last_latency_ms: float = 0.0
|
||||||
|
|
||||||
|
# Concurrency tracking
|
||||||
|
active_requests: int = 0
|
||||||
|
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""Check if endpoint is available for requests."""
|
||||||
|
if self.status == EndpointStatus.FAILED:
|
||||||
|
return False
|
||||||
|
if self.status == EndpointStatus.COOLING:
|
||||||
|
if time.time() >= self.cooldown_until:
|
||||||
|
self.status = EndpointStatus.AVAILABLE
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def set_cooldown(self, seconds: float) -> None:
|
||||||
|
"""Put endpoint in cooldown state."""
|
||||||
|
self.status = EndpointStatus.COOLING
|
||||||
|
self.cooldown_until = time.time() + seconds
|
||||||
|
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
|
||||||
|
|
||||||
|
def mark_failed(self) -> None:
|
||||||
|
"""Mark endpoint as permanently failed."""
|
||||||
|
self.status = EndpointStatus.FAILED
|
||||||
|
logger.error(f"Endpoint {self.config.model} marked as failed")
|
||||||
|
|
||||||
|
def record_success(self, latency_ms: float) -> None:
|
||||||
|
"""Record successful request."""
|
||||||
|
self.total_requests += 1
|
||||||
|
self.last_latency_ms = latency_ms
|
||||||
|
# Exponential moving average for latency
|
||||||
|
alpha = 0.3
|
||||||
|
if self.avg_latency_ms == 0:
|
||||||
|
self.avg_latency_ms = latency_ms
|
||||||
|
else:
|
||||||
|
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
|
||||||
|
|
||||||
|
def record_failure(self) -> None:
|
||||||
|
"""Record failed request."""
|
||||||
|
self.total_requests += 1
|
||||||
|
self.total_failures += 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def health_score(self) -> float:
|
||||||
|
"""Calculate health score (0-1) based on metrics."""
|
||||||
|
if not self.is_available():
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Base score from success rate
|
||||||
|
if self.total_requests > 0:
|
||||||
|
success_rate = 1 - (self.total_failures / self.total_requests)
|
||||||
|
else:
|
||||||
|
success_rate = 1.0
|
||||||
|
|
||||||
|
# Latency factor (faster = higher score)
|
||||||
|
# Normalize: 100ms = 1.0, 1000ms = 0.1
|
||||||
|
if self.avg_latency_ms > 0:
|
||||||
|
latency_factor = min(1.0, 100 / self.avg_latency_ms)
|
||||||
|
else:
|
||||||
|
latency_factor = 1.0
|
||||||
|
|
||||||
|
# Availability factor (less concurrent = more available)
|
||||||
|
if self.config.max_concurrent > 0:
|
||||||
|
availability = 1 - (self.active_requests / self.config.max_concurrent)
|
||||||
|
else:
|
||||||
|
availability = 1.0
|
||||||
|
|
||||||
|
# Combined score with weights
|
||||||
|
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
|
||||||
|
|
||||||
|
|
||||||
|
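# Illustrative sketch (hypothetical numbers, not part of the original module): how the
# health_score weighting above plays out. A 95% success rate, 200ms average latency and
# 1 of 4 slots busy gives roughly 0.95*0.4 + 0.5*0.3 + 0.75*0.3 = 0.755 before the
# configured weight is applied.
def _health_score_sketch(success_rate: float = 0.95,
                         avg_latency_ms: float = 200.0,
                         active_requests: int = 1,
                         max_concurrent: int = 4,
                         weight: float = 1.0) -> float:
    latency_factor = min(1.0, 100 / avg_latency_ms) if avg_latency_ms > 0 else 1.0
    availability = 1 - (active_requests / max_concurrent) if max_concurrent > 0 else 1.0
    return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * weight
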
class RotationalEmbedder(BaseEmbedder):
    """Embedder that load balances across multiple API endpoints.

    Features:
    - Intelligent endpoint selection based on latency and health
    - Automatic failover on rate limits (429) and server errors
    - Cooldown management to respect rate limits
    - Thread-safe concurrent request handling

    Args:
        endpoints: List of endpoint configurations
        strategy: Selection strategy (default: latency_aware)
        default_cooldown: Default cooldown seconds for rate limits (default: 60)
        max_retries: Maximum retry attempts across all endpoints (default: 3)
    """

    def __init__(
        self,
        endpoints: List[EndpointConfig],
        strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
        default_cooldown: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        if not endpoints:
            raise ValueError("At least one endpoint must be provided")

        self.strategy = strategy
        self.default_cooldown = default_cooldown
        self.max_retries = max_retries

        # Initialize endpoint states
        self._endpoints: List[EndpointState] = []
        self._lock = threading.Lock()
        self._round_robin_index = 0

        # Create embedder instances for each endpoint
        from .litellm_embedder import LiteLLMEmbedderWrapper

        for config in endpoints:
            # Build kwargs for LiteLLMEmbedderWrapper
            kwargs: Dict[str, Any] = {}
            if config.api_key:
                kwargs["api_key"] = config.api_key
            if config.api_base:
                kwargs["api_base"] = config.api_base

            try:
                embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
                state = EndpointState(config=config, embedder=embedder)
                self._endpoints.append(state)
                logger.info(f"Initialized endpoint: {config.model}")
            except Exception as e:
                logger.error(f"Failed to initialize endpoint {config.model}: {e}")

        if not self._endpoints:
            raise ValueError("Failed to initialize any endpoints")

        # Cache embedding properties from first endpoint
        self._embedding_dim = self._endpoints[0].embedder.embedding_dim
        self._model_name = f"rotational({len(self._endpoints)} endpoints)"
        self._max_tokens = self._endpoints[0].embedder.max_tokens

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimensions."""
        return self._embedding_dim

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit."""
        return self._max_tokens

    @property
    def endpoint_count(self) -> int:
        """Return number of configured endpoints."""
        return len(self._endpoints)

    @property
    def available_endpoint_count(self) -> int:
        """Return number of available endpoints."""
        return sum(1 for ep in self._endpoints if ep.is_available())

    def get_endpoint_stats(self) -> List[Dict[str, Any]]:
        """Get statistics for all endpoints."""
        stats = []
        for ep in self._endpoints:
            stats.append({
                "model": ep.config.model,
                "status": ep.status.value,
                "total_requests": ep.total_requests,
                "total_failures": ep.total_failures,
                "avg_latency_ms": round(ep.avg_latency_ms, 2),
                "health_score": round(ep.health_score, 3),
                "active_requests": ep.active_requests,
            })
        return stats

    def _select_endpoint(self) -> Optional[EndpointState]:
        """Select best available endpoint based on strategy."""
        available = [ep for ep in self._endpoints if ep.is_available()]

        if not available:
            return None

        if self.strategy == SelectionStrategy.ROUND_ROBIN:
            with self._lock:
                self._round_robin_index = (self._round_robin_index + 1) % len(available)
                return available[self._round_robin_index]

        elif self.strategy == SelectionStrategy.LATENCY_AWARE:
            # Sort by health score (descending) and pick top candidate
            # Add small random factor to prevent thundering herd
            scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
            scored.sort(key=lambda x: x[1], reverse=True)
            return scored[0][0]

        elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
            # Weighted random selection based on health scores
            scores = [ep.health_score for ep in available]
            total = sum(scores)
            if total == 0:
                return random.choice(available)

            weights = [s / total for s in scores]
            return random.choices(available, weights=weights, k=1)[0]

        return available[0]

    def _parse_retry_after(self, error: Exception) -> Optional[float]:
        """Extract Retry-After value from error if available."""
        error_str = str(error)

        # Try to find Retry-After in error message
        import re
        match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
        if match:
            return float(match.group(1))

        return None

    def _is_rate_limit_error(self, error: Exception) -> bool:
        """Check if error is a rate limit error."""
        error_str = str(error).lower()
        return any(x in error_str for x in ["429", "rate limit", "too many requests"])

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if error is retryable (not auth/config error)."""
        error_str = str(error).lower()
        # Retryable errors
        if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
                                        "timeout", "connection", "service unavailable"]):
            return True
        # Non-retryable errors (auth, config)
        if any(x in error_str for x in ["401", "403", "invalid", "authentication",
                                        "unauthorized", "api key"]):
            return False
        # Default to retryable for unknown errors
        return True

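    # Illustrative note (hypothetical error strings, not from the original file): with the
    # rules above, "429 Too Many Requests" triggers a Retry-After-based cooldown,
    # "503 Service Unavailable" is retried after a short 5s cooldown, and
    # "401 Unauthorized: invalid api key" marks the endpoint as permanently failed.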
    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts using load-balanced endpoint selection.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments passed to underlying embedder.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.

        Raises:
            RuntimeError: If all endpoints fail after retries.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        last_error: Optional[Exception] = None
        tried_endpoints: set = set()

        for attempt in range(self.max_retries + 1):
            endpoint = self._select_endpoint()

            if endpoint is None:
                # All endpoints unavailable, wait for shortest cooldown
                min_cooldown = min(
                    (ep.cooldown_until - time.time() for ep in self._endpoints
                     if ep.status == EndpointStatus.COOLING),
                    default=self.default_cooldown
                )
                if min_cooldown > 0 and attempt < self.max_retries:
                    wait_time = min(min_cooldown, 30)  # Cap wait at 30s
                    logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    continue
                break

            # Track tried endpoints to avoid infinite loops
            endpoint_id = id(endpoint)
            if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
                # Already tried all endpoints
                break
            tried_endpoints.add(endpoint_id)

            # Acquire slot
            with endpoint.lock:
                endpoint.active_requests += 1

            try:
                start_time = time.time()
                result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
                latency_ms = (time.time() - start_time) * 1000

                # Record success
                endpoint.record_success(latency_ms)

                return result

            except Exception as e:
                last_error = e
                endpoint.record_failure()

                if self._is_rate_limit_error(e):
                    # Rate limited - set cooldown
                    retry_after = self._parse_retry_after(e) or self.default_cooldown
                    endpoint.set_cooldown(retry_after)
                    logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
                                   f"cooling for {retry_after}s")

                elif not self._is_retryable_error(e):
                    # Permanent failure (auth error, etc.)
                    endpoint.mark_failed()
                    logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")

                else:
                    # Temporary error - short cooldown
                    endpoint.set_cooldown(5.0)
                    logger.warning(f"Endpoint {endpoint.config.model} error: {e}")

            finally:
                with endpoint.lock:
                    endpoint.active_requests -= 1

        # All retries exhausted
        available = self.available_endpoint_count
        raise RuntimeError(
            f"All embedding attempts failed after {self.max_retries + 1} tries. "
            f"Available endpoints: {available}/{len(self._endpoints)}. "
            f"Last error: {last_error}"
        )

def create_rotational_embedder(
    endpoints_config: List[Dict[str, Any]],
    strategy: str = "latency_aware",
    default_cooldown: float = 60.0,
) -> RotationalEmbedder:
    """Factory function to create RotationalEmbedder from config dicts.

    Args:
        endpoints_config: List of endpoint configuration dicts with keys:
            - model: Model identifier (required)
            - api_key: API key (optional)
            - api_base: API base URL (optional)
            - weight: Request weight (optional, default 1.0)
            - max_concurrent: Max concurrent requests (optional, default 4)
        strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
        default_cooldown: Default cooldown seconds for rate limits

    Returns:
        Configured RotationalEmbedder instance

    Example config:
        endpoints_config = [
            {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
            {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ]
    """
    endpoints = []
    for cfg in endpoints_config:
        endpoints.append(EndpointConfig(
            model=cfg["model"],
            api_key=cfg.get("api_key"),
            api_base=cfg.get("api_base"),
            weight=cfg.get("weight", 1.0),
            max_concurrent=cfg.get("max_concurrent", 4),
        ))

    strategy_enum = SelectionStrategy[strategy.upper()]

    return RotationalEmbedder(
        endpoints=endpoints,
        strategy=strategy_enum,
        default_cooldown=default_cooldown,
    )

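# Illustrative usage sketch (hypothetical endpoint values, not part of the original file):
# how a caller might wire the factory above to two endpoints and embed a batch.
def _example_rotational_usage() -> None:
    embedder = create_rotational_embedder(
        endpoints_config=[
            {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 2.0},
            {"model": "azure/my-embedding", "api_base": "https://example.invalid", "api_key": "..."},
        ],
        strategy="weighted_random",
        default_cooldown=30.0,
    )
    vectors = embedder.embed_to_numpy(["def add(a, b): return a + b"])
    print(vectors.shape, embedder.get_endpoint_stats())
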
codex-lens/build/lib/codexlens/semantic/splade_encoder.py (new file)
@@ -0,0 +1,567 @@
|
|||||||
|
"""ONNX-optimized SPLADE sparse encoder for code search.
|
||||||
|
|
||||||
|
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
|
||||||
|
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
|
||||||
|
that combine the interpretability of BM25 with neural relevance modeling.
|
||||||
|
|
||||||
|
Install (CPU):
|
||||||
|
pip install onnxruntime optimum[onnxruntime] transformers
|
||||||
|
|
||||||
|
Install (GPU):
|
||||||
|
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_splade_available() -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Check whether SPLADE dependencies are available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (available: bool, error_message: Optional[str])
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import numpy # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
return False, f"numpy not available: {exc}. Install with: pip install numpy"
|
||||||
|
|
||||||
|
try:
|
||||||
|
import onnxruntime # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoTokenizer # noqa: F401
|
||||||
|
except ImportError as exc:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"transformers not available: {exc}. Install with: pip install transformers",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
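# Illustrative sketch (not part of the original module): callers are expected to probe the
# optional dependencies with check_splade_available() before constructing an encoder.
def _example_dependency_guard() -> None:
    ok, err = check_splade_available()
    if not ok:
        raise ImportError(err)
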
# Global cache for SPLADE encoders (singleton pattern)
|
||||||
|
_splade_cache: Dict[str, "SpladeEncoder"] = {}
|
||||||
|
_cache_lock = threading.RLock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_splade_encoder(
|
||||||
|
model_name: str = "naver/splade-cocondenser-ensembledistil",
|
||||||
|
use_gpu: bool = True,
|
||||||
|
max_length: int = 512,
|
||||||
|
sparsity_threshold: float = 0.01,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> "SpladeEncoder":
|
||||||
|
"""Get or create cached SPLADE encoder (thread-safe singleton).
|
||||||
|
|
||||||
|
This function provides significant performance improvement by reusing
|
||||||
|
SpladeEncoder instances across multiple searches, avoiding repeated model
|
||||||
|
loading overhead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||||
|
use_gpu: If True, use GPU acceleration when available
|
||||||
|
max_length: Maximum sequence length for tokenization
|
||||||
|
sparsity_threshold: Minimum weight to include in sparse vector
|
||||||
|
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached SpladeEncoder instance for the given configuration
|
||||||
|
"""
|
||||||
|
global _splade_cache
|
||||||
|
|
||||||
|
# Cache key includes all configuration parameters
|
||||||
|
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
|
||||||
|
|
||||||
|
with _cache_lock:
|
||||||
|
encoder = _splade_cache.get(cache_key)
|
||||||
|
if encoder is not None:
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
# Create new encoder and cache it
|
||||||
|
encoder = SpladeEncoder(
|
||||||
|
model_name=model_name,
|
||||||
|
use_gpu=use_gpu,
|
||||||
|
max_length=max_length,
|
||||||
|
sparsity_threshold=sparsity_threshold,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
)
|
||||||
|
# Pre-load model to ensure it's ready
|
||||||
|
encoder._load_model()
|
||||||
|
_splade_cache[cache_key] = encoder
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def clear_splade_cache() -> None:
|
||||||
|
"""Clear the SPLADE encoder cache and release ONNX resources.
|
||||||
|
|
||||||
|
This method ensures proper cleanup of ONNX model resources to prevent
|
||||||
|
memory leaks when encoders are no longer needed.
|
||||||
|
"""
|
||||||
|
global _splade_cache
|
||||||
|
with _cache_lock:
|
||||||
|
# Release ONNX resources before clearing cache
|
||||||
|
for encoder in _splade_cache.values():
|
||||||
|
if encoder._model is not None:
|
||||||
|
del encoder._model
|
||||||
|
encoder._model = None
|
||||||
|
if encoder._tokenizer is not None:
|
||||||
|
del encoder._tokenizer
|
||||||
|
encoder._tokenizer = None
|
||||||
|
_splade_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
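# Illustrative sketch (not part of the original module): the cached factory above reuses
# one encoder per configuration key, and clear_splade_cache() releases the ONNX resources
# when encoders are no longer needed.
def _example_cached_encoder() -> None:
    encoder = get_splade_encoder(use_gpu=False)        # loads and caches on first call
    same_encoder = get_splade_encoder(use_gpu=False)   # same key, returns cached instance
    assert encoder is same_encoder
    clear_splade_cache()                               # releases model and tokenizer
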
class SpladeEncoder:
|
||||||
|
"""ONNX-optimized SPLADE sparse encoder.
|
||||||
|
|
||||||
|
Produces sparse vectors with vocabulary-aligned dimensions.
|
||||||
|
Output: Dict[int, float] mapping token_id to weight.
|
||||||
|
|
||||||
|
SPLADE activation formula:
|
||||||
|
splade_repr = log(1 + ReLU(logits)) * attention_mask
|
||||||
|
splade_vec = max_pooling(splade_repr, axis=sequence_length)
|
||||||
|
|
||||||
|
References:
|
||||||
|
- SPLADE: https://arxiv.org/abs/2107.05720
|
||||||
|
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = DEFAULT_MODEL,
|
||||||
|
use_gpu: bool = True,
|
||||||
|
max_length: int = 512,
|
||||||
|
sparsity_threshold: float = 0.01,
|
||||||
|
providers: Optional[List[Any]] = None,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize SPLADE encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||||
|
use_gpu: If True, use GPU acceleration when available
|
||||||
|
max_length: Maximum sequence length for tokenization
|
||||||
|
sparsity_threshold: Minimum weight to include in sparse vector
|
||||||
|
providers: Explicit ONNX providers list (overrides use_gpu)
|
||||||
|
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||||
|
"""
|
||||||
|
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||||
|
if not self.model_name:
|
||||||
|
raise ValueError("model_name cannot be blank")
|
||||||
|
|
||||||
|
self.use_gpu = bool(use_gpu)
|
||||||
|
self.max_length = int(max_length) if max_length > 0 else 512
|
||||||
|
self.sparsity_threshold = float(sparsity_threshold)
|
||||||
|
self.providers = providers
|
||||||
|
|
||||||
|
# Setup ONNX cache directory
|
||||||
|
if cache_dir:
|
||||||
|
self._cache_dir = Path(cache_dir)
|
||||||
|
else:
|
||||||
|
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
|
||||||
|
|
||||||
|
self._tokenizer: Any | None = None
|
||||||
|
self._model: Any | None = None
|
||||||
|
self._vocab_size: int | None = None
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
|
def _get_local_cache_path(self) -> Path:
|
||||||
|
"""Get local cache path for this model's ONNX files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the local ONNX cache directory for this model
|
||||||
|
"""
|
||||||
|
# Replace / with -- for filesystem-safe naming
|
||||||
|
safe_name = self.model_name.replace("/", "--")
|
||||||
|
return self._cache_dir / safe_name
|
||||||
|
|
||||||
|
def _load_model(self) -> None:
|
||||||
|
"""Lazy load ONNX model and tokenizer.
|
||||||
|
|
||||||
|
First checks local cache for ONNX model, falling back to
|
||||||
|
HuggingFace download and conversion if not cached.
|
||||||
|
"""
|
||||||
|
if self._model is not None and self._tokenizer is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
ok, err = check_splade_available()
|
||||||
|
if not ok:
|
||||||
|
raise ImportError(err)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if self._model is not None and self._tokenizer is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
from optimum.onnxruntime import ORTModelForMaskedLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
if self.providers is None:
|
||||||
|
from .gpu_support import get_optimal_providers, get_selected_device_id
|
||||||
|
|
||||||
|
# Get providers as pure string list (cache-friendly)
|
||||||
|
# NOTE: with_device_options=False to avoid tuple-based providers
|
||||||
|
# which break optimum's caching mechanism
|
||||||
|
self.providers = get_optimal_providers(
|
||||||
|
use_gpu=self.use_gpu, with_device_options=False
|
||||||
|
)
|
||||||
|
# Get device_id separately for provider_options
|
||||||
|
self._device_id = get_selected_device_id() if self.use_gpu else None
|
||||||
|
|
||||||
|
# Some Optimum versions accept `providers`, others accept a single `provider`
|
||||||
|
# Prefer passing the full providers list, with a conservative fallback
|
||||||
|
model_kwargs: dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
|
||||||
|
if "providers" in params:
|
||||||
|
model_kwargs["providers"] = self.providers
|
||||||
|
# Pass device_id via provider_options for GPU selection
|
||||||
|
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
|
||||||
|
# Build provider_options dict for each GPU provider
|
||||||
|
provider_options = {}
|
||||||
|
for p in self.providers:
|
||||||
|
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
|
||||||
|
provider_options[p] = {"device_id": self._device_id}
|
||||||
|
if provider_options:
|
||||||
|
model_kwargs["provider_options"] = provider_options
|
||||||
|
elif "provider" in params:
|
||||||
|
provider_name = "CPUExecutionProvider"
|
||||||
|
if self.providers:
|
||||||
|
first = self.providers[0]
|
||||||
|
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||||
|
model_kwargs["provider"] = provider_name
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to inspect ORTModel signature: {e}")
|
||||||
|
model_kwargs = {}
|
||||||
|
|
||||||
|
# Check for local ONNX cache first
|
||||||
|
local_cache = self._get_local_cache_path()
|
||||||
|
onnx_model_path = local_cache / "model.onnx"
|
||||||
|
|
||||||
|
if onnx_model_path.exists():
|
||||||
|
# Load from local cache
|
||||||
|
logger.info(f"Loading SPLADE from local cache: {local_cache}")
|
||||||
|
try:
|
||||||
|
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||||
|
str(local_cache),
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
str(local_cache), use_fast=True
|
||||||
|
)
|
||||||
|
self._vocab_size = len(self._tokenizer)
|
||||||
|
logger.info(
|
||||||
|
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load from cache, redownloading: {e}")
|
||||||
|
|
||||||
|
# Download and convert from HuggingFace
|
||||||
|
logger.info(f"Downloading SPLADE model: {self.model_name}")
|
||||||
|
try:
|
||||||
|
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
export=True, # Export to ONNX
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
logger.debug(f"SPLADE model loaded: {self.model_name}")
|
||||||
|
except TypeError:
|
||||||
|
# Fallback for older Optimum versions: retry without provider arguments
|
||||||
|
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
export=True,
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"Optimum version doesn't support provider parameters. "
|
||||||
|
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||||
|
|
||||||
|
# Cache vocabulary size
|
||||||
|
self._vocab_size = len(self._tokenizer)
|
||||||
|
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
|
||||||
|
|
||||||
|
# Save to local cache for future use
|
||||||
|
try:
|
||||||
|
local_cache.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._model.save_pretrained(str(local_cache))
|
||||||
|
self._tokenizer.save_pretrained(str(local_cache))
|
||||||
|
logger.info(f"SPLADE model cached to: {local_cache}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cache SPLADE model: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
|
||||||
|
"""Apply SPLADE activation function to model outputs.
|
||||||
|
|
||||||
|
Formula: log(1 + ReLU(logits)) * attention_mask
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Model output logits (batch, seq_len, vocab_size)
|
||||||
|
attention_mask: Attention mask (batch, seq_len)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SPLADE representations (batch, seq_len, vocab_size)
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# ReLU activation
|
||||||
|
relu_logits = np.maximum(0, logits)
|
||||||
|
|
||||||
|
# Log(1 + x) transformation
|
||||||
|
log_relu = np.log1p(relu_logits)
|
||||||
|
|
||||||
|
# Apply attention mask (expand to match vocab dimension)
|
||||||
|
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
|
||||||
|
mask_expanded = np.expand_dims(attention_mask, axis=-1)
|
||||||
|
|
||||||
|
# Element-wise multiplication
|
||||||
|
splade_repr = log_relu * mask_expanded
|
||||||
|
|
||||||
|
return splade_repr
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _max_pooling(splade_repr: Any) -> Any:
|
||||||
|
"""Max pooling over sequence length dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pooled sparse vectors (batch, vocab_size)
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Max pooling over sequence dimension (axis=1)
|
||||||
|
return np.max(splade_repr, axis=1)
|
||||||
|
|
||||||
|
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
|
||||||
|
"""Convert dense vector to sparse dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dense_vec: Dense vector (vocab_size,)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sparse dictionary {token_id: weight} with weights above threshold
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Find non-zero indices above threshold
|
||||||
|
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
|
||||||
|
|
||||||
|
# Create sparse dictionary
|
||||||
|
sparse_dict = {
|
||||||
|
int(idx): float(dense_vec[idx])
|
||||||
|
for idx in nonzero_indices
|
||||||
|
}
|
||||||
|
|
||||||
|
return sparse_dict
|
||||||
|
|
||||||
|
def warmup(self, text: str = "warmup query") -> None:
|
||||||
|
"""Warmup the encoder by running a dummy inference.
|
||||||
|
|
||||||
|
First-time model inference includes initialization overhead.
|
||||||
|
Call this method once before the first real search to avoid
|
||||||
|
latency spikes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Dummy text for warmup (default: "warmup query")
|
||||||
|
"""
|
||||||
|
logger.info("Warming up SPLADE encoder...")
|
||||||
|
# Trigger model loading and first inference
|
||||||
|
_ = self.encode_text(text)
|
||||||
|
logger.info("SPLADE encoder warmup complete")
|
||||||
|
|
||||||
|
def encode_text(self, text: str) -> Dict[int, float]:
|
||||||
|
"""Encode text to sparse vector {token_id: weight}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to encode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sparse vector as dictionary mapping token_id to weight
|
||||||
|
"""
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._model is None or self._tokenizer is None:
|
||||||
|
raise RuntimeError("Model not loaded")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Tokenize input
|
||||||
|
encoded = self._tokenizer(
|
||||||
|
text,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward pass through model
|
||||||
|
outputs = self._model(**encoded)
|
||||||
|
|
||||||
|
# Extract logits
|
||||||
|
if hasattr(outputs, "logits"):
|
||||||
|
logits = outputs.logits
|
||||||
|
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||||
|
logits = outputs["logits"]
|
||||||
|
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||||
|
logits = outputs[0]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unexpected model output format")
|
||||||
|
|
||||||
|
# Apply SPLADE activation
|
||||||
|
attention_mask = encoded["attention_mask"]
|
||||||
|
splade_repr = self._splade_activation(logits, attention_mask)
|
||||||
|
|
||||||
|
# Max pooling over sequence length
|
||||||
|
splade_vec = self._max_pooling(splade_repr)
|
||||||
|
|
||||||
|
# Convert to sparse dictionary (single item batch)
|
||||||
|
sparse_dict = self._to_sparse_dict(splade_vec[0])
|
||||||
|
|
||||||
|
return sparse_dict
|
||||||
|
|
||||||
|
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
|
||||||
|
"""Batch encode texts to sparse vectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of input texts to encode
|
||||||
|
batch_size: Batch size for encoding (default: 32)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sparse vectors as dictionaries
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._model is None or self._tokenizer is None:
|
||||||
|
raise RuntimeError("Model not loaded")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
results: List[Dict[int, float]] = []
|
||||||
|
|
||||||
|
# Process in batches
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch_texts = texts[i:i + batch_size]
|
||||||
|
|
||||||
|
# Tokenize batch
|
||||||
|
encoded = self._tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward pass through model
|
||||||
|
outputs = self._model(**encoded)
|
||||||
|
|
||||||
|
# Extract logits
|
||||||
|
if hasattr(outputs, "logits"):
|
||||||
|
logits = outputs.logits
|
||||||
|
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||||
|
logits = outputs["logits"]
|
||||||
|
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||||
|
logits = outputs[0]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unexpected model output format")
|
||||||
|
|
||||||
|
# Apply SPLADE activation
|
||||||
|
attention_mask = encoded["attention_mask"]
|
||||||
|
splade_repr = self._splade_activation(logits, attention_mask)
|
||||||
|
|
||||||
|
# Max pooling over sequence length
|
||||||
|
splade_vecs = self._max_pooling(splade_repr)
|
||||||
|
|
||||||
|
# Convert each vector to sparse dictionary
|
||||||
|
for vec in splade_vecs:
|
||||||
|
sparse_dict = self._to_sparse_dict(vec)
|
||||||
|
results.append(sparse_dict)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
"""Return vocabulary size (~30k for BERT-based models).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Vocabulary size (number of tokens in tokenizer)
|
||||||
|
"""
|
||||||
|
if self._vocab_size is not None:
|
||||||
|
return self._vocab_size
|
||||||
|
|
||||||
|
self._load_model()
|
||||||
|
return self._vocab_size or 0
|
||||||
|
|
||||||
|
def get_token(self, token_id: int) -> str:
|
||||||
|
"""Convert token_id to string (for debugging).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: Token ID to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token string
|
||||||
|
"""
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if self._tokenizer is None:
|
||||||
|
raise RuntimeError("Tokenizer not loaded")
|
||||||
|
|
||||||
|
return self._tokenizer.decode([token_id])
|
||||||
|
|
||||||
|
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
|
||||||
|
"""Get top-k tokens with highest weights from sparse vector.
|
||||||
|
|
||||||
|
Useful for debugging and understanding what the model is focusing on.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_vec: Sparse vector as {token_id: weight}
|
||||||
|
top_k: Number of top tokens to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (token_string, weight) tuples, sorted by weight descending
|
||||||
|
"""
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
if not sparse_vec:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sort by weight descending
|
||||||
|
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# Take top-k and convert token_ids to strings
|
||||||
|
top_items = sorted_items[:top_k]
|
||||||
|
|
||||||
|
return [
|
||||||
|
(self.get_token(token_id), weight)
|
||||||
|
for token_id, weight in top_items
|
||||||
|
]
|
||||||
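# Illustrative sketch (synthetic shapes, not part of the original file): the SPLADE
# activation pipeline described above, written out with plain numpy for a toy batch:
# log(1 + ReLU(logits)) masked by attention, then max pooling over the sequence.
def _example_splade_activation() -> None:
    import numpy as np

    batch, seq_len, vocab = 1, 4, 8
    logits = np.random.randn(batch, seq_len, vocab)
    attention_mask = np.array([[1, 1, 1, 0]])          # last position is padding

    splade_repr = np.log1p(np.maximum(0, logits)) * attention_mask[..., None]
    splade_vec = splade_repr.max(axis=1)               # (batch, vocab_size)
    sparse = {int(i): float(v) for i, v in enumerate(splade_vec[0]) if v > 0.01}
    print(sparse)
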
codex-lens/build/lib/codexlens/semantic/vector_store.py (new file, 1278 lines; diff suppressed because it is too large)

codex-lens/build/lib/codexlens/storage/__init__.py (new file)
@@ -0,0 +1,32 @@
"""Storage backends for CodexLens."""

from __future__ import annotations

from .sqlite_store import SQLiteStore
from .path_mapper import PathMapper
from .registry import RegistryStore, ProjectInfo, DirMapping
from .dir_index import DirIndexStore, SubdirLink, FileEntry
from .index_tree import IndexTreeBuilder, BuildResult, DirBuildResult
from .vector_meta_store import VectorMetadataStore

__all__ = [
    # Legacy (workspace-local)
    "SQLiteStore",
    # Path mapping
    "PathMapper",
    # Global registry
    "RegistryStore",
    "ProjectInfo",
    "DirMapping",
    # Directory index
    "DirIndexStore",
    "SubdirLink",
    "FileEntry",
    # Tree builder
    "IndexTreeBuilder",
    "BuildResult",
    "DirBuildResult",
    # Vector metadata
    "VectorMetadataStore",
]

codex-lens/build/lib/codexlens/storage/dir_index.py (new file, 2358 lines; diff suppressed because it is too large)

codex-lens/build/lib/codexlens/storage/file_cache.py (new file)
@@ -0,0 +1,32 @@
"""Simple filesystem cache helpers."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class FileCache:
    """Caches file mtimes for incremental indexing."""

    cache_path: Path

    def load_mtime(self, path: Path) -> Optional[float]:
        try:
            key = self._key_for(path)
            record = (self.cache_path / key).read_text(encoding="utf-8")
            return float(record)
        except Exception:
            return None

    def store_mtime(self, path: Path, mtime: float) -> None:
        self.cache_path.mkdir(parents=True, exist_ok=True)
        key = self._key_for(path)
        (self.cache_path / key).write_text(str(mtime), encoding="utf-8")

    def _key_for(self, path: Path) -> str:
        safe = str(path).replace(":", "_").replace("\\", "_").replace("/", "_")
        return f"{safe}.mtime"

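# Illustrative sketch (hypothetical caller, not part of the original file): how an indexer
# might use FileCache to skip files whose mtime has not changed since the last run.
def _example_skip_unchanged(source: Path, cache: FileCache) -> bool:
    current = source.stat().st_mtime
    if cache.load_mtime(source) == current:
        return True          # unchanged, safe to skip re-indexing
    cache.store_mtime(source, current)
    return False
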
codex-lens/build/lib/codexlens/storage/global_index.py (new file)
@@ -0,0 +1,398 @@
|
|||||||
|
"""Global cross-directory symbol index for fast lookups.
|
||||||
|
|
||||||
|
Stores symbols for an entire project in a single SQLite database so symbol search
|
||||||
|
does not require traversing every directory _index.db.
|
||||||
|
|
||||||
|
This index is updated incrementally during file indexing (delete+insert per file)
|
||||||
|
to avoid expensive batch rebuilds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from codexlens.entities import Symbol
|
||||||
|
from codexlens.errors import StorageError
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalSymbolIndex:
|
||||||
|
"""Project-wide symbol index with incremental updates."""
|
||||||
|
|
||||||
|
SCHEMA_VERSION = 1
|
||||||
|
DEFAULT_DB_NAME = "_global_symbols.db"
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path, project_id: int) -> None:
|
||||||
|
self.db_path = Path(db_path).resolve()
|
||||||
|
self.project_id = int(project_id)
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
self._conn: Optional[sqlite3.Connection] = None
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
"""Create database and schema if not exists."""
|
||||||
|
with self._lock:
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
conn = self._get_connection()
|
||||||
|
|
||||||
|
current_version = self._get_schema_version(conn)
|
||||||
|
if current_version > self.SCHEMA_VERSION:
|
||||||
|
raise StorageError(
|
||||||
|
f"Database schema version {current_version} is newer than "
|
||||||
|
f"supported version {self.SCHEMA_VERSION}. "
|
||||||
|
f"Please update the application or use a compatible database.",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="initialize",
|
||||||
|
details={
|
||||||
|
"current_version": current_version,
|
||||||
|
"supported_version": self.SCHEMA_VERSION,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_version == 0:
|
||||||
|
self._create_schema(conn)
|
||||||
|
self._set_schema_version(conn, self.SCHEMA_VERSION)
|
||||||
|
elif current_version < self.SCHEMA_VERSION:
|
||||||
|
self._apply_migrations(conn, current_version)
|
||||||
|
self._set_schema_version(conn, self.SCHEMA_VERSION)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close database connection."""
|
||||||
|
with self._lock:
|
||||||
|
if self._conn is not None:
|
||||||
|
try:
|
||||||
|
self._conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
|
def __enter__(self) -> "GlobalSymbolIndex":
|
||||||
|
self.initialize()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def add_symbol(self, symbol: Symbol, file_path: str | Path, index_path: str | Path) -> None:
|
||||||
|
"""Insert a single symbol (idempotent) for incremental updates."""
|
||||||
|
file_path_str = str(Path(file_path).resolve())
|
||||||
|
index_path_str = str(Path(index_path).resolve())
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO global_symbols(
|
||||||
|
project_id, symbol_name, symbol_kind,
|
||||||
|
file_path, start_line, end_line, index_path
|
||||||
|
)
|
||||||
|
VALUES(?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(
|
||||||
|
project_id, symbol_name, symbol_kind,
|
||||||
|
file_path, start_line, end_line
|
||||||
|
)
|
||||||
|
DO UPDATE SET
|
||||||
|
index_path=excluded.index_path
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
self.project_id,
|
||||||
|
symbol.name,
|
||||||
|
symbol.kind,
|
||||||
|
file_path_str,
|
||||||
|
symbol.range[0],
|
||||||
|
symbol.range[1],
|
||||||
|
index_path_str,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.DatabaseError as exc:
|
||||||
|
conn.rollback()
|
||||||
|
raise StorageError(
|
||||||
|
f"Failed to add symbol {symbol.name}: {exc}",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="add_symbol",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def update_file_symbols(
|
||||||
|
self,
|
||||||
|
file_path: str | Path,
|
||||||
|
symbols: List[Symbol],
|
||||||
|
index_path: str | Path | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Replace all symbols for a file atomically (delete + insert)."""
|
||||||
|
file_path_str = str(Path(file_path).resolve())
|
||||||
|
|
||||||
|
index_path_str: Optional[str]
|
||||||
|
if index_path is not None:
|
||||||
|
index_path_str = str(Path(index_path).resolve())
|
||||||
|
else:
|
||||||
|
index_path_str = self._get_existing_index_path(file_path_str)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute("BEGIN")
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
|
||||||
|
(self.project_id, file_path_str),
|
||||||
|
)
|
||||||
|
|
||||||
|
if symbols:
|
||||||
|
if not index_path_str:
|
||||||
|
raise StorageError(
|
||||||
|
"index_path is required when inserting symbols for a new file",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="update_file_symbols",
|
||||||
|
details={"file_path": file_path_str},
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
(
|
||||||
|
self.project_id,
|
||||||
|
s.name,
|
||||||
|
s.kind,
|
||||||
|
file_path_str,
|
||||||
|
s.range[0],
|
||||||
|
s.range[1],
|
||||||
|
index_path_str,
|
||||||
|
)
|
||||||
|
for s in symbols
|
||||||
|
]
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO global_symbols(
|
||||||
|
project_id, symbol_name, symbol_kind,
|
||||||
|
file_path, start_line, end_line, index_path
|
||||||
|
)
|
||||||
|
VALUES(?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(
|
||||||
|
project_id, symbol_name, symbol_kind,
|
||||||
|
file_path, start_line, end_line
|
||||||
|
)
|
||||||
|
DO UPDATE SET
|
||||||
|
index_path=excluded.index_path
|
||||||
|
""",
|
||||||
|
rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.DatabaseError as exc:
|
||||||
|
conn.rollback()
|
||||||
|
raise StorageError(
|
||||||
|
f"Failed to update symbols for {file_path_str}: {exc}",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="update_file_symbols",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def delete_file_symbols(self, file_path: str | Path) -> int:
|
||||||
|
"""Remove all symbols for a file. Returns number of rows deleted."""
|
||||||
|
file_path_str = str(Path(file_path).resolve())
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cur = conn.execute(
|
||||||
|
"DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
|
||||||
|
(self.project_id, file_path_str),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return int(cur.rowcount or 0)
|
||||||
|
except sqlite3.DatabaseError as exc:
|
||||||
|
conn.rollback()
|
||||||
|
raise StorageError(
|
||||||
|
f"Failed to delete symbols for {file_path_str}: {exc}",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="delete_file_symbols",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
kind: Optional[str] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
prefix_mode: bool = True,
|
||||||
|
) -> List[Symbol]:
|
||||||
|
"""Search symbols and return full Symbol objects."""
|
||||||
|
if prefix_mode:
|
||||||
|
pattern = f"{name}%"
|
||||||
|
else:
|
||||||
|
pattern = f"%{name}%"
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
if kind:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
|
||||||
|
FROM global_symbols
|
||||||
|
WHERE project_id=? AND symbol_name LIKE ? AND symbol_kind=?
|
||||||
|
ORDER BY symbol_name
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(self.project_id, pattern, kind, limit),
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
|
||||||
|
FROM global_symbols
|
||||||
|
WHERE project_id=? AND symbol_name LIKE ?
|
||||||
|
ORDER BY symbol_name
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(self.project_id, pattern, limit),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
Symbol(
|
||||||
|
name=row["symbol_name"],
|
||||||
|
kind=row["symbol_kind"],
|
||||||
|
range=(row["start_line"], row["end_line"]),
|
||||||
|
file=row["file_path"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def search_symbols(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
kind: Optional[str] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
prefix_mode: bool = True,
|
||||||
|
) -> List[Tuple[str, Tuple[int, int]]]:
|
||||||
|
"""Search symbols and return only (file_path, (start_line, end_line))."""
|
||||||
|
symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
|
||||||
|
return [(s.file or "", s.range) for s in symbols]
|
||||||
|
|
||||||
|
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
|
||||||
|
"""Get all symbols in a specific file, sorted by start_line.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Full path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Symbol objects sorted by start_line
|
||||||
|
"""
|
||||||
|
file_path_str = str(Path(file_path).resolve())
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
|
||||||
|
FROM global_symbols
|
||||||
|
WHERE project_id=? AND file_path=?
|
||||||
|
ORDER BY start_line
|
||||||
|
""",
|
||||||
|
(self.project_id, file_path_str),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
Symbol(
|
||||||
|
name=row["symbol_name"],
|
||||||
|
kind=row["symbol_kind"],
|
||||||
|
range=(row["start_line"], row["end_line"]),
|
||||||
|
file=row["file_path"],
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
|
||||||
|
with self._lock:
|
||||||
|
conn = self._get_connection()
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT index_path
|
||||||
|
FROM global_symbols
|
||||||
|
WHERE project_id=? AND file_path=?
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(self.project_id, file_path_str),
|
||||||
|
).fetchone()
|
||||||
|
return str(row["index_path"]) if row else None
|
||||||
|
|
||||||
|
def _get_schema_version(self, conn: sqlite3.Connection) -> int:
|
||||||
|
try:
|
||||||
|
row = conn.execute("PRAGMA user_version").fetchone()
|
||||||
|
return int(row[0]) if row else 0
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _set_schema_version(self, conn: sqlite3.Connection, version: int) -> None:
|
||||||
|
conn.execute(f"PRAGMA user_version = {int(version)}")
|
||||||
|
|
||||||
|
def _apply_migrations(self, conn: sqlite3.Connection, from_version: int) -> None:
|
||||||
|
# No migrations yet (v1).
|
||||||
|
_ = (conn, from_version)
|
||||||
|
return
|
||||||
|
|
||||||
|
def _get_connection(self) -> sqlite3.Connection:
|
||||||
|
if self._conn is None:
|
||||||
|
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self._conn.execute("PRAGMA synchronous=NORMAL")
|
||||||
|
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
self._conn.execute("PRAGMA mmap_size=30000000000")
|
||||||
|
return self._conn
|
||||||
|
|
||||||
|
def _create_schema(self, conn: sqlite3.Connection) -> None:
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS global_symbols (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
project_id INTEGER NOT NULL,
|
||||||
|
symbol_name TEXT NOT NULL,
|
||||||
|
symbol_kind TEXT NOT NULL,
|
||||||
|
file_path TEXT NOT NULL,
|
||||||
|
start_line INTEGER,
|
||||||
|
end_line INTEGER,
|
||||||
|
index_path TEXT NOT NULL,
|
||||||
|
UNIQUE(
|
||||||
|
project_id, symbol_name, symbol_kind,
|
||||||
|
file_path, start_line, end_line
|
||||||
|
)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Required by optimization spec.
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_global_symbols_name_kind
|
||||||
|
ON global_symbols(symbol_name, symbol_kind)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# Used by common queries (project-scoped name lookups).
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_name_kind
|
||||||
|
ON global_symbols(project_id, symbol_name, symbol_kind)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_file
|
||||||
|
ON global_symbols(project_id, file_path)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_index_path
|
||||||
|
ON global_symbols(project_id, index_path)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except sqlite3.DatabaseError as exc:
|
||||||
|
raise StorageError(
|
||||||
|
f"Failed to initialize global symbol schema: {exc}",
|
||||||
|
db_path=str(self.db_path),
|
||||||
|
operation="_create_schema",
|
||||||
|
) from exc
|
||||||
|
|
||||||
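# Illustrative sketch (hypothetical paths, not part of the original file): replacing a
# file's symbols and querying them back through the project-wide index, assuming the
# Symbol constructor accepts the fields used above (name, kind, range, file).
def _example_global_index_usage() -> None:
    with GlobalSymbolIndex("_global_symbols.db", project_id=1) as index:
        index.update_file_symbols(
            file_path="src/app.py",
            symbols=[Symbol(name="main", kind="function", range=(1, 20), file="src/app.py")],
            index_path="src/_index.db",
        )
        for sym in index.search("ma", kind="function", prefix_mode=True):
            print(sym.name, sym.file, sym.range)
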
codex-lens/build/lib/codexlens/storage/index_tree.py (new file, 1064 lines; diff suppressed because it is too large)

codex-lens/build/lib/codexlens/storage/merkle_tree.py (new file)
@@ -0,0 +1,136 @@
|
|||||||
|
"""Merkle tree utilities for change detection.
|
||||||
|
|
||||||
|
This module provides a generic, file-system based Merkle tree implementation
|
||||||
|
that can be used to efficiently diff directory states.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_bytes(data: bytes) -> str:
|
||||||
|
return hashlib.sha256(data).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_text(text: str) -> str:
|
||||||
|
return sha256_bytes(text.encode("utf-8", errors="ignore"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MerkleNode:
|
||||||
|
"""A Merkle node representing either a file (leaf) or directory (internal)."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
rel_path: str
|
||||||
|
hash: str
|
||||||
|
is_dir: bool
|
||||||
|
children: Dict[str, "MerkleNode"] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def iter_files(self) -> Iterable["MerkleNode"]:
|
||||||
|
if not self.is_dir:
|
||||||
|
yield self
|
||||||
|
return
|
||||||
|
for child in self.children.values():
|
||||||
|
yield from child.iter_files()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MerkleTree:
|
||||||
|
"""Merkle tree for a directory snapshot."""
|
||||||
|
|
||||||
|
root: MerkleNode
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_from_directory(cls, root_dir: Path) -> "MerkleTree":
|
||||||
|
root_dir = Path(root_dir).resolve()
|
||||||
|
node = cls._build_node(root_dir, base=root_dir)
|
||||||
|
return cls(root=node)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_node(cls, path: Path, *, base: Path) -> MerkleNode:
|
||||||
|
if path.is_file():
|
||||||
|
rel = str(path.relative_to(base)).replace("\\", "/")
|
||||||
|
return MerkleNode(
|
||||||
|
name=path.name,
|
||||||
|
rel_path=rel,
|
||||||
|
hash=sha256_bytes(path.read_bytes()),
|
||||||
|
is_dir=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not path.is_dir():
|
||||||
|
rel = str(path.relative_to(base)).replace("\\", "/")
|
||||||
|
return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False)
|
||||||
|
|
||||||
|
children: Dict[str, MerkleNode] = {}
|
||||||
|
for child in sorted(path.iterdir(), key=lambda p: p.name):
|
||||||
|
child_node = cls._build_node(child, base=base)
|
||||||
|
children[child_node.name] = child_node
|
||||||
|
|
||||||
|
items = [
|
||||||
|
f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}"
|
||||||
|
for name, n in sorted(children.items(), key=lambda kv: kv[0])
|
||||||
|
]
|
||||||
|
dir_hash = sha256_text("\n".join(items))
|
||||||
|
|
||||||
|
rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/")
|
||||||
|
return MerkleNode(
|
||||||
|
name="." if path == base else path.name,
|
||||||
|
rel_path=rel_path,
|
||||||
|
hash=dir_hash,
|
||||||
|
is_dir=True,
|
||||||
|
children=children,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]:
|
||||||
|
"""Find changed/added/removed files between two trees.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of relative file paths (POSIX-style separators).
|
||||||
|
"""
|
||||||
|
if old is None and new is None:
|
||||||
|
return []
|
||||||
|
if old is None:
|
||||||
|
return sorted({n.rel_path for n in new.root.iter_files()}) # type: ignore[union-attr]
|
||||||
|
if new is None:
|
||||||
|
return sorted({n.rel_path for n in old.root.iter_files()})
|
||||||
|
|
||||||
|
changed: set[str] = set()
|
||||||
|
|
||||||
|
def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None:
|
||||||
|
if old_node is None and new_node is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if old_node is None and new_node is not None:
|
||||||
|
changed.update(n.rel_path for n in new_node.iter_files())
|
||||||
|
return
|
||||||
|
|
||||||
|
if new_node is None and old_node is not None:
|
||||||
|
changed.update(n.rel_path for n in old_node.iter_files())
|
||||||
|
return
|
||||||
|
|
||||||
|
assert old_node is not None and new_node is not None
|
||||||
|
|
||||||
|
if old_node.hash == new_node.hash:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not old_node.is_dir and not new_node.is_dir:
|
||||||
|
changed.add(new_node.rel_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
if old_node.is_dir != new_node.is_dir:
|
||||||
|
changed.update(n.rel_path for n in old_node.iter_files())
|
||||||
|
changed.update(n.rel_path for n in new_node.iter_files())
|
||||||
|
return
|
||||||
|
|
||||||
|
names = set(old_node.children.keys()) | set(new_node.children.keys())
|
||||||
|
for name in names:
|
||||||
|
walk(old_node.children.get(name), new_node.children.get(name))
|
||||||
|
|
||||||
|
walk(old.root, new.root)
|
||||||
|
return sorted(changed)
|
||||||
|
|
||||||
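# Illustrative sketch (hypothetical directories, not part of the original file): diffing
# two snapshots of a project directory to find the files that need re-indexing.
def _example_merkle_diff(before_dir: Path, after_dir: Path) -> List[str]:
    old_tree = MerkleTree.build_from_directory(before_dir)
    new_tree = MerkleTree.build_from_directory(after_dir)
    return MerkleTree.find_changed_files(old_tree, new_tree)
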
codex-lens/build/lib/codexlens/storage/migration_manager.py (new file)
@@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Manages database schema migrations.
|
||||||
|
|
||||||
|
This module provides a framework for applying versioned migrations to the SQLite
|
||||||
|
database. Migrations are discovered from the `codexlens.storage.migrations`
|
||||||
|
package and applied sequentially. The database schema version is tracked using
|
||||||
|
the `user_version` pragma.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import pkgutil
|
||||||
|
from pathlib import Path
|
||||||
|
from sqlite3 import Connection
|
||||||
|
from typing import List, NamedTuple
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(NamedTuple):
|
||||||
|
"""Represents a single database migration."""
|
||||||
|
|
||||||
|
version: int
|
||||||
|
name: str
|
||||||
|
upgrade: callable
|
||||||
|
|
||||||
|
|
||||||
|
def discover_migrations() -> List[Migration]:
|
||||||
|
"""
|
||||||
|
Discovers and returns a sorted list of database migrations.
|
||||||
|
|
||||||
|
Migrations are expected to be in the `codexlens.storage.migrations` package,
|
||||||
|
with filenames in the format `migration_XXX_description.py`, where XXX is
|
||||||
|
the version number. Each migration module must contain an `upgrade` function
|
||||||
|
that takes a `sqlite3.Connection` object as its argument.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of Migration objects, sorted by version.
|
||||||
|
"""
|
||||||
|
import codexlens.storage.migrations
|
||||||
|
|
||||||
|
migrations = []
|
||||||
|
package_path = Path(codexlens.storage.migrations.__file__).parent
|
||||||
|
|
||||||
|
for _, name, _ in pkgutil.iter_modules([str(package_path)]):
|
||||||
|
if name.startswith("migration_"):
|
||||||
|
try:
|
||||||
|
version = int(name.split("_")[1])
|
||||||
|
module = importlib.import_module(f"codexlens.storage.migrations.{name}")
|
||||||
|
if hasattr(module, "upgrade"):
|
||||||
|
migrations.append(
|
||||||
|
Migration(version=version, name=name, upgrade=module.upgrade)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.warning(f"Migration {name} is missing 'upgrade' function.")
|
||||||
|
except (ValueError, IndexError) as e:
|
||||||
|
log.warning(f"Could not parse migration name {name}: {e}")
|
||||||
|
except ImportError as e:
|
||||||
|
log.warning(f"Could not import migration {name}: {e}")
|
||||||
|
|
||||||
|
migrations.sort(key=lambda m: m.version)
|
||||||
|
return migrations
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationManager:
|
||||||
|
"""
|
||||||
|
Manages the application of migrations to a database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_conn: Connection):
|
||||||
|
"""
|
||||||
|
Initializes the MigrationManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_conn: The SQLite database connection.
|
||||||
|
"""
|
||||||
|
self.db_conn = db_conn
|
||||||
|
self.migrations = discover_migrations()
|
||||||
|
|
||||||
|
def get_current_version(self) -> int:
|
||||||
|
"""
|
||||||
|
Gets the current version of the database schema.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The current schema version number.
|
||||||
|
"""
|
||||||
|
return self.db_conn.execute("PRAGMA user_version").fetchone()[0]
|
||||||
|
|
||||||
|
def set_version(self, version: int):
|
||||||
|
"""
|
||||||
|
Sets the database schema version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: The version number to set.
|
||||||
|
"""
|
||||||
|
self.db_conn.execute(f"PRAGMA user_version = {version}")
|
||||||
|
log.info(f"Database schema version set to {version}")
|
||||||
|
|
||||||
|
def apply_migrations(self):
|
||||||
|
"""
|
||||||
|
Applies all pending migrations to the database.
|
||||||
|
|
||||||
|
This method checks the current database version and applies all
|
||||||
|
subsequent migrations in order. Each migration is applied within
|
||||||
|
a transaction, unless the migration manages its own transactions.
|
||||||
|
"""
|
||||||
|
current_version = self.get_current_version()
|
||||||
|
log.info(f"Current database schema version: {current_version}")
|
||||||
|
|
||||||
|
for migration in self.migrations:
|
||||||
|
if migration.version > current_version:
|
||||||
|
log.info(f"Applying migration {migration.version}: {migration.name}...")
|
||||||
|
try:
|
||||||
|
# Check if a transaction is already in progress
|
||||||
|
in_transaction = self.db_conn.in_transaction
|
||||||
|
|
||||||
|
# Only start transaction if not already in one
|
||||||
|
if not in_transaction:
|
||||||
|
self.db_conn.execute("BEGIN")
|
||||||
|
|
||||||
|
migration.upgrade(self.db_conn)
|
||||||
|
self.set_version(migration.version)
|
||||||
|
|
||||||
|
# Only commit if we started the transaction and it's still active
|
||||||
|
if not in_transaction and self.db_conn.in_transaction:
|
||||||
|
self.db_conn.execute("COMMIT")
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Successfully applied migration {migration.version}: {migration.name}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to apply migration {migration.version}: {migration.name}. Error: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Try to rollback if transaction is active
|
||||||
|
try:
|
||||||
|
if self.db_conn.in_transaction:
|
||||||
|
self.db_conn.execute("ROLLBACK")
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore rollback errors
|
||||||
|
raise
|
||||||
|
|
||||||
|
latest_migration_version = self.migrations[-1].version if self.migrations else 0
|
||||||
|
if current_version < latest_migration_version:
|
||||||
|
# This case can be hit if migrations were applied but the loop was exited
|
||||||
|
# and set_version was not called for the last one for some reason.
|
||||||
|
# To be safe, we explicitly set the version to the latest known migration.
|
||||||
|
final_version = self.get_current_version()
|
||||||
|
if final_version != latest_migration_version:
|
||||||
|
log.warning(f"Database version ({final_version}) is not the latest migration version ({latest_migration_version}). This may indicate a problem.")
|
||||||
|
|
||||||
|
log.info("All pending migrations applied successfully.")
|
||||||
|
|
||||||
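A minimal usage sketch for the manager above, not part of the diff itself. It assumes an on-disk index database and that `codexlens.storage.migration_manager` is importable from the installed package; the database filename is hypothetical.

import sqlite3
from codexlens.storage.migration_manager import MigrationManager

conn = sqlite3.connect("_index.db")  # hypothetical index database path
try:
    manager = MigrationManager(conn)              # discovers migration_XXX_*.py modules
    print("schema version:", manager.get_current_version())
    manager.apply_migrations()                    # applies pending versions in order
finally:
    conn.close()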
@@ -0,0 +1 @@
# This file makes the 'migrations' directory a Python package.
@@ -0,0 +1,123 @@
"""
Migration 001: Normalize keywords into separate tables.

This migration introduces two new tables, `keywords` and `file_keywords`, to
store semantic keywords in a normalized fashion. It then migrates the existing
keywords from the `semantic_data` JSON blob in the `files` table into these
new tables. This is intended to speed up keyword-based searches significantly.
"""

import json
import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to normalize keywords.

    - Creates `keywords` and `file_keywords` tables.
    - Creates indexes for efficient querying.
    - Migrates data from `files.semantic_data` to the new tables.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating 'keywords' and 'file_keywords' tables...")
    # Create a table to store unique keywords
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS keywords (
            id INTEGER PRIMARY KEY,
            keyword TEXT NOT NULL UNIQUE
        )
        """
    )

    # Create a join table to link files and keywords (many-to-many)
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS file_keywords (
            file_id INTEGER NOT NULL,
            keyword_id INTEGER NOT NULL,
            PRIMARY KEY (file_id, keyword_id),
            FOREIGN KEY (file_id) REFERENCES files (id) ON DELETE CASCADE,
            FOREIGN KEY (keyword_id) REFERENCES keywords (id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for new keyword tables...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords (keyword)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords (file_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_keyword_id ON file_keywords (keyword_id)")

    log.info("Migrating existing keywords from 'semantic_metadata' table...")

    # Check if semantic_metadata table exists before querying
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'")
    if not cursor.fetchone():
        log.info("No 'semantic_metadata' table found, skipping data migration.")
        return

    # Check if 'keywords' column exists in semantic_metadata table
    # (current schema may already use normalized tables without this column)
    cursor.execute("PRAGMA table_info(semantic_metadata)")
    columns = {row[1] for row in cursor.fetchall()}
    if "keywords" not in columns:
        log.info("No 'keywords' column in semantic_metadata table, skipping data migration.")
        return

    cursor.execute("SELECT file_id, keywords FROM semantic_metadata WHERE keywords IS NOT NULL AND keywords != ''")

    files_to_migrate = cursor.fetchall()
    if not files_to_migrate:
        log.info("No existing files with semantic metadata to migrate.")
        return

    log.info(f"Found {len(files_to_migrate)} files with semantic metadata to migrate.")

    for file_id, keywords_json in files_to_migrate:
        if not keywords_json:
            continue
        try:
            keywords = json.loads(keywords_json)

            if not isinstance(keywords, list):
                log.warning(f"Keywords for file_id {file_id} is not a list, skipping.")
                continue

            for keyword in keywords:
                if not isinstance(keyword, str):
                    log.warning(f"Non-string keyword '{keyword}' found for file_id {file_id}, skipping.")
                    continue

                keyword = keyword.strip()
                if not keyword:
                    continue

                # Get or create keyword_id
                cursor.execute("INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,))
                cursor.execute("SELECT id FROM keywords WHERE keyword = ?", (keyword,))
                keyword_id_result = cursor.fetchone()

                if keyword_id_result:
                    keyword_id = keyword_id_result[0]
                    # Link file to keyword
                    cursor.execute(
                        "INSERT OR IGNORE INTO file_keywords (file_id, keyword_id) VALUES (?, ?)",
                        (file_id, keyword_id),
                    )
                else:
                    log.error(f"Failed to retrieve or create keyword_id for keyword: {keyword}")

        except json.JSONDecodeError as e:
            log.warning(f"Could not parse keywords for file_id {file_id}: {e}")
        except Exception as e:
            log.error(f"An unexpected error occurred during migration for file_id {file_id}: {e}", exc_info=True)

    log.info("Finished migrating keywords.")
@@ -0,0 +1,48 @@
"""
Migration 002: Add token_count and symbol_type to symbols table.

This migration adds token counting metadata to symbols for accurate chunk
splitting and performance optimization. It also adds symbol_type for better
filtering in searches.
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add token metadata to symbols.

    - Adds token_count column to symbols table
    - Adds symbol_type column to symbols table (for future use)
    - Creates index on symbol_type for efficient filtering
    - Backfills existing symbols with NULL token_count (to be calculated lazily)

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Adding token_count column to symbols table...")
    try:
        cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
        log.info("Successfully added token_count column.")
    except Exception as e:
        # Column might already exist
        log.warning(f"Could not add token_count column (might already exist): {e}")

    log.info("Adding symbol_type column to symbols table...")
    try:
        cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
        log.info("Successfully added symbol_type column.")
    except Exception as e:
        # Column might already exist
        log.warning(f"Could not add symbol_type column (might already exist): {e}")

    log.info("Creating index on symbol_type for efficient filtering...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")

    log.info("Migration 002 completed successfully.")
@@ -0,0 +1,232 @@
"""
Migration 004: Add dual FTS tables for exact and fuzzy matching.

This migration introduces two FTS5 tables:
- files_fts_exact: Uses unicode61 tokenizer for exact token matching
- files_fts_fuzzy: Uses trigram tokenizer (or extended unicode61) for substring/fuzzy matching

Both tables are synchronized with the files table via triggers for automatic updates.
"""

import logging
from sqlite3 import Connection

from codexlens.storage.sqlite_utils import check_trigram_support, get_sqlite_version

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add dual FTS tables.

    - Drops old files_fts table and triggers
    - Creates files_fts_exact with unicode61 tokenizer
    - Creates files_fts_fuzzy with trigram or extended unicode61 tokenizer
    - Creates synchronized triggers for both tables
    - Rebuilds FTS indexes from files table

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    try:
        # Check trigram support
        has_trigram = check_trigram_support(db_conn)
        version = get_sqlite_version(db_conn)
        log.info(f"SQLite version: {'.'.join(map(str, version))}")

        if has_trigram:
            log.info("Trigram tokenizer available, using for fuzzy FTS table")
            fuzzy_tokenizer = "trigram"
        else:
            log.warning(
                f"Trigram tokenizer not available (requires SQLite >= 3.34), "
                f"using extended unicode61 tokenizer for fuzzy matching"
            )
            fuzzy_tokenizer = "unicode61 tokenchars '_-.'"

        # Start transaction
        cursor.execute("BEGIN TRANSACTION")

        # Check if files table has 'name' column (v2 schema doesn't have it)
        cursor.execute("PRAGMA table_info(files)")
        columns = {row[1] for row in cursor.fetchall()}

        if 'name' not in columns:
            log.info("Adding 'name' column to files table (v2 schema upgrade)...")
            # Add name column
            cursor.execute("ALTER TABLE files ADD COLUMN name TEXT")
            # Populate name from path (extract filename from last '/')
            # Use Python to do the extraction since SQLite doesn't have reverse()
            cursor.execute("SELECT rowid, path FROM files")
            rows = cursor.fetchall()
            for rowid, path in rows:
                # Extract filename from path
                name = path.split('/')[-1] if '/' in path else path
                cursor.execute("UPDATE files SET name = ? WHERE rowid = ?", (name, rowid))

        # Rename 'path' column to 'full_path' if needed
        if 'path' in columns and 'full_path' not in columns:
            log.info("Renaming 'path' to 'full_path' (v2 schema upgrade)...")
            # Check if indexed_at column exists in v2 schema
            has_indexed_at = 'indexed_at' in columns
            has_mtime = 'mtime' in columns

            # SQLite doesn't support RENAME COLUMN before 3.25, so use table recreation
            cursor.execute("""
                CREATE TABLE files_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    full_path TEXT NOT NULL UNIQUE,
                    content TEXT,
                    language TEXT,
                    mtime REAL,
                    indexed_at TEXT
                )
            """)

            # Build INSERT statement based on available columns
            # Note: v2 schema has no rowid (path is PRIMARY KEY), so use NULL for AUTOINCREMENT
            if has_indexed_at and has_mtime:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, mtime, indexed_at)
                    SELECT name, path, content, language, mtime, indexed_at FROM files
                """)
            elif has_indexed_at:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, indexed_at)
                    SELECT name, path, content, language, indexed_at FROM files
                """)
            elif has_mtime:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, mtime)
                    SELECT name, path, content, language, mtime FROM files
                """)
            else:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language)
                    SELECT name, path, content, language FROM files
                """)

            cursor.execute("DROP TABLE files")
            cursor.execute("ALTER TABLE files_new RENAME TO files")

        log.info("Dropping old FTS triggers and table...")
        # Drop old triggers
        cursor.execute("DROP TRIGGER IF EXISTS files_ai")
        cursor.execute("DROP TRIGGER IF EXISTS files_ad")
        cursor.execute("DROP TRIGGER IF EXISTS files_au")

        # Drop old FTS table
        cursor.execute("DROP TABLE IF EXISTS files_fts")

        # Create exact FTS table (unicode61 with underscores/hyphens/dots as token chars)
        # Note: tokenchars includes '.' to properly tokenize qualified names like PortRole.FLOW
        log.info("Creating files_fts_exact table with unicode61 tokenizer...")
        cursor.execute(
            """
            CREATE VIRTUAL TABLE files_fts_exact USING fts5(
                name, full_path UNINDEXED, content,
                content='files',
                content_rowid='id',
                tokenize="unicode61 tokenchars '_-.'"
            )
            """
        )

        # Create fuzzy FTS table (trigram or extended unicode61)
        log.info(f"Creating files_fts_fuzzy table with {fuzzy_tokenizer} tokenizer...")
        cursor.execute(
            f"""
            CREATE VIRTUAL TABLE files_fts_fuzzy USING fts5(
                name, full_path UNINDEXED, content,
                content='files',
                content_rowid='id',
                tokenize="{fuzzy_tokenizer}"
            )
            """
        )

        # Create synchronized triggers for files_fts_exact
        log.info("Creating triggers for files_fts_exact...")
        cursor.execute(
            """
            CREATE TRIGGER files_exact_ai AFTER INSERT ON files BEGIN
                INSERT INTO files_fts_exact(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_exact_ad AFTER DELETE ON files BEGIN
                INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_exact_au AFTER UPDATE ON files BEGIN
                INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
                INSERT INTO files_fts_exact(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )

        # Create synchronized triggers for files_fts_fuzzy
        log.info("Creating triggers for files_fts_fuzzy...")
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_ai AFTER INSERT ON files BEGIN
                INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_ad AFTER DELETE ON files BEGIN
                INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_au AFTER UPDATE ON files BEGIN
                INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
                INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )

        # Rebuild FTS indexes from files table
        log.info("Rebuilding FTS indexes from files table...")
        cursor.execute("INSERT INTO files_fts_exact(files_fts_exact) VALUES('rebuild')")
        cursor.execute("INSERT INTO files_fts_fuzzy(files_fts_fuzzy) VALUES('rebuild')")

        # Commit transaction
        cursor.execute("COMMIT")
        log.info("Migration 004 completed successfully")

        # Vacuum to reclaim space (outside transaction)
        try:
            log.info("Running VACUUM to reclaim space...")
            cursor.execute("VACUUM")
        except Exception as e:
            log.warning(f"VACUUM failed (non-critical): {e}")

    except Exception as e:
        log.error(f"Migration 004 failed: {e}")
        try:
            cursor.execute("ROLLBACK")
        except Exception:
            pass
        raise
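A rough sketch, not part of the diff, of how the two FTS tables created by migration 004 could be queried side by side. It assumes the database has already gone through this migration; the filename and query term are illustrative only.

import sqlite3

conn = sqlite3.connect("_index.db")  # hypothetical index database path
query = "merkle"

# Exact-token hits from the unicode61-tokenized table.
exact = conn.execute(
    "SELECT rowid, name FROM files_fts_exact WHERE files_fts_exact MATCH ? LIMIT 10",
    (query,),
).fetchall()

# Substring/fuzzy hits from the trigram (or extended unicode61) table.
fuzzy = conn.execute(
    "SELECT rowid, name FROM files_fts_fuzzy WHERE files_fts_fuzzy MATCH ? LIMIT 10",
    (query,),
).fetchall()

print(len(exact), "exact hits;", len(fuzzy), "fuzzy hits")
conn.close()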
@@ -0,0 +1,196 @@
"""
Migration 005: Remove unused and redundant database fields.

This migration removes four problematic fields identified by Gemini analysis:

1. **semantic_metadata.keywords** (deprecated - replaced by file_keywords table)
   - Data: Migrated to normalized file_keywords table in migration 001
   - Impact: Column now redundant, remove to prevent sync issues

2. **symbols.token_count** (unused - always NULL)
   - Data: Never populated, always NULL
   - Impact: No data loss, just removes unused column

3. **symbols.symbol_type** (redundant - duplicates kind)
   - Data: Redundant with symbols.kind field
   - Impact: No data loss, kind field contains same information

4. **subdirs.direct_files** (unused - never displayed)
   - Data: Never used in queries or display logic
   - Impact: No data loss, just removes unused column

Schema changes use table recreation pattern (SQLite best practice):
- Create new table without deprecated columns
- Copy data from old table
- Drop old table
- Rename new table
- Recreate indexes
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """Remove unused and redundant fields from schema.

    Note: Transaction management is handled by MigrationManager.
    This migration should NOT start its own transaction.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    # Step 1: Remove semantic_metadata.keywords (if column exists)
    log.info("Checking semantic_metadata.keywords column...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'"
    )
    if cursor.fetchone():
        # Check if keywords column exists
        cursor.execute("PRAGMA table_info(semantic_metadata)")
        columns = {row[1] for row in cursor.fetchall()}

        if "keywords" in columns:
            log.info("Removing semantic_metadata.keywords column...")
            cursor.execute("""
                CREATE TABLE semantic_metadata_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL UNIQUE,
                    summary TEXT,
                    purpose TEXT,
                    llm_tool TEXT,
                    generated_at REAL,
                    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
                )
            """)

            cursor.execute("""
                INSERT INTO semantic_metadata_new (id, file_id, summary, purpose, llm_tool, generated_at)
                SELECT id, file_id, summary, purpose, llm_tool, generated_at
                FROM semantic_metadata
            """)

            cursor.execute("DROP TABLE semantic_metadata")
            cursor.execute("ALTER TABLE semantic_metadata_new RENAME TO semantic_metadata")

            # Recreate index
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)"
            )
            log.info("Removed semantic_metadata.keywords column")
        else:
            log.info("semantic_metadata.keywords column does not exist, skipping")
    else:
        log.info("semantic_metadata table does not exist, skipping")

    # Step 2: Remove symbols.token_count and symbols.symbol_type (if columns exist)
    log.info("Checking symbols.token_count and symbols.symbol_type columns...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'"
    )
    if cursor.fetchone():
        # Check if token_count or symbol_type columns exist
        cursor.execute("PRAGMA table_info(symbols)")
        columns = {row[1] for row in cursor.fetchall()}

        if "token_count" in columns or "symbol_type" in columns:
            log.info("Removing symbols.token_count and symbols.symbol_type columns...")
            cursor.execute("""
                CREATE TABLE symbols_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
                )
            """)

            cursor.execute("""
                INSERT INTO symbols_new (id, file_id, name, kind, start_line, end_line)
                SELECT id, file_id, name, kind, start_line, end_line
                FROM symbols
            """)

            cursor.execute("DROP TABLE symbols")
            cursor.execute("ALTER TABLE symbols_new RENAME TO symbols")

            # Recreate indexes
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
            log.info("Removed symbols.token_count and symbols.symbol_type columns")
        else:
            log.info("symbols.token_count/symbol_type columns do not exist, skipping")
    else:
        log.info("symbols table does not exist, skipping")

    # Step 3: Remove subdirs.direct_files (if column exists)
    log.info("Checking subdirs.direct_files column...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='subdirs'"
    )
    if cursor.fetchone():
        # Check if direct_files column exists
        cursor.execute("PRAGMA table_info(subdirs)")
        columns = {row[1] for row in cursor.fetchall()}

        if "direct_files" in columns:
            log.info("Removing subdirs.direct_files column...")
            cursor.execute("""
                CREATE TABLE subdirs_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    index_path TEXT NOT NULL,
                    files_count INTEGER DEFAULT 0,
                    last_updated REAL
                )
            """)

            cursor.execute("""
                INSERT INTO subdirs_new (id, name, index_path, files_count, last_updated)
                SELECT id, name, index_path, files_count, last_updated
                FROM subdirs
            """)

            cursor.execute("DROP TABLE subdirs")
            cursor.execute("ALTER TABLE subdirs_new RENAME TO subdirs")

            # Recreate index
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
            log.info("Removed subdirs.direct_files column")
        else:
            log.info("subdirs.direct_files column does not exist, skipping")
    else:
        log.info("subdirs table does not exist, skipping")

    log.info("Migration 005 completed successfully")

    # Vacuum to reclaim space (outside transaction, optional)
    # Note: VACUUM cannot run inside a transaction, so we skip it here
    # The caller can run VACUUM separately if desired


def downgrade(db_conn: Connection):
    """Restore removed fields (data will be lost for keywords, token_count, symbol_type, direct_files).

    This is a placeholder - true downgrade is not feasible as data is lost.
    The migration is designed to be one-way since removed fields are unused/redundant.

    Args:
        db_conn: The SQLite database connection.
    """
    log.warning(
        "Migration 005 downgrade not supported - removed fields are unused/redundant. "
        "Data cannot be restored."
    )
    raise NotImplementedError(
        "Migration 005 downgrade not supported - this is a one-way migration"
    )
@@ -0,0 +1,37 @@
"""
Migration 006: Ensure relationship tables and indexes exist.

This migration is intentionally idempotent. It creates the `code_relationships`
table (used for graph visualization) and its indexes if missing.
"""

from __future__ import annotations

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    cursor = db_conn.cursor()

    log.info("Ensuring code_relationships table exists...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT
        )
        """
    )

    log.info("Ensuring relationship indexes exist...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")
@@ -0,0 +1,47 @@
"""
Migration 007: Add precomputed graph neighbor table for search expansion.

Adds:
- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids)

This table is derived data (a cache) and is safe to rebuild at any time.
The migration is intentionally idempotent.
"""

from __future__ import annotations

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    cursor = db_conn.cursor()

    log.info("Creating graph_neighbors table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS graph_neighbors (
            source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
            neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
            relationship_depth INTEGER NOT NULL,
            PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
        )
        """
    )

    log.info("Creating indexes for graph_neighbors...")
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth
        ON graph_neighbors(source_symbol_id, relationship_depth)
        """
    )
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor
        ON graph_neighbors(neighbor_symbol_id)
        """
    )
@@ -0,0 +1,81 @@
"""
Migration 008: Add Merkle hash tables for content-based incremental indexing.

Adds:
- merkle_hashes: per-file SHA-256 hashes (keyed by file_id)
- merkle_state: directory-level root hash (single row, id=1)

Backfills merkle_hashes using the existing `files.content` column when available.
"""

from __future__ import annotations

import hashlib
import logging
import time
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    cursor = db_conn.cursor()

    log.info("Creating merkle_hashes table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS merkle_hashes (
            file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE,
            sha256 TEXT NOT NULL,
            updated_at REAL
        )
        """
    )

    log.info("Creating merkle_state table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS merkle_state (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            root_hash TEXT,
            updated_at REAL
        )
        """
    )

    # Backfill file hashes from stored content (best-effort).
    try:
        rows = cursor.execute("SELECT id, content FROM files").fetchall()
    except Exception as exc:
        log.warning("Unable to backfill merkle hashes (files table missing?): %s", exc)
        return

    now = time.time()
    inserts: list[tuple[int, str, float]] = []

    for row in rows:
        file_id = int(row[0])
        content = row[1]
        if content is None:
            continue
        try:
            digest = hashlib.sha256(str(content).encode("utf-8", errors="ignore")).hexdigest()
            inserts.append((file_id, digest, now))
        except Exception:
            continue

    if not inserts:
        return

    log.info("Backfilling %d file hashes...", len(inserts))
    cursor.executemany(
        """
        INSERT INTO merkle_hashes(file_id, sha256, updated_at)
        VALUES(?, ?, ?)
        ON CONFLICT(file_id) DO UPDATE SET
            sha256=excluded.sha256,
            updated_at=excluded.updated_at
        """,
        inserts,
    )
@@ -0,0 +1,103 @@
"""
Migration 009: Add SPLADE sparse retrieval tables.

This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)

The SPLADE tables are designed for efficient sparse vector retrieval:
- Token-based lookup for query expansion
- Chunk-based deletion for index maintenance
- Maintains backward compatibility with existing FTS tables
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    """
    Adds SPLADE tables for sparse retrieval.

    Creates:
    - splade_metadata: Stores model configuration and ONNX path
    - splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
    - Indexes for efficient token-based and chunk-based lookups

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating splade_metadata table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_metadata (
            id INTEGER PRIMARY KEY DEFAULT 1,
            model_name TEXT NOT NULL,
            vocab_size INTEGER NOT NULL,
            onnx_path TEXT,
            created_at REAL
        )
        """
    )

    log.info("Creating splade_posting_list table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_posting_list (
            token_id INTEGER NOT NULL,
            chunk_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            PRIMARY KEY (token_id, chunk_id),
            FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for splade_posting_list...")
    # Index for efficient chunk-based lookups (deletion, updates)
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
        ON splade_posting_list(chunk_id)
        """
    )

    # Index for efficient term-based retrieval
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_token
        ON splade_posting_list(token_id)
        """
    )

    log.info("Migration 009 completed successfully")


def downgrade(db_conn: Connection) -> None:
    """
    Removes SPLADE tables.

    Drops:
    - splade_posting_list (and associated indexes)
    - splade_metadata

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Dropping SPLADE indexes...")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")

    log.info("Dropping splade_posting_list table...")
    cursor.execute("DROP TABLE IF EXISTS splade_posting_list")

    log.info("Dropping splade_metadata table...")
    cursor.execute("DROP TABLE IF EXISTS splade_metadata")

    log.info("Migration 009 downgrade completed successfully")
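A sketch, not part of the diff, of how the posting list created by migration 009 could be scored. It assumes query-side SPLADE token weights are computed elsewhere (for example by the ONNX encoder referenced in splade_metadata); the plain sparse dot product below is illustrative, not the project's actual ranking code.

import sqlite3

def splade_score(conn: sqlite3.Connection, query_weights: dict[int, float], top_k: int = 10):
    """Accumulate dot-product scores over the inverted index from migration 009."""
    scores: dict[int, float] = {}
    for token_id, q_weight in query_weights.items():
        rows = conn.execute(
            "SELECT chunk_id, weight FROM splade_posting_list WHERE token_id = ?",
            (token_id,),
        )
        for chunk_id, weight in rows:
            scores[chunk_id] = scores.get(chunk_id, 0.0) + q_weight * weight
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]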
@@ -0,0 +1,162 @@
"""
Migration 010: Add multi-vector storage support for cascade retrieval.

This migration introduces the chunks table with multi-vector support:
- chunks: Stores code chunks with multiple embedding types
  - embedding: Original embedding for backward compatibility
  - embedding_binary: 256-dim binary vector for coarse ranking (fast)
  - embedding_dense: 2048-dim dense vector for fine ranking (precise)

The multi-vector architecture enables cascade retrieval:
1. First stage: Fast binary vector search for candidate retrieval
2. Second stage: Dense vector reranking for precision
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    """
    Adds chunks table with multi-vector embedding columns.

    Creates:
    - chunks: Table for storing code chunks with multiple embedding types
    - idx_chunks_file_path: Index for efficient file-based lookups

    Also migrates existing chunks tables by adding new columns if needed.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    # Check if chunks table already exists
    table_exists = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
    ).fetchone()

    if table_exists:
        # Migrate existing table - add new columns if missing
        log.info("chunks table exists, checking for missing columns...")

        col_info = cursor.execute("PRAGMA table_info(chunks)").fetchall()
        existing_columns = {row[1] for row in col_info}

        if "embedding_binary" not in existing_columns:
            log.info("Adding embedding_binary column to chunks table...")
            cursor.execute(
                "ALTER TABLE chunks ADD COLUMN embedding_binary BLOB"
            )

        if "embedding_dense" not in existing_columns:
            log.info("Adding embedding_dense column to chunks table...")
            cursor.execute(
                "ALTER TABLE chunks ADD COLUMN embedding_dense BLOB"
            )
    else:
        # Create new table with all columns
        log.info("Creating chunks table with multi-vector support...")
        cursor.execute(
            """
            CREATE TABLE chunks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_path TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding BLOB,
                embedding_binary BLOB,
                embedding_dense BLOB,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )

    # Create index for file-based lookups
    log.info("Creating index for chunks table...")
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_chunks_file_path
        ON chunks(file_path)
        """
    )

    log.info("Migration 010 completed successfully")


def downgrade(db_conn: Connection) -> None:
    """
    Removes multi-vector columns from chunks table.

    Note: This does not drop the chunks table entirely to preserve data.
    Only the new columns added by this migration are removed.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Removing multi-vector columns from chunks table...")

    # SQLite doesn't support DROP COLUMN directly in older versions
    # We need to recreate the table without the columns

    # Check if chunks table exists
    table_exists = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks'"
    ).fetchone()

    if not table_exists:
        log.info("chunks table does not exist, nothing to downgrade")
        return

    # Check if the columns exist before trying to remove them
    col_info = cursor.execute("PRAGMA table_info(chunks)").fetchall()
    existing_columns = {row[1] for row in col_info}

    needs_migration = (
        "embedding_binary" in existing_columns or
        "embedding_dense" in existing_columns
    )

    if not needs_migration:
        log.info("Multi-vector columns not present, nothing to remove")
        return

    # Recreate table without the new columns
    log.info("Recreating chunks table without multi-vector columns...")

    cursor.execute(
        """
        CREATE TABLE chunks_backup (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            file_path TEXT NOT NULL,
            content TEXT NOT NULL,
            embedding BLOB,
            metadata TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
    )

    cursor.execute(
        """
        INSERT INTO chunks_backup (id, file_path, content, embedding, metadata, created_at)
        SELECT id, file_path, content, embedding, metadata, created_at FROM chunks
        """
    )

    cursor.execute("DROP TABLE chunks")
    cursor.execute("ALTER TABLE chunks_backup RENAME TO chunks")

    # Recreate index
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_chunks_file_path
        ON chunks(file_path)
        """
    )

    log.info("Migration 010 downgrade completed successfully")
300
codex-lens/build/lib/codexlens/storage/path_mapper.py
Normal file
300
codex-lens/build/lib/codexlens/storage/path_mapper.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Path mapping utilities for source paths and index paths.
|
||||||
|
|
||||||
|
This module provides bidirectional mapping between source code directories
|
||||||
|
and their corresponding index storage locations.
|
||||||
|
|
||||||
|
Storage Structure:
|
||||||
|
~/.codexlens/
|
||||||
|
├── registry.db # Global mapping table
|
||||||
|
└── indexes/
|
||||||
|
└── D/
|
||||||
|
└── Claude_dms3/
|
||||||
|
├── _index.db # Root directory index
|
||||||
|
└── src/
|
||||||
|
└── _index.db # src/ directory index
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def _get_configured_index_root() -> Path:
|
||||||
|
"""Get the index root from environment or config file.
|
||||||
|
|
||||||
|
Priority order:
|
||||||
|
1. CODEXLENS_INDEX_DIR environment variable
|
||||||
|
2. index_dir from ~/.codexlens/config.json
|
||||||
|
3. Default: ~/.codexlens/indexes
|
||||||
|
"""
|
||||||
|
env_override = os.getenv("CODEXLENS_INDEX_DIR")
|
||||||
|
if env_override:
|
||||||
|
return Path(env_override).expanduser().resolve()
|
||||||
|
|
||||||
|
config_file = Path.home() / ".codexlens" / "config.json"
|
||||||
|
if config_file.exists():
|
||||||
|
try:
|
||||||
|
cfg = json.loads(config_file.read_text(encoding="utf-8"))
|
||||||
|
if "index_dir" in cfg:
|
||||||
|
return Path(cfg["index_dir"]).expanduser().resolve()
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return Path.home() / ".codexlens" / "indexes"
|
||||||
|
|
||||||
|
|
||||||
|
class PathMapper:
|
||||||
|
"""Bidirectional mapping tool for source paths ↔ index paths.
|
||||||
|
|
||||||
|
Handles cross-platform path normalization and conversion between
|
||||||
|
source code directories and their index storage locations.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
DEFAULT_INDEX_ROOT: Default root directory for all indexes
|
||||||
|
INDEX_DB_NAME: Standard name for index database files
|
||||||
|
index_root: Configured index root directory
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_INDEX_ROOT = _get_configured_index_root()
|
||||||
|
INDEX_DB_NAME = "_index.db"
|
||||||
|
|
||||||
|
def __init__(self, index_root: Optional[Path] = None):
|
||||||
|
"""Initialize PathMapper with optional custom index root.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_root: Custom index root directory. If None, uses DEFAULT_INDEX_ROOT.
|
||||||
|
"""
|
||||||
|
self.index_root = (index_root or self.DEFAULT_INDEX_ROOT).resolve()
|
||||||
|
|
||||||
|
def source_to_index_dir(self, source_path: Path) -> Path:
|
||||||
|
"""Convert source directory to its index directory path.
|
||||||
|
|
||||||
|
Maps a source code directory to where its index data should be stored.
|
||||||
|
The mapping preserves the directory structure but normalizes paths
|
||||||
|
for cross-platform compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_path: Source directory path to map
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Index directory path under index_root
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mapper = PathMapper()
|
||||||
|
>>> mapper.source_to_index_dir(Path("D:/Claude_dms3/src"))
|
||||||
|
PosixPath('/home/user/.codexlens/indexes/D/Claude_dms3/src')
|
||||||
|
|
||||||
|
>>> mapper.source_to_index_dir(Path("/home/user/project"))
|
||||||
|
PosixPath('/home/user/.codexlens/indexes/home/user/project')
|
||||||
|
"""
|
||||||
|
source_path = source_path.resolve()
|
||||||
|
normalized = self.normalize_path(source_path)
|
||||||
|
return self.index_root / normalized
|
||||||
|
|
||||||
|
def source_to_index_db(self, source_path: Path) -> Path:
|
||||||
|
"""Convert source directory to its index database file path.
|
||||||
|
|
||||||
|
Maps a source directory to the full path of its index database file,
|
||||||
|
including the standard INDEX_DB_NAME.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_path: Source directory path to map
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Full path to the index database file
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mapper = PathMapper()
|
||||||
|
>>> mapper.source_to_index_db(Path("D:/Claude_dms3/src"))
|
||||||
|
PosixPath('/home/user/.codexlens/indexes/D/Claude_dms3/src/_index.db')
|
||||||
|
"""
|
||||||
|
index_dir = self.source_to_index_dir(source_path)
|
||||||
|
return index_dir / self.INDEX_DB_NAME
|
||||||
|
|
||||||
|
def index_to_source(self, index_path: Path) -> Path:
|
||||||
|
"""Convert index path back to original source path.
|
||||||
|
|
||||||
|
Performs reverse mapping from an index storage location to the
|
||||||
|
original source directory. Handles both directory paths and
|
||||||
|
database file paths.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Index directory or database file path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Original source directory path
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If index_path is not under index_root
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mapper = PathMapper()
|
||||||
|
>>> mapper.index_to_source(
|
||||||
|
... Path("~/.codexlens/indexes/D/Claude_dms3/src/_index.db")
|
||||||
|
... )
|
||||||
|
WindowsPath('D:/Claude_dms3/src')
|
||||||
|
|
||||||
|
>>> mapper.index_to_source(
|
||||||
|
... Path("~/.codexlens/indexes/D/Claude_dms3/src")
|
||||||
|
... )
|
||||||
|
WindowsPath('D:/Claude_dms3/src')
|
||||||
|
"""
|
||||||
|
index_path = index_path.resolve()
|
||||||
|
|
||||||
|
# Remove _index.db if present
|
||||||
|
if index_path.name == self.INDEX_DB_NAME:
|
||||||
|
index_path = index_path.parent
|
||||||
|
|
||||||
|
# Verify path is under index_root
|
||||||
|
try:
|
||||||
|
relative = index_path.relative_to(self.index_root)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Index path {index_path} is not under index root {self.index_root}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert normalized path back to source path
|
||||||
|
normalized_str = str(relative).replace("\\", "/")
|
||||||
|
return self.denormalize_path(normalized_str)
|
||||||
|
|
||||||
|
def get_project_root(self, source_path: Path) -> Path:
|
||||||
|
"""Find the project root directory (topmost indexed directory).
|
||||||
|
|
||||||
|
Walks up the directory tree to find the highest-level directory
|
||||||
|
that has an index database.
|
||||||
|
|
||||||
|
Args:
|
            source_path: Source directory to start from

        Returns:
            Project root directory path. Returns source_path itself if
            no parent index is found.

        Examples:
            >>> mapper = PathMapper()
            >>> mapper.get_project_root(Path("D:/Claude_dms3/src/codexlens"))
            WindowsPath('D:/Claude_dms3')
        """
        source_path = source_path.resolve()
        current = source_path
        project_root = source_path

        # Walk up the tree
        while current.parent != current:  # Stop at filesystem root
            parent_index_db = self.source_to_index_db(current.parent)
            if parent_index_db.exists():
                project_root = current.parent
                current = current.parent
            else:
                break

        return project_root

    def get_relative_depth(self, source_path: Path, project_root: Path) -> int:
        """Calculate directory depth relative to project root.

        Args:
            source_path: Target directory path
            project_root: Project root directory path

        Returns:
            Number of directory levels from project_root to source_path

        Raises:
            ValueError: If source_path is not under project_root

        Examples:
            >>> mapper = PathMapper()
            >>> mapper.get_relative_depth(
            ...     Path("D:/Claude_dms3/src/codexlens"),
            ...     Path("D:/Claude_dms3")
            ... )
            2
        """
        source_path = source_path.resolve()
        project_root = project_root.resolve()

        try:
            relative = source_path.relative_to(project_root)
            # Count path components
            return len(relative.parts)
        except ValueError:
            raise ValueError(
                f"Source path {source_path} is not under project root {project_root}"
            )

    def normalize_path(self, path: Path) -> str:
        """Normalize path to cross-platform storage format.

        Converts OS-specific paths to a standardized format for storage:
        - Windows: Removes drive colons (D: → D)
        - Unix: Removes leading slash
        - Uses forward slashes throughout

        Args:
            path: Path to normalize

        Returns:
            Normalized path string

        Examples:
            >>> mapper = PathMapper()
            >>> mapper.normalize_path(Path("D:/path/to/dir"))
            'D/path/to/dir'

            >>> mapper.normalize_path(Path("/home/user/path"))
            'home/user/path'
        """
        path = path.resolve()
        path_str = str(path)

        # Handle Windows paths with drive letters
        if platform.system() == "Windows" and len(path.parts) > 0:
            # Convert D:\path\to\dir → D/path/to/dir
            drive = path.parts[0].replace(":", "")  # D: → D
            rest = Path(*path.parts[1:]) if len(path.parts) > 1 else Path()
            normalized = f"{drive}/{rest}".replace("\\", "/")
            return normalized.rstrip("/")

        # Handle Unix paths
        # /home/user/path → home/user/path
        return path_str.lstrip("/").replace("\\", "/")

    def denormalize_path(self, normalized: str) -> Path:
        """Convert normalized path back to OS-specific path.

        Reverses the normalization process to restore OS-native path format:
        - Windows: Adds drive colons (D → D:)
        - Unix: Adds leading slash

        Args:
            normalized: Normalized path string

        Returns:
            OS-specific Path object

        Examples:
            >>> mapper = PathMapper()
            >>> mapper.denormalize_path("D/path/to/dir")  # On Windows
            WindowsPath('D:/path/to/dir')

            >>> mapper.denormalize_path("home/user/path")  # On Unix
            PosixPath('/home/user/path')
        """
        parts = normalized.split("/")

        # Handle Windows paths
        if platform.system() == "Windows" and len(parts) > 0:
            # Check if first part is a drive letter
            if len(parts[0]) == 1 and parts[0].isalpha():
                # D/path/to/dir → D:/path/to/dir
                drive = f"{parts[0]}:"
                if len(parts) > 1:
                    return Path(drive) / Path(*parts[1:])
                return Path(drive)

        # Handle Unix paths or relative paths
        # home/user/path → /home/user/path
        return Path("/") / Path(*parts)
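The normalize/denormalize pair above is meant to round-trip a path through the cross-platform storage format. A minimal sketch of that round trip, based only on the docstring examples above (the module path codexlens.storage.path_mapper is inferred from the file location in this diff; a Unix host is assumed, on Windows the drive colon is restored instead of the leading slash):

from pathlib import Path
from codexlens.storage.path_mapper import PathMapper  # import path assumed from the file location above

mapper = PathMapper()
stored = mapper.normalize_path(Path("/home/user/project"))   # 'home/user/project' per the docstring
restored = mapper.denormalize_path(stored)                    # PosixPath('/home/user/project')
print(stored, restored)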
codex-lens/build/lib/codexlens/storage/registry.py (new file, 683 lines)
@@ -0,0 +1,683 @@
"""Global project registry for CodexLens - SQLite storage."""

from __future__ import annotations

import platform
import sqlite3
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from codexlens.errors import StorageError


@dataclass
class ProjectInfo:
    """Registered project information."""

    id: int
    source_root: Path
    index_root: Path
    created_at: float
    last_indexed: float
    total_files: int
    total_dirs: int
    status: str


@dataclass
class DirMapping:
    """Directory to index path mapping."""

    id: int
    project_id: int
    source_path: Path
    index_path: Path
    depth: int
    files_count: int
    last_updated: float


class RegistryStore:
    """Global project registry - SQLite storage.

    Manages indexed projects and directory-to-index path mappings.
    Thread-safe with connection pooling.
    """

    DEFAULT_DB_PATH = Path.home() / ".codexlens" / "registry.db"

    def __init__(self, db_path: Path | None = None) -> None:
        self.db_path = (db_path or self.DEFAULT_DB_PATH).resolve()
        self._lock = threading.RLock()
        self._local = threading.local()
        self._pool_lock = threading.Lock()
        self._pool: Dict[int, sqlite3.Connection] = {}
        self._pool_generation = 0

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a thread-local database connection."""
        thread_id = threading.get_ident()
        if getattr(self._local, "generation", None) == self._pool_generation:
            conn = getattr(self._local, "conn", None)
            if conn is not None:
                return conn

        with self._pool_lock:
            conn = self._pool.get(thread_id)
            if conn is None:
                conn = sqlite3.connect(self.db_path, check_same_thread=False)
                conn.row_factory = sqlite3.Row
                conn.execute("PRAGMA journal_mode=WAL")
                conn.execute("PRAGMA synchronous=NORMAL")
                conn.execute("PRAGMA foreign_keys=ON")
                self._pool[thread_id] = conn

        self._local.conn = conn
        self._local.generation = self._pool_generation
        return conn

    def close(self) -> None:
        """Close all pooled connections."""
        with self._lock:
            with self._pool_lock:
                for conn in self._pool.values():
                    conn.close()
                self._pool.clear()
                self._pool_generation += 1

            if hasattr(self._local, "conn"):
                self._local.conn = None
            if hasattr(self._local, "generation"):
                self._local.generation = self._pool_generation

    def __enter__(self) -> RegistryStore:
        self.initialize()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        self.close()

    def initialize(self) -> None:
        """Create database and schema."""
        with self._lock:
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            conn = self._get_connection()
            self._create_schema(conn)

    def _create_schema(self, conn: sqlite3.Connection) -> None:
        """Create database schema."""
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS projects (
                    id INTEGER PRIMARY KEY,
                    source_root TEXT UNIQUE NOT NULL,
                    index_root TEXT NOT NULL,
                    created_at REAL,
                    last_indexed REAL,
                    total_files INTEGER DEFAULT 0,
                    total_dirs INTEGER DEFAULT 0,
                    status TEXT DEFAULT 'active'
                )
                """
            )

            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS dir_mapping (
                    id INTEGER PRIMARY KEY,
                    project_id INTEGER REFERENCES projects(id) ON DELETE CASCADE,
                    source_path TEXT NOT NULL,
                    index_path TEXT NOT NULL,
                    depth INTEGER,
                    files_count INTEGER DEFAULT 0,
                    last_updated REAL,
                    UNIQUE(source_path)
                )
                """
            )

            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_dir_source ON dir_mapping(source_path)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_dir_project ON dir_mapping(project_id)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_project_source ON projects(source_root)"
            )

            conn.commit()
        except sqlite3.DatabaseError as exc:
            raise StorageError(f"Failed to initialize registry schema: {exc}") from exc

    def _normalize_path_for_comparison(self, path: Path) -> str:
        """Normalize paths for comparisons and storage.

        Windows paths are treated as case-insensitive, so normalize to lowercase.
        Unix platforms preserve case sensitivity.
        """
        path_str = str(path)
        if platform.system() == "Windows":
            return path_str.lower()
        return path_str

    # === Project Operations ===

    def register_project(self, source_root: Path, index_root: Path) -> ProjectInfo:
        """Register a new project or update existing one.

        Args:
            source_root: Source code root directory
            index_root: Index storage root directory

        Returns:
            ProjectInfo for the registered project
        """
        with self._lock:
            conn = self._get_connection()
            source_root_str = self._normalize_path_for_comparison(source_root.resolve())
            index_root_str = str(index_root.resolve())
            now = time.time()

            conn.execute(
                """
                INSERT INTO projects(source_root, index_root, created_at, last_indexed)
                VALUES(?, ?, ?, ?)
                ON CONFLICT(source_root) DO UPDATE SET
                    index_root=excluded.index_root,
                    last_indexed=excluded.last_indexed,
                    status='active'
                """,
                (source_root_str, index_root_str, now, now),
            )

            row = conn.execute(
                "SELECT * FROM projects WHERE source_root=?", (source_root_str,)
            ).fetchone()

            conn.commit()

            if not row:
                raise StorageError(f"Failed to register project: {source_root}")

            return self._row_to_project_info(row)

    def unregister_project(self, source_root: Path) -> bool:
        """Remove a project registration (cascades to directory mappings).

        Args:
            source_root: Source code root directory

        Returns:
            True if project was removed, False if not found
        """
        with self._lock:
            conn = self._get_connection()
            source_root_str = self._normalize_path_for_comparison(source_root.resolve())

            row = conn.execute(
                "SELECT id FROM projects WHERE source_root=?", (source_root_str,)
            ).fetchone()

            if not row:
                return False

            conn.execute("DELETE FROM projects WHERE source_root=?", (source_root_str,))
            conn.commit()
            return True

    def get_project(self, source_root: Path) -> Optional[ProjectInfo]:
        """Get project information by source root.

        Args:
            source_root: Source code root directory

        Returns:
            ProjectInfo if found, None otherwise
        """
        with self._lock:
            conn = self._get_connection()
            source_root_str = self._normalize_path_for_comparison(source_root.resolve())

            row = conn.execute(
                "SELECT * FROM projects WHERE source_root=?", (source_root_str,)
            ).fetchone()

            return self._row_to_project_info(row) if row else None

    def get_project_by_id(self, project_id: int) -> Optional[ProjectInfo]:
        """Get project information by ID.

        Args:
            project_id: Project database ID

        Returns:
            ProjectInfo if found, None otherwise
        """
        with self._lock:
            conn = self._get_connection()

            row = conn.execute(
                "SELECT * FROM projects WHERE id=?", (project_id,)
            ).fetchone()

            return self._row_to_project_info(row) if row else None

    def list_projects(self, status: Optional[str] = None) -> List[ProjectInfo]:
        """List all registered projects.

        Args:
            status: Optional status filter ('active', 'stale', 'removed')

        Returns:
            List of ProjectInfo objects
        """
        with self._lock:
            conn = self._get_connection()

            if status:
                rows = conn.execute(
                    "SELECT * FROM projects WHERE status=? ORDER BY created_at DESC",
                    (status,),
                ).fetchall()
            else:
                rows = conn.execute(
                    "SELECT * FROM projects ORDER BY created_at DESC"
                ).fetchall()

            return [self._row_to_project_info(row) for row in rows]

    def update_project_stats(
        self, source_root: Path, total_files: int, total_dirs: int
    ) -> None:
        """Update project statistics.

        Args:
            source_root: Source code root directory
            total_files: Total number of indexed files
            total_dirs: Total number of indexed directories
        """
        with self._lock:
            conn = self._get_connection()
            source_root_str = self._normalize_path_for_comparison(source_root.resolve())

            conn.execute(
                """
                UPDATE projects
                SET total_files=?, total_dirs=?, last_indexed=?
                WHERE source_root=?
                """,
                (total_files, total_dirs, time.time(), source_root_str),
            )
            conn.commit()

    def set_project_status(self, source_root: Path, status: str) -> None:
        """Set project status.

        Args:
            source_root: Source code root directory
            status: Status string ('active', 'stale', 'removed')
        """
        with self._lock:
            conn = self._get_connection()
            source_root_str = self._normalize_path_for_comparison(source_root.resolve())

            conn.execute(
                "UPDATE projects SET status=? WHERE source_root=?",
                (status, source_root_str),
            )
            conn.commit()

    # === Directory Mapping Operations ===

    def register_dir(
        self,
        project_id: int,
        source_path: Path,
        index_path: Path,
        depth: int,
        files_count: int = 0,
    ) -> DirMapping:
        """Register a directory mapping.

        Args:
            project_id: Project database ID
            source_path: Source directory path
            index_path: Index database path
            depth: Directory depth relative to project root
            files_count: Number of files in directory

        Returns:
            DirMapping for the registered directory
        """
        with self._lock:
            conn = self._get_connection()
            source_path_str = self._normalize_path_for_comparison(source_path.resolve())
            index_path_str = str(index_path.resolve())
            now = time.time()

            conn.execute(
                """
                INSERT INTO dir_mapping(
                    project_id, source_path, index_path, depth, files_count, last_updated
                )
                VALUES(?, ?, ?, ?, ?, ?)
                ON CONFLICT(source_path) DO UPDATE SET
                    index_path=excluded.index_path,
                    depth=excluded.depth,
                    files_count=excluded.files_count,
                    last_updated=excluded.last_updated
                """,
                (project_id, source_path_str, index_path_str, depth, files_count, now),
            )

            row = conn.execute(
                "SELECT * FROM dir_mapping WHERE source_path=?", (source_path_str,)
            ).fetchone()

            conn.commit()

            if not row:
                raise StorageError(f"Failed to register directory: {source_path}")

            return self._row_to_dir_mapping(row)

    def unregister_dir(self, source_path: Path) -> bool:
        """Remove a directory mapping.

        Args:
            source_path: Source directory path

        Returns:
            True if directory was removed, False if not found
        """
        with self._lock:
            conn = self._get_connection()
            source_path_str = self._normalize_path_for_comparison(source_path.resolve())

            row = conn.execute(
                "SELECT id FROM dir_mapping WHERE source_path=?", (source_path_str,)
            ).fetchone()

            if not row:
                return False

            conn.execute("DELETE FROM dir_mapping WHERE source_path=?", (source_path_str,))
            conn.commit()
            return True

    def find_index_path(self, source_path: Path) -> Optional[Path]:
        """Find index path for a source directory (exact match).

        Args:
            source_path: Source directory path

        Returns:
            Index path if found, None otherwise
        """
        with self._lock:
            conn = self._get_connection()
            source_path_str = self._normalize_path_for_comparison(source_path.resolve())

            row = conn.execute(
                "SELECT index_path FROM dir_mapping WHERE source_path=?",
                (source_path_str,),
            ).fetchone()

            return Path(row["index_path"]) if row else None

    def find_nearest_index(self, source_path: Path) -> Optional[DirMapping]:
        """Find nearest indexed ancestor directory.

        Searches for the closest parent directory that has an index.
        Useful for supporting subdirectory searches.

        Optimized to use single database query instead of iterating through
        each parent directory level.

        Args:
            source_path: Source directory or file path

        Returns:
            DirMapping for nearest ancestor, None if not found
        """
        with self._lock:
            conn = self._get_connection()
            source_path_resolved = source_path.resolve()

            # Build list of all parent paths from deepest to shallowest
            paths_to_check = []
            current = source_path_resolved
            while True:
                paths_to_check.append(self._normalize_path_for_comparison(current))
                parent = current.parent
                if parent == current:  # Reached filesystem root
                    break
                current = parent

            if not paths_to_check:
                return None

            # Single query with WHERE IN, ordered by path length (longest = nearest)
            placeholders = ','.join('?' * len(paths_to_check))
            query = f"""
                SELECT * FROM dir_mapping
                WHERE source_path IN ({placeholders})
                ORDER BY LENGTH(source_path) DESC
                LIMIT 1
            """

            row = conn.execute(query, paths_to_check).fetchone()
            return self._row_to_dir_mapping(row) if row else None

    def find_by_source_path(self, source_path: str) -> Optional[Dict[str, str]]:
        """Find project by source path (exact or nearest match).

        Searches for a project whose source_root matches or contains
        the given source_path.

        Args:
            source_path: Source directory path as string

        Returns:
            Dict with project info including 'index_root', or None if not found
        """
        with self._lock:
            conn = self._get_connection()
            resolved_path = Path(source_path).resolve()
            source_path_resolved = self._normalize_path_for_comparison(resolved_path)

            # First try exact match on projects table
            row = conn.execute(
                "SELECT * FROM projects WHERE source_root=?", (source_path_resolved,)
            ).fetchone()

            if row:
                return {
                    "id": str(row["id"]),
                    "source_root": row["source_root"],
                    "index_root": row["index_root"],
                    "status": row["status"] or "active",
                }

            # Try finding project that contains this path
            # Build list of all parent paths
            paths_to_check = []
            current = resolved_path
            while True:
                paths_to_check.append(self._normalize_path_for_comparison(current))
                parent = current.parent
                if parent == current:
                    break
                current = parent

            if paths_to_check:
                placeholders = ','.join('?' * len(paths_to_check))
                query = f"""
                    SELECT * FROM projects
                    WHERE source_root IN ({placeholders})
                    ORDER BY LENGTH(source_root) DESC
                    LIMIT 1
                """
                row = conn.execute(query, paths_to_check).fetchone()

                if row:
                    return {
                        "id": str(row["id"]),
                        "source_root": row["source_root"],
                        "index_root": row["index_root"],
                        "status": row["status"] or "active",
                    }

            return None

    def get_project_dirs(self, project_id: int) -> List[DirMapping]:
        """Get all directory mappings for a project.

        Args:
            project_id: Project database ID

        Returns:
            List of DirMapping objects
        """
        with self._lock:
            conn = self._get_connection()

            rows = conn.execute(
                "SELECT * FROM dir_mapping WHERE project_id=? ORDER BY depth, source_path",
                (project_id,),
            ).fetchall()

            return [self._row_to_dir_mapping(row) for row in rows]

    def get_subdirs(self, source_path: Path) -> List[DirMapping]:
        """Get direct subdirectory mappings.

        Args:
            source_path: Parent directory path

        Returns:
            List of DirMapping objects for direct children
        """
        with self._lock:
            conn = self._get_connection()
            source_path_str = self._normalize_path_for_comparison(source_path.resolve())

            # First get the parent's depth
            parent_row = conn.execute(
                "SELECT depth, project_id FROM dir_mapping WHERE source_path=?",
                (source_path_str,),
            ).fetchone()

            if not parent_row:
                return []

            parent_depth = int(parent_row["depth"])
            project_id = int(parent_row["project_id"])

            # Get all subdirs with depth = parent_depth + 1 and matching path prefix
            rows = conn.execute(
                """
                SELECT * FROM dir_mapping
                WHERE project_id=? AND depth=? AND source_path LIKE ?
                ORDER BY source_path
                """,
                (project_id, parent_depth + 1, f"{source_path_str}%"),
            ).fetchall()

            return [self._row_to_dir_mapping(row) for row in rows]

    def update_dir_stats(self, source_path: Path, files_count: int) -> None:
        """Update directory statistics.

        Args:
            source_path: Source directory path
            files_count: Number of files in directory
        """
        with self._lock:
            conn = self._get_connection()
            source_path_str = self._normalize_path_for_comparison(source_path.resolve())

            conn.execute(
                """
                UPDATE dir_mapping
                SET files_count=?, last_updated=?
                WHERE source_path=?
                """,
                (files_count, time.time(), source_path_str),
            )
            conn.commit()

    def update_index_paths(self, old_root: Path, new_root: Path) -> int:
        """Update all index paths after migration.

        Replaces old_root prefix with new_root in all stored index paths.

        Args:
            old_root: Old index root directory
            new_root: New index root directory

        Returns:
            Number of paths updated
        """
        with self._lock:
            conn = self._get_connection()
            old_root_str = str(old_root.resolve())
            new_root_str = str(new_root.resolve())
            updated = 0

            # Update projects
            conn.execute(
                """
                UPDATE projects
                SET index_root = REPLACE(index_root, ?, ?)
                WHERE index_root LIKE ?
                """,
                (old_root_str, new_root_str, f"{old_root_str}%"),
            )
            updated += conn.total_changes

            # Update dir_mapping
            conn.execute(
                """
                UPDATE dir_mapping
                SET index_path = REPLACE(index_path, ?, ?)
                WHERE index_path LIKE ?
                """,
                (old_root_str, new_root_str, f"{old_root_str}%"),
            )
            updated += conn.total_changes

            conn.commit()
            return updated

    # === Internal Methods ===

    def _row_to_project_info(self, row: sqlite3.Row) -> ProjectInfo:
        """Convert database row to ProjectInfo."""
        return ProjectInfo(
            id=int(row["id"]),
            source_root=Path(row["source_root"]),
            index_root=Path(row["index_root"]),
            created_at=float(row["created_at"]) if row["created_at"] else 0.0,
            last_indexed=float(row["last_indexed"]) if row["last_indexed"] else 0.0,
            total_files=int(row["total_files"]) if row["total_files"] else 0,
            total_dirs=int(row["total_dirs"]) if row["total_dirs"] else 0,
            status=str(row["status"]) if row["status"] else "active",
        )

    def _row_to_dir_mapping(self, row: sqlite3.Row) -> DirMapping:
        """Convert database row to DirMapping."""
        return DirMapping(
            id=int(row["id"]),
            project_id=int(row["project_id"]),
            source_path=Path(row["source_path"]),
            index_path=Path(row["index_path"]),
            depth=int(row["depth"]) if row["depth"] is not None else 0,
            files_count=int(row["files_count"]) if row["files_count"] else 0,
            last_updated=float(row["last_updated"]) if row["last_updated"] else 0.0,
        )
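Taken together, the registry above is driven through a handful of calls: open the store (the context manager runs initialize/close), register a project, register its directories, then resolve lookups such as the nearest indexed ancestor. A minimal usage sketch under assumed names (the module path codexlens.storage.registry and the index file name _index.db are inferred from the file locations in this diff, not confirmed by it; the filesystem paths are illustrative):

from pathlib import Path
from codexlens.storage.registry import RegistryStore  # module path assumed from the file location above

# Use an explicit db_path so the sketch does not touch the default ~/.codexlens/registry.db
with RegistryStore(db_path=Path("/tmp/codexlens-registry.db")) as store:
    project = store.register_project(
        source_root=Path("/home/user/project"),
        index_root=Path("/home/user/.codexlens/project"),
    )
    store.register_dir(
        project_id=project.id,
        source_path=Path("/home/user/project/src"),
        index_path=Path("/home/user/.codexlens/project/src/_index.db"),  # index file name assumed
        depth=1,
        files_count=42,
    )

    # Nearest-ancestor lookup: an unregistered deeper path resolves to the closest
    # registered parent via the single WHERE IN query in find_nearest_index above.
    nearest = store.find_nearest_index(Path("/home/user/project/src/codexlens/utils"))
    print(nearest.source_path if nearest else "no index found")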
Some files were not shown because too many files have changed in this diff.