Mirror of https://github.com/catlog22/Claude-Code-Workflow.git
Synced 2026-02-08 02:14:08 +08:00

Compare commits (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 88ff109ac4 | |
| | 261196a804 | |
| | ea6cb8440f | |
| | bf896342f4 | |
| | f2b0a5bbc9 | |
| | cf5fecd66d | |
| | 86d469ccc9 | |
@@ -401,17 +401,19 @@ async function executeCommandChain(chain, analysis) {
     state.updated_at = new Date().toISOString();
     Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));

-    // Assemble prompt with previous results
-    let prompt = `Task: ${analysis.goal}\n`;
+    // Assemble prompt: Command first, then context
+    let promptContent = formatCommand(cmd, state.execution_results, analysis);
+
+    // Build full prompt: Command → Task → Previous Results
+    let prompt = `${promptContent}\n\nTask: ${analysis.goal}`;
     if (state.execution_results.length > 0) {
-      prompt += '\nPrevious results:\n';
+      prompt += '\n\nPrevious results:\n';
       state.execution_results.forEach(r => {
         if (r.session_id) {
           prompt += `- ${r.command}: ${r.session_id} (${r.artifacts?.join(', ') || 'completed'})\n`;
         }
       });
     }
-    prompt += `\n${formatCommand(cmd, state.execution_results, analysis)}\n`;

     // Record prompt used
     state.prompts_used.push({
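For concreteness, the new assembly order produces prompts like the following (values illustrative, matching the JSON examples later in this diff):

```
/workflow:execute -y --resume-session="WFS-plan-20250124"

Task: Implement user registration

Previous results:
- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)
```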
@@ -421,9 +423,12 @@ async function executeCommandChain(chain, analysis) {
     });

     // Execute CLI command in background and stop
+    // Format: ccw cli -p "PROMPT" --tool <tool> --mode <mode>
+    // Note: -y is a command parameter INSIDE the prompt, not a ccw cli parameter
+    // Example prompt: "/workflow:plan -y \"task description here\""
     try {
       const taskId = Bash(
-        `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write -y`,
+        `ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
         { run_in_background: true }
       ).task_id;
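`escapePrompt` is referenced above but not defined in this diff; a minimal sketch, assuming its only job is to make the prompt safe inside a double-quoted shell argument:

```javascript
// Hypothetical helper (not part of this diff): escape characters that the
// shell would interpret inside a double-quoted string.
function escapePrompt(prompt) {
  return prompt.replace(/[\\"`$]/g, ch => '\\' + ch);
}
```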
@@ -486,69 +491,71 @@ async function executeCommandChain(chain, analysis) {
 }

 // Smart parameter assembly
+// Returns prompt content to be used with: ccw cli -p "RETURNED_VALUE" --tool claude --mode write
 function formatCommand(cmd, previousResults, analysis) {
-  let line = cmd.command + ' --yes';
+  // Format: /workflow:<command> -y <parameters>
+  let prompt = `/workflow:${cmd.name} -y`;
   const name = cmd.name;

   // Planning commands - take task description
   if (['lite-plan', 'plan', 'tdd-plan', 'multi-cli-plan'].includes(name)) {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Lite execution - use --in-memory if plan exists
   } else if (name === 'lite-execute') {
     const hasPlan = previousResults.some(r => r.command.includes('plan'));
-    line += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
+    prompt += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;

   // Standard execution - resume from planning session
   } else if (name === 'execute') {
     const plan = previousResults.find(r => r.command.includes('plan'));
-    if (plan?.session_id) line += ` --resume-session="${plan.session_id}"`;
+    if (plan?.session_id) prompt += ` --resume-session="${plan.session_id}"`;

   // Bug fix commands - take bug description
   } else if (['lite-fix', 'debug'].includes(name)) {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Brainstorm - take topic description
   } else if (name === 'brainstorm:auto-parallel' || name === 'auto-parallel') {
-    line += ` "${analysis.goal}"`;
+    prompt += ` "${analysis.goal}"`;

   // Test generation from session - needs source session
   } else if (name === 'test-gen') {
     const impl = previousResults.find(r =>
       r.command.includes('execute') || r.command.includes('lite-execute')
     );
-    if (impl?.session_id) line += ` "${impl.session_id}"`;
-    else line += ` "${analysis.goal}"`;
+    if (impl?.session_id) prompt += ` "${impl.session_id}"`;
+    else prompt += ` "${analysis.goal}"`;

   // Test fix generation - session or description
   } else if (name === 'test-fix-gen') {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` "${latest.session_id}"`;
-    else line += ` "${analysis.goal}"`;
+    if (latest?.session_id) prompt += ` "${latest.session_id}"`;
+    else prompt += ` "${analysis.goal}"`;

   // Review commands - take session or use latest
   } else if (name === 'review') {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;

   // Review fix - takes session from review
   } else if (name === 'review-fix') {
     const review = previousResults.find(r => r.command.includes('review'));
     const latest = review || previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;

   // TDD verify - takes execution session
   } else if (name === 'tdd-verify') {
     const exec = previousResults.find(r => r.command.includes('execute'));
-    if (exec?.session_id) line += ` --session="${exec.session_id}"`;
+    if (exec?.session_id) prompt += ` --session="${exec.session_id}"`;

   // Session-based commands (test-cycle, review-session, plan-verify)
   } else if (name.includes('test') || name.includes('review') || name.includes('verify')) {
     const latest = previousResults.filter(r => r.session_id).pop();
-    if (latest?.session_id) line += ` --session="${latest.session_id}"`;
+    if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
   }

-  return line;
+  return prompt;
 }

 // Hook callback: Called when background CLI completes
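Illustrative expected outputs of the rewritten formatCommand (session IDs are placeholder values):

```javascript
// Illustrative only - placeholder values, mirroring the branches above.
const analysis = { goal: 'Implement OAuth2' };
const planned  = [{ command: '/workflow:plan', session_id: 'WFS-plan-001' }];

formatCommand({ name: 'plan' }, [], analysis);
// → '/workflow:plan -y "Implement OAuth2"'
formatCommand({ name: 'execute' }, planned, analysis);
// → '/workflow:execute -y --resume-session="WFS-plan-001"'
formatCommand({ name: 'lite-execute' }, planned, analysis);
// → '/workflow:lite-execute -y --in-memory'
```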
@@ -663,12 +670,12 @@ function parseOutput(output) {
   {
     "index": 0,
     "command": "/workflow:plan",
-    "prompt": "Task: Implement user registration...\n\n/workflow:plan --yes \"Implement user registration...\""
+    "prompt": "/workflow:plan -y \"Implement user registration...\"\n\nTask: Implement user registration..."
   },
   {
     "index": 1,
     "command": "/workflow:execute",
-    "prompt": "Task: Implement user registration...\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)\n\n/workflow:execute --yes --resume-session=\"WFS-plan-20250124\""
+    "prompt": "/workflow:execute -y --resume-session=\"WFS-plan-20250124\"\n\nTask: Implement user registration\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)"
   }
 ]
 }
@@ -728,226 +735,68 @@ const cmd = registry.getCommand('lite-plan');
 // {name, command, description, argumentHint, allowedTools, filePath}
 ```

-## Execution Examples
+## Universal Prompt Template

-### Simple Feature
-```
-Goal: Add API endpoint for user profile
-Scope: [api]
-Complexity: simple
-Constraints: []
-Task Type: feature
-
-Pipeline (with Minimum Execution Units):
-Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Quick Implementation
-1. /workflow:lite-plan --yes "Add API endpoint..."
-2. /workflow:lite-execute --yes --in-memory
-
-# Unit 2: Test Validation
-3. /workflow:test-fix-gen --yes --session="WFS-xxx"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
+### Standard Format
+
+```bash
+ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
+```

-### Complex Feature with Verification
-```
-Goal: Implement OAuth2 authentication system
-Scope: [auth, database, api, frontend]
-Complexity: complex
-Constraints: [no breaking changes]
-Task Type: feature
-
-Pipeline (with Minimum Execution Units):
-Requirements →【plan → plan-verify】→ Verified plan → execute → Code
-→【review-session-cycle → review-fix】→ Fixed code
-→【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Full Planning (plan + plan-verify)
-1. /workflow:plan --yes "Implement OAuth2..."
-2. /workflow:plan-verify --yes --session="WFS-xxx"
-
-# Execution phase
-3. /workflow:execute --yes --resume-session="WFS-xxx"
-
-# Unit 2: Code Review (review-session-cycle + review-fix)
-4. /workflow:review-session-cycle --yes --session="WFS-xxx"
-5. /workflow:review-fix --yes --session="WFS-xxx"
-
-# Unit 3: Test Validation (test-fix-gen + test-cycle-execute)
-6. /workflow:test-fix-gen --yes --session="WFS-xxx"
-7. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
+### Prompt Content Template
+
+```
+/workflow:<command> -y <command_parameters>
+
+Task: <task_description>
+
+<optional_previous_results>
+```

-### Quick Bug Fix
-```
-Goal: Fix login timeout issue
-Scope: [auth]
-Complexity: simple
-Constraints: [urgent]
-Task Type: bugfix
-
-Pipeline:
-Bug report → lite-fix → Fixed code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:lite-fix --yes "Fix login timeout..."
-2. /workflow:test-fix-gen --yes --session="WFS-xxx"
-3. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```
+### Template Variables
+
+| Variable | Description | Examples |
+|----------|-------------|----------|
+| `<command>` | Workflow command name | `plan`, `lite-execute`, `test-cycle-execute` |
+| `-y` | Auto-confirm flag (inside prompt) | Always include for automation |
+| `<command_parameters>` | Command-specific parameters | Task description, session ID, flags |
+| `<task_description>` | Brief task description | "Implement user authentication", "Fix memory leak" |
+| `<optional_previous_results>` | Context from previous commands | "Previous results:\n- /workflow:plan: WFS-xxx" |
+
+### Command Parameter Patterns
+
+| Command Type | Parameter Pattern | Example |
+|--------------|------------------|---------|
+| **Planning** | `"task description"` | `/workflow:plan -y "Implement OAuth2"` |
+| **Execution (with plan)** | `--resume-session="WFS-xxx"` | `/workflow:execute -y --resume-session="WFS-plan-001"` |
+| **Execution (standalone)** | `--in-memory` or `"task"` | `/workflow:lite-execute -y --in-memory` |
+| **Session-based** | `--session="WFS-xxx"` | `/workflow:test-fix-gen -y --session="WFS-impl-001"` |
+| **Fix/Debug** | `"problem description"` | `/workflow:lite-fix -y "Fix timeout bug"` |
+
+### Complete Examples
+
+**Planning Command**:
+```bash
+ccw cli -p '/workflow:plan -y "Implement user registration with email validation"
+
+Task: Implement user registration' --tool claude --mode write
+```

-### Skip Tests
-```
-Goal: Update documentation
-Scope: [docs]
-Complexity: simple
-Constraints: [skip-tests]
-Task Type: feature
-
-Pipeline:
-Requirements → lite-plan → Plan → lite-execute → Code
-
-Chain:
-1. /workflow:lite-plan --yes "Update documentation..."
-2. /workflow:lite-execute --yes --in-memory
-```
+**Execution with Context**:
+```bash
+ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
+
+Task: Implement user registration
+
+Previous results:
+- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)' --tool claude --mode write
+```

-### TDD Workflow
-```
-Goal: Implement user authentication with test-first approach
-Scope: [auth]
-Complexity: medium
-Constraints: [test-driven]
-Task Type: tdd
-
-Pipeline:
-Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify → TDD verification passed
-
-Chain:
-1. /workflow:tdd-plan --yes "Implement user authentication..."
-2. /workflow:execute --yes --resume-session="WFS-xxx"
-3. /workflow:tdd-verify --yes --session="WFS-xxx"
-```
+**Standalone Lite Execution**:
+```bash
+ccw cli -p '/workflow:lite-fix -y "Fix login timeout in auth module"
+
+Task: Fix login timeout' --tool claude --mode write
+```

-### Debug Workflow
-```
-Goal: Fix memory leak in WebSocket handler
-Scope: [websocket]
-Complexity: medium
-Constraints: [production-issue]
-Task Type: bugfix
-
-Pipeline (quick fix):
-Bug report → lite-fix → Fixed code → test-cycle-execute → Tests pass
-
-Pipeline (systematic debugging):
-Bug report → debug → Debug logs → Analyze & locate → Fix
-
-Chain:
-1. /workflow:lite-fix --yes "Fix memory leak in WebSocket..."
-2. /workflow:test-cycle-execute --yes --session="WFS-xxx"
-
-OR (for hypothesis-driven debugging):
-1. /workflow:debug --yes "Memory leak in WebSocket handler..."
-```

-### Test Fix Workflow
-```
-Goal: Fix failing authentication tests
-Scope: [auth, tests]
-Complexity: simple
-Constraints: []
-Task Type: test-fix
-
-Pipeline:
-Failing tests → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:test-fix-gen --yes "WFS-auth-impl-001"
-2. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```

-### Test Generation from Implementation
-```
-Goal: Generate comprehensive tests for completed user registration feature
-Scope: [auth, tests]
-Complexity: medium
-Constraints: []
-Task Type: test-gen
-
-Pipeline (with Minimum Execution Units):
-Code/Session →【test-gen → execute】→ Tests pass
-
-Chain:
-# Unit: Test Generation (test-gen + execute)
-1. /workflow:test-gen --yes "WFS-registration-20250124"
-2. /workflow:execute --yes --session="WFS-test-registration"
-
-Note: test-gen creates IMPL-001 (test generation) and IMPL-002 (test execution & fix)
-      execute runs both tasks - this is a Minimum Execution Unit
-```

-### Review + Fix Workflow
-```
-Goal: Code review of payment module
-Scope: [payment]
-Complexity: medium
-Constraints: []
-Task Type: review
-
-Pipeline (with Minimum Execution Units):
-Code →【review-session-cycle → review-fix】→ Fixed code
-→【test-fix-gen → test-cycle-execute】→ Tests pass
-
-Chain:
-# Unit 1: Code Review (review-session-cycle + review-fix)
-1. /workflow:review-session-cycle --yes --session="WFS-payment-impl"
-2. /workflow:review-fix --yes --session="WFS-payment-impl"
-
-# Unit 2: Test Validation (test-fix-gen + test-cycle-execute)
-3. /workflow:test-fix-gen --yes --session="WFS-payment-impl"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-payment-impl"
-```

-### Brainstorm Workflow (Uncertain Requirements)
-```
-Goal: Explore solutions for real-time notification system
-Scope: [notifications, architecture]
-Complexity: complex
-Constraints: []
-Task Type: brainstorm
-
-Pipeline:
-Exploration topic → brainstorm:auto-parallel → Analysis results → plan → Detailed plan
-→ plan-verify → Verified plan → execute → Code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:brainstorm:auto-parallel --yes "Explore solutions for real-time..."
-2. /workflow:plan --yes "Implement chosen notification approach..."
-3. /workflow:plan-verify --yes --session="WFS-xxx"
-4. /workflow:execute --yes --resume-session="WFS-xxx"
-5. /workflow:test-fix-gen --yes --session="WFS-xxx"
-6. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```

-### Multi-CLI Plan (Multi-Perspective Analysis)
-```
-Goal: Compare microservices vs monolith architecture
-Scope: [architecture]
-Complexity: complex
-Constraints: []
-Task Type: multi-cli
-
-Pipeline:
-Requirements → multi-cli-plan → Comparison plan → lite-execute → Code → test-fix-gen → Test tasks → test-cycle-execute → Tests pass
-
-Chain:
-1. /workflow:multi-cli-plan --yes "Compare microservices vs monolith..."
-2. /workflow:lite-execute --yes --in-memory
-3. /workflow:test-fix-gen --yes --session="WFS-xxx"
-4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
-```

 ## Execution Flow
@@ -983,19 +832,76 @@ async function ccwCoordinator(taskDescription) {

 ## CLI Execution Model

-**Serial Blocking**: Commands execute one-by-one. After launching CLI in background, orchestrator stops immediately and waits for hook callback.
+### CLI Invocation Format
+
+**IMPORTANT**: The `ccw cli` command executes prompts through external tools. The format is:
+
+```bash
+ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
+```
+
+**Parameters**:
+- `-p "PROMPT_CONTENT"`: The prompt content to execute (required)
+- `--tool <tool>`: CLI tool to use (e.g., `claude`, `gemini`, `qwen`)
+- `--mode <mode>`: Execution mode (`analysis` or `write`)
+
+**Note**: `-y` is a **command parameter inside the prompt**, NOT a `ccw cli` parameter.
+
+### Prompt Assembly
+
+The prompt content MUST start with the workflow command, followed by task context:
+
+```
+/workflow:<command> -y <parameters>
+
+Task: <description>
+
+<optional_context>
+```
+
+**Examples**:
+```bash
+# Planning command
+ccw cli -p '/workflow:plan -y "Implement user registration feature"
+
+Task: Implement user registration' --tool claude --mode write
+
+# Execution command (with session reference)
+ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
+
+Task: Implement user registration
+
+Previous results:
+- /workflow:plan: WFS-plan-20250124' --tool claude --mode write
+
+# Lite execution (in-memory from previous plan)
+ccw cli -p '/workflow:lite-execute -y --in-memory
+
+Task: Implement user registration' --tool claude --mode write
+```
+
+### Serial Blocking
+
+**CRITICAL**: Commands execute one-by-one. After launching CLI in background:
+1. Orchestrator stops immediately (`break`)
+2. Wait for hook callback - **DO NOT use TaskOutput polling**
+3. Hook callback triggers next command
+
+**Prompt Structure**: Command must be first in prompt content

 ```javascript
 // Example: Execute command and stop
-const taskId = Bash(`ccw cli -p "..." --tool claude --mode write -y`, { run_in_background: true }).task_id;
+const prompt = '/workflow:plan -y "Implement user authentication"\n\nTask: Implement user auth system';
+const taskId = Bash(`ccw cli -p "${prompt}" --tool claude --mode write`, { run_in_background: true }).task_id;
 state.execution_results.push({ status: 'in-progress', task_id: taskId, ... });
 Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
-break; // Stop, wait for hook callback
+break; // ⚠️ STOP HERE - DO NOT use TaskOutput polling

-// Hook calls handleCliCompletion(sessionId, taskId, output) when done
+// Hook callback will call handleCliCompletion(sessionId, taskId, output) when done
 // → Updates state → Triggers next command via resumeChainExecution()
 ```
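The hook side is only named in this diff (`handleCliCompletion`, `parseOutput`, `resumeChainExecution`); a minimal sketch of how those pieces could fit together, with `updateState` as a hypothetical helper:

```javascript
// Hypothetical sketch - not part of this diff. Wires the named pieces together:
function handleCliCompletion(sessionId, taskId, output) {
  const result = parseOutput(output);     // extract session_id / artifacts from CLI output
  updateState(sessionId, taskId, result); // hypothetical helper: mark the command completed in state.json
  resumeChainExecution(sessionId);        // launch the next command in the chain
}
```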
## Available Commands

All from `~/.claude/commands/workflow/`:

@@ -1023,20 +929,20 @@ All from `~/.claude/commands/workflow/`:
 - **test-gen → execute**: Generates a comprehensive test suite; execute runs both generation and testing
 - **test-fix-gen → test-cycle-execute**: Generates fix tasks for specific issues; test-cycle-execute iterates testing and fixing until tests pass

-### Task Type Routing (Pipeline View)
+### Task Type Routing (Pipeline Summary)

 **Note**: `【 】` marks Minimum Execution Units - these commands must execute together.

-| Task Type | Pipeline |
-|-----------|----------|
-| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **feature** (complex) | Requirements →【plan → plan-verify】→ Verified plan → execute → Code →【review-session-cycle → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **bugfix** | Bug report → lite-fix → Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **tdd** | Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify → TDD verification passed |
-| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **test-gen** | Code/Session →【test-gen → execute】→ Tests pass |
-| **review** | Code →【review-session-cycle/review-module-cycle → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **brainstorm** | Exploration topic → brainstorm:auto-parallel → Analysis results →【plan → plan-verify】→ Verified plan → execute → Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
-| **multi-cli** | Requirements → multi-cli-plan → Comparison plan → lite-execute → Code →【test-fix-gen → test-cycle-execute】→ Tests pass |
+| Task Type | Pipeline | Minimum Units |
+|-----------|----------|---|
+| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ Code →【test-fix-gen → test-cycle-execute】→ Tests pass | Quick Implementation + Test Validation |
+| **feature** (complex) | Requirements →【plan → plan-verify】→ validate → execute → Code → review → fix | Full Planning + Code Review + Testing |
+| **bugfix** | Bug report → lite-fix → Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass | Bug Fix + Test Validation |
+| **tdd** | Requirements → tdd-plan → TDD tasks → execute → Code → tdd-verify | TDD Planning + Execution |
+| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ Tests pass | Test Validation |
+| **test-gen** | Code/Session →【test-gen → execute】→ Tests pass | Test Generation + Execution |
+| **review** | Code →【review-* → review-fix】→ Fixed code →【test-fix-gen → test-cycle-execute】→ Tests pass | Code Review + Testing |
+| **brainstorm** | Exploration topic → brainstorm → Analysis →【plan → plan-verify】→ execute → test | Exploration + Planning + Execution |
+| **multi-cli** | Requirements → multi-cli-plan → Comparative analysis → lite-execute → test | Multi-Perspective + Testing |

 Use `CommandRegistry.getAllCommandsSummary()` to discover all commands dynamically.
README.md (43 lines changed)

@@ -263,6 +263,49 @@ Open Dashboard via `ccw view`, manage indexes and execute searches in **CodexLens**

 ## 💻 CCW CLI Commands

+### 🌟 Recommended Commands (Main Features)
+
+<div align="center">
+<table>
+<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
+<tr>
+<td><b>/ccw</b></td>
+<td>Auto workflow orchestrator - analyzes intent, selects workflow level, executes command chain in main process</td>
+<td>✅ General tasks, auto workflow selection, quick development</td>
+</tr>
+<tr>
+<td><b>/ccw-coordinator</b></td>
+<td>Manual orchestrator - recommends command chains, executes via external CLI with state persistence</td>
+<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
+</tr>
+</table>
+</div>
+
+**Quick Examples**:
+
+```bash
+# /ccw - Auto workflow selection (Main Process)
+/ccw "Add user authentication"       # Auto-selects workflow based on intent
+/ccw "Fix memory leak in WebSocket"  # Detects bugfix workflow
+/ccw "Implement with TDD"            # Routes to TDD workflow
+
+# /ccw-coordinator - Manual chain orchestration (External CLI)
+/ccw-coordinator "Implement OAuth2 system"  # Analyzes → Recommends chain → User confirms → Executes
+```
+
+**Key Differences**:
+
+| Aspect | /ccw | /ccw-coordinator |
+|--------|------|------------------|
+| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
+| **Selection** | Auto intent-based | Manual chain confirmation |
+| **State** | TodoWrite tracking | Persistent state.json |
+| **Use Case** | General tasks, quick dev | Complex chains, resumable |
+
+---
+
 ### Other CLI Commands

 ```bash
 ccw install  # Install workflow files
 ccw view     # Open dashboard
README_CN.md (43 lines changed)

@@ -263,6 +263,49 @@ codexlens index /path/to/project

 ## 💻 CCW CLI Commands

+### 🌟 Recommended Commands (Core Features)
+
+<div align="center">
+<table>
+<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
+<tr>
+<td><b>/ccw</b></td>
+<td>Auto workflow orchestrator - analyzes intent, auto-selects workflow level, executes the command chain in the main process</td>
+<td>✅ General tasks, automatic workflow selection, quick development</td>
+</tr>
+<tr>
+<td><b>/ccw-coordinator</b></td>
+<td>Manual orchestrator - recommends command chains, executes via external CLI, persists state</td>
+<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
+</tr>
+</table>
+</div>
+
+**Quick Examples**:
+
+```bash
+# /ccw - Automatic workflow selection (main process)
+/ccw "Add user authentication"       # Auto-selects workflow based on intent
+/ccw "Fix memory leak in WebSocket"  # Recognized as bugfix workflow
+/ccw "Implement using TDD"           # Routes to TDD workflow
+
+# /ccw-coordinator - Manual chain orchestration (external CLI)
+/ccw-coordinator "Implement OAuth2 system"  # Analyze → Recommend chain → User confirms → Execute
+```
+
+**Key Differences**:
+
+| Aspect | /ccw | /ccw-coordinator |
+|------|------|------------------|
+| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
+| **Selection** | Automatic, intent-based | Manual chain confirmation |
+| **State** | TodoWrite tracking | Persistent state.json |
+| **Use Case** | General tasks, quick development | Complex chains, resumable |
+
+---
+
 ### Other CLI Commands

 ```bash
 ccw install  # Install workflow files
 ccw view     # Open the Dashboard
@@ -19,5 +19,5 @@
     "noEmit": false
   },
   "include": ["src/**/*"],
-  "exclude": ["src/templates/**/*", "node_modules", "dist"]
+  "exclude": ["src/templates/**/*", "src/**/*.test.ts", "node_modules", "dist"]
 }
codex-lens/LICENSE (new file, 21 lines)

MIT License

Copyright (c) 2024 CodexLens Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
codex-lens/README.md (new file, 59 lines)

# CodexLens

CodexLens is a multi-modal code analysis platform designed to provide comprehensive code understanding and analysis capabilities.

## Features

- **Multi-language Support**: Analyze code in Python, JavaScript, TypeScript and more using Tree-sitter parsers
- **Semantic Search**: Find relevant code snippets using semantic understanding with fastembed and HNSWLIB
- **Code Parsing**: Advanced code structure parsing with tree-sitter
- **Flexible Architecture**: Modular design for easy extension and customization

## Installation

### Basic Installation

```bash
pip install codex-lens
```

### With Semantic Search

```bash
pip install codex-lens[semantic]
```

### With GPU Acceleration (NVIDIA CUDA)

```bash
pip install codex-lens[semantic-gpu]
```

### With DirectML (Windows - NVIDIA/AMD/Intel)

```bash
pip install codex-lens[semantic-directml]
```

### With All Optional Features

```bash
pip install codex-lens[full]
```
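A quick post-install sanity check: the package exposes a `python -m codexlens` entrypoint (see `__main__.py` below); the `--help` flag is assumed to be supported by the CLI app:

```bash
python -m codexlens --help
```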
## Requirements

- Python >= 3.10
- See `pyproject.toml` for detailed dependency list

## Development

This project uses setuptools for building and packaging.

## License

MIT License

## Authors

CodexLens Contributors
codex-lens/build/lib/codexlens/__init__.py (new file, 28 lines)

```python
"""CodexLens package."""

from __future__ import annotations

from . import config, entities, errors
from .config import Config
from .entities import IndexedFile, SearchResult, SemanticChunk, Symbol
from .errors import CodexLensError, ConfigError, ParseError, SearchError, StorageError

__version__ = "0.1.0"

__all__ = [
    "__version__",
    "config",
    "entities",
    "errors",
    "Config",
    "IndexedFile",
    "SearchResult",
    "SemanticChunk",
    "Symbol",
    "CodexLensError",
    "ConfigError",
    "ParseError",
    "StorageError",
    "SearchError",
]
```

codex-lens/build/lib/codexlens/__main__.py (new file, 14 lines)

```python
"""Module entrypoint for `python -m codexlens`."""

from __future__ import annotations

from codexlens.cli import app


def main() -> None:
    app()


if __name__ == "__main__":
    main()
```
codex-lens/build/lib/codexlens/api/__init__.py (new file, 88 lines)

```python
"""Codexlens Public API Layer.

This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.

Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result

Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity

Example:
    >>> from codexlens.api import (
    ...     DefinitionResult,
    ...     resolve_project,
    ...     normalize_relationship_type
    ... )
    >>> project = resolve_project("/path/to/project")
    >>> rel_type = normalize_relationship_type("calls")
    >>> print(rel_type)
    'call'
"""

from __future__ import annotations

# Dataclasses
from .models import (
    CallInfo,
    MethodContext,
    FileContextResult,
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
    SymbolInfo,
    HoverInfo,
    SemanticResult,
)

# Utility functions
from .utils import (
    resolve_project,
    normalize_relationship_type,
    rank_by_proximity,
    rank_by_score,
)

# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search

__all__ = [
    # Dataclasses
    "CallInfo",
    "MethodContext",
    "FileContextResult",
    "DefinitionResult",
    "ReferenceResult",
    "GroupedReferences",
    "SymbolInfo",
    "HoverInfo",
    "SemanticResult",
    # Utility functions
    "resolve_project",
    "normalize_relationship_type",
    "rank_by_proximity",
    "rank_by_score",
    # API functions
    "find_definition",
    "workspace_symbols",
    "get_hover",
    "file_context",
    "find_references",
    "semantic_search",
]
```
codex-lens/build/lib/codexlens/api/definition.py (new file, 126 lines)

```python
"""find_definition API implementation.

This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity

logger = logging.getLogger(__name__)


def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """Find definition locations for a symbol.

    Uses a 3-stage fallback strategy:
    1. Exact match with kind filter
    2. Exact match without kind filter
    3. Prefix match

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to find
        symbol_kind: Optional symbol kind filter (class, function, etc.)
        file_context: Optional file path for proximity ranking
        limit: Maximum number of results to return

    Returns:
        List of DefinitionResult sorted by proximity if file_context provided

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Stage 1: Exact match with kind filter
    results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
    if results:
        logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    # Stage 2: Exact match without kind (if kind was specified)
    if symbol_kind:
        results = _search_with_kind(global_index, symbol_name, None, limit)
        if results:
            logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
            return _rank_and_convert(results, file_context)

    # Stage 3: Prefix match
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=limit,
        prefix_mode=True
    )
    if results:
        logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    logger.debug(f"No definitions found for {symbol_name}")
    return []


def _search_with_kind(
    global_index: GlobalSymbolIndex,
    symbol_name: str,
    symbol_kind: Optional[str],
    limit: int
) -> List[Symbol]:
    """Search for symbols with optional kind filter."""
    return global_index.search(
        name=symbol_name,
        kind=symbol_kind,
        limit=limit,
        prefix_mode=False
    )


def _rank_and_convert(
    symbols: List[Symbol],
    file_context: Optional[str]
) -> List[DefinitionResult]:
    """Convert symbols to DefinitionResult and rank by proximity."""
    results = [
        DefinitionResult(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            end_line=sym.range[1] if sym.range else 1,
            signature=None,  # Could extract from file if needed
            container=None,  # Could extract from parent symbol
            score=1.0
        )
        for sym in symbols
    ]
    return rank_by_proximity(results, file_context)
```
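Illustrative usage of the three-stage lookup (paths and names are placeholders; the project must already be indexed):

```python
from codexlens.api import find_definition

results = find_definition("/path/to/project", "GlobalSymbolIndex",
                          symbol_kind="class", file_context="src/api/definition.py")
for r in results:
    # DefinitionResult fields as defined in models.py
    print(f"{r.file_path}:{r.line}-{r.end_line}  {r.kind} {r.name} (score={r.score})")
```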
codex-lens/build/lib/codexlens/api/file_context.py (new file, 271 lines)

```python
"""file_context API implementation.

This module provides the file_context() function for retrieving
method call graphs from a source file.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
    FileContextResult,
    MethodContext,
    CallInfo,
)
from .utils import resolve_project, normalize_relationship_type

logger = logging.getLogger(__name__)


def file_context(
    project_root: str,
    file_path: str,
    include_calls: bool = True,
    include_callers: bool = True,
    max_depth: int = 1,
    format: str = "brief"
) -> FileContextResult:
    """Get method call context for a code file.

    Retrieves all methods/functions in the file along with their
    outgoing calls and incoming callers.

    Args:
        project_root: Project root directory (for index location)
        file_path: Path to the code file to analyze
        include_calls: Whether to include outgoing calls
        include_callers: Whether to include incoming callers
        max_depth: Call chain depth (V1 only supports 1)
        format: Output format (brief | detailed | tree)

    Returns:
        FileContextResult with method contexts and summary

    Raises:
        IndexNotFoundError: If project is not indexed
        FileNotFoundError: If file does not exist
        ValueError: If max_depth > 1 (V1 limitation)
    """
    # V1 limitation: only depth=1 supported
    if max_depth > 1:
        raise ValueError(
            f"max_depth > 1 not supported in V1. "
            f"Requested: {max_depth}, supported: 1"
        )

    project_path = resolve_project(project_root)
    file_path_resolved = Path(file_path).resolve()

    # Validate file exists
    if not file_path_resolved.exists():
        raise FileNotFoundError(f"File not found: {file_path_resolved}")

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Get all symbols in the file
    symbols = global_index.get_file_symbols(str(file_path_resolved))

    # Filter to functions, methods, and classes
    method_symbols = [
        s for s in symbols
        if s.kind in ("function", "method", "class")
    ]

    logger.debug(f"Found {len(method_symbols)} methods in {file_path}")

    # Try to find dir_index for relationship queries
    dir_index = _find_dir_index(project_info, file_path_resolved)

    # Build method contexts
    methods: List[MethodContext] = []
    outgoing_resolved = True
    incoming_resolved = True
    targets_resolved = True

    for symbol in method_symbols:
        calls: List[CallInfo] = []
        callers: List[CallInfo] = []

        if include_calls and dir_index:
            try:
                outgoing = dir_index.get_outgoing_calls(
                    str(file_path_resolved),
                    symbol.name
                )
                for target_name, rel_type, line, target_file in outgoing:
                    calls.append(CallInfo(
                        symbol_name=target_name,
                        file_path=target_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
                    if target_file is None:
                        targets_resolved = False
            except Exception as e:
                logger.debug(f"Failed to get outgoing calls: {e}")
                outgoing_resolved = False

        if include_callers and dir_index:
            try:
                incoming = dir_index.get_incoming_calls(symbol.name)
                for source_name, rel_type, line, source_file in incoming:
                    callers.append(CallInfo(
                        symbol_name=source_name,
                        file_path=source_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
            except Exception as e:
                logger.debug(f"Failed to get incoming calls: {e}")
                incoming_resolved = False

        methods.append(MethodContext(
            name=symbol.name,
            kind=symbol.kind,
            line_range=symbol.range if symbol.range else (1, 1),
            signature=None,  # Could extract from source
            calls=calls,
            callers=callers
        ))

    # Detect language from file extension
    language = _detect_language(file_path_resolved)

    # Generate summary
    summary = _generate_summary(file_path_resolved, methods, format)

    return FileContextResult(
        file_path=str(file_path_resolved),
        language=language,
        methods=methods,
        summary=summary,
        discovery_status={
            "outgoing_resolved": outgoing_resolved,
            "incoming_resolved": incoming_resolved,
            "targets_resolved": targets_resolved
        }
    )


def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
    """Find the dir_index that contains the file.

    Args:
        project_info: Project information from registry
        file_path: Path to the file

    Returns:
        DirIndexStore if found, None otherwise
    """
    try:
        # Look for _index.db in file's directory or parent directories
        current = file_path.parent
        while current != current.parent:
            index_db = current / "_index.db"
            if index_db.exists():
                return DirIndexStore(str(index_db))

            # Also check in project's index_root
            relative = current.relative_to(project_info.source_root)
            index_in_cache = project_info.index_root / relative / "_index.db"
            if index_in_cache.exists():
                return DirIndexStore(str(index_in_cache))

            current = current.parent
    except Exception as e:
        logger.debug(f"Failed to find dir_index: {e}")

    return None


def _detect_language(file_path: Path) -> str:
    """Detect programming language from file extension.

    Args:
        file_path: Path to the file

    Returns:
        Language name
    """
    ext_map = {
        ".py": "python",
        ".js": "javascript",
        ".ts": "typescript",
        ".jsx": "javascript",
        ".tsx": "typescript",
        ".go": "go",
        ".rs": "rust",
        ".java": "java",
        ".c": "c",
        ".cpp": "cpp",
        ".h": "c",
        ".hpp": "cpp",
    }
    return ext_map.get(file_path.suffix.lower(), "unknown")


def _generate_summary(
    file_path: Path,
    methods: List[MethodContext],
    format: str
) -> str:
    """Generate human-readable summary of file context.

    Args:
        file_path: Path to the file
        methods: List of method contexts
        format: Output format (brief | detailed | tree)

    Returns:
        Markdown-formatted summary
    """
    lines = [f"## {file_path.name} ({len(methods)} methods)\n"]

    for method in methods:
        start, end = method.line_range
        lines.append(f"### {method.name} (line {start}-{end})")

        if method.calls:
            calls_str = ", ".join(
                f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.calls
            )
            lines.append(f"- Calls: {calls_str}")

        if method.callers:
            callers_str = ", ".join(
                f"{c.symbol_name} ({c.file_path}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.callers
            )
            lines.append(f"- Called by: {callers_str}")

        if not method.calls and not method.callers:
            lines.append("- (no call relationships)")

        lines.append("")

    return "\n".join(lines)
```
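Illustrative usage (paths are placeholders; the project must be indexed so the symbol and call-graph lookups can resolve):

```python
from codexlens.api import file_context

ctx = file_context("/path/to/project", "src/auth/service.py", format="detailed")
print(ctx.language)          # e.g. "python", from the extension map above
print(ctx.summary)           # markdown summary of methods, calls, and callers
print(ctx.discovery_status)  # e.g. {"outgoing_resolved": True, ...}
```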
codex-lens/build/lib/codexlens/api/hover.py (new file, 148 lines)

```python
"""get_hover API implementation.

This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def get_hover(
    project_root: str,
    symbol_name: str,
    file_path: Optional[str] = None
) -> Optional[HoverInfo]:
    """Get detailed hover information for a symbol.

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to look up
        file_path: Optional file path to disambiguate when symbol
            appears in multiple files

    Returns:
        HoverInfo if symbol found, None otherwise

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search for the symbol
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=50,
        prefix_mode=False
    )

    if not results:
        logger.debug(f"No hover info found for {symbol_name}")
        return None

    # If file_path provided, filter to that file
    if file_path:
        file_path_resolved = str(Path(file_path).resolve())
        matching = [s for s in results if s.file == file_path_resolved]
        if matching:
            results = matching

    # Take the first result
    symbol = results[0]

    # Build hover info
    return HoverInfo(
        name=symbol.name,
        kind=symbol.kind,
        signature=_extract_signature(symbol),
        documentation=_extract_documentation(symbol),
        file_path=symbol.file or "",
        line_range=symbol.range if symbol.range else (1, 1),
        type_info=_extract_type_info(symbol)
    )


def _extract_signature(symbol: Symbol) -> str:
    """Extract signature from symbol.

    For now, generates a basic signature based on kind and name.
    In a full implementation, this would parse the actual source code.

    Args:
        symbol: The symbol to extract signature from

    Returns:
        Signature string
    """
    if symbol.kind == "function":
        return f"def {symbol.name}(...)"
    elif symbol.kind == "method":
        return f"def {symbol.name}(self, ...)"
    elif symbol.kind == "class":
        return f"class {symbol.name}"
    elif symbol.kind == "variable":
        return symbol.name
    elif symbol.kind == "constant":
        return f"{symbol.name} = ..."
    else:
        return f"{symbol.kind} {symbol.name}"


def _extract_documentation(symbol: Symbol) -> Optional[str]:
    """Extract documentation from symbol.

    In a full implementation, this would parse docstrings from source.
    For now, returns None.

    Args:
        symbol: The symbol to extract documentation from

    Returns:
        Documentation string if available, None otherwise
    """
    # Would need to read source file and parse docstring
    # For V1, return None
    return None


def _extract_type_info(symbol: Symbol) -> Optional[str]:
    """Extract type information from symbol.

    In a full implementation, this would parse type annotations.
    For now, returns None.

    Args:
        symbol: The symbol to extract type info from

    Returns:
        Type info string if available, None otherwise
    """
    # Would need to parse type annotations from source
    # For V1, return None
    return None
```
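Illustrative usage (placeholder paths; returns None when the symbol is not in the index):

```python
from codexlens.api import get_hover

info = get_hover("/path/to/project", "find_definition", file_path="src/api/definition.py")
if info is not None:
    print(info.signature)   # e.g. "def find_definition(...)" from the kind-based fallback
    print(info.file_path, info.line_range)
```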
codex-lens/build/lib/codexlens/api/models.py (new file, 281 lines)

```python
"""API dataclass definitions for codexlens LSP API.

This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple


# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================

@dataclass
class CallInfo:
    """Call relationship information.

    Attributes:
        symbol_name: Name of the called/calling symbol
        file_path: Target file path (may be None if unresolved)
        line: Line number of the call
        relationship: Type of relationship (call | import | inheritance)
    """
    symbol_name: str
    file_path: Optional[str]
    line: int
    relationship: str  # call | import | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MethodContext:
    """Method context with call relationships.

    Attributes:
        name: Method/function name
        kind: Symbol kind (function | method | class)
        line_range: Start and end line numbers
        signature: Function signature (if available)
        calls: List of outgoing calls
        callers: List of incoming calls
    """
    name: str
    kind: str  # function | method | class
    line_range: Tuple[int, int]
    signature: Optional[str]
    calls: List[CallInfo] = field(default_factory=list)
    callers: List[CallInfo] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "line_range": list(self.line_range),
            "calls": [c.to_dict() for c in self.calls],
            "callers": [c.to_dict() for c in self.callers],
        }
        if self.signature is not None:
            result["signature"] = self.signature
        return result


@dataclass
class FileContextResult:
    """File context result with method summaries.

    Attributes:
        file_path: Path to the analyzed file
        language: Programming language
        methods: List of method contexts
        summary: Human-readable summary
        discovery_status: Status flags for call resolution
    """
    file_path: str
    language: str
    methods: List[MethodContext]
    summary: str
    discovery_status: Dict[str, bool] = field(default_factory=lambda: {
        "outgoing_resolved": False,
        "incoming_resolved": True,
        "targets_resolved": False
    })

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "file_path": self.file_path,
            "language": self.language,
            "methods": [m.to_dict() for m in self.methods],
            "summary": self.summary,
            "discovery_status": self.discovery_status,
        }


# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================

@dataclass
class DefinitionResult:
    """Definition lookup result.

    Attributes:
        name: Symbol name
        kind: Symbol kind (class, function, method, etc.)
        file_path: File where symbol is defined
        line: Start line number
        end_line: End line number
        signature: Symbol signature (if available)
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str] = None
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================

@dataclass
class ReferenceResult:
    """Reference lookup result.

    Attributes:
        file_path: File containing the reference
        line: Line number
        column: Column number
        context_line: The line of code containing the reference
        relationship: Type of reference (call | import | type_annotation | inheritance)
    """
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class GroupedReferences:
    """References grouped by definition.

    Used when a symbol has multiple definitions (e.g., overloads).

    Attributes:
        definition: The definition this group refers to
        references: List of references to this definition
    """
    definition: DefinitionResult
    references: List[ReferenceResult] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "definition": self.definition.to_dict(),
            "references": [r.to_dict() for r in self.references],
        }
```
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.5: workspace_symbols dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class SymbolInfo:
|
||||
"""Symbol information for workspace search.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind
|
||||
file_path: File where symbol is defined
|
||||
line: Line number
|
||||
container: Containing class/module (if any)
|
||||
score: Match score for ranking
|
||||
"""
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line: int
|
||||
container: Optional[str] = None
|
||||
score: float = 1.0
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.6: get_hover dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class HoverInfo:
|
||||
"""Hover information for a symbol.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind
|
||||
signature: Symbol signature
|
||||
documentation: Documentation string (if available)
|
||||
file_path: File where symbol is defined
|
||||
line_range: Start and end line numbers
|
||||
type_info: Type information (if available)
|
||||
"""
|
||||
name: str
|
||||
kind: str
|
||||
signature: str
|
||||
documentation: Optional[str]
|
||||
file_path: str
|
||||
line_range: Tuple[int, int]
|
||||
type_info: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"signature": self.signature,
|
||||
"file_path": self.file_path,
|
||||
"line_range": list(self.line_range),
|
||||
}
|
||||
if self.documentation is not None:
|
||||
result["documentation"] = self.documentation
|
||||
if self.type_info is not None:
|
||||
result["type_info"] = self.type_info
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.7: semantic_search dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class SemanticResult:
|
||||
"""Semantic search result.
|
||||
|
||||
Attributes:
|
||||
symbol_name: Name of the matched symbol
|
||||
kind: Symbol kind
|
||||
file_path: File where symbol is defined
|
||||
line: Line number
|
||||
vector_score: Vector similarity score (None if not available)
|
||||
structural_score: Structural match score (None if not available)
|
||||
fusion_score: Combined fusion score
|
||||
snippet: Code snippet
|
||||
match_reason: Explanation of why this matched (optional)
|
||||
"""
|
||||
symbol_name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line: int
|
||||
vector_score: Optional[float]
|
||||
structural_score: Optional[float]
|
||||
fusion_score: float
|
||||
snippet: str
|
||||
match_reason: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
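Note: the to_dict() helpers above drop None-valued optional fields before JSON serialization. A minimal sketch of that behavior, using made-up values (not taken from the codebase):

# Hypothetical values, for illustration only.
d = DefinitionResult(
    name="authenticate",
    kind="function",
    file_path="src/auth.py",
    line=10,
    end_line=42,
)
# signature and container are None, so they are omitted:
print(d.to_dict())
# {'name': 'authenticate', 'kind': 'function', 'file_path': 'src/auth.py',
#  'line': 10, 'end_line': 42, 'score': 1.0}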
codex-lens/build/lib/codexlens/api/references.py (new file, 345 lines)
@@ -0,0 +1,345 @@
"""Find references API for codexlens.

This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional, Dict

from .models import (
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
)
from .utils import (
    resolve_project,
    normalize_relationship_type,
)


logger = logging.getLogger(__name__)


def _read_line_from_file(file_path: str, line: int) -> str:
    """Read a specific line from a file.

    Args:
        file_path: Path to the file
        line: Line number (1-based)

    Returns:
        The line content, stripped of trailing whitespace.
        Returns empty string if file cannot be read or line doesn't exist.
    """
    try:
        path = Path(file_path)
        if not path.exists():
            return ""

        with path.open("r", encoding="utf-8", errors="replace") as f:
            for i, content in enumerate(f, 1):
                if i == line:
                    return content.rstrip()
        return ""
    except Exception as exc:
        logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
        return ""


def _transform_to_reference_result(
    raw_ref: "RawReferenceResult",
) -> ReferenceResult:
    """Transform raw ChainSearchEngine reference to API ReferenceResult.

    Args:
        raw_ref: Raw reference result from ChainSearchEngine

    Returns:
        API ReferenceResult with context_line and normalized relationship
    """
    # Read the actual line from the file
    context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)

    # Normalize relationship type
    relationship = normalize_relationship_type(raw_ref.relationship_type)

    return ReferenceResult(
        file_path=raw_ref.file_path,
        line=raw_ref.line,
        column=raw_ref.column,
        context_line=context_line,
        relationship=relationship,
    )


def find_references(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    include_definition: bool = True,
    group_by_definition: bool = True,
    limit: int = 100,
) -> List[GroupedReferences]:
    """Find all reference locations for a symbol.

    Multi-definition case returns grouped results to resolve ambiguity.

    This function wraps ChainSearchEngine.search_references() and groups
    the results by definition location. Each GroupedReferences contains
    a definition and all references that point to it.

    Args:
        project_root: Project root directory path
        symbol_name: Name of the symbol to find references for
        symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
        include_definition: Whether to include the definition location
            in the result (default True)
        group_by_definition: Whether to group references by definition.
            If False, returns a single group with all references.
            (default True)
        limit: Maximum number of references to return (default 100)

    Returns:
        List of GroupedReferences. Each group contains:
        - definition: The DefinitionResult for this symbol definition
        - references: List of ReferenceResult pointing to this definition

    Raises:
        ValueError: If project_root does not exist or is not a directory

    Examples:
        >>> refs = find_references("/path/to/project", "authenticate")
        >>> for group in refs:
        ...     print(f"Definition: {group.definition.file_path}:{group.definition.line}")
        ...     for ref in group.references:
        ...         print(f"  Reference: {ref.file_path}:{ref.line} ({ref.relationship})")

    Note:
        Reference relationship types are normalized:
        - 'calls' -> 'call'
        - 'imports' -> 'import'
        - 'inherits' -> 'inheritance'
    """
    # Validate and resolve project root
    project_path = resolve_project(project_root)

    # Import here to avoid circular imports
    from codexlens.config import Config
    from codexlens.storage.registry import RegistryStore
    from codexlens.storage.path_mapper import PathMapper
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.search.chain_search import ChainSearchEngine
    from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
    from codexlens.entities import Symbol

    # Initialize infrastructure
    config = Config()
    registry = RegistryStore()
    mapper = PathMapper(config.index_dir)

    # Create chain search engine
    engine = ChainSearchEngine(registry, mapper, config=config)

    try:
        # Step 1: Find definitions for the symbol
        definitions: List[DefinitionResult] = []

        if include_definition or group_by_definition:
            # Search for symbol definitions
            symbols = engine.search_symbols(
                name=symbol_name,
                source_path=project_path,
                kind=symbol_kind,
            )

            # Convert Symbol to DefinitionResult
            for sym in symbols:
                # Only include exact name matches for definitions
                if sym.name != symbol_name:
                    continue

                # Optionally filter by kind
                if symbol_kind and sym.kind != symbol_kind:
                    continue

                definitions.append(DefinitionResult(
                    name=sym.name,
                    kind=sym.kind,
                    file_path=sym.file or "",
                    line=sym.range[0] if sym.range else 1,
                    end_line=sym.range[1] if sym.range else 1,
                    signature=None,  # Not available from Symbol
                    container=None,  # Not available from Symbol
                    score=1.0,
                ))

        # Step 2: Get all references using ChainSearchEngine
        raw_references = engine.search_references(
            symbol_name=symbol_name,
            source_path=project_path,
            depth=-1,
            limit=limit,
        )

        # Step 3: Transform raw references to API ReferenceResult
        api_references: List[ReferenceResult] = []
        for raw_ref in raw_references:
            api_ref = _transform_to_reference_result(raw_ref)
            api_references.append(api_ref)

        # Step 4: Group references by definition
        if group_by_definition and definitions:
            return _group_references_by_definition(
                definitions=definitions,
                references=api_references,
                include_definition=include_definition,
            )
        else:
            # Return single group with placeholder definition or first definition
            if definitions:
                definition = definitions[0]
            else:
                # Create placeholder definition when no definition found
                definition = DefinitionResult(
                    name=symbol_name,
                    kind=symbol_kind or "unknown",
                    file_path="",
                    line=0,
                    end_line=0,
                    signature=None,
                    container=None,
                    score=0.0,
                )

            return [GroupedReferences(
                definition=definition,
                references=api_references,
            )]

    finally:
        engine.close()


def _group_references_by_definition(
    definitions: List[DefinitionResult],
    references: List[ReferenceResult],
    include_definition: bool = True,
) -> List[GroupedReferences]:
    """Group references by their likely definition.

    Uses file proximity heuristic to assign references to definitions.
    References in the same file or directory as a definition are
    assigned to that definition.

    Args:
        definitions: List of definition locations
        references: List of reference locations
        include_definition: Whether to include definition in results

    Returns:
        List of GroupedReferences with references assigned to definitions
    """
    import os

    if not definitions:
        return []

    if len(definitions) == 1:
        # Single definition - all references belong to it
        return [GroupedReferences(
            definition=definitions[0],
            references=references,
        )]

    # Multiple definitions - group by proximity
    groups: Dict[int, List[ReferenceResult]] = {
        i: [] for i in range(len(definitions))
    }

    for ref in references:
        # Find the closest definition by file proximity
        best_def_idx = 0
        best_score = -1

        for i, defn in enumerate(definitions):
            score = _proximity_score(ref.file_path, defn.file_path)
            if score > best_score:
                best_score = score
                best_def_idx = i

        groups[best_def_idx].append(ref)

    # Build result groups
    result: List[GroupedReferences] = []
    for i, defn in enumerate(definitions):
        # Skip definitions with no references if not including definition itself
        if not include_definition and not groups[i]:
            continue

        result.append(GroupedReferences(
            definition=defn,
            references=groups[i],
        ))

    return result


def _proximity_score(ref_path: str, def_path: str) -> int:
    """Calculate proximity score between two file paths.

    Args:
        ref_path: Reference file path
        def_path: Definition file path

    Returns:
        Proximity score (higher = closer):
        - Same file: 1000
        - Same directory: 100
        - Otherwise: common path prefix length
    """
    import os

    if not ref_path or not def_path:
        return 0

    # Normalize paths
    ref_path = os.path.normpath(ref_path)
    def_path = os.path.normpath(def_path)

    # Same file
    if ref_path == def_path:
        return 1000

    ref_dir = os.path.dirname(ref_path)
    def_dir = os.path.dirname(def_path)

    # Same directory
    if ref_dir == def_dir:
        return 100

    # Common path prefix
    try:
        common = os.path.commonpath([ref_path, def_path])
        return len(common)
    except ValueError:
        # No common path (different drives on Windows)
        return 0


# Type alias for the raw reference from ChainSearchEngine
class RawReferenceResult:
    """Type stub for ChainSearchEngine.ReferenceResult.

    This is only used for type hints and is replaced at runtime
    by the actual import.
    """
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str
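Note: the proximity heuristic above is easiest to verify with a few throwaway POSIX-style paths (made up for illustration):

# Hypothetical paths, for illustration only.
print(_proximity_score("/a/b/auth.py", "/a/b/auth.py"))   # 1000 (same file)
print(_proximity_score("/a/b/test.py", "/a/b/auth.py"))   # 100  (same directory)
print(_proximity_score("/a/x/y.py", "/a/b/auth.py"))      # 2    (length of common prefix "/a")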
codex-lens/build/lib/codexlens/api/semantic.py (new file, 471 lines)
@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.

This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from .models import SemanticResult
from .utils import resolve_project

logger = logging.getLogger(__name__)


def semantic_search(
    project_root: str,
    query: str,
    mode: str = "fusion",
    vector_weight: float = 0.5,
    structural_weight: float = 0.3,
    keyword_weight: float = 0.2,
    fusion_strategy: str = "rrf",
    kind_filter: Optional[List[str]] = None,
    limit: int = 20,
    include_match_reason: bool = False,
) -> List[SemanticResult]:
    """Semantic search - combining vector and structural search.

    This function provides a high-level API for semantic code search,
    combining vector similarity, structural (symbol + relationships),
    and keyword-based search methods with configurable fusion.

    Args:
        project_root: Project root directory
        query: Natural language query
        mode: Search mode
            - vector: Vector search only
            - structural: Structural search only (symbol + relationships)
            - fusion: Fusion search (default)
        vector_weight: Vector search weight [0, 1] (default 0.5)
        structural_weight: Structural search weight [0, 1] (default 0.3)
        keyword_weight: Keyword search weight [0, 1] (default 0.2)
        fusion_strategy: Fusion strategy (maps to chain_search.py)
            - rrf: Reciprocal Rank Fusion (recommended, default)
            - staged: Staged cascade -> staged_cascade_search
            - binary: Binary rerank cascade -> binary_cascade_search
            - hybrid: Hybrid cascade -> hybrid_cascade_search
        kind_filter: Symbol type filter (e.g., ["function", "class"])
        limit: Max return count (default 20)
        include_match_reason: Generate match reason (heuristic, not LLM)

    Returns:
        Results sorted by fusion_score

    Degradation:
        - No vector index: vector_score=None, uses FTS + structural search
        - No relationship data: structural_score=None, vector search only

    Examples:
        >>> results = semantic_search(
        ...     "/path/to/project",
        ...     "authentication handler",
        ...     mode="fusion",
        ...     fusion_strategy="rrf"
        ... )
        >>> for r in results:
        ...     print(f"{r.symbol_name}: {r.fusion_score:.3f}")
    """
    # Validate and resolve project path
    project_path = resolve_project(project_root)

    # Normalize weights to sum to 1.0
    total_weight = vector_weight + structural_weight + keyword_weight
    if total_weight > 0:
        vector_weight = vector_weight / total_weight
        structural_weight = structural_weight / total_weight
        keyword_weight = keyword_weight / total_weight
    else:
        # Default to equal weights if all zero
        vector_weight = structural_weight = keyword_weight = 1.0 / 3.0

    # Initialize search infrastructure
    try:
        from codexlens.config import Config
        from codexlens.storage.registry import RegistryStore
        from codexlens.storage.path_mapper import PathMapper
        from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
    except ImportError as exc:
        logger.error("Failed to import search dependencies: %s", exc)
        return []

    # Load config
    config = Config.load()

    # Get or create registry and mapper
    try:
        registry = RegistryStore.default()
        mapper = PathMapper(registry)
    except Exception as exc:
        logger.error("Failed to initialize search infrastructure: %s", exc)
        return []

    # Build search options based on mode
    search_options = _build_search_options(
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        limit=limit,
    )

    # Execute search based on fusion_strategy
    try:
        with ChainSearchEngine(registry, mapper, config=config) as engine:
            chain_result = _execute_search(
                engine=engine,
                query=query,
                source_path=project_path,
                fusion_strategy=fusion_strategy,
                options=search_options,
                limit=limit,
            )
    except Exception as exc:
        logger.error("Search execution failed: %s", exc)
        return []

    # Transform results to SemanticResult
    semantic_results = _transform_results(
        results=chain_result.results,
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        kind_filter=kind_filter,
        include_match_reason=include_match_reason,
        query=query,
    )

    return semantic_results[:limit]


def _build_search_options(
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    limit: int,
) -> "SearchOptions":
    """Build SearchOptions based on mode and weights.

    Args:
        mode: Search mode (vector, structural, fusion)
        vector_weight: Vector search weight
        structural_weight: Structural search weight
        keyword_weight: Keyword search weight
        limit: Result limit

    Returns:
        Configured SearchOptions
    """
    from codexlens.search.chain_search import SearchOptions

    # Default options
    options = SearchOptions(
        total_limit=limit * 2,  # Fetch extra for filtering
        limit_per_dir=limit,
        include_symbols=True,  # Always include symbols for structural
    )

    if mode == "vector":
        # Pure vector mode
        options.hybrid_mode = True
        options.enable_vector = True
        options.pure_vector = True
        options.enable_fuzzy = False
    elif mode == "structural":
        # Structural only - use FTS + symbols
        options.hybrid_mode = True
        options.enable_vector = False
        options.enable_fuzzy = True
        options.include_symbols = True
    else:
        # Fusion mode (default)
        options.hybrid_mode = True
        options.enable_vector = vector_weight > 0
        options.enable_fuzzy = keyword_weight > 0
        options.include_symbols = structural_weight > 0

        # Set custom weights for RRF
        if options.enable_vector and keyword_weight > 0:
            options.hybrid_weights = {
                "vector": vector_weight,
                "exact": keyword_weight * 0.7,
                "fuzzy": keyword_weight * 0.3,
            }

    return options


def _execute_search(
    engine: "ChainSearchEngine",
    query: str,
    source_path: Path,
    fusion_strategy: str,
    options: "SearchOptions",
    limit: int,
) -> "ChainSearchResult":
    """Execute search using appropriate strategy.

    Maps fusion_strategy to ChainSearchEngine methods:
    - rrf: Standard hybrid search with RRF fusion
    - staged: staged_cascade_search
    - binary: binary_cascade_search
    - hybrid: hybrid_cascade_search

    Args:
        engine: ChainSearchEngine instance
        query: Search query
        source_path: Project root path
        fusion_strategy: Strategy name
        options: Search options
        limit: Result limit

    Returns:
        ChainSearchResult from the search
    """
    from codexlens.search.chain_search import ChainSearchResult

    if fusion_strategy == "staged":
        # Use staged cascade search (4-stage pipeline)
        return engine.staged_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "binary":
        # Use binary cascade search (binary coarse + dense fine)
        return engine.binary_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "hybrid":
        # Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
        return engine.hybrid_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    else:
        # Default: rrf - Standard search with RRF fusion
        return engine.search(
            query=query,
            source_path=source_path,
            options=options,
        )


def _transform_results(
    results: List,
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    kind_filter: Optional[List[str]],
    include_match_reason: bool,
    query: str,
) -> List[SemanticResult]:
    """Transform ChainSearchEngine results to SemanticResult.

    Args:
        results: List of SearchResult objects
        mode: Search mode
        vector_weight: Vector weight used
        structural_weight: Structural weight used
        keyword_weight: Keyword weight used
        kind_filter: Optional symbol kind filter
        include_match_reason: Whether to generate match reasons
        query: Original query (for match reason generation)

    Returns:
        List of SemanticResult objects
    """
    semantic_results = []

    for result in results:
        # Extract symbol info
        symbol_name = getattr(result, "symbol_name", None)
        symbol_kind = getattr(result, "symbol_kind", None)
        start_line = getattr(result, "start_line", None)

        # Use symbol object if available
        if hasattr(result, "symbol") and result.symbol:
            symbol_name = symbol_name or result.symbol.name
            symbol_kind = symbol_kind or result.symbol.kind
            if hasattr(result.symbol, "range") and result.symbol.range:
                start_line = start_line or result.symbol.range[0]

        # Filter by kind if specified
        if kind_filter and symbol_kind:
            if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
                continue

        # Determine scores based on mode and metadata
        metadata = getattr(result, "metadata", {}) or {}
        fusion_score = result.score

        # Try to extract source scores from metadata
        source_scores = metadata.get("source_scores", {})
        vector_score: Optional[float] = None
        structural_score: Optional[float] = None

        if mode == "vector":
            # In pure vector mode, the main score is the vector score
            vector_score = result.score
            structural_score = None
        elif mode == "structural":
            # In structural mode, no vector score
            vector_score = None
            structural_score = result.score
        else:
            # Fusion mode - try to extract individual scores
            if "vector" in source_scores:
                vector_score = source_scores["vector"]
            elif metadata.get("fusion_method") == "simple_weighted":
                # From weighted fusion
                vector_score = source_scores.get("vector")

            # Structural score approximation (from exact/fuzzy FTS)
            fts_scores = []
            if "exact" in source_scores:
                fts_scores.append(source_scores["exact"])
            if "fuzzy" in source_scores:
                fts_scores.append(source_scores["fuzzy"])
            if "splade" in source_scores:
                fts_scores.append(source_scores["splade"])

            if fts_scores:
                structural_score = max(fts_scores)

        # Build snippet
        snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
        if len(snippet) > 500:
            snippet = snippet[:500] + "..."

        # Generate match reason if requested
        match_reason = None
        if include_match_reason:
            match_reason = _generate_match_reason(
                query=query,
                symbol_name=symbol_name,
                symbol_kind=symbol_kind,
                snippet=snippet,
                vector_score=vector_score,
                structural_score=structural_score,
            )

        semantic_result = SemanticResult(
            symbol_name=symbol_name or Path(result.path).stem,
            kind=symbol_kind or "unknown",
            file_path=result.path,
            line=start_line or 1,
            vector_score=vector_score,
            structural_score=structural_score,
            fusion_score=fusion_score,
            snippet=snippet,
            match_reason=match_reason,
        )

        semantic_results.append(semantic_result)

    # Sort by fusion_score descending
    semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)

    return semantic_results


def _generate_match_reason(
    query: str,
    symbol_name: Optional[str],
    symbol_kind: Optional[str],
    snippet: str,
    vector_score: Optional[float],
    structural_score: Optional[float],
) -> str:
    """Generate human-readable match reason heuristically.

    This is a simple heuristic-based approach, not LLM-powered.

    Args:
        query: Original search query
        symbol_name: Symbol name if available
        symbol_kind: Symbol kind if available
        snippet: Code snippet
        vector_score: Vector similarity score
        structural_score: Structural match score

    Returns:
        Human-readable explanation string
    """
    reasons = []

    # Check for direct name match
    query_lower = query.lower()
    query_words = set(query_lower.split())

    if symbol_name:
        name_lower = symbol_name.lower()
        # Direct substring match
        if query_lower in name_lower or name_lower in query_lower:
            reasons.append(f"Symbol name '{symbol_name}' matches query")
        # Word overlap
        name_words = set(_split_camel_case(symbol_name).lower().split())
        overlap = query_words & name_words
        if overlap and not reasons:
            reasons.append(f"Symbol name contains: {', '.join(overlap)}")

    # Check snippet for keyword matches
    snippet_lower = snippet.lower()
    matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
    if matching_words and len(reasons) < 2:
        reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")

    # Add score-based reasoning
    if vector_score is not None and vector_score > 0.7:
        reasons.append("High semantic similarity")
    elif vector_score is not None and vector_score > 0.5:
        reasons.append("Moderate semantic similarity")

    if structural_score is not None and structural_score > 0.8:
        reasons.append("Strong structural match")

    # Symbol kind context
    if symbol_kind and len(reasons) < 3:
        reasons.append(f"Matched {symbol_kind}")

    if not reasons:
        reasons.append("Partial relevance based on content analysis")

    return "; ".join(reasons[:3])


def _split_camel_case(name: str) -> str:
    """Split camelCase and PascalCase to words.

    Args:
        name: Symbol name in camelCase or PascalCase

    Returns:
        Space-separated words
    """
    import re

    # Insert space before uppercase letters
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Insert space before uppercase followed by lowercase
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
    # Replace underscores with spaces
    result = result.replace("_", " ")

    return result
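Note: two small behaviors worth sanity-checking in semantic_search(): weights are rescaled to sum to 1.0 before the options are built, and _split_camel_case() feeds the word-overlap match reasons. Both shown with made-up inputs:

# Hypothetical inputs, for illustration only.
total = 1.0 + 1.0 + 2.0
print([w / total for w in (1.0, 1.0, 2.0)])    # [0.25, 0.25, 0.5]
print(_split_camel_case("parseHTTPResponse"))  # parse HTTP Response
print(_split_camel_case("snake_case_name"))    # snake case name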
codex-lens/build/lib/codexlens/api/symbols.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.

This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""

from __future__ import annotations

import fnmatch
import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def workspace_symbols(
    project_root: str,
    query: str,
    kind_filter: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    limit: int = 50
) -> List[SymbolInfo]:
    """Search for symbols across the entire workspace.

    Uses prefix matching for efficient searching.

    Args:
        project_root: Project root directory (for index location)
        query: Search query (prefix match)
        kind_filter: Optional list of symbol kinds to include
            (e.g., ["class", "function"])
        file_pattern: Optional glob pattern to filter by file path
            (e.g., "*.py", "src/**/*.ts")
        limit: Maximum number of results to return

    Returns:
        List of SymbolInfo sorted by score

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project(project_path)
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search with prefix matching
    # If kind_filter has multiple kinds, we need to search for each
    all_results: List[Symbol] = []

    if kind_filter and len(kind_filter) > 0:
        # Search for each kind separately
        for kind in kind_filter:
            results = global_index.search(
                name=query,
                kind=kind,
                limit=limit,
                prefix_mode=True
            )
            all_results.extend(results)
    else:
        # Search without kind filter
        all_results = global_index.search(
            name=query,
            kind=None,
            limit=limit,
            prefix_mode=True
        )

    logger.debug(f"Found {len(all_results)} symbols matching '{query}'")

    # Apply file pattern filter if specified
    if file_pattern:
        all_results = [
            sym for sym in all_results
            if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
        ]
        logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")

    # Convert to SymbolInfo and sort by relevance
    symbols = [
        SymbolInfo(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            container=None,  # Could extract from parent
            score=_calculate_score(sym.name, query)
        )
        for sym in all_results
    ]

    # Sort by score (exact matches first)
    symbols.sort(key=lambda s: s.score, reverse=True)

    return symbols[:limit]


def _calculate_score(symbol_name: str, query: str) -> float:
    """Calculate relevance score for a symbol match.

    Scoring:
    - Exact match: 1.0
    - Case-insensitive exact match: 0.9
    - Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
    - Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
    - Otherwise: 0.5

    Args:
        symbol_name: The matched symbol name
        query: The search query

    Returns:
        Score between 0.0 and 1.0
    """
    if symbol_name == query:
        return 1.0

    if symbol_name.lower() == query.lower():
        return 0.9

    if symbol_name.startswith(query):
        ratio = len(query) / len(symbol_name)
        return 0.8 + 0.2 * ratio

    if symbol_name.lower().startswith(query.lower()):
        ratio = len(query) / len(symbol_name)
        return 0.6 + 0.2 * ratio

    return 0.5
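Note: the scoring tiers in _calculate_score() can be checked directly; the symbol names below are made up:

print(_calculate_score("Config", "Config"))       # 1.0  (exact)
print(_calculate_score("config", "Config"))       # 0.9  (case-insensitive exact)
print(_calculate_score("ConfigError", "Config"))  # 0.8 + 0.2 * 6/11, about 0.909 (prefix)
print(_calculate_score("configerror", "Config"))  # 0.6 + 0.2 * 6/11, about 0.709 (case-insensitive prefix)
print(_calculate_score("Settings", "Config"))     # 0.5  (fallback)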
codex-lens/build/lib/codexlens/api/utils.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.

This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable

from .models import DefinitionResult


# Type variable for generic ranking
T = TypeVar('T')


def resolve_project(project_root: str) -> Path:
    """Resolve and validate project root path.

    Args:
        project_root: Path to project root (relative or absolute)

    Returns:
        Resolved absolute Path

    Raises:
        ValueError: If path does not exist or is not a directory
    """
    path = Path(project_root).resolve()
    if not path.exists():
        raise ValueError(f"Project root does not exist: {path}")
    if not path.is_dir():
        raise ValueError(f"Project root is not a directory: {path}")
    return path


# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
    # Plural to singular
    "calls": "call",
    "imports": "import",
    "inherits": "inheritance",
    "uses": "use",
    # Already normalized (passthrough)
    "call": "call",
    "import": "import",
    "inheritance": "inheritance",
    "use": "use",
    "type_annotation": "type_annotation",
}


def normalize_relationship_type(relationship: str) -> str:
    """Normalize relationship type to canonical form.

    Converts plural forms and variations to standard singular forms:
    - 'calls' -> 'call'
    - 'imports' -> 'import'
    - 'inherits' -> 'inheritance'
    - 'uses' -> 'use'

    Args:
        relationship: Raw relationship type string

    Returns:
        Normalized relationship type

    Examples:
        >>> normalize_relationship_type('calls')
        'call'
        >>> normalize_relationship_type('inherits')
        'inheritance'
        >>> normalize_relationship_type('call')
        'call'
    """
    return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)


def rank_by_proximity(
    results: List[DefinitionResult],
    file_context: Optional[str] = None
) -> List[DefinitionResult]:
    """Rank results by file path proximity to context.

    V1 Implementation: Uses path-based proximity scoring.

    Scoring algorithm:
    1. Same directory: highest score (100)
    2. Otherwise: length of common path prefix

    Args:
        results: List of definition results to rank
        file_context: Reference file path for proximity calculation.
            If None, returns results unchanged.

    Returns:
        Results sorted by proximity score (highest first)

    Examples:
        >>> results = [
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/b/c.py", line=1, end_line=10),
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/x/y.py", line=1, end_line=10),
        ... ]
        >>> ranked = rank_by_proximity(results, "/a/b/test.py")
        >>> ranked[0].file_path
        '/a/b/c.py'
    """
    if not file_context or not results:
        return results

    def proximity_score(result: DefinitionResult) -> int:
        """Calculate proximity score for a result."""
        result_dir = os.path.dirname(result.file_path)
        context_dir = os.path.dirname(file_context)

        # Same directory gets highest score
        if result_dir == context_dir:
            return 100

        # Otherwise, score by common path prefix length
        try:
            common = os.path.commonpath([result.file_path, file_context])
            return len(common)
        except ValueError:
            # No common path (different drives on Windows)
            return 0

    return sorted(results, key=proximity_score, reverse=True)


def rank_by_score(
    results: List[T],
    score_fn: Callable[[T], float],
    reverse: bool = True
) -> List[T]:
    """Generic ranking function by custom score.

    Args:
        results: List of items to rank
        score_fn: Function to extract score from item
        reverse: If True, highest scores first (default)

    Returns:
        Sorted list
    """
    return sorted(results, key=score_fn, reverse=reverse)
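Note: rank_by_score() is the generic counterpart of rank_by_proximity(); a minimal sketch using the DefinitionResult type already imported above (the values are illustrative, not from the codebase):

defs = [
    DefinitionResult(name="foo", kind="function", file_path="/a/b/c.py",
                     line=1, end_line=10, score=0.4),
    DefinitionResult(name="foo", kind="function", file_path="/a/x/y.py",
                     line=1, end_line=10, score=0.9),
]
top = rank_by_score(defs, score_fn=lambda d: d.score)
print([d.file_path for d in top])  # ['/a/x/y.py', '/a/b/c.py']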
codex-lens/build/lib/codexlens/cli/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""CLI package for CodexLens."""

from __future__ import annotations

import sys
import os

# Force UTF-8 encoding for Windows console
# This ensures Chinese characters display correctly instead of GBK garbled text
if sys.platform == "win32":
    # Set environment variable for Python I/O encoding
    os.environ.setdefault("PYTHONIOENCODING", "utf-8")

    # Reconfigure stdout/stderr to use UTF-8 if possible
    try:
        if hasattr(sys.stdout, "reconfigure"):
            sys.stdout.reconfigure(encoding="utf-8", errors="replace")
        if hasattr(sys.stderr, "reconfigure"):
            sys.stderr.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        # Fallback: some environments don't support reconfigure
        pass

from .commands import app

__all__ = ["app"]
codex-lens/build/lib/codexlens/cli/commands.py (new file, 4494 lines)
File diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/embedding_manager.py (new file, 2001 lines)
File diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/model_manager.py (new file, 1026 lines)
File diff suppressed because it is too large
codex-lens/build/lib/codexlens/cli/output.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""Rich and JSON output helpers for CodexLens CLI."""

from __future__ import annotations

import json
import sys
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence

from rich.console import Console
from rich.table import Table
from rich.text import Text

from codexlens.entities import SearchResult, Symbol

# Force UTF-8 encoding for Windows console to properly display Chinese text
# Use force_terminal=True and legacy_windows=False to avoid GBK encoding issues
console = Console(force_terminal=True, legacy_windows=False)


def _to_jsonable(value: Any) -> Any:
    if value is None:
        return None
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if is_dataclass(value):
        return asdict(value)
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, Mapping):
        return {k: _to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_to_jsonable(v) for v in value]
    return value


def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
    """Print JSON output with optional additional fields.

    Args:
        success: Whether the operation succeeded
        result: Result data (used when success=True)
        error: Error message (used when success=False)
        **kwargs: Additional fields to include in the payload (e.g., code, details)
    """
    payload: dict[str, Any] = {"success": success}
    if success:
        payload["result"] = _to_jsonable(result)
    else:
        payload["error"] = error or "Unknown error"
    # Include additional error details if provided
    for key, value in kwargs.items():
        payload[key] = _to_jsonable(value)
    console.print_json(json.dumps(payload, ensure_ascii=False))


def render_search_results(
    results: Sequence[SearchResult], *, title: str = "Search Results", verbose: bool = False
) -> None:
    """Render search results with optional source tags in verbose mode.

    Args:
        results: Search results to display
        title: Table title
        verbose: If True, show search source tags ([E], [F], [V]) and fusion scores
    """
    table = Table(title=title, show_lines=False)

    if verbose:
        # Verbose mode: show source tags
        table.add_column("Source", style="dim", width=6, justify="center")

    table.add_column("Path", style="cyan", no_wrap=True)
    table.add_column("Score", style="magenta", justify="right")
    table.add_column("Excerpt", style="white")

    for res in results:
        excerpt = res.excerpt or ""
        score_str = f"{res.score:.3f}"

        if verbose:
            # Extract search source tag if available
            source = getattr(res, "search_source", None)
            source_tag = ""
            if source == "exact":
                source_tag = "[E]"
            elif source == "fuzzy":
                source_tag = "[F]"
            elif source == "vector":
                source_tag = "[V]"
            elif source == "fusion":
                source_tag = "[RRF]"
            table.add_row(source_tag, res.path, score_str, excerpt)
        else:
            table.add_row(res.path, score_str, excerpt)

    console.print(table)


def render_symbols(symbols: Sequence[Symbol], *, title: str = "Symbols") -> None:
    table = Table(title=title)
    table.add_column("Name", style="green")
    table.add_column("Kind", style="yellow")
    table.add_column("Range", style="white", justify="right")

    for sym in symbols:
        start, end = sym.range
        table.add_row(sym.name, sym.kind, f"{start}-{end}")

    console.print(table)


def render_status(stats: Mapping[str, Any]) -> None:
    table = Table(title="Index Status")
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="white")

    for key, value in stats.items():
        if isinstance(value, Mapping):
            value_text = ", ".join(f"{k}:{v}" for k, v in value.items())
        elif isinstance(value, (list, tuple)):
            value_text = ", ".join(str(v) for v in value)
        else:
            value_text = str(value)
        table.add_row(str(key), value_text)

    console.print(table)


def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> None:
    header = Text.assemble(("File: ", "bold"), (path, "cyan"), (" Language: ", "bold"), (language, "green"))
    console.print(header)
    render_symbols(list(symbols), title="Discovered Symbols")
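Note: a short sketch of the print_json() envelope on the error path; the "code" field name is hypothetical, and rich's print_json pretty-prints the payload rather than emitting it on one line:

# Hypothetical error fields, for illustration only.
print_json(success=False, error="Project not indexed", code="E_NO_INDEX")
# Payload: {"success": false, "error": "Project not indexed", "code": "E_NO_INDEX"}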
codex-lens/build/lib/codexlens/config.py (new file, 692 lines)
@@ -0,0 +1,692 @@
|
||||
"""Configuration system for CodexLens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .errors import ConfigError
|
||||
|
||||
|
||||
# Workspace-local directory name
|
||||
WORKSPACE_DIR_NAME = ".codexlens"
|
||||
|
||||
# Settings file name
|
||||
SETTINGS_FILE_NAME = "settings.json"
|
||||
|
||||
# SPLADE index database name (centralized storage)
|
||||
SPLADE_DB_NAME = "_splade.db"
|
||||
|
||||
# Dense vector storage names (centralized storage)
|
||||
VECTORS_HNSW_NAME = "_vectors.hnsw"
|
||||
VECTORS_META_DB_NAME = "_vectors_meta.db"
|
||||
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_global_dir() -> Path:
|
||||
"""Get global CodexLens data directory."""
|
||||
env_override = os.getenv("CODEXLENS_DATA_DIR")
|
||||
if env_override:
|
||||
return Path(env_override).expanduser().resolve()
|
||||
return (Path.home() / ".codexlens").resolve()
|
||||
|
||||
|
||||
def find_workspace_root(start_path: Path) -> Optional[Path]:
|
||||
"""Find the workspace root by looking for .codexlens directory.
|
||||
|
||||
Searches from start_path upward to find an existing .codexlens directory.
|
||||
Returns None if not found.
|
||||
"""
|
||||
current = start_path.resolve()
|
||||
|
||||
# Search up to filesystem root
|
||||
while current != current.parent:
|
||||
workspace_dir = current / WORKSPACE_DIR_NAME
|
||||
if workspace_dir.is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
# Check root as well
|
||||
workspace_dir = current / WORKSPACE_DIR_NAME
|
||||
if workspace_dir.is_dir():
|
||||
return current
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Runtime configuration for CodexLens.
|
||||
|
||||
- data_dir: Base directory for all persistent CodexLens data.
|
||||
- venv_path: Optional virtualenv used for language tooling.
|
||||
- supported_languages: Language IDs and their associated file extensions.
|
||||
- parsing_rules: Per-language parsing and chunking hints.
|
||||
"""
|
||||
|
||||
data_dir: Path = field(default_factory=_default_global_dir)
|
||||
venv_path: Path = field(default_factory=lambda: _default_global_dir() / "venv")
|
||||
supported_languages: Dict[str, Dict[str, Any]] = field(
|
||||
default_factory=lambda: {
|
||||
# Source code languages (category: "code")
|
||||
"python": {"extensions": [".py"], "tree_sitter_language": "python", "category": "code"},
|
||||
"javascript": {"extensions": [".js", ".jsx"], "tree_sitter_language": "javascript", "category": "code"},
|
||||
"typescript": {"extensions": [".ts", ".tsx"], "tree_sitter_language": "typescript", "category": "code"},
|
||||
"java": {"extensions": [".java"], "tree_sitter_language": "java", "category": "code"},
|
||||
"go": {"extensions": [".go"], "tree_sitter_language": "go", "category": "code"},
|
||||
"zig": {"extensions": [".zig"], "tree_sitter_language": "zig", "category": "code"},
|
||||
"objective-c": {"extensions": [".m", ".mm"], "tree_sitter_language": "objc", "category": "code"},
|
||||
"c": {"extensions": [".c", ".h"], "tree_sitter_language": "c", "category": "code"},
|
||||
"cpp": {"extensions": [".cc", ".cpp", ".hpp", ".cxx"], "tree_sitter_language": "cpp", "category": "code"},
|
||||
"rust": {"extensions": [".rs"], "tree_sitter_language": "rust", "category": "code"},
|
||||
}
|
||||
)
|
||||
parsing_rules: Dict[str, Dict[str, Any]] = field(
|
||||
default_factory=lambda: {
|
||||
"default": {
|
||||
"max_chunk_chars": 4000,
|
||||
"max_chunk_lines": 200,
|
||||
"overlap_lines": 20,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
llm_enabled: bool = False
|
||||
llm_tool: str = "gemini"
|
||||
llm_timeout_ms: int = 300000
|
||||
llm_batch_size: int = 5
|
||||
|
||||
# Hybrid chunker configuration
|
||||
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
|
||||
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
|
||||
|
||||
# Embedding configuration
|
||||
embedding_backend: str = "fastembed" # "fastembed" (local) or "litellm" (API)
|
||||
embedding_model: str = "code" # For fastembed: profile (fast/code/multilingual/balanced)
|
||||
# For litellm: model name from config (e.g., "qwen3-embedding")
|
||||
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
|
||||
|
||||
# SPLADE sparse retrieval configuration
|
||||
enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead)
|
||||
splade_model: str = "naver/splade-cocondenser-ensembledistil"
|
||||
splade_threshold: float = 0.01 # Min weight to store in index
|
||||
splade_onnx_path: Optional[str] = None # Custom ONNX model path
|
||||
|
||||
# FTS fallback (disabled by default, available via --use-fts)
|
||||
use_fts_fallback: bool = True # Use FTS for sparse search (fast, SPLADE disabled)
|
||||
|
||||
# Indexing/search optimizations
|
||||
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
|
||||
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
|
||||
|
||||
# Graph expansion (search-time, uses precomputed neighbors)
|
||||
enable_graph_expansion: bool = False
|
||||
graph_expansion_depth: int = 2
|
||||
|
||||
    # Optional search reranking (disabled by default)
    enable_reranking: bool = False
    reranking_top_k: int = 50
    symbol_boost_factor: float = 1.5

    # Optional cross-encoder reranking (second stage; requires optional reranker deps)
    enable_cross_encoder_rerank: bool = False
    reranker_backend: str = "onnx"
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    reranker_top_k: int = 50
    reranker_max_input_tokens: int = 8192  # Maximum tokens for reranker API batching
    reranker_chunk_type_weights: Optional[Dict[str, float]] = None  # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
    reranker_test_file_penalty: float = 0.0  # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)

    # Chunk stripping configuration (for semantic embedding)
    chunk_strip_comments: bool = True  # Strip comments from code chunks
    chunk_strip_docstrings: bool = True  # Strip docstrings from code chunks

    # Cascade search configuration (two-stage retrieval)
    enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
    cascade_coarse_k: int = 100  # Number of coarse candidates from first stage
    cascade_fine_k: int = 10  # Number of final results after reranking
    cascade_strategy: str = "binary"  # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)

    # Staged cascade search configuration (4-stage pipeline)
    staged_coarse_k: int = 200  # Number of coarse candidates from Stage 1 binary search
    staged_lsp_depth: int = 2  # LSP relationship expansion depth in Stage 2
    staged_clustering_strategy: str = "auto"  # "auto", "hdbscan", "dbscan", "frequency", "noop"
    staged_clustering_min_size: int = 3  # Minimum cluster size for Stage 3 grouping
    enable_staged_rerank: bool = True  # Enable optional cross-encoder reranking in Stage 4

    # RRF fusion configuration
    fusion_method: str = "rrf"  # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
    rrf_k: int = 60  # RRF constant (default 60)

    # Category-based filtering to separate code/doc results
    enable_category_filter: bool = True  # Enable code/doc result separation

    # Multi-endpoint configuration for litellm backend
    embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
    # List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
    embedding_pool_enabled: bool = False  # Enable high availability pool for embeddings
    embedding_strategy: str = "latency_aware"  # round_robin, latency_aware, weighted_random
    embedding_cooldown: float = 60.0  # Default cooldown seconds for rate-limited endpoints

    # Reranker multi-endpoint configuration
    reranker_pool_enabled: bool = False  # Enable high availability pool for reranker
    reranker_strategy: str = "latency_aware"  # round_robin, latency_aware, weighted_random
    reranker_cooldown: float = 60.0  # Default cooldown seconds for rate-limited endpoints

    # API concurrency settings
    api_max_workers: int = 4  # Max concurrent API calls for embedding/reranking
    api_batch_size: int = 8  # Batch size for API requests
    api_batch_size_dynamic: bool = False  # Enable dynamic batch size calculation
    api_batch_size_utilization_factor: float = 0.8  # Use 80% of model token capacity
    api_batch_size_max: int = 2048  # Absolute upper limit for batch size
    chars_per_token_estimate: int = 4  # Characters per token estimation ratio
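For reference, the `fusion_method = "rrf"` setting above refers to reciprocal rank fusion, where a document's fused score is the sum of 1 / (rrf_k + rank) over every ranked list it appears in. A minimal standalone sketch of that formula (this is the standard technique, not CodexLens's internal fusion code):

from collections import defaultdict
from typing import Dict, List

def rrf_fuse(rankings: List[List[str]], k: int = 60) -> Dict[str, float]:
    """Fuse several ranked ID lists: score(d) = sum over lists of 1 / (k + rank)."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return dict(scores)

# Example: lexical and vector rankings disagree; RRF rewards consensus.
fts = ["a.py", "b.py", "c.py"]
vec = ["b.py", "a.py", "d.py"]
print(sorted(rrf_fuse([fts, vec]).items(), key=lambda kv: -kv[1]))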
    def __post_init__(self) -> None:
        try:
            self.data_dir = self.data_dir.expanduser().resolve()
            self.venv_path = self.venv_path.expanduser().resolve()
            self.data_dir.mkdir(parents=True, exist_ok=True)
        except PermissionError as exc:
            raise ConfigError(
                f"Permission denied initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
                f"[{type(exc).__name__}]: {exc}"
            ) from exc
        except OSError as exc:
            raise ConfigError(
                f"Filesystem error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
                f"[{type(exc).__name__}]: {exc}"
            ) from exc
        except Exception as exc:
            raise ConfigError(
                f"Unexpected error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
                f"[{type(exc).__name__}]: {exc}"
            ) from exc
    @cached_property
    def cache_dir(self) -> Path:
        """Directory for transient caches."""
        return self.data_dir / "cache"

    @cached_property
    def index_dir(self) -> Path:
        """Directory where index artifacts are stored."""
        return self.data_dir / "index"

    @cached_property
    def db_path(self) -> Path:
        """Default SQLite index path."""
        return self.index_dir / "codexlens.db"
    def ensure_runtime_dirs(self) -> None:
        """Create standard runtime directories if missing."""
        for directory in (self.cache_dir, self.index_dir):
            try:
                directory.mkdir(parents=True, exist_ok=True)
            except PermissionError as exc:
                raise ConfigError(
                    f"Permission denied creating directory {directory} [{type(exc).__name__}]: {exc}"
                ) from exc
            except OSError as exc:
                raise ConfigError(
                    f"Filesystem error creating directory {directory} [{type(exc).__name__}]: {exc}"
                ) from exc
            except Exception as exc:
                raise ConfigError(
                    f"Unexpected error creating directory {directory} [{type(exc).__name__}]: {exc}"
                ) from exc
    def language_for_path(self, path: str | Path) -> str | None:
        """Infer a supported language ID from a file path."""
        extension = Path(path).suffix.lower()
        for language_id, spec in self.supported_languages.items():
            extensions: List[str] = spec.get("extensions", [])
            if extension in extensions:
                return language_id
        return None

    def category_for_path(self, path: str | Path) -> str | None:
        """Get file category ('code' or 'doc') from a file path."""
        language = self.language_for_path(path)
        if language is None:
            return None
        spec = self.supported_languages.get(language, {})
        return spec.get("category")

    def rules_for_language(self, language_id: str) -> Dict[str, Any]:
        """Get parsing rules for a specific language, falling back to defaults."""
        return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}

    @cached_property
    def settings_path(self) -> Path:
        """Path to the settings file."""
        return self.data_dir / SETTINGS_FILE_NAME
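Because `rules_for_language` merges via dict unpacking, language-specific keys always win over the `default` block while untouched defaults pass through. A small illustration with hypothetical rule keys (not the project's actual rule schema):

# Hypothetical parsing_rules; the language entry overrides the default block.
parsing_rules = {
    "default": {"max_chunk_lines": 120, "strip_comments": True},
    "python": {"max_chunk_lines": 80},
}
merged = {**parsing_rules.get("default", {}), **parsing_rules.get("python", {})}
print(merged)  # {'max_chunk_lines': 80, 'strip_comments': True}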
    def save_settings(self) -> None:
        """Save embedding and other settings to file."""
        embedding_config = {
            "backend": self.embedding_backend,
            "model": self.embedding_model,
            "use_gpu": self.embedding_use_gpu,
            "pool_enabled": self.embedding_pool_enabled,
            "strategy": self.embedding_strategy,
            "cooldown": self.embedding_cooldown,
        }
        # Include multi-endpoint config if present
        if self.embedding_endpoints:
            embedding_config["endpoints"] = self.embedding_endpoints

        settings = {
            "embedding": embedding_config,
            "llm": {
                "enabled": self.llm_enabled,
                "tool": self.llm_tool,
                "timeout_ms": self.llm_timeout_ms,
                "batch_size": self.llm_batch_size,
            },
            "reranker": {
                "enabled": self.enable_cross_encoder_rerank,
                "backend": self.reranker_backend,
                "model": self.reranker_model,
                "top_k": self.reranker_top_k,
                "max_input_tokens": self.reranker_max_input_tokens,
                "pool_enabled": self.reranker_pool_enabled,
                "strategy": self.reranker_strategy,
                "cooldown": self.reranker_cooldown,
            },
            "cascade": {
                "strategy": self.cascade_strategy,
                "coarse_k": self.cascade_coarse_k,
                "fine_k": self.cascade_fine_k,
            },
            "api": {
                "max_workers": self.api_max_workers,
                "batch_size": self.api_batch_size,
                "batch_size_dynamic": self.api_batch_size_dynamic,
                "batch_size_utilization_factor": self.api_batch_size_utilization_factor,
                "batch_size_max": self.api_batch_size_max,
                "chars_per_token_estimate": self.chars_per_token_estimate,
            },
        }
        with open(self.settings_path, "w", encoding="utf-8") as f:
            json.dump(settings, f, indent=2)
    def load_settings(self) -> None:
        """Load settings from the settings file, if it exists."""
        if not self.settings_path.exists():
            return

        try:
            with open(self.settings_path, "r", encoding="utf-8") as f:
                settings = json.load(f)

            # Load embedding settings
            embedding = settings.get("embedding", {})
            if "backend" in embedding:
                backend = embedding["backend"]
                # Support 'api' as alias for 'litellm'
                if backend == "api":
                    backend = "litellm"
                if backend in {"fastembed", "litellm"}:
                    self.embedding_backend = backend
                else:
                    log.warning(
                        "Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
                        self.settings_path,
                        embedding["backend"],
                    )
            if "model" in embedding:
                self.embedding_model = embedding["model"]
            if "use_gpu" in embedding:
                self.embedding_use_gpu = embedding["use_gpu"]

            # Load multi-endpoint configuration
            if "endpoints" in embedding:
                self.embedding_endpoints = embedding["endpoints"]
            if "pool_enabled" in embedding:
                self.embedding_pool_enabled = embedding["pool_enabled"]
            if "strategy" in embedding:
                self.embedding_strategy = embedding["strategy"]
            if "cooldown" in embedding:
                self.embedding_cooldown = embedding["cooldown"]

            # Load LLM settings
            llm = settings.get("llm", {})
            if "enabled" in llm:
                self.llm_enabled = llm["enabled"]
            if "tool" in llm:
                self.llm_tool = llm["tool"]
            if "timeout_ms" in llm:
                self.llm_timeout_ms = llm["timeout_ms"]
            if "batch_size" in llm:
                self.llm_batch_size = llm["batch_size"]

            # Load reranker settings
            reranker = settings.get("reranker", {})
            if "enabled" in reranker:
                self.enable_cross_encoder_rerank = reranker["enabled"]
            if "backend" in reranker:
                backend = reranker["backend"]
                if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
                    self.reranker_backend = backend
                else:
                    log.warning(
                        "Invalid reranker backend in %s: %r (expected 'fastembed', 'onnx', 'api', 'litellm', or 'legacy')",
                        self.settings_path,
                        backend,
                    )
            if "model" in reranker:
                self.reranker_model = reranker["model"]
            if "top_k" in reranker:
                self.reranker_top_k = reranker["top_k"]
            if "max_input_tokens" in reranker:
                self.reranker_max_input_tokens = reranker["max_input_tokens"]
            if "pool_enabled" in reranker:
                self.reranker_pool_enabled = reranker["pool_enabled"]
            if "strategy" in reranker:
                self.reranker_strategy = reranker["strategy"]
            if "cooldown" in reranker:
                self.reranker_cooldown = reranker["cooldown"]

            # Load cascade settings
            cascade = settings.get("cascade", {})
            if "strategy" in cascade:
                strategy = cascade["strategy"]
                if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
                    self.cascade_strategy = strategy
                else:
                    log.warning(
                        "Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
                        self.settings_path,
                        strategy,
                    )
            if "coarse_k" in cascade:
                self.cascade_coarse_k = cascade["coarse_k"]
            if "fine_k" in cascade:
                self.cascade_fine_k = cascade["fine_k"]

            # Load API settings
            api = settings.get("api", {})
            if "max_workers" in api:
                self.api_max_workers = api["max_workers"]
            if "batch_size" in api:
                self.api_batch_size = api["batch_size"]
            if "batch_size_dynamic" in api:
                self.api_batch_size_dynamic = api["batch_size_dynamic"]
            if "batch_size_utilization_factor" in api:
                self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
            if "batch_size_max" in api:
                self.api_batch_size_max = api["batch_size_max"]
            if "chars_per_token_estimate" in api:
                self.chars_per_token_estimate = api["chars_per_token_estimate"]
        except Exception as exc:
            log.warning(
                "Failed to load settings from %s (%s): %s",
                self.settings_path,
                type(exc).__name__,
                exc,
            )

        # Apply .env overrides (highest priority)
        self._apply_env_overrides()
    def _apply_env_overrides(self) -> None:
        """Apply environment variable overrides from .env file.

        Priority: default → settings.json → .env (highest)

        Supported variables (with or without CODEXLENS_ prefix):
            EMBEDDING_MODEL: Override embedding model/profile
            EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
            EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
            EMBEDDING_STRATEGY: Load balance strategy for embedding
            EMBEDDING_COOLDOWN: Rate limit cooldown for embedding
            RERANKER_MODEL: Override reranker model
            RERANKER_BACKEND: Override reranker backend
            RERANKER_ENABLED: Override reranker enabled state (true/false)
            RERANKER_POOL_ENABLED: Enable reranker high availability pool
            RERANKER_STRATEGY: Load balance strategy for reranker
            RERANKER_COOLDOWN: Rate limit cooldown for reranker
        """
        from .env_config import load_global_env

        env_vars = load_global_env()
        if not env_vars:
            return

        def get_env(key: str) -> str | None:
            """Get env var with or without CODEXLENS_ prefix."""
            # Check prefixed version first (Dashboard format), then unprefixed
            return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)

        # Embedding overrides
        embedding_model = get_env("EMBEDDING_MODEL")
        if embedding_model:
            self.embedding_model = embedding_model
            log.debug("Overriding embedding_model from .env: %s", self.embedding_model)

        embedding_backend = get_env("EMBEDDING_BACKEND")
        if embedding_backend:
            backend = embedding_backend.lower()
            # Support 'api' as alias for 'litellm'
            if backend == "api":
                backend = "litellm"
            if backend in {"fastembed", "litellm"}:
                self.embedding_backend = backend
                log.debug("Overriding embedding_backend from .env: %s", backend)
            else:
                log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)

        embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
        if embedding_pool:
            value = embedding_pool.lower()
            self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
            log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)

        embedding_strategy = get_env("EMBEDDING_STRATEGY")
        if embedding_strategy:
            strategy = embedding_strategy.lower()
            if strategy in {"round_robin", "latency_aware", "weighted_random"}:
                self.embedding_strategy = strategy
                log.debug("Overriding embedding_strategy from .env: %s", strategy)
            else:
                log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)

        embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
        if embedding_cooldown:
            try:
                self.embedding_cooldown = float(embedding_cooldown)
                log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
            except ValueError:
                log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)

        # Reranker overrides
        reranker_model = get_env("RERANKER_MODEL")
        if reranker_model:
            self.reranker_model = reranker_model
            log.debug("Overriding reranker_model from .env: %s", self.reranker_model)

        reranker_backend = get_env("RERANKER_BACKEND")
        if reranker_backend:
            backend = reranker_backend.lower()
            if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
                self.reranker_backend = backend
                log.debug("Overriding reranker_backend from .env: %s", backend)
            else:
                log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)

        reranker_enabled = get_env("RERANKER_ENABLED")
        if reranker_enabled:
            value = reranker_enabled.lower()
            self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
            log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)

        reranker_pool = get_env("RERANKER_POOL_ENABLED")
        if reranker_pool:
            value = reranker_pool.lower()
            self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
            log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)

        reranker_strategy = get_env("RERANKER_STRATEGY")
        if reranker_strategy:
            strategy = reranker_strategy.lower()
            if strategy in {"round_robin", "latency_aware", "weighted_random"}:
                self.reranker_strategy = strategy
                log.debug("Overriding reranker_strategy from .env: %s", strategy)
            else:
                log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)

        reranker_cooldown = get_env("RERANKER_COOLDOWN")
        if reranker_cooldown:
            try:
                self.reranker_cooldown = float(reranker_cooldown)
                log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
            except ValueError:
                log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)

        reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
        if reranker_max_tokens:
            try:
                self.reranker_max_input_tokens = int(reranker_max_tokens)
                log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
            except ValueError:
                log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)

        # Reranker tuning from environment
        test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
        if test_penalty:
            try:
                self.reranker_test_file_penalty = float(test_penalty)
                log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
            except ValueError:
                log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)

        docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
        if docstring_weight:
            try:
                weight = float(docstring_weight)
                self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
                log.debug("Overriding reranker docstring weight from .env: %s", weight)
            except ValueError:
                log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)

        # Chunk stripping from environment
        strip_comments = get_env("CHUNK_STRIP_COMMENTS")
        if strip_comments:
            self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
            log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)

        strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
        if strip_docstrings:
            self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
            log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
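The `get_env` helper above checks the `CODEXLENS_`-prefixed name before the bare name, so a Dashboard-written `CODEXLENS_EMBEDDING_MODEL` wins over a hand-written `EMBEDDING_MODEL`. A standalone sketch of that lookup order (hypothetical values):

# Sketch of the prefixed-first lookup used by _apply_env_overrides.
env_vars = {
    "CODEXLENS_EMBEDDING_MODEL": "bge-m3",
    "EMBEDDING_MODEL": "minilm",
}

def get_env(key: str) -> str | None:
    return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)

assert get_env("EMBEDDING_MODEL") == "bge-m3"  # prefixed value wins
assert get_env("RERANKER_MODEL") is None       # absent in both forms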
    @classmethod
    def load(cls) -> "Config":
        """Load config with settings from file."""
        config = cls()
        config.load_settings()
        return config
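Putting the pieces together, `Config.load()` yields defaults, then applies settings.json, then .env overrides on top (the priority noted in `_apply_env_overrides`). A sketch of a typical call site, using only attributes defined above:

# Sketch of typical usage; defaults -> settings.json -> .env overrides.
config = Config.load()
config.ensure_runtime_dirs()   # creates cache/ and index/ under data_dir
print(config.db_path)          # <data_dir>/index/codexlens.db
print(config.embedding_backend, config.reranker_backend)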
@dataclass
class WorkspaceConfig:
    """Workspace-local configuration for CodexLens.

    Stores index data in the project/.codexlens/ directory.
    """

    workspace_root: Path

    def __post_init__(self) -> None:
        self.workspace_root = Path(self.workspace_root).resolve()

    @property
    def codexlens_dir(self) -> Path:
        """The .codexlens directory in workspace root."""
        return self.workspace_root / WORKSPACE_DIR_NAME

    @property
    def db_path(self) -> Path:
        """SQLite index path for this workspace."""
        return self.codexlens_dir / "index.db"

    @property
    def cache_dir(self) -> Path:
        """Cache directory for this workspace."""
        return self.codexlens_dir / "cache"

    @property
    def env_path(self) -> Path:
        """Path to workspace .env file."""
        return self.codexlens_dir / ".env"

    def load_env(self, *, override: bool = False) -> int:
        """Load .env file and apply to os.environ.

        Args:
            override: If True, override existing environment variables

        Returns:
            Number of variables applied
        """
        from .env_config import apply_workspace_env
        return apply_workspace_env(self.workspace_root, override=override)

    def get_api_config(self, prefix: str) -> dict:
        """Get API configuration from environment.

        Args:
            prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")

        Returns:
            Dictionary with api_key, api_base, model, etc.
        """
        from .env_config import get_api_config
        return get_api_config(prefix, workspace_root=self.workspace_root)

    def initialize(self) -> None:
        """Create the .codexlens directory structure."""
        try:
            self.codexlens_dir.mkdir(parents=True, exist_ok=True)
            self.cache_dir.mkdir(parents=True, exist_ok=True)

            # Create .gitignore to exclude cache but keep index
            gitignore_path = self.codexlens_dir / ".gitignore"
            if not gitignore_path.exists():
                gitignore_path.write_text(
                    "# CodexLens workspace data\n"
                    "cache/\n"
                    "*.log\n"
                    ".env\n"  # Exclude .env from git
                )
        except Exception as exc:
            raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc

    def exists(self) -> bool:
        """Check if workspace is already initialized."""
        return self.codexlens_dir.is_dir() and self.db_path.exists()

    @classmethod
    def from_path(cls, path: Path) -> Optional["WorkspaceConfig"]:
        """Create WorkspaceConfig from a path by finding the workspace root.

        Returns None if no workspace is found.
        """
        root = find_workspace_root(path)
        if root is None:
            return None
        return cls(workspace_root=root)

    @classmethod
    def create_at(cls, path: Path) -> "WorkspaceConfig":
        """Create a new workspace at the given path."""
        config = cls(workspace_root=path)
        config.initialize()
        return config
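A workspace is either discovered from an existing project or created in place; a sketch of both paths, assuming `find_workspace_root` walks up from the given directory:

from pathlib import Path

# Discover an existing workspace, or initialize one in place.
ws = WorkspaceConfig.from_path(Path.cwd())
if ws is None:
    ws = WorkspaceConfig.create_at(Path.cwd())  # writes .codexlens/ + .gitignore
ws.load_env()                                   # apply .codexlens/.env to os.environ
print(ws.db_path)                               # <project>/.codexlens/index.db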
codex-lens/build/lib/codexlens/entities.py (new file, 128 lines)
@@ -0,0 +1,128 @@
"""Pydantic entity models for CodexLens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class Symbol(BaseModel):
|
||||
"""A code symbol discovered in a file."""
|
||||
|
||||
name: str = Field(..., min_length=1)
|
||||
kind: str = Field(..., min_length=1)
|
||||
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
|
||||
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
|
||||
|
||||
@field_validator("range")
|
||||
@classmethod
|
||||
def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
|
||||
if len(value) != 2:
|
||||
raise ValueError("range must be a (start_line, end_line) tuple")
|
||||
start_line, end_line = value
|
||||
if start_line < 1 or end_line < 1:
|
||||
raise ValueError("range lines must be >= 1")
|
||||
if end_line < start_line:
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
return value
|
||||
|
||||
|
||||
class SemanticChunk(BaseModel):
|
||||
"""A semantically meaningful chunk of content, optionally embedded."""
|
||||
|
||||
content: str = Field(..., min_length=1)
|
||||
embedding: Optional[List[float]] = Field(default=None, description="Vector embedding for semantic search")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
id: Optional[int] = Field(default=None, description="Database row ID")
|
||||
file_path: Optional[str] = Field(default=None, description="Source file path")
|
||||
|
||||
@field_validator("embedding")
|
||||
@classmethod
|
||||
def validate_embedding(cls, value: Optional[List[float]]) -> Optional[List[float]]:
|
||||
if value is None:
|
||||
return value
|
||||
if not value:
|
||||
raise ValueError("embedding cannot be empty when provided")
|
||||
norm = math.sqrt(sum(x * x for x in value))
|
||||
epsilon = 1e-10
|
||||
if norm < epsilon:
|
||||
raise ValueError("embedding cannot be a zero vector")
|
||||
return value
|
||||
|
||||
|
||||
class IndexedFile(BaseModel):
|
||||
"""An indexed source file with symbols and optional semantic chunks."""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
language: str = Field(..., min_length=1)
|
||||
symbols: List[Symbol] = Field(default_factory=list)
|
||||
chunks: List[SemanticChunk] = Field(default_factory=list)
|
||||
relationships: List["CodeRelationship"] = Field(default_factory=list)
|
||||
|
||||
@field_validator("path", "language")
|
||||
@classmethod
|
||||
def strip_and_validate_nonempty(cls, value: str) -> str:
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("value cannot be blank")
|
||||
return cleaned
|
||||
|
||||
|
||||
class RelationshipType(str, Enum):
|
||||
"""Types of code relationships."""
|
||||
CALL = "calls"
|
||||
INHERITS = "inherits"
|
||||
IMPORTS = "imports"
|
||||
|
||||
|
||||
class CodeRelationship(BaseModel):
|
||||
"""A relationship between code symbols (e.g., function calls, inheritance)."""
|
||||
|
||||
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
|
||||
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
|
||||
relationship_type: RelationshipType = Field(..., description="Type of relationship (call, inherits, etc.)")
|
||||
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
|
||||
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
|
||||
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
|
||||
|
||||
|
||||
class AdditionalLocation(BaseModel):
|
||||
"""A pointer to another location where a similar result was found.
|
||||
|
||||
Used for grouping search results with similar scores and content,
|
||||
where the primary result is stored in SearchResult and secondary
|
||||
locations are stored in this model.
|
||||
"""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
score: float = Field(..., ge=0.0)
|
||||
start_line: Optional[int] = Field(default=None, description="Start line of the result (1-based)")
|
||||
end_line: Optional[int] = Field(default=None, description="End line of the result (1-based)")
|
||||
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol")
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""A unified search result for lexical or semantic search."""
|
||||
|
||||
path: str = Field(..., min_length=1)
|
||||
score: float = Field(..., ge=0.0)
|
||||
excerpt: Optional[str] = None
|
||||
content: Optional[str] = Field(default=None, description="Full content of matched code block")
|
||||
symbol: Optional[Symbol] = None
|
||||
chunk: Optional[SemanticChunk] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# Additional context for complete code blocks
|
||||
start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
|
||||
end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
|
||||
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
|
||||
symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")
|
||||
|
||||
# Field for grouping similar results
|
||||
additional_locations: List["AdditionalLocation"] = Field(
|
||||
default_factory=list,
|
||||
description="Other locations for grouped results with similar scores and content."
|
||||
)
|
||||
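A quick sketch of how these models compose (all values hypothetical): a Symbol carries a 1-based inclusive line range, and SearchResult groups near-duplicates under additional_locations; the validators reject inverted ranges and zero-vector embeddings.

# Sketch: composing the entity models; rejected inputs shown in comments.
sym = Symbol(name="load_settings", kind="method", range=(402, 460), file="config.py")
hit = SearchResult(
    path="config.py",
    score=0.92,
    symbol=sym,
    start_line=402,
    end_line=460,
    additional_locations=[AdditionalLocation(path="build/lib/config.py", score=0.91)],
)
# Symbol(name="x", kind="fn", range=(10, 5))   -> ValueError: end_line must be >= start_line
# SemanticChunk(content="f", embedding=[0.0])  -> ValueError: embedding cannot be a zero vector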
codex-lens/build/lib/codexlens/env_config.py (new file, 304 lines)
@@ -0,0 +1,304 @@
"""Environment configuration loader for CodexLens.
|
||||
|
||||
Loads .env files from workspace .codexlens directory with fallback to project root.
|
||||
Provides unified access to API configurations.
|
||||
|
||||
Priority order:
|
||||
1. Environment variables (already set)
|
||||
2. .codexlens/.env (workspace-local)
|
||||
3. .env (project root)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Supported environment variables with descriptions
|
||||
ENV_VARS = {
|
||||
# Reranker configuration (overrides settings.json)
|
||||
"RERANKER_MODEL": "Reranker model name (overrides settings.json)",
|
||||
"RERANKER_BACKEND": "Reranker backend: fastembed, onnx, api, litellm, legacy",
|
||||
"RERANKER_ENABLED": "Enable reranker: true/false",
|
||||
"RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
|
||||
"RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
|
||||
"RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
|
||||
"RERANKER_POOL_ENABLED": "Enable reranker high availability pool: true/false",
|
||||
"RERANKER_STRATEGY": "Reranker load balance strategy: round_robin, latency_aware, weighted_random",
|
||||
"RERANKER_COOLDOWN": "Reranker rate limit cooldown in seconds",
|
||||
# Embedding configuration (overrides settings.json)
|
||||
"EMBEDDING_MODEL": "Embedding model/profile name (overrides settings.json)",
|
||||
"EMBEDDING_BACKEND": "Embedding backend: fastembed, litellm",
|
||||
"EMBEDDING_API_KEY": "API key for embedding service",
|
||||
"EMBEDDING_API_BASE": "Base URL for embedding API",
|
||||
"EMBEDDING_POOL_ENABLED": "Enable embedding high availability pool: true/false",
|
||||
"EMBEDDING_STRATEGY": "Embedding load balance strategy: round_robin, latency_aware, weighted_random",
|
||||
"EMBEDDING_COOLDOWN": "Embedding rate limit cooldown in seconds",
|
||||
# LiteLLM configuration
|
||||
"LITELLM_API_KEY": "API key for LiteLLM",
|
||||
"LITELLM_API_BASE": "Base URL for LiteLLM",
|
||||
"LITELLM_MODEL": "LiteLLM model name",
|
||||
# General configuration
|
||||
"CODEXLENS_DATA_DIR": "Custom data directory path",
|
||||
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
|
||||
# Chunking configuration
|
||||
"CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
|
||||
"CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
|
||||
# Reranker tuning
|
||||
"RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
|
||||
"RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
|
||||
}
|
||||
|
||||
|
||||
def _parse_env_line(line: str) -> tuple[str, str] | None:
|
||||
"""Parse a single .env line, returning (key, value) or None."""
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith("#"):
|
||||
return None
|
||||
|
||||
# Handle export prefix
|
||||
if line.startswith("export "):
|
||||
line = line[7:].strip()
|
||||
|
||||
# Split on first =
|
||||
if "=" not in line:
|
||||
return None
|
||||
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Remove surrounding quotes
|
||||
if len(value) >= 2:
|
||||
if (value.startswith('"') and value.endswith('"')) or \
|
||||
(value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1]
|
||||
|
||||
return key, value
|
||||
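The parser tolerates `export` prefixes and surrounding quotes, so shell-style files round-trip cleanly; for instance (hypothetical values):

# Each line resolves to a (key, value) pair or None.
assert _parse_env_line('export RERANKER_MODEL="bge-reranker-v2"') == ("RERANKER_MODEL", "bge-reranker-v2")
assert _parse_env_line("# just a comment") is None
assert _parse_env_line("NOT_AN_ASSIGNMENT") is None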
def load_env_file(env_path: Path) -> Dict[str, str]:
    """Load environment variables from a .env file.

    Args:
        env_path: Path to .env file

    Returns:
        Dictionary of environment variables
    """
    if not env_path.is_file():
        return {}

    env_vars: Dict[str, str] = {}

    try:
        content = env_path.read_text(encoding="utf-8")
        for line in content.splitlines():
            result = _parse_env_line(line)
            if result:
                key, value = result
                env_vars[key] = value
    except Exception as exc:
        log.warning("Failed to load .env file %s: %s", env_path, exc)

    return env_vars


def _get_global_data_dir() -> Path:
    """Get the global CodexLens data directory."""
    env_override = os.environ.get("CODEXLENS_DATA_DIR")
    if env_override:
        return Path(env_override).expanduser().resolve()
    return (Path.home() / ".codexlens").resolve()


def load_global_env() -> Dict[str, str]:
    """Load environment variables from the global ~/.codexlens/.env file.

    Returns:
        Dictionary of environment variables from global config
    """
    global_env_path = _get_global_data_dir() / ".env"
    if global_env_path.is_file():
        env_vars = load_env_file(global_env_path)
        log.debug("Loaded %d vars from global %s", len(env_vars), global_env_path)
        return env_vars
    return {}
def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
    """Load environment variables from workspace .env files.

    Priority (later overrides earlier):
    1. Global ~/.codexlens/.env (lowest priority)
    2. Project root .env
    3. .codexlens/.env (highest priority)

    Args:
        workspace_root: Workspace root directory. If None, uses current directory.

    Returns:
        Merged dictionary of environment variables
    """
    if workspace_root is None:
        workspace_root = Path.cwd()

    workspace_root = Path(workspace_root).resolve()

    env_vars: Dict[str, str] = {}

    # Load from global ~/.codexlens/.env (lowest priority)
    global_vars = load_global_env()
    if global_vars:
        env_vars.update(global_vars)

    # Load from project root .env (medium priority)
    root_env = workspace_root / ".env"
    if root_env.is_file():
        loaded = load_env_file(root_env)
        env_vars.update(loaded)
        log.debug("Loaded %d vars from %s", len(loaded), root_env)

    # Load from .codexlens/.env (highest priority)
    codexlens_env = workspace_root / ".codexlens" / ".env"
    if codexlens_env.is_file():
        loaded = load_env_file(codexlens_env)
        env_vars.update(loaded)
        log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)

    return env_vars


def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
    """Load .env files and apply to os.environ.

    Args:
        workspace_root: Workspace root directory
        override: If True, override existing environment variables

    Returns:
        Number of variables applied
    """
    env_vars = load_workspace_env(workspace_root)
    applied = 0

    for key, value in env_vars.items():
        if override or key not in os.environ:
            os.environ[key] = value
            applied += 1
            log.debug("Applied env var: %s", key)

    return applied
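In practice a caller runs `apply_workspace_env` once at startup: later .env sources win during merging, but anything already in os.environ wins at apply time unless override=True. A sketch with a hypothetical project path:

import os
from pathlib import Path

count = apply_workspace_env(Path("/path/to/project"))   # hypothetical root
print(f"applied {count} vars")
print(os.environ.get("EMBEDDING_MODEL"))  # set only if it was not already present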
def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback.

    Priority:
    1. os.environ (already set)
    2. .codexlens/.env
    3. .env
    4. default value

    Args:
        key: Environment variable name
        default: Default value if not found
        workspace_root: Workspace root for .env file lookup

    Returns:
        Value or default
    """
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Load from .env files
    env_vars = load_workspace_env(workspace_root)
    if key in env_vars:
        return env_vars[key]

    return default


def get_api_config(
    prefix: str,
    *,
    workspace_root: Path | None = None,
    defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Get API configuration from environment.

    Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.

    Args:
        prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
        workspace_root: Workspace root for .env file lookup
        defaults: Default values

    Returns:
        Dictionary with api_key, api_base, model, etc.
    """
    defaults = defaults or {}

    config: Dict[str, Any] = {}

    # Standard API config fields
    field_mapping = {
        "api_key": f"{prefix}_API_KEY",
        "api_base": f"{prefix}_API_BASE",
        "model": f"{prefix}_MODEL",
        "provider": f"{prefix}_PROVIDER",
        "timeout": f"{prefix}_TIMEOUT",
    }

    for field, env_key in field_mapping.items():
        value = get_env(env_key, workspace_root=workspace_root)
        if value is not None:
            # Type conversion for specific fields
            if field == "timeout":
                try:
                    config[field] = float(value)
                except ValueError:
                    pass
            else:
                config[field] = value
        elif field in defaults:
            config[field] = defaults[field]

    return config
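For example, with RERANKER_API_KEY set in .codexlens/.env, a call like the following sketch returns a ready-to-use mapping; timeout is coerced to float, and only found or defaulted keys appear (path and defaults here are hypothetical):

from pathlib import Path

cfg = get_api_config(
    "RERANKER",
    workspace_root=Path("/path/to/project"),    # hypothetical project root
    defaults={"provider": "siliconflow"},
)
# -> e.g. {"api_key": "...", "provider": "siliconflow", "timeout": 30.0}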
def generate_env_example() -> str:
    """Generate .env.example content with all supported variables.

    Returns:
        String content for .env.example file
    """
    lines = [
        "# CodexLens Environment Configuration",
        "# Copy this file to .codexlens/.env and fill in your values",
        "",
    ]

    # Group by prefix
    groups: Dict[str, list] = {}
    for key, desc in ENV_VARS.items():
        prefix = key.split("_")[0]
        if prefix not in groups:
            groups[prefix] = []
        groups[prefix].append((key, desc))

    for prefix, items in groups.items():
        lines.append(f"# {prefix} Configuration")
        for key, desc in items:
            lines.append(f"# {desc}")
            lines.append(f"# {key}=")
        lines.append("")

    return "\n".join(lines)
codex-lens/build/lib/codexlens/errors.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""CodexLens exception hierarchy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class CodexLensError(Exception):
|
||||
"""Base class for all CodexLens errors."""
|
||||
|
||||
|
||||
class ConfigError(CodexLensError):
|
||||
"""Raised when configuration is invalid or cannot be loaded."""
|
||||
|
||||
|
||||
class ParseError(CodexLensError):
|
||||
"""Raised when parsing or indexing a file fails."""
|
||||
|
||||
|
||||
class StorageError(CodexLensError):
|
||||
"""Raised when reading/writing index storage fails.
|
||||
|
||||
Attributes:
|
||||
message: Human-readable error description
|
||||
db_path: Path to the database file (if applicable)
|
||||
operation: The operation that failed (e.g., 'query', 'initialize', 'migrate')
|
||||
details: Additional context for debugging
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
db_path: str | None = None,
|
||||
operation: str | None = None,
|
||||
details: dict | None = None
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.db_path = db_path
|
||||
self.operation = operation
|
||||
self.details = details or {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [self.message]
|
||||
if self.db_path:
|
||||
parts.append(f"[db: {self.db_path}]")
|
||||
if self.operation:
|
||||
parts.append(f"[op: {self.operation}]")
|
||||
if self.details:
|
||||
detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
|
||||
parts.append(f"[{detail_str}]")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
class SearchError(CodexLensError):
|
||||
"""Raised when a search operation fails."""
|
||||
|
||||
|
||||
class IndexNotFoundError(CodexLensError):
|
||||
"""Raised when a project's index cannot be found."""
|
||||
|
||||
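The structured fields make failures greppable; a raise site might look like this sketch (values hypothetical), rendered by the __str__ above:

err = StorageError(
    "unable to open index",
    db_path="~/.codexlens/index/codexlens.db",
    operation="query",
    details={"table": "chunks"},
)
print(err)  # unable to open index [db: ~/.codexlens/index/codexlens.db] [op: query] [table=chunks]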
codex-lens/build/lib/codexlens/hybrid_search/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.
|
||||
|
||||
This module provides core data structures for hybrid search:
|
||||
- CodeSymbolNode: Graph node representing a code symbol
|
||||
- CodeAssociationGraph: Graph of code relationships
|
||||
- SearchResultCluster: Clustered search results
|
||||
- Range: Position range in source files
|
||||
- CallHierarchyItem: LSP call hierarchy item
|
||||
|
||||
Note: The search engine is in codexlens.search.hybrid_search
|
||||
LSP-based expansion is in codexlens.lsp module
|
||||
"""
|
||||
|
||||
from codexlens.hybrid_search.data_structures import (
|
||||
CallHierarchyItem,
|
||||
CodeAssociationGraph,
|
||||
CodeSymbolNode,
|
||||
Range,
|
||||
SearchResultCluster,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallHierarchyItem",
|
||||
"CodeAssociationGraph",
|
||||
"CodeSymbolNode",
|
||||
"Range",
|
||||
"SearchResultCluster",
|
||||
]
|
||||
codex-lens/build/lib/codexlens/hybrid_search/data_structures.py (new file, 602 lines)
@@ -0,0 +1,602 @@
"""Core data structures for the hybrid search system.
|
||||
|
||||
This module defines the fundamental data structures used throughout the
|
||||
hybrid search pipeline, including code symbol representations, association
|
||||
graphs, and clustered search results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import networkx as nx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
"""Position range within a source file.
|
||||
|
||||
Attributes:
|
||||
start_line: Starting line number (0-based).
|
||||
start_character: Starting character offset within the line.
|
||||
end_line: Ending line number (0-based).
|
||||
end_character: Ending character offset within the line.
|
||||
"""
|
||||
|
||||
start_line: int
|
||||
start_character: int
|
||||
end_line: int
|
||||
end_character: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate range values."""
|
||||
if self.start_line < 0:
|
||||
raise ValueError("start_line must be >= 0")
|
||||
if self.start_character < 0:
|
||||
raise ValueError("start_character must be >= 0")
|
||||
if self.end_line < 0:
|
||||
raise ValueError("end_line must be >= 0")
|
||||
if self.end_character < 0:
|
||||
raise ValueError("end_character must be >= 0")
|
||||
if self.end_line < self.start_line:
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
if self.end_line == self.start_line and self.end_character < self.start_character:
|
||||
raise ValueError("end_character must be >= start_character on the same line")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"start": {"line": self.start_line, "character": self.start_character},
|
||||
"end": {"line": self.end_line, "character": self.end_character},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Range:
|
||||
"""Create Range from dictionary representation."""
|
||||
return cls(
|
||||
start_line=data["start"]["line"],
|
||||
start_character=data["start"]["character"],
|
||||
end_line=data["end"]["line"],
|
||||
end_character=data["end"]["character"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
|
||||
"""Create Range from LSP Range object.
|
||||
|
||||
LSP Range format:
|
||||
{"start": {"line": int, "character": int},
|
||||
"end": {"line": int, "character": int}}
|
||||
"""
|
||||
return cls(
|
||||
start_line=lsp_range["start"]["line"],
|
||||
start_character=lsp_range["start"]["character"],
|
||||
end_line=lsp_range["end"]["line"],
|
||||
end_character=lsp_range["end"]["character"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class CallHierarchyItem:
    """LSP CallHierarchyItem for representing callers/callees.

    Attributes:
        name: Symbol name (function, method, class name).
        kind: Symbol kind (function, method, class, etc.).
        file_path: Absolute file path where the symbol is defined.
        range: Position range in the source file.
        detail: Optional additional detail about the symbol.
    """

    name: str
    kind: str
    file_path: str
    range: Range
    detail: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        result: Dict[str, Any] = {
            "name": self.name,
            "kind": self.kind,
            "file_path": self.file_path,
            "range": self.range.to_dict(),
        }
        if self.detail:
            result["detail"] = self.detail
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
        """Create CallHierarchyItem from dictionary representation."""
        return cls(
            name=data["name"],
            kind=data["kind"],
            file_path=data["file_path"],
            range=Range.from_dict(data["range"]),
            detail=data.get("detail"),
        )
@dataclass
class CodeSymbolNode:
    """Graph node representing a code symbol.

    Attributes:
        id: Unique identifier in format 'file_path:name:line'.
        name: Symbol name (function, class, variable name).
        kind: Symbol kind (function, class, method, variable, etc.).
        file_path: Absolute file path where symbol is defined.
        range: Start/end position in the source file.
        embedding: Optional vector embedding for semantic search.
        raw_code: Raw source code of the symbol.
        docstring: Documentation string (if available).
        score: Ranking score (used during reranking).
    """

    id: str
    name: str
    kind: str
    file_path: str
    range: Range
    embedding: Optional[List[float]] = None
    raw_code: str = ""
    docstring: str = ""
    score: float = 0.0

    def __post_init__(self) -> None:
        """Validate required fields."""
        if not self.id:
            raise ValueError("id cannot be empty")
        if not self.name:
            raise ValueError("name cannot be empty")
        if not self.kind:
            raise ValueError("kind cannot be empty")
        if not self.file_path:
            raise ValueError("file_path cannot be empty")

    def __hash__(self) -> int:
        """Hash based on unique ID."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Equality based on unique ID."""
        if not isinstance(other, CodeSymbolNode):
            return False
        return self.id == other.id

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        result: Dict[str, Any] = {
            "id": self.id,
            "name": self.name,
            "kind": self.kind,
            "file_path": self.file_path,
            "range": self.range.to_dict(),
            "score": self.score,
        }
        if self.raw_code:
            result["raw_code"] = self.raw_code
        if self.docstring:
            result["docstring"] = self.docstring
        # Exclude embedding from serialization (too large for JSON responses)
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
        """Create CodeSymbolNode from dictionary representation."""
        return cls(
            id=data["id"],
            name=data["name"],
            kind=data["kind"],
            file_path=data["file_path"],
            range=Range.from_dict(data["range"]),
            embedding=data.get("embedding"),
            raw_code=data.get("raw_code", ""),
            docstring=data.get("docstring", ""),
            score=data.get("score", 0.0),
        )

    @classmethod
    def from_lsp_location(
        cls,
        uri: str,
        name: str,
        kind: str,
        lsp_range: Dict[str, Any],
        raw_code: str = "",
        docstring: str = "",
    ) -> CodeSymbolNode:
        """Create CodeSymbolNode from LSP location data.

        Args:
            uri: File URI (file:// prefix will be stripped).
            name: Symbol name.
            kind: Symbol kind.
            lsp_range: LSP Range object.
            raw_code: Optional raw source code.
            docstring: Optional documentation string.

        Returns:
            New CodeSymbolNode instance.
        """
        # Strip file:// prefix if present
        file_path = uri
        if file_path.startswith("file://"):
            file_path = file_path[7:]
        # Handle Windows paths (file:///C:/...)
        if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
            file_path = file_path[1:]

        range_obj = Range.from_lsp_range(lsp_range)
        symbol_id = f"{file_path}:{name}:{range_obj.start_line}"

        return cls(
            id=symbol_id,
            name=name,
            kind=kind,
            file_path=file_path,
            range=range_obj,
            raw_code=raw_code,
            docstring=docstring,
        )

    @classmethod
    def create_id(cls, file_path: str, name: str, line: int) -> str:
        """Generate a unique symbol ID.

        Args:
            file_path: Absolute file path.
            name: Symbol name.
            line: Start line number.

        Returns:
            Unique ID string in format 'file_path:name:line'.
        """
        return f"{file_path}:{name}:{line}"
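The ID scheme makes nodes hashable and deduplicatable across sources; a sketch of building one from raw LSP data (URI and range values hypothetical):

# file:// URIs are normalized, and IDs are 'file_path:name:start_line'.
node = CodeSymbolNode.from_lsp_location(
    uri="file:///C:/repo/app.py",   # Windows-style URI; leading slash is stripped
    name="main",
    kind="function",
    lsp_range={"start": {"line": 10, "character": 0}, "end": {"line": 20, "character": 1}},
)
assert node.file_path == "C:/repo/app.py"
assert node.id == CodeSymbolNode.create_id("C:/repo/app.py", "main", 10)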
@dataclass
class CodeAssociationGraph:
    """Graph of code relationships between symbols.

    This graph represents the association between code symbols discovered
    through LSP queries (references, call hierarchy, etc.).

    Attributes:
        nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
        edges: List of (from_id, to_id, relationship_type) tuples.
        relationship_type: 'calls', 'references', 'inherits', 'imports'.
    """

    nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
    edges: List[Tuple[str, str, str]] = field(default_factory=list)

    def add_node(self, node: CodeSymbolNode) -> None:
        """Add a node to the graph.

        Args:
            node: CodeSymbolNode to add. If a node with the same ID exists,
                it will be replaced.
        """
        self.nodes[node.id] = node

    def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
        """Add an edge to the graph.

        Args:
            from_id: Source node ID.
            to_id: Target node ID.
            rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').

        Raises:
            ValueError: If from_id or to_id is not in the graph nodes.
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found in graph")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found in graph")

        edge = (from_id, to_id, rel_type)
        if edge not in self.edges:
            self.edges.append(edge)

    def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
        """Add an edge without validating node existence.

        Use this method during bulk graph construction where nodes may be
        added after edges, or when performance is critical.

        Args:
            from_id: Source node ID.
            to_id: Target node ID.
            rel_type: Relationship type.
        """
        edge = (from_id, to_id, rel_type)
        if edge not in self.edges:
            self.edges.append(edge)

    def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
        """Get a node by ID.

        Args:
            node_id: Node ID to look up.

        Returns:
            CodeSymbolNode if found, None otherwise.
        """
        return self.nodes.get(node_id)

    def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
        """Get neighboring nodes connected by outgoing edges.

        Args:
            node_id: Node ID to find neighbors for.
            rel_type: Optional filter by relationship type.

        Returns:
            List of neighboring CodeSymbolNode objects.
        """
        neighbors = []
        for from_id, to_id, edge_rel in self.edges:
            if from_id == node_id:
                if rel_type is None or edge_rel == rel_type:
                    node = self.nodes.get(to_id)
                    if node:
                        neighbors.append(node)
        return neighbors

    def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
        """Get nodes connected by incoming edges.

        Args:
            node_id: Node ID to find incoming connections for.
            rel_type: Optional filter by relationship type.

        Returns:
            List of CodeSymbolNode objects with edges pointing to node_id.
        """
        incoming = []
        for from_id, to_id, edge_rel in self.edges:
            if to_id == node_id:
                if rel_type is None or edge_rel == rel_type:
                    node = self.nodes.get(from_id)
                    if node:
                        incoming.append(node)
        return incoming

    def to_networkx(self) -> "nx.DiGraph":
        """Convert to NetworkX DiGraph for graph algorithms.

        Returns:
            NetworkX directed graph with nodes and edges.

        Raises:
            ImportError: If networkx is not installed.
        """
        try:
            import networkx as nx
        except ImportError:
            raise ImportError(
                "networkx is required for graph algorithms. "
                "Install with: pip install networkx"
            )

        graph = nx.DiGraph()

        # Add nodes with attributes
        for node_id, node in self.nodes.items():
            graph.add_node(
                node_id,
                name=node.name,
                kind=node.kind,
                file_path=node.file_path,
                score=node.score,
            )

        # Add edges with relationship type
        for from_id, to_id, rel_type in self.edges:
            graph.add_edge(from_id, to_id, relationship=rel_type)

        return graph

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        Returns:
            Dictionary with 'nodes' and 'edges' keys.
        """
        return {
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "edges": [
                {"from": from_id, "to": to_id, "relationship": rel_type}
                for from_id, to_id, rel_type in self.edges
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
        """Create CodeAssociationGraph from dictionary representation.

        Args:
            data: Dictionary with 'nodes' and 'edges' keys.

        Returns:
            New CodeAssociationGraph instance.
        """
        graph = cls()

        # Load nodes
        for node_id, node_data in data.get("nodes", {}).items():
            graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)

        # Load edges
        for edge_data in data.get("edges", []):
            graph.edges.append((
                edge_data["from"],
                edge_data["to"],
                edge_data["relationship"],
            ))

        return graph

    def __len__(self) -> int:
        """Return the number of nodes in the graph."""
        return len(self.nodes)
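A small sketch of graph construction and traversal; `add_edge` validates both endpoints, while `add_edge_unchecked` defers that during bulk loads (all names hypothetical):

r = Range(start_line=0, start_character=0, end_line=0, end_character=1)
a = CodeSymbolNode(id="app.py:main:1", name="main", kind="function", file_path="app.py", range=r)
b = CodeSymbolNode(id="db.py:query:5", name="query", kind="function", file_path="db.py", range=r)

graph = CodeAssociationGraph()
graph.add_node(a)
graph.add_node(b)
graph.add_edge(a.id, b.id, "calls")

assert [n.name for n in graph.get_neighbors(a.id, "calls")] == ["query"]
assert [n.name for n in graph.get_incoming(b.id)] == ["main"]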
@dataclass
class SearchResultCluster:
    """Clustered search result containing related code symbols.

    Search results are grouped into clusters based on graph community
    detection or embedding similarity. Each cluster represents a
    conceptually related group of code symbols.

    Attributes:
        cluster_id: Unique cluster identifier.
        score: Cluster relevance score (max of symbol scores).
        title: Human-readable cluster title/summary.
        symbols: List of CodeSymbolNode in this cluster.
        metadata: Additional cluster metadata.
    """

    cluster_id: str
    score: float
    title: str
    symbols: List[CodeSymbolNode] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate cluster fields."""
        if not self.cluster_id:
            raise ValueError("cluster_id cannot be empty")
        if self.score < 0:
            raise ValueError("score must be >= 0")

    def add_symbol(self, symbol: CodeSymbolNode) -> None:
        """Add a symbol to the cluster.

        Args:
            symbol: CodeSymbolNode to add.
        """
        self.symbols.append(symbol)

    def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
        """Get top N symbols by score.

        Args:
            n: Number of symbols to return.

        Returns:
            List of top N CodeSymbolNode objects sorted by score descending.
        """
        sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
        return sorted_symbols[:n]

    def update_score(self) -> None:
        """Update cluster score to max of symbol scores."""
        if self.symbols:
            self.score = max(s.score for s in self.symbols)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        Returns:
            Dictionary representation of the cluster.
        """
        return {
            "cluster_id": self.cluster_id,
            "score": self.score,
            "title": self.title,
            "symbols": [s.to_dict() for s in self.symbols],
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
        """Create SearchResultCluster from dictionary representation.

        Args:
            data: Dictionary with cluster data.

        Returns:
            New SearchResultCluster instance.
        """
        return cls(
            cluster_id=data["cluster_id"],
            score=data["score"],
            title=data["title"],
            symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
            metadata=data.get("metadata", {}),
        )

    def __len__(self) -> int:
        """Return the number of symbols in the cluster."""
        return len(self.symbols)
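Cluster scores track their best member after update_score(); a self-contained sketch with hypothetical symbols:

r = Range(start_line=0, start_character=0, end_line=0, end_character=1)
cluster = SearchResultCluster(cluster_id="c1", score=0.0, title="auth flow")
for name, score in [("login", 0.9), ("logout", 0.4)]:
    node = CodeSymbolNode(id=f"auth.py:{name}:1", name=name, kind="function",
                          file_path="auth.py", range=r)
    node.score = score
    cluster.add_symbol(node)
cluster.update_score()
assert cluster.score == 0.9
assert [s.name for s in cluster.get_top_symbols(1)] == ["login"]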
@dataclass
class CallHierarchyItem:
    """LSP CallHierarchyItem for representing callers/callees.

    Attributes:
        name: Symbol name (function, method, etc.).
        kind: Symbol kind (function, method, etc.).
        file_path: Absolute file path.
        range: Position range in the file.
        detail: Optional additional detail (e.g., signature).
    """

    name: str
    kind: str
    file_path: str
    range: Range
    detail: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        result: Dict[str, Any] = {
            "name": self.name,
            "kind": self.kind,
            "file_path": self.file_path,
            "range": self.range.to_dict(),
        }
        if self.detail:
            result["detail"] = self.detail
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
        """Create CallHierarchyItem from dictionary representation."""
        default_range = {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}}
        return cls(
            name=data.get("name", "unknown"),
            kind=data.get("kind", "unknown"),
            file_path=data.get("file_path", data.get("uri", "")),
            range=Range.from_dict(data.get("range", default_range)),
            detail=data.get("detail"),
        )

    @classmethod
    def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
        """Create CallHierarchyItem from LSP response format.

        LSP uses 0-based line numbers and 'character' instead of 'char'.
        """
        uri = data.get("uri", data.get("file_path", ""))
        # Strip the file:// prefix; also strip the extra leading slash that
        # precedes a Windows drive letter (file:///C:/... -> C:/...)
        file_path = uri
        if file_path.startswith("file://"):
            file_path = file_path[7:]
            if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
                file_path = file_path[1:]

        default_range = {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}}
        return cls(
            name=data.get("name", "unknown"),
            kind=str(data.get("kind", "unknown")),
            file_path=file_path,
            range=Range.from_lsp_range(data.get("range", default_range)),
            detail=data.get("detail"),
        )
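
A quick orientation to the dataclass above: a minimal sketch of converting a raw LSP call-hierarchy payload into a CallHierarchyItem and round-tripping it through its dict form. The payload values here are invented for illustration; the import path is the one used later in lsp_bridge.py.

    from codexlens.hybrid_search.data_structures import CallHierarchyItem

    item = CallHierarchyItem.from_lsp({
        "name": "handle_request",
        "kind": 12,  # LSP numeric kind; stored as a string
        "uri": "file:///C:/proj/server.py",
        "range": {"start": {"line": 41, "character": 4},
                  "end": {"line": 41, "character": 18}},
    })
    assert item.file_path == "C:/proj/server.py"  # drive-letter slash stripped
    restored = CallHierarchyItem.from_dict(item.to_dict())
    assert restored.name == item.name
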
26
codex-lens/build/lib/codexlens/indexing/__init__.py
Normal file
@@ -0,0 +1,26 @@
"""Code indexing and symbol extraction."""
|
||||
from codexlens.indexing.symbol_extractor import SymbolExtractor
|
||||
from codexlens.indexing.embedding import (
|
||||
BinaryEmbeddingBackend,
|
||||
DenseEmbeddingBackend,
|
||||
CascadeEmbeddingBackend,
|
||||
get_cascade_embedder,
|
||||
binarize_embedding,
|
||||
pack_binary_embedding,
|
||||
unpack_binary_embedding,
|
||||
hamming_distance,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SymbolExtractor",
|
||||
# Cascade embedding backends
|
||||
"BinaryEmbeddingBackend",
|
||||
"DenseEmbeddingBackend",
|
||||
"CascadeEmbeddingBackend",
|
||||
"get_cascade_embedder",
|
||||
# Utility functions
|
||||
"binarize_embedding",
|
||||
"pack_binary_embedding",
|
||||
"unpack_binary_embedding",
|
||||
"hamming_distance",
|
||||
]
|
||||
582
codex-lens/build/lib/codexlens/indexing/embedding.py
Normal file
@@ -0,0 +1,582 @@
"""Multi-type embedding backends for cascade retrieval.
|
||||
|
||||
This module provides embedding backends optimized for cascade retrieval:
|
||||
1. BinaryEmbeddingBackend - Fast coarse filtering with binary vectors
|
||||
2. DenseEmbeddingBackend - High-precision dense vectors for reranking
|
||||
3. CascadeEmbeddingBackend - Combined binary + dense for two-stage retrieval
|
||||
|
||||
Cascade retrieval workflow:
|
||||
1. Binary search (fast, ~32 bytes/vector) -> top-K candidates
|
||||
2. Dense rerank (precise, ~8KB/vector) -> final results
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from codexlens.semantic.base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utility Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def binarize_embedding(embedding: np.ndarray) -> np.ndarray:
|
||||
"""Convert float embedding to binary vector.
|
||||
|
||||
Applies sign-based quantization: values > 0 become 1, values <= 0 become 0.
|
||||
|
||||
Args:
|
||||
embedding: Float32 embedding of any dimension
|
||||
|
||||
Returns:
|
||||
Binary vector (uint8 with values 0 or 1) of same dimension
|
||||
"""
|
||||
return (embedding > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def pack_binary_embedding(binary_vector: np.ndarray) -> bytes:
|
||||
"""Pack binary vector into compact bytes format.
|
||||
|
||||
Packs 8 binary values into each byte for storage efficiency.
|
||||
For a 256-dim binary vector, output is 32 bytes.
|
||||
|
||||
Args:
|
||||
binary_vector: Binary vector (uint8 with values 0 or 1)
|
||||
|
||||
Returns:
|
||||
Packed bytes (length = ceil(dim / 8))
|
||||
"""
|
||||
# Ensure vector length is multiple of 8 by padding if needed
|
||||
dim = len(binary_vector)
|
||||
padded_dim = ((dim + 7) // 8) * 8
|
||||
if padded_dim > dim:
|
||||
padded = np.zeros(padded_dim, dtype=np.uint8)
|
||||
padded[:dim] = binary_vector
|
||||
binary_vector = padded
|
||||
|
||||
# Pack 8 bits per byte
|
||||
packed = np.packbits(binary_vector)
|
||||
return packed.tobytes()
|
||||
|
||||
|
||||
def unpack_binary_embedding(packed_bytes: bytes, dim: int = 256) -> np.ndarray:
|
||||
"""Unpack bytes back to binary vector.
|
||||
|
||||
Args:
|
||||
packed_bytes: Packed binary data
|
||||
dim: Original vector dimension (default: 256)
|
||||
|
||||
Returns:
|
||||
Binary vector (uint8 with values 0 or 1)
|
||||
"""
|
||||
unpacked = np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))
|
||||
return unpacked[:dim]
|
||||
|
||||
|
||||
def hamming_distance(a: bytes, b: bytes) -> int:
|
||||
"""Compute Hamming distance between two packed binary vectors.
|
||||
|
||||
Uses XOR and popcount for efficient distance computation.
|
||||
|
||||
Args:
|
||||
a: First packed binary vector
|
||||
b: Second packed binary vector
|
||||
|
||||
Returns:
|
||||
Hamming distance (number of differing bits)
|
||||
"""
|
||||
a_arr = np.frombuffer(a, dtype=np.uint8)
|
||||
b_arr = np.frombuffer(b, dtype=np.uint8)
|
||||
xor = np.bitwise_xor(a_arr, b_arr)
|
||||
return int(np.unpackbits(xor).sum())
|
||||
|
||||
|
||||
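# Worked example for the helpers above (an illustrative sketch, not called by
# the library): sign-quantize two float vectors, pack them to 32 bytes each,
# and compare the packed forms by Hamming distance.
#
#   rng = np.random.default_rng(0)
#   vec_a = rng.standard_normal(256).astype(np.float32)
#   vec_b = rng.standard_normal(256).astype(np.float32)
#   packed_a = pack_binary_embedding(binarize_embedding(vec_a))  # 32 bytes
#   packed_b = pack_binary_embedding(binarize_embedding(vec_b))
#   dist = hamming_distance(packed_a, packed_b)  # int in [0, 256]
#   assert unpack_binary_embedding(packed_a, dim=256).shape == (256,)

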
# =============================================================================
# Binary Embedding Backend
# =============================================================================


class BinaryEmbeddingBackend(BaseEmbedder):
    """Generate 256-dimensional binary embeddings for fast coarse retrieval.

    Uses a lightweight embedding model and applies sign-based quantization
    to produce compact binary vectors (32 bytes per embedding).

    Suitable for:
    - First-stage candidate retrieval
    - Hamming distance-based similarity search
    - Memory-constrained environments

    Model: BAAI/bge-small-en-v1.5 (384 dim) -> quantized to 256 bits
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"  # 384 dim, fast
    BINARY_DIM = 256

    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
    ) -> None:
        """Initialize binary embedding backend.

        Args:
            model_name: Base embedding model name. Defaults to BAAI/bge-small-en-v1.5
            use_gpu: Whether to use GPU acceleration
        """
        from codexlens.semantic import SEMANTIC_AVAILABLE

        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        self._model_name = model_name or self.DEFAULT_MODEL
        self._use_gpu = use_gpu
        self._model = None

        # Projection matrix for dimension reduction (lazily initialized)
        self._projection_matrix: Optional[np.ndarray] = None

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Return binary embedding dimension (256)."""
        return self.BINARY_DIM

    @property
    def packed_bytes(self) -> int:
        """Return packed bytes size (32 bytes for 256 bits)."""
        return self.BINARY_DIM // 8

    def _load_model(self) -> None:
        """Lazy-load the embedding model."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding
        from codexlens.semantic.gpu_support import get_optimal_providers

        providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
        try:
            self._model = TextEmbedding(
                model_name=self._model_name,
                providers=providers,
            )
        except TypeError:
            # Fallback for older fastembed versions without a providers argument
            self._model = TextEmbedding(model_name=self._model_name)

        logger.debug(f"BinaryEmbeddingBackend loaded model: {self._model_name}")

    def _get_projection_matrix(self, input_dim: int) -> np.ndarray:
        """Get or create projection matrix for dimension reduction.

        Uses random projection with a fixed seed for reproducibility.

        Args:
            input_dim: Input embedding dimension from base model

        Returns:
            Projection matrix of shape (input_dim, BINARY_DIM)
        """
        if self._projection_matrix is not None:
            return self._projection_matrix

        # Fixed seed for reproducibility across sessions
        rng = np.random.RandomState(42)
        # Gaussian random projection
        self._projection_matrix = rng.randn(input_dim, self.BINARY_DIM).astype(np.float32)
        # Normalize columns for consistent scale
        norms = np.linalg.norm(self._projection_matrix, axis=0, keepdims=True)
        self._projection_matrix /= (norms + 1e-8)

        return self._projection_matrix

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate binary embeddings as numpy array.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Binary embeddings of shape (n_texts, 256) with values 0 or 1
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Get base float embeddings
        float_embeddings = np.array(list(self._model.embed(texts)))
        input_dim = float_embeddings.shape[1]

        # Project to target dimension if needed
        if input_dim != self.BINARY_DIM:
            projection = self._get_projection_matrix(input_dim)
            float_embeddings = float_embeddings @ projection

        # Binarize
        return binarize_embedding(float_embeddings)

    def embed_packed(self, texts: str | Iterable[str]) -> List[bytes]:
        """Generate packed binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            List of packed bytes (32 bytes each for 256-dim)
        """
        binary = self.embed_to_numpy(texts)
        return [pack_binary_embedding(vec) for vec in binary]


# =============================================================================
# Dense Embedding Backend
# =============================================================================


class DenseEmbeddingBackend(BaseEmbedder):
    """Generate higher-dimensional dense embeddings for precise reranking.

    Produces float32 vectors of TARGET_DIM dimensions (expanded from the base
    model's native dimension when expand_dim is True) for maximum retrieval
    quality.

    Suitable for:
    - Second-stage reranking
    - High-precision similarity search
    - Quality-critical applications

    Model: BAAI/bge-small-en-v1.5 (384 dim) by default, with optional
    expansion to TARGET_DIM
    """

    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"  # 384 dim; small model for testing
    TARGET_DIM = 768  # Reduced target for faster testing

    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
        expand_dim: bool = True,
    ) -> None:
        """Initialize dense embedding backend.

        Args:
            model_name: Dense embedding model name. Defaults to DEFAULT_MODEL
                (BAAI/bge-small-en-v1.5)
            use_gpu: Whether to use GPU acceleration
            expand_dim: If True, expand embeddings to TARGET_DIM using a fixed
                semi-orthogonal expansion
        """
        from codexlens.semantic import SEMANTIC_AVAILABLE

        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        self._model_name = model_name or self.DEFAULT_MODEL
        self._use_gpu = use_gpu
        self._expand_dim = expand_dim
        self._model = None
        self._native_dim: Optional[int] = None

        # Expansion matrix for dimension expansion (lazily initialized)
        self._expansion_matrix: Optional[np.ndarray] = None

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimension.

        Returns TARGET_DIM if expand_dim is True, otherwise the native model
        dimension.
        """
        if self._expand_dim:
            return self.TARGET_DIM
        # Return cached native dim or estimate based on model
        if self._native_dim is not None:
            return self._native_dim
        # Model dimension estimates
        model_dims = {
            "BAAI/bge-large-en-v1.5": 1024,
            "BAAI/bge-base-en-v1.5": 768,
            "BAAI/bge-small-en-v1.5": 384,
            "intfloat/multilingual-e5-large": 1024,
        }
        return model_dims.get(self._model_name, 1024)

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit."""
        return 512  # Conservative default for large models

    def _load_model(self) -> None:
        """Lazy-load the embedding model."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding
        from codexlens.semantic.gpu_support import get_optimal_providers

        providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
        try:
            self._model = TextEmbedding(
                model_name=self._model_name,
                providers=providers,
            )
        except TypeError:
            self._model = TextEmbedding(model_name=self._model_name)

        logger.debug(f"DenseEmbeddingBackend loaded model: {self._model_name}")

    def _get_expansion_matrix(self, input_dim: int) -> np.ndarray:
        """Get or create expansion matrix for dimension expansion.

        Uses a fixed semi-orthogonal projection for information-preserving
        expansion.

        Args:
            input_dim: Input embedding dimension from base model

        Returns:
            Expansion matrix of shape (input_dim, TARGET_DIM)
        """
        if self._expansion_matrix is not None:
            return self._expansion_matrix

        # Fixed seed for reproducibility
        rng = np.random.RandomState(123)

        # Create semi-orthogonal expansion matrix: the first input_dim
        # columns form an identity-like structure
        self._expansion_matrix = np.zeros((input_dim, self.TARGET_DIM), dtype=np.float32)

        # Copy original dimensions
        copy_dim = min(input_dim, self.TARGET_DIM)
        self._expansion_matrix[:copy_dim, :copy_dim] = np.eye(copy_dim, dtype=np.float32)

        # Fill remaining columns with normalized random projections
        if self.TARGET_DIM > input_dim:
            random_part = rng.randn(input_dim, self.TARGET_DIM - input_dim).astype(np.float32)
            norms = np.linalg.norm(random_part, axis=0, keepdims=True)
            random_part /= (norms + 1e-8)
            self._expansion_matrix[:, input_dim:] = random_part

        return self._expansion_matrix

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate dense embeddings as numpy array.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, embedding_dim) as float32
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Get base float embeddings
        float_embeddings = np.array(list(self._model.embed(texts)), dtype=np.float32)
        self._native_dim = float_embeddings.shape[1]

        # Expand to target dimension if needed
        if self._expand_dim and self._native_dim < self.TARGET_DIM:
            expansion = self._get_expansion_matrix(self._native_dim)
            float_embeddings = float_embeddings @ expansion

        return float_embeddings


# =============================================================================
# Cascade Embedding Backend
# =============================================================================


class CascadeEmbeddingBackend(BaseEmbedder):
    """Combined binary + dense embedding backend for cascade retrieval.

    Generates both binary (for fast coarse filtering) and dense (for precise
    reranking) embeddings in a single pass, optimized for two-stage retrieval.

    Cascade workflow:
    1. encode_cascade() returns (binary_embeddings, dense_embeddings)
    2. Binary search: Hamming distance on binary vectors -> top-K candidates
    3. Dense rerank: cosine similarity on dense vectors -> final results

    Memory efficiency:
    - Binary: 32 bytes per vector (256 bits)
    - Dense: 3072 bytes per vector (768 x float32 at the default TARGET_DIM)
    - Total: ~3 KB per document for full cascade support
    """

    def __init__(
        self,
        binary_model: Optional[str] = None,
        dense_model: Optional[str] = None,
        use_gpu: bool = True,
    ) -> None:
        """Initialize cascade embedding backend.

        Args:
            binary_model: Model for binary embeddings. Defaults to
                BinaryEmbeddingBackend.DEFAULT_MODEL
            dense_model: Model for dense embeddings. Defaults to
                DenseEmbeddingBackend.DEFAULT_MODEL
            use_gpu: Whether to use GPU acceleration
        """
        self._binary_backend = BinaryEmbeddingBackend(
            model_name=binary_model,
            use_gpu=use_gpu,
        )
        self._dense_backend = DenseEmbeddingBackend(
            model_name=dense_model,
            use_gpu=use_gpu,
            expand_dim=True,
        )
        self._use_gpu = use_gpu

    @property
    def model_name(self) -> str:
        """Return model names for both backends."""
        return f"cascade({self._binary_backend.model_name}, {self._dense_backend.model_name})"

    @property
    def embedding_dim(self) -> int:
        """Return dense embedding dimension (for compatibility)."""
        return self._dense_backend.embedding_dim

    @property
    def binary_dim(self) -> int:
        """Return binary embedding dimension."""
        return self._binary_backend.embedding_dim

    @property
    def dense_dim(self) -> int:
        """Return dense embedding dimension."""
        return self._dense_backend.embedding_dim

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate dense embeddings (for BaseEmbedder compatibility).

        For cascade embeddings, use encode_cascade() instead.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, dense_dim)
        """
        return self._dense_backend.embed_to_numpy(texts)

    def encode_cascade(
        self,
        texts: str | Iterable[str],
        batch_size: int = 32,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate both binary and dense embeddings.

        Args:
            texts: Single text or iterable of texts
            batch_size: Batch size for processing

        Returns:
            Tuple of:
            - binary_embeddings: Shape (n_texts, 256), uint8 values 0/1
            - dense_embeddings: Shape (n_texts, dense_dim), float32
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        binary_embeddings = self._binary_backend.embed_to_numpy(texts)
        dense_embeddings = self._dense_backend.embed_to_numpy(texts)

        return binary_embeddings, dense_embeddings

    def encode_binary(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate only binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Binary embeddings of shape (n_texts, 256)
        """
        return self._binary_backend.embed_to_numpy(texts)

    def encode_dense(self, texts: str | Iterable[str]) -> np.ndarray:
        """Generate only dense embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            Dense embeddings of shape (n_texts, dense_dim)
        """
        return self._dense_backend.embed_to_numpy(texts)

    def encode_binary_packed(self, texts: str | Iterable[str]) -> List[bytes]:
        """Generate packed binary embeddings.

        Args:
            texts: Single text or iterable of texts

        Returns:
            List of packed bytes (32 bytes each)
        """
        return self._binary_backend.embed_packed(texts)


# =============================================================================
# Factory Function
# =============================================================================


def get_cascade_embedder(
    binary_model: Optional[str] = None,
    dense_model: Optional[str] = None,
    use_gpu: bool = True,
) -> CascadeEmbeddingBackend:
    """Factory function to create a cascade embedder.

    Args:
        binary_model: Model for binary embeddings (default: BAAI/bge-small-en-v1.5)
        dense_model: Model for dense embeddings (default: BAAI/bge-small-en-v1.5)
        use_gpu: Whether to use GPU acceleration

    Returns:
        Configured CascadeEmbeddingBackend instance

    Example:
        >>> embedder = get_cascade_embedder()
        >>> binary, dense = embedder.encode_cascade(["hello world"])
        >>> binary.shape  # (1, 256)
        >>> dense.shape   # (1, 768)
    """
    return CascadeEmbeddingBackend(
        binary_model=binary_model,
        dense_model=dense_model,
        use_gpu=use_gpu,
    )
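
Putting the module together, the sketch below runs the two-stage cascade end to end: a Hamming-distance coarse filter over packed binary signatures, then a cosine rerank over the dense vectors. It is illustrative only and assumes the optional semantic extras (fastembed and friends) are installed; model weights download on first use, and the document strings are placeholders.

    import numpy as np
    from codexlens.indexing.embedding import get_cascade_embedder, hamming_distance

    docs = [
        "def read_config(path): ...",
        "class HttpClient: ...",
        "def parse_args(argv): ...",
    ]
    embedder = get_cascade_embedder(use_gpu=False)

    doc_packed = embedder.encode_binary_packed(docs)   # 32-byte signatures
    doc_dense = embedder.encode_dense(docs)            # (3, 768) float32

    query = "load configuration from a file"
    q_packed = embedder.encode_binary_packed([query])[0]
    q_dense = embedder.encode_dense([query])[0]

    # Stage 1: coarse filter by Hamming distance on the packed binary vectors
    candidates = sorted(range(len(docs)),
                        key=lambda i: hamming_distance(q_packed, doc_packed[i]))[:2]

    # Stage 2: rerank the survivors by cosine similarity on the dense vectors
    def cosine(a: np.ndarray, b: np.ndarray) -> float:
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

    best = max(candidates, key=lambda i: cosine(q_dense, doc_dense[i]))
    print(docs[best])
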
277
codex-lens/build/lib/codexlens/indexing/symbol_extractor.py
Normal file
@@ -0,0 +1,277 @@
"""Symbol and relationship extraction from source code."""
|
||||
import re
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
except Exception: # pragma: no cover - optional dependency / platform variance
|
||||
TreeSitterSymbolParser = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class SymbolExtractor:
|
||||
"""Extract symbols and relationships from source code using regex patterns."""
|
||||
|
||||
# Pattern definitions for different languages
|
||||
PATTERNS = {
|
||||
'python': {
|
||||
'function': r'^(?:async\s+)?def\s+(\w+)\s*\(',
|
||||
'class': r'^class\s+(\w+)\s*[:\(]',
|
||||
'import': r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)',
|
||||
'call': r'(?<![.\w])(\w+)\s*\(',
|
||||
},
|
||||
'typescript': {
|
||||
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<\(]',
|
||||
'class': r'(?:export\s+)?class\s+(\w+)',
|
||||
'import': r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
|
||||
'call': r'(?<![.\w])(\w+)\s*[<\(]',
|
||||
},
|
||||
'javascript': {
|
||||
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(',
|
||||
'class': r'(?:export\s+)?class\s+(\w+)',
|
||||
'import': r"(?:import|require)\s*\(?['\"]([^'\"]+)['\"]",
|
||||
'call': r'(?<![.\w])(\w+)\s*\(',
|
||||
}
|
||||
}
|
||||
|
||||
LANGUAGE_MAP = {
|
||||
'.py': 'python',
|
||||
'.ts': 'typescript',
|
||||
'.tsx': 'typescript',
|
||||
'.js': 'javascript',
|
||||
'.jsx': 'javascript',
|
||||
}
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self.db_conn: Optional[sqlite3.Connection] = None
|
||||
|
||||
def connect(self) -> None:
|
||||
"""Connect to database and ensure schema exists."""
|
||||
self.db_conn = sqlite3.connect(str(self.db_path))
|
||||
self._ensure_tables()
|
||||
|
||||
def __enter__(self) -> "SymbolExtractor":
|
||||
"""Context manager entry: connect to database."""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Context manager exit: close database connection."""
|
||||
self.close()
|
||||
|
||||
def _ensure_tables(self) -> None:
|
||||
"""Create symbols and relationships tables if they don't exist."""
|
||||
if not self.db_conn:
|
||||
return
|
||||
cursor = self.db_conn.cursor()
|
||||
|
||||
# Create symbols table with qualified_name
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS symbols (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
qualified_name TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
kind TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
start_line INTEGER NOT NULL,
|
||||
end_line INTEGER NOT NULL,
|
||||
UNIQUE(file_path, name, start_line)
|
||||
)
|
||||
''')
|
||||
|
||||
# Create relationships table with target_symbol_fqn
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS symbol_relationships (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_symbol_id INTEGER NOT NULL,
|
||||
target_symbol_fqn TEXT NOT NULL,
|
||||
relationship_type TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
line INTEGER,
|
||||
FOREIGN KEY (source_symbol_id) REFERENCES symbols(id) ON DELETE CASCADE
|
||||
)
|
||||
''')
|
||||
|
||||
# Create performance indexes
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_path)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_source ON symbol_relationships(source_symbol_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_target ON symbol_relationships(target_symbol_fqn)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_type ON symbol_relationships(relationship_type)')
|
||||
|
||||
self.db_conn.commit()
|
||||
|
||||
def extract_from_file(self, file_path: Path, content: str) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""Extract symbols and relationships from file content.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file
|
||||
content: File content as string
|
||||
|
||||
Returns:
|
||||
Tuple of (symbols, relationships) where:
|
||||
- symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
|
||||
- relationships: List of relationship dicts with source_scope, target, type, file_path, line
|
||||
"""
|
||||
ext = file_path.suffix.lower()
|
||||
lang = self.LANGUAGE_MAP.get(ext)
|
||||
|
||||
if not lang or lang not in self.PATTERNS:
|
||||
return [], []
|
||||
|
||||
patterns = self.PATTERNS[lang]
|
||||
symbols = []
|
||||
relationships: List[Dict] = []
|
||||
lines = content.split('\n')
|
||||
|
||||
current_scope = None
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
# Extract function/class definitions
|
||||
for kind in ['function', 'class']:
|
||||
if kind in patterns:
|
||||
match = re.search(patterns[kind], line)
|
||||
if match:
|
||||
name = match.group(1)
|
||||
qualified_name = f"{file_path.stem}.{name}"
|
||||
symbols.append({
|
||||
'qualified_name': qualified_name,
|
||||
'name': name,
|
||||
'kind': kind,
|
||||
'file_path': str(file_path),
|
||||
'start_line': line_num,
|
||||
'end_line': line_num, # Simplified - would need proper parsing for actual end
|
||||
})
|
||||
current_scope = name
|
||||
|
||||
if TreeSitterSymbolParser is not None:
|
||||
try:
|
||||
ts_parser = TreeSitterSymbolParser(lang, file_path)
|
||||
if ts_parser.is_available():
|
||||
indexed = ts_parser.parse(content, file_path)
|
||||
if indexed is not None and indexed.relationships:
|
||||
relationships = [
|
||||
{
|
||||
"source_scope": r.source_symbol,
|
||||
"target": r.target_symbol,
|
||||
"type": r.relationship_type.value,
|
||||
"file_path": str(file_path),
|
||||
"line": r.source_line,
|
||||
}
|
||||
for r in indexed.relationships
|
||||
]
|
||||
except Exception:
|
||||
relationships = []
|
||||
|
||||
# Regex fallback for relationships (when tree-sitter is unavailable)
|
||||
if not relationships:
|
||||
current_scope = None
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
for kind in ['function', 'class']:
|
||||
if kind in patterns:
|
||||
match = re.search(patterns[kind], line)
|
||||
if match:
|
||||
current_scope = match.group(1)
|
||||
|
||||
# Extract imports
|
||||
if 'import' in patterns:
|
||||
match = re.search(patterns['import'], line)
|
||||
if match:
|
||||
import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
if import_target and current_scope:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': import_target.strip(),
|
||||
'type': 'imports',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
|
||||
# Extract function calls (simplified)
|
||||
if 'call' in patterns and current_scope:
|
||||
for match in re.finditer(patterns['call'], line):
|
||||
call_name = match.group(1)
|
||||
# Skip common keywords and the current function
|
||||
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': call_name,
|
||||
'type': 'calls',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
|
||||
return symbols, relationships
|
||||
|
||||
def save_symbols(self, symbols: List[Dict]) -> Dict[str, int]:
|
||||
"""Save symbols to database and return name->id mapping.
|
||||
|
||||
Args:
|
||||
symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol name to database id
|
||||
"""
|
||||
if not self.db_conn or not symbols:
|
||||
return {}
|
||||
|
||||
cursor = self.db_conn.cursor()
|
||||
name_to_id = {}
|
||||
|
||||
for sym in symbols:
|
||||
try:
|
||||
cursor.execute('''
|
||||
INSERT OR IGNORE INTO symbols
|
||||
(qualified_name, name, kind, file_path, start_line, end_line)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (sym['qualified_name'], sym['name'], sym['kind'],
|
||||
sym['file_path'], sym['start_line'], sym['end_line']))
|
||||
|
||||
# Get the id
|
||||
cursor.execute('''
|
||||
SELECT id FROM symbols
|
||||
WHERE file_path = ? AND name = ? AND start_line = ?
|
||||
''', (sym['file_path'], sym['name'], sym['start_line']))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
name_to_id[sym['name']] = row[0]
|
||||
except sqlite3.Error:
|
||||
continue
|
||||
|
||||
self.db_conn.commit()
|
||||
return name_to_id
|
||||
|
||||
def save_relationships(self, relationships: List[Dict], name_to_id: Dict[str, int]) -> None:
|
||||
"""Save relationships to database.
|
||||
|
||||
Args:
|
||||
relationships: List of relationship dicts with source_scope, target, type, file_path, line
|
||||
name_to_id: Dictionary mapping symbol names to database ids
|
||||
"""
|
||||
if not self.db_conn or not relationships:
|
||||
return
|
||||
|
||||
cursor = self.db_conn.cursor()
|
||||
|
||||
for rel in relationships:
|
||||
source_id = name_to_id.get(rel['source_scope'])
|
||||
if source_id:
|
||||
try:
|
||||
cursor.execute('''
|
||||
INSERT INTO symbol_relationships
|
||||
(source_symbol_id, target_symbol_fqn, relationship_type, file_path, line)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', (source_id, rel['target'], rel['type'], rel['file_path'], rel['line']))
|
||||
except sqlite3.Error:
|
||||
continue
|
||||
|
||||
self.db_conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
if self.db_conn:
|
||||
self.db_conn.close()
|
||||
self.db_conn = None
|
||||
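
A minimal end-to-end sketch of the extractor above: index one small source string, then persist the symbols and relationships. "index.db" and "app.py" are placeholder paths; with tree-sitter installed the relationship output may differ, since the regex pass is only the fallback.

    from pathlib import Path
    from codexlens.indexing.symbol_extractor import SymbolExtractor

    source = "def greet(name):\n    return format_greeting(name)\n"
    with SymbolExtractor(Path("index.db")) as extractor:
        symbols, relationships = extractor.extract_from_file(Path("app.py"), source)
        # symbols[0]: {'name': 'greet', 'kind': 'function', ...}
        # relationships (regex fallback): greet --calls--> format_greeting
        name_to_id = extractor.save_symbols(symbols)
        extractor.save_relationships(relationships, name_to_id)
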
34
codex-lens/build/lib/codexlens/lsp/__init__.py
Normal file
@@ -0,0 +1,34 @@
"""LSP module for real-time language server integration.
|
||||
|
||||
This module provides:
|
||||
- LspBridge: HTTP bridge to VSCode language servers
|
||||
- LspGraphBuilder: Build code association graphs via LSP
|
||||
- Location: Position in a source file
|
||||
|
||||
Example:
|
||||
>>> from codexlens.lsp import LspBridge, LspGraphBuilder
|
||||
>>>
|
||||
>>> async with LspBridge() as bridge:
|
||||
... refs = await bridge.get_references(symbol)
|
||||
... graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
|
||||
"""
|
||||
|
||||
from codexlens.lsp.lsp_bridge import (
|
||||
CacheEntry,
|
||||
Location,
|
||||
LspBridge,
|
||||
)
|
||||
from codexlens.lsp.lsp_graph_builder import (
|
||||
LspGraphBuilder,
|
||||
)
|
||||
|
||||
# Alias for backward compatibility
|
||||
GraphBuilder = LspGraphBuilder
|
||||
|
||||
__all__ = [
|
||||
"CacheEntry",
|
||||
"GraphBuilder",
|
||||
"Location",
|
||||
"LspBridge",
|
||||
"LspGraphBuilder",
|
||||
]
|
||||
551
codex-lens/build/lib/codexlens/lsp/handlers.py
Normal file
@@ -0,0 +1,551 @@
"""LSP request handlers for codex-lens.
|
||||
|
||||
This module contains handlers for LSP requests:
|
||||
- textDocument/definition
|
||||
- textDocument/completion
|
||||
- workspace/symbol
|
||||
- textDocument/didSave
|
||||
- textDocument/hover
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
try:
|
||||
from lsprotocol import types as lsp
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||
) from exc
|
||||
|
||||
from codexlens.entities import Symbol
|
||||
from codexlens.lsp.server import server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Symbol kind mapping from codex-lens to LSP
|
||||
SYMBOL_KIND_MAP = {
|
||||
"class": lsp.SymbolKind.Class,
|
||||
"function": lsp.SymbolKind.Function,
|
||||
"method": lsp.SymbolKind.Method,
|
||||
"variable": lsp.SymbolKind.Variable,
|
||||
"constant": lsp.SymbolKind.Constant,
|
||||
"property": lsp.SymbolKind.Property,
|
||||
"field": lsp.SymbolKind.Field,
|
||||
"interface": lsp.SymbolKind.Interface,
|
||||
"module": lsp.SymbolKind.Module,
|
||||
"namespace": lsp.SymbolKind.Namespace,
|
||||
"package": lsp.SymbolKind.Package,
|
||||
"enum": lsp.SymbolKind.Enum,
|
||||
"enum_member": lsp.SymbolKind.EnumMember,
|
||||
"struct": lsp.SymbolKind.Struct,
|
||||
"type": lsp.SymbolKind.TypeParameter,
|
||||
"type_alias": lsp.SymbolKind.TypeParameter,
|
||||
}
|
||||
|
||||
# Completion kind mapping from codex-lens to LSP
|
||||
COMPLETION_KIND_MAP = {
|
||||
"class": lsp.CompletionItemKind.Class,
|
||||
"function": lsp.CompletionItemKind.Function,
|
||||
"method": lsp.CompletionItemKind.Method,
|
||||
"variable": lsp.CompletionItemKind.Variable,
|
||||
"constant": lsp.CompletionItemKind.Constant,
|
||||
"property": lsp.CompletionItemKind.Property,
|
||||
"field": lsp.CompletionItemKind.Field,
|
||||
"interface": lsp.CompletionItemKind.Interface,
|
||||
"module": lsp.CompletionItemKind.Module,
|
||||
"enum": lsp.CompletionItemKind.Enum,
|
||||
"enum_member": lsp.CompletionItemKind.EnumMember,
|
||||
"struct": lsp.CompletionItemKind.Struct,
|
||||
"type": lsp.CompletionItemKind.TypeParameter,
|
||||
"type_alias": lsp.CompletionItemKind.TypeParameter,
|
||||
}
|
||||
|
||||
|
||||
def _path_to_uri(path: Union[str, Path]) -> str:
|
||||
"""Convert a file path to a URI.
|
||||
|
||||
Args:
|
||||
path: File path (string or Path object)
|
||||
|
||||
Returns:
|
||||
File URI string
|
||||
"""
|
||||
path_str = str(Path(path).resolve())
|
||||
# Handle Windows paths
|
||||
if path_str.startswith("/"):
|
||||
return f"file://{quote(path_str)}"
|
||||
else:
|
||||
return f"file:///{quote(path_str.replace(chr(92), '/'))}"
|
||||
|
||||
|
||||
def _uri_to_path(uri: str) -> Path:
|
||||
"""Convert a URI to a file path.
|
||||
|
||||
Args:
|
||||
uri: File URI string
|
||||
|
||||
Returns:
|
||||
Path object
|
||||
"""
|
||||
path = uri.replace("file:///", "").replace("file://", "")
|
||||
return Path(unquote(path))
|
||||
|
||||
|
||||
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
|
||||
"""Extract the word at the given position in the document.
|
||||
|
||||
Args:
|
||||
document_text: Full document text
|
||||
line: 0-based line number
|
||||
character: 0-based character position
|
||||
|
||||
Returns:
|
||||
Word at position, or None if no word found
|
||||
"""
|
||||
lines = document_text.splitlines()
|
||||
if line >= len(lines):
|
||||
return None
|
||||
|
||||
line_text = lines[line]
|
||||
if character > len(line_text):
|
||||
return None
|
||||
|
||||
# Find word boundaries
|
||||
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
|
||||
for match in word_pattern.finditer(line_text):
|
||||
if match.start() <= character <= match.end():
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
|
||||
"""Extract the incomplete word prefix at the given position.
|
||||
|
||||
Args:
|
||||
document_text: Full document text
|
||||
line: 0-based line number
|
||||
character: 0-based character position
|
||||
|
||||
Returns:
|
||||
Prefix string (may be empty)
|
||||
"""
|
||||
lines = document_text.splitlines()
|
||||
if line >= len(lines):
|
||||
return ""
|
||||
|
||||
line_text = lines[line]
|
||||
if character > len(line_text):
|
||||
character = len(line_text)
|
||||
|
||||
# Extract text before cursor
|
||||
before_cursor = line_text[:character]
|
||||
|
||||
# Find the start of the current word
|
||||
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
|
||||
"""Convert a codex-lens Symbol to an LSP Location.
|
||||
|
||||
Args:
|
||||
symbol: codex-lens Symbol object
|
||||
|
||||
Returns:
|
||||
LSP Location, or None if symbol has no file
|
||||
"""
|
||||
if not symbol.file:
|
||||
return None
|
||||
|
||||
# LSP uses 0-based lines, codex-lens uses 1-based
|
||||
start_line = max(0, symbol.range[0] - 1)
|
||||
end_line = max(0, symbol.range[1] - 1)
|
||||
|
||||
return lsp.Location(
|
||||
uri=_path_to_uri(symbol.file),
|
||||
range=lsp.Range(
|
||||
start=lsp.Position(line=start_line, character=0),
|
||||
end=lsp.Position(line=end_line, character=0),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
|
||||
"""Map codex-lens symbol kind to LSP SymbolKind.
|
||||
|
||||
Args:
|
||||
kind: codex-lens symbol kind string
|
||||
|
||||
Returns:
|
||||
LSP SymbolKind
|
||||
"""
|
||||
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
|
||||
|
||||
|
||||
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
|
||||
"""Map codex-lens symbol kind to LSP CompletionItemKind.
|
||||
|
||||
Args:
|
||||
kind: codex-lens symbol kind string
|
||||
|
||||
Returns:
|
||||
LSP CompletionItemKind
|
||||
"""
|
||||
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# LSP Request Handlers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
|
||||
def lsp_definition(
|
||||
params: lsp.DefinitionParams,
|
||||
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
|
||||
"""Handle textDocument/definition request.
|
||||
|
||||
Finds the definition of the symbol at the cursor position.
|
||||
"""
|
||||
if not server.global_index:
|
||||
logger.debug("No global index available for definition lookup")
|
||||
return None
|
||||
|
||||
# Get document
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
# Get word at position
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
logger.debug("No word found at position")
|
||||
return None
|
||||
|
||||
logger.debug("Looking up definition for: %s", word)
|
||||
|
||||
# Search for exact symbol match
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=10,
|
||||
prefix_mode=False, # Exact match preferred
|
||||
)
|
||||
|
||||
# Filter for exact name match
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
if not exact_matches:
|
||||
# Fall back to prefix search
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=10,
|
||||
prefix_mode=True,
|
||||
)
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
|
||||
if not exact_matches:
|
||||
logger.debug("No definition found for: %s", word)
|
||||
return None
|
||||
|
||||
# Convert to LSP locations
|
||||
locations = []
|
||||
for sym in exact_matches:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
locations.append(loc)
|
||||
|
||||
if len(locations) == 1:
|
||||
return locations[0]
|
||||
elif locations:
|
||||
return locations
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error looking up definition: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
|
||||
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
|
||||
"""Handle textDocument/references request.
|
||||
|
||||
Finds all references to the symbol at the cursor position using
|
||||
the code_relationships table for accurate call-site tracking.
|
||||
Falls back to same-name symbol search if search_engine is unavailable.
|
||||
"""
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
return None
|
||||
|
||||
logger.debug("Finding references for: %s", word)
|
||||
|
||||
try:
|
||||
# Try using search_engine.search_references() for accurate reference tracking
|
||||
if server.search_engine and server.workspace_root:
|
||||
references = server.search_engine.search_references(
|
||||
symbol_name=word,
|
||||
source_path=server.workspace_root,
|
||||
limit=200,
|
||||
)
|
||||
|
||||
if references:
|
||||
locations = []
|
||||
for ref in references:
|
||||
locations.append(
|
||||
lsp.Location(
|
||||
uri=_path_to_uri(ref.file_path),
|
||||
range=lsp.Range(
|
||||
start=lsp.Position(
|
||||
line=max(0, ref.line - 1),
|
||||
character=ref.column,
|
||||
),
|
||||
end=lsp.Position(
|
||||
line=max(0, ref.line - 1),
|
||||
character=ref.column + len(word),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return locations if locations else None
|
||||
|
||||
# Fallback: search for symbols with same name using global_index
|
||||
if server.global_index:
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=100,
|
||||
prefix_mode=False,
|
||||
)
|
||||
|
||||
# Filter for exact matches
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
|
||||
locations = []
|
||||
for sym in exact_matches:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
locations.append(loc)
|
||||
|
||||
return locations if locations else None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error finding references: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
|
||||
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
|
||||
"""Handle textDocument/completion request.
|
||||
|
||||
Provides code completion suggestions based on indexed symbols.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
prefix = _get_prefix_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not prefix or len(prefix) < 2:
|
||||
# Require at least 2 characters for completion
|
||||
return None
|
||||
|
||||
logger.debug("Completing prefix: %s", prefix)
|
||||
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=prefix,
|
||||
limit=50,
|
||||
prefix_mode=True,
|
||||
)
|
||||
|
||||
if not symbols:
|
||||
return None
|
||||
|
||||
# Convert to completion items
|
||||
items = []
|
||||
seen_names = set()
|
||||
|
||||
for sym in symbols:
|
||||
if sym.name in seen_names:
|
||||
continue
|
||||
seen_names.add(sym.name)
|
||||
|
||||
items.append(
|
||||
lsp.CompletionItem(
|
||||
label=sym.name,
|
||||
kind=_symbol_kind_to_completion_kind(sym.kind),
|
||||
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
|
||||
sort_text=sym.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return lsp.CompletionList(
|
||||
is_incomplete=len(symbols) >= 50,
|
||||
items=items,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error getting completions: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
|
||||
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
|
||||
"""Handle textDocument/hover request.
|
||||
|
||||
Provides hover information for the symbol at the cursor position
|
||||
using HoverProvider for rich symbol information including
|
||||
signature, documentation, and location.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
return None
|
||||
|
||||
logger.debug("Hover for: %s", word)
|
||||
|
||||
try:
|
||||
# Use HoverProvider for rich symbol information
|
||||
from codexlens.lsp.providers import HoverProvider
|
||||
|
||||
provider = HoverProvider(server.global_index, server.registry)
|
||||
info = provider.get_hover_info(word)
|
||||
|
||||
if not info:
|
||||
return None
|
||||
|
||||
# Format as markdown with signature and location
|
||||
content = provider.format_hover_markdown(info)
|
||||
|
||||
return lsp.Hover(
|
||||
contents=lsp.MarkupContent(
|
||||
kind=lsp.MarkupKind.Markdown,
|
||||
value=content,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error getting hover info: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.WORKSPACE_SYMBOL)
|
||||
def lsp_workspace_symbol(
|
||||
params: lsp.WorkspaceSymbolParams,
|
||||
) -> Optional[List[lsp.SymbolInformation]]:
|
||||
"""Handle workspace/symbol request.
|
||||
|
||||
Searches for symbols across the workspace.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
query = params.query
|
||||
if not query or len(query) < 2:
|
||||
return None
|
||||
|
||||
logger.debug("Workspace symbol search: %s", query)
|
||||
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=query,
|
||||
limit=100,
|
||||
prefix_mode=True,
|
||||
)
|
||||
|
||||
if not symbols:
|
||||
return None
|
||||
|
||||
result = []
|
||||
for sym in symbols:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
result.append(
|
||||
lsp.SymbolInformation(
|
||||
name=sym.name,
|
||||
kind=_symbol_kind_to_lsp(sym.kind),
|
||||
location=loc,
|
||||
container_name=Path(sym.file).parent.name if sym.file else None,
|
||||
)
|
||||
)
|
||||
|
||||
return result if result else None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error searching workspace symbols: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
|
||||
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didSave notification.
|
||||
|
||||
Triggers incremental re-indexing of the saved file.
|
||||
Note: Full incremental indexing requires WatcherManager integration,
|
||||
which is planned for Phase 2.
|
||||
"""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.info("File saved: %s", file_path)
|
||||
|
||||
# Phase 1: Just log the save event
|
||||
# Phase 2 will integrate with WatcherManager for incremental indexing
|
||||
# if server.watcher_manager:
|
||||
# server.watcher_manager.trigger_reindex(file_path)
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
|
||||
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didOpen notification."""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.debug("File opened: %s", file_path)
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
|
||||
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didClose notification."""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.debug("File closed: %s", file_path)
|
||||
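
The position helpers above drive both definition lookup and completion. A tiny sketch of their semantics, assuming the [lsp] extras are installed so the module imports cleanly (the functions are module-private, imported here only for illustration):

    from codexlens.lsp.handlers import _get_word_at_position, _get_prefix_at_position

    text = "result = compute_total(items)\n"
    # Cursor on the "p" of compute_total: line 0, character 12.
    assert _get_word_at_position(text, 0, 12) == "compute_total"
    # Completion only considers what was typed before the cursor.
    assert _get_prefix_at_position(text, 0, 12) == "com"
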
834
codex-lens/build/lib/codexlens/lsp/lsp_bridge.py
Normal file
@@ -0,0 +1,834 @@
"""LspBridge service for real-time LSP communication with caching.
|
||||
|
||||
This module provides a bridge to communicate with language servers either via:
|
||||
1. Standalone LSP Manager (direct subprocess communication - default)
|
||||
2. VSCode Bridge extension (HTTP-based, legacy mode)
|
||||
|
||||
Features:
|
||||
- Direct communication with language servers (no VSCode dependency)
|
||||
- Cache with TTL and file modification time invalidation
|
||||
- Graceful error handling with empty results on failure
|
||||
- Support for definition, references, hover, and call hierarchy
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||
|
||||
# Check for optional dependencies
|
||||
try:
|
||||
import aiohttp
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
from codexlens.hybrid_search.data_structures import (
|
||||
CallHierarchyItem,
|
||||
CodeSymbolNode,
|
||||
Range,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Location:
|
||||
"""A location in a source file (LSP response format)."""
|
||||
|
||||
file_path: str
|
||||
line: int
|
||||
character: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"line": self.line,
|
||||
"character": self.character,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
|
||||
"""Create Location from LSP response format.
|
||||
|
||||
Handles both direct format and VSCode URI format.
|
||||
"""
|
||||
# Handle VSCode URI format (file:///path/to/file)
|
||||
uri = data.get("uri", data.get("file_path", ""))
|
||||
if uri.startswith("file:///"):
|
||||
# Windows: file:///C:/path -> C:/path
|
||||
# Unix: file:///path -> /path
|
||||
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||
elif uri.startswith("file://"):
|
||||
file_path = uri[7:]
|
||||
else:
|
||||
file_path = uri
|
||||
|
||||
# Get position from range or direct fields
|
||||
if "range" in data:
|
||||
range_data = data["range"]
|
||||
start = range_data.get("start", {})
|
||||
line = start.get("line", 0) + 1 # LSP is 0-based, convert to 1-based
|
||||
character = start.get("character", 0) + 1
|
||||
else:
|
||||
line = data.get("line", 1)
|
||||
character = data.get("character", 1)
|
||||
|
||||
return cls(file_path=file_path, line=line, character=character)
|
||||
|
||||
|
||||
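# Examples of the URI forms handled by from_lsp_response above (illustrative
# values, not part of the module):
#
#   Location.from_lsp_response({"uri": "file:///C:/proj/a.py",
#                               "range": {"start": {"line": 4, "character": 0},
#                                         "end": {"line": 4, "character": 3}}})
#       -> Location(file_path="C:/proj/a.py", line=5, character=1)   # 1-based
#   Location.from_lsp_response({"uri": "file:///home/user/a.py", "line": 7})
#       -> Location(file_path="/home/user/a.py", line=7, character=1)

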
@dataclass
class CacheEntry:
    """A cached LSP response with expiration metadata.

    Attributes:
        data: The cached response data
        file_mtime: File modification time when cached (for invalidation)
        cached_at: Unix timestamp when entry was cached
    """

    data: Any
    file_mtime: float
    cached_at: float


class LspBridge:
    """Bridge for real-time LSP communication with language servers.

    By default, uses StandaloneLspManager to directly spawn and communicate
    with language servers via JSON-RPC over stdio. No VSCode dependency required.

    For legacy mode, can use the VSCode Bridge HTTP server (set use_vscode_bridge=True).

    Features:
    - Direct language server communication (default)
    - Response caching with TTL and file modification invalidation
    - Timeout handling
    - Graceful error handling returning empty results

    Example:
        # Default: standalone mode (no VSCode needed)
        async with LspBridge() as bridge:
            refs = await bridge.get_references(symbol)
            definition = await bridge.get_definition(symbol)

        # Legacy: VSCode Bridge mode
        async with LspBridge(use_vscode_bridge=True) as bridge:
            refs = await bridge.get_references(symbol)
    """

    DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
    DEFAULT_TIMEOUT = 30.0  # seconds (increased for standalone mode)
    DEFAULT_CACHE_TTL = 300  # 5 minutes
    DEFAULT_MAX_CACHE_SIZE = 1000  # Maximum cache entries

    def __init__(
        self,
        bridge_url: str = DEFAULT_BRIDGE_URL,
        timeout: float = DEFAULT_TIMEOUT,
        cache_ttl: int = DEFAULT_CACHE_TTL,
        max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
        use_vscode_bridge: bool = False,
        workspace_root: Optional[str] = None,
        config_file: Optional[str] = None,
    ):
        """Initialize LspBridge.

        Args:
            bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
            timeout: Request timeout in seconds
            cache_ttl: Cache time-to-live in seconds
            max_cache_size: Maximum number of cache entries (LRU eviction)
            use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
            workspace_root: Root directory for standalone LSP manager
            config_file: Path to lsp-servers.json configuration file
        """
        self.bridge_url = bridge_url
        self.timeout = timeout
        self.cache_ttl = cache_ttl
        self.max_cache_size = max_cache_size
        self.use_vscode_bridge = use_vscode_bridge
        self.workspace_root = workspace_root
        self.config_file = config_file

        self.cache: OrderedDict[str, CacheEntry] = OrderedDict()

        # VSCode Bridge mode (legacy)
        self._session: Optional["aiohttp.ClientSession"] = None

        # Standalone mode (default)
        self._manager: Optional["StandaloneLspManager"] = None
        self._manager_started = False

        # Validate dependencies
        if use_vscode_bridge and not HAS_AIOHTTP:
            raise ImportError(
                "aiohttp is required for VSCode Bridge mode: pip install aiohttp"
            )

    async def _ensure_manager(self) -> "StandaloneLspManager":
        """Ensure the standalone LSP manager is started."""
        if self._manager is None:
            from codexlens.lsp.standalone_manager import StandaloneLspManager
            self._manager = StandaloneLspManager(
                workspace_root=self.workspace_root,
                config_file=self.config_file,
                timeout=self.timeout,
            )

        if not self._manager_started:
            await self._manager.start()
            self._manager_started = True

        return self._manager

    async def _get_session(self) -> "aiohttp.ClientSession":
        """Get or create the aiohttp session (VSCode Bridge mode only)."""
        if not HAS_AIOHTTP:
            raise ImportError("aiohttp required for VSCode Bridge mode")

        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def close(self) -> None:
        """Close connections and clean up resources."""
        # Close VSCode Bridge session
        if self._session and not self._session.closed:
            await self._session.close()
            self._session = None

        # Stop standalone manager
        if self._manager and self._manager_started:
            await self._manager.stop()
            self._manager_started = False

    def _get_file_mtime(self, file_path: str) -> float:
        """Get file modification time, or 0 if the file doesn't exist."""
        try:
            return os.path.getmtime(file_path)
        except OSError:
            return 0.0

    def _is_cached(self, cache_key: str, file_path: str) -> bool:
        """Check if a cache entry is valid.

        Cache is invalid if:
        - Entry doesn't exist
        - TTL has expired
        - File has been modified since caching

        Args:
            cache_key: The cache key to check
            file_path: Path to source file for mtime check

        Returns:
            True if cache is valid and can be used
        """
        if cache_key not in self.cache:
            return False

        entry = self.cache[cache_key]
        now = time.time()

        # Check TTL
        if now - entry.cached_at > self.cache_ttl:
            del self.cache[cache_key]
            return False

        # Check file modification time
        current_mtime = self._get_file_mtime(file_path)
        if current_mtime != entry.file_mtime:
            del self.cache[cache_key]
            return False

        # Move to end on access (LRU behavior)
        self.cache.move_to_end(cache_key)
        return True

    def _cache(self, key: str, file_path: str, data: Any) -> None:
        """Store data in cache with LRU eviction.

        Args:
            key: Cache key
            file_path: Path to source file (for mtime tracking)
            data: Data to cache
        """
        # Evict least-recently-used entries if at capacity
        while len(self.cache) >= self.max_cache_size:
            self.cache.popitem(last=False)  # Oldest access order is at the front

        # Move to end if key exists (update access order)
        if key in self.cache:
            self.cache.move_to_end(key)
|
||||
|
||||
self.cache[key] = CacheEntry(
|
||||
data=data,
|
||||
file_mtime=self._get_file_mtime(file_path),
|
||||
cached_at=time.time(),
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
self.cache.clear()
|
||||
|
||||
async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
|
||||
"""Make HTTP request to VSCode Bridge (legacy mode).
|
||||
|
||||
Args:
|
||||
action: The endpoint/action name (e.g., "get_definition")
|
||||
params: Request parameters
|
||||
|
||||
Returns:
|
||||
Response data on success, None on failure
|
||||
"""
|
||||
url = f"{self.bridge_url}/{action}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=params) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
if data.get("success") is False:
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
|
||||
"""Get all references to a symbol via real-time LSP.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to find references for
|
||||
|
||||
Returns:
|
||||
List of Location objects where the symbol is referenced.
|
||||
Returns empty list on error or timeout.
|
||||
"""
|
||||
cache_key = f"refs:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
locations: List[Location] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_references", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
# Don't cache on connection error (result is None)
|
||||
if result is None:
|
||||
return locations
|
||||
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
try:
|
||||
locations.append(Location.from_lsp_response(item))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_references(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
for item in result:
|
||||
try:
|
||||
locations.append(Location.from_lsp_response(item))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
|
||||
self._cache(cache_key, symbol.file_path, locations)
|
||||
return locations
|
||||
|
||||
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
|
||||
"""Get symbol definition location.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to find definition for
|
||||
|
||||
Returns:
|
||||
Location of the definition, or None if not found
|
||||
"""
|
||||
cache_key = f"def:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
location: Optional[Location] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_definition", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
try:
|
||||
location = Location.from_lsp_response(result[0])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
elif isinstance(result, dict):
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_definition(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if result:
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
self._cache(cache_key, symbol.file_path, location)
|
||||
return location
|
||||
|
||||
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
|
||||
"""Get incoming/outgoing calls for a symbol.
|
||||
|
||||
If call hierarchy is not supported by the language server,
|
||||
falls back to using references.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get call hierarchy for
|
||||
|
||||
Returns:
|
||||
List of CallHierarchyItem representing callers/callees.
|
||||
Returns empty list on error or if not supported.
|
||||
"""
|
||||
cache_key = f"calls:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
items: List[CallHierarchyItem] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_call_hierarchy", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result is None:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
elif isinstance(result, list):
|
||||
for item in result:
|
||||
try:
|
||||
range_data = item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=item.get("name", "unknown"),
|
||||
kind=item.get("kind", "unknown"),
|
||||
file_path=item.get("file_path", item.get("uri", "")),
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=item.get("detail"),
|
||||
))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
|
||||
# Try to get call hierarchy items
|
||||
hierarchy_items = await manager.get_call_hierarchy_items(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if hierarchy_items:
|
||||
# Get incoming calls for each item
|
||||
for h_item in hierarchy_items:
|
||||
incoming = await manager.get_incoming_calls(h_item)
|
||||
for call in incoming:
|
||||
from_item = call.get("from", {})
|
||||
range_data = from_item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
# Parse URI
|
||||
uri = from_item.get("uri", "")
|
||||
if uri.startswith("file:///"):
|
||||
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||
elif uri.startswith("file://"):
|
||||
fp = uri[7:]
|
||||
else:
|
||||
fp = uri
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=from_item.get("name", "unknown"),
|
||||
kind=str(from_item.get("kind", "unknown")),
|
||||
file_path=fp,
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=from_item.get("detail"),
|
||||
))
|
||||
else:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
|
||||
self._cache(cache_key, symbol.file_path, items)
|
||||
return items
|
||||
|
||||
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""Get all symbols in a document (batch operation).
|
||||
|
||||
This is more efficient than individual hover queries when processing
|
||||
multiple locations in the same file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file
|
||||
|
||||
Returns:
|
||||
List of symbol dictionaries with name, kind, range, etc.
|
||||
Returns empty list on error or timeout.
|
||||
"""
|
||||
cache_key = f"symbols:{file_path}"
|
||||
|
||||
if self._is_cached(cache_key, file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
symbols: List[Dict[str, Any]] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_document_symbols", {
|
||||
"file_path": file_path,
|
||||
})
|
||||
|
||||
if isinstance(result, list):
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_document_symbols(file_path)
|
||||
|
||||
if result:
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
|
||||
self._cache(cache_key, file_path, symbols)
|
||||
return symbols
|
||||
|
||||
def _flatten_document_symbols(
|
||||
self, symbols: List[Dict[str, Any]], parent_name: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Flatten nested document symbols into a flat list.
|
||||
|
||||
Document symbols can be nested (e.g., methods inside classes).
|
||||
This flattens them for easier lookup by line number.
|
||||
|
||||
Args:
|
||||
symbols: List of symbol dictionaries (may be nested)
|
||||
parent_name: Name of parent symbol for qualification
|
||||
|
||||
Returns:
|
||||
Flat list of all symbols with their ranges
|
||||
"""
|
||||
flat: List[Dict[str, Any]] = []
|
||||
|
||||
for sym in symbols:
|
||||
# Add the symbol itself
|
||||
symbol_entry = {
|
||||
"name": sym.get("name", "unknown"),
|
||||
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
|
||||
"range": sym.get("range", sym.get("location", {}).get("range", {})),
|
||||
"selection_range": sym.get("selectionRange", {}),
|
||||
"detail": sym.get("detail", ""),
|
||||
"parent": parent_name,
|
||||
}
|
||||
flat.append(symbol_entry)
|
||||
|
||||
# Recursively process children
|
||||
children = sym.get("children", [])
|
||||
if children:
|
||||
qualified_name = sym.get("name", "")
|
||||
if parent_name:
|
||||
qualified_name = f"{parent_name}.{qualified_name}"
|
||||
flat.extend(self._flatten_document_symbols(children, qualified_name))
|
||||
|
||||
return flat
|
||||
|
||||
def _symbol_kind_to_string(self, kind: int) -> str:
|
||||
"""Convert LSP SymbolKind integer to string.
|
||||
|
||||
Args:
|
||||
kind: LSP SymbolKind enum value
|
||||
|
||||
Returns:
|
||||
Human-readable string representation
|
||||
"""
|
||||
# LSP SymbolKind enum (1-indexed)
|
||||
kinds = {
|
||||
1: "file",
|
||||
2: "module",
|
||||
3: "namespace",
|
||||
4: "package",
|
||||
5: "class",
|
||||
6: "method",
|
||||
7: "property",
|
||||
8: "field",
|
||||
9: "constructor",
|
||||
10: "enum",
|
||||
11: "interface",
|
||||
12: "function",
|
||||
13: "variable",
|
||||
14: "constant",
|
||||
15: "string",
|
||||
16: "number",
|
||||
17: "boolean",
|
||||
18: "array",
|
||||
19: "object",
|
||||
20: "key",
|
||||
21: "null",
|
||||
22: "enum_member",
|
||||
23: "struct",
|
||||
24: "event",
|
||||
25: "operator",
|
||||
26: "type_parameter",
|
||||
}
|
||||
return kinds.get(kind, "unknown")
|
||||
|
||||
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
|
||||
"""Get hover documentation for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get hover info for
|
||||
|
||||
Returns:
|
||||
Hover documentation as string, or None if not available
|
||||
"""
|
||||
cache_key = f"hover:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
hover_text: Optional[str] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_hover", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
hover_text = self._parse_hover_result(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
hover_text = await manager.get_hover(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
self._cache(cache_key, symbol.file_path, hover_text)
|
||||
return hover_text
|
||||
|
||||
def _parse_hover_result(self, result: Any) -> Optional[str]:
|
||||
"""Parse hover result into string."""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
elif isinstance(result, list):
|
||||
parts = []
|
||||
for item in result:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
value = item.get("value", item.get("contents", ""))
|
||||
if value:
|
||||
parts.append(str(value))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
elif isinstance(result, dict):
|
||||
contents = result.get("contents", result.get("value", ""))
|
||||
if isinstance(contents, str):
|
||||
return contents
|
||||
elif isinstance(contents, list):
|
||||
parts = []
|
||||
for c in contents:
|
||||
if isinstance(c, str):
|
||||
parts.append(c)
|
||||
elif isinstance(c, dict):
|
||||
parts.append(str(c.get("value", "")))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "LspBridge":
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit - close connections."""
|
||||
await self.close()
|
||||
|
||||
|
||||
# Simple test
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
async def test_lsp_bridge():
|
||||
"""Simple test of LspBridge functionality."""
|
||||
print("Testing LspBridge (Standalone Mode)...")
|
||||
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
|
||||
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
|
||||
print()
|
||||
|
||||
# Create a test symbol pointing to this file
|
||||
test_file = os.path.abspath(__file__)
|
||||
test_symbol = CodeSymbolNode(
|
||||
id=f"{test_file}:LspBridge:96",
|
||||
name="LspBridge",
|
||||
kind="class",
|
||||
file_path=test_file,
|
||||
range=Range(
|
||||
start_line=96,
|
||||
start_character=1,
|
||||
end_line=200,
|
||||
end_character=1,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
|
||||
print()
|
||||
|
||||
# Use standalone mode (default)
|
||||
async with LspBridge(
|
||||
workspace_root=str(Path(__file__).parent.parent.parent.parent),
|
||||
) as bridge:
|
||||
print("1. Testing get_document_symbols...")
|
||||
try:
|
||||
symbols = await bridge.get_document_symbols(test_file)
|
||||
print(f" Found {len(symbols)} symbols")
|
||||
for sym in symbols[:5]:
|
||||
print(f" - {sym.get('name')} ({sym.get('kind')})")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("2. Testing get_definition...")
|
||||
try:
|
||||
definition = await bridge.get_definition(test_symbol)
|
||||
if definition:
|
||||
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
|
||||
else:
|
||||
print(" No definition found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("3. Testing get_references...")
|
||||
try:
|
||||
refs = await bridge.get_references(test_symbol)
|
||||
print(f" Found {len(refs)} references")
|
||||
for ref in refs[:3]:
|
||||
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("4. Testing get_hover...")
|
||||
try:
|
||||
hover = await bridge.get_hover(test_symbol)
|
||||
if hover:
|
||||
print(f" Hover: {hover[:100]}...")
|
||||
else:
|
||||
print(" No hover info found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("5. Testing get_call_hierarchy...")
|
||||
try:
|
||||
calls = await bridge.get_call_hierarchy(test_symbol)
|
||||
print(f" Found {len(calls)} call hierarchy items")
|
||||
for call in calls[:3]:
|
||||
print(f" - {call.name} in {os.path.basename(call.file_path)}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("6. Testing cache...")
|
||||
print(f" Cache entries: {len(bridge.cache)}")
|
||||
for key in list(bridge.cache.keys())[:5]:
|
||||
print(f" - {key}")
|
||||
|
||||
print()
|
||||
print("Test complete!")
|
||||
|
||||
# Run the test
|
||||
# Note: On Windows, use default ProactorEventLoop (supports subprocess creation)
|
||||
|
||||
asyncio.run(test_lsp_bridge())
|
||||
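The bridge keys its cache on symbol id and invalidates on TTL expiry or a change in the source file's mtime, so repeated queries against an unchanged file never hit the language server twice. A minimal usage sketch, assuming codex-lens is installed with a configured language server; the symbol values are hypothetical placeholders, not part of the diff above.

import asyncio

from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp.lsp_bridge import LspBridge


async def demo() -> None:
    # Hypothetical symbol; real ids come from the codex-lens index.
    symbol = CodeSymbolNode(
        id="src/app.py:handler:10",
        name="handler",
        kind="function",
        file_path="src/app.py",
        range=Range(start_line=10, start_character=5, end_line=20, end_character=1),
    )
    async with LspBridge(workspace_root=".") as bridge:
        refs = await bridge.get_references(symbol)         # first call queries the server
        refs_cached = await bridge.get_references(symbol)  # second call is served from cache
        print(len(refs), len(refs_cached))
        # Editing src/app.py changes its mtime, which drops the entry on the next lookup.


asyncio.run(demo())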
375
codex-lens/build/lib/codexlens/lsp/lsp_graph_builder.py
Normal file
@@ -0,0 +1,375 @@
"""Graph builder for code association graphs via LSP."""

from __future__ import annotations

import asyncio
import logging
from typing import Any, Dict, List, Optional, Set, Tuple

from codexlens.hybrid_search.data_structures import (
    CallHierarchyItem,
    CodeAssociationGraph,
    CodeSymbolNode,
    Range,
)
from codexlens.lsp.lsp_bridge import (
    Location,
    LspBridge,
)

logger = logging.getLogger(__name__)


class LspGraphBuilder:
    """Builds code association graph by expanding from seed symbols using LSP."""

    def __init__(
        self,
        max_depth: int = 2,
        max_nodes: int = 100,
        max_concurrent: int = 10,
    ):
        """Initialize GraphBuilder.

        Args:
            max_depth: Maximum depth for BFS expansion from seeds.
            max_nodes: Maximum number of nodes in the graph.
            max_concurrent: Maximum concurrent LSP requests.
        """
        self.max_depth = max_depth
        self.max_nodes = max_nodes
        self.max_concurrent = max_concurrent
        # Cache for document symbols per file (avoids per-location hover queries)
        self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}

    async def build_from_seeds(
        self,
        seeds: List[CodeSymbolNode],
        lsp_bridge: LspBridge,
    ) -> CodeAssociationGraph:
        """Build association graph by BFS expansion from seeds.

        For each seed:
        1. Get references via LSP
        2. Get call hierarchy via LSP
        3. Add nodes and edges to graph
        4. Continue expanding until max_depth or max_nodes reached

        Args:
            seeds: Initial seed symbols to expand from.
            lsp_bridge: LSP bridge for querying language servers.

        Returns:
            CodeAssociationGraph with expanded nodes and relationships.
        """
        graph = CodeAssociationGraph()
        visited: Set[str] = set()
        semaphore = asyncio.Semaphore(self.max_concurrent)

        # Initialize queue with seeds at depth 0
        queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]

        # Add seed nodes to graph
        for seed in seeds:
            graph.add_node(seed)

        # BFS expansion
        while queue and len(graph.nodes) < self.max_nodes:
            # Take a batch of nodes from queue
            batch_size = min(self.max_concurrent, len(queue))
            batch = queue[:batch_size]
            queue = queue[batch_size:]

            # Expand nodes in parallel
            tasks = [
                self._expand_node(
                    node, depth, graph, lsp_bridge, visited, semaphore
                )
                for node, depth in batch
            ]

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Process results and add new nodes to queue
            for result in results:
                if isinstance(result, Exception):
                    logger.warning("Error expanding node: %s", result)
                    continue
                if result:
                    # Add new nodes to queue if not at max depth
                    for new_node, new_depth in result:
                        if (
                            new_depth <= self.max_depth
                            and len(graph.nodes) < self.max_nodes
                        ):
                            queue.append((new_node, new_depth))

        return graph

    async def _expand_node(
        self,
        node: CodeSymbolNode,
        depth: int,
        graph: CodeAssociationGraph,
        lsp_bridge: LspBridge,
        visited: Set[str],
        semaphore: asyncio.Semaphore,
    ) -> List[Tuple[CodeSymbolNode, int]]:
        """Expand a single node, return new nodes to process.

        Args:
            node: Node to expand.
            depth: Current depth in BFS.
            graph: Graph to add nodes and edges to.
            lsp_bridge: LSP bridge for queries.
            visited: Set of visited node IDs.
            semaphore: Semaphore for concurrency control.

        Returns:
            List of (new_node, new_depth) tuples to add to queue.
        """
        # Skip if already visited or at max depth
        if node.id in visited:
            return []
        if depth > self.max_depth:
            return []
        if len(graph.nodes) >= self.max_nodes:
            return []

        visited.add(node.id)
        new_nodes: List[Tuple[CodeSymbolNode, int]] = []

        async with semaphore:
            # Get relationships in parallel
            try:
                refs_task = lsp_bridge.get_references(node)
                calls_task = lsp_bridge.get_call_hierarchy(node)

                refs, calls = await asyncio.gather(
                    refs_task, calls_task, return_exceptions=True
                )

                # Handle reference results
                if isinstance(refs, Exception):
                    logger.debug(
                        "Failed to get references for %s: %s", node.id, refs
                    )
                    refs = []

                # Handle call hierarchy results
                if isinstance(calls, Exception):
                    logger.debug(
                        "Failed to get call hierarchy for %s: %s",
                        node.id,
                        calls,
                    )
                    calls = []

                # Process references
                for ref in refs:
                    if len(graph.nodes) >= self.max_nodes:
                        break

                    ref_node = await self._location_to_node(ref, lsp_bridge)
                    if ref_node and ref_node.id != node.id:
                        if ref_node.id not in graph.nodes:
                            graph.add_node(ref_node)
                            new_nodes.append((ref_node, depth + 1))
                        # Use add_edge since both nodes should exist now
                        graph.add_edge(node.id, ref_node.id, "references")

                # Process call hierarchy (incoming calls)
                for call in calls:
                    if len(graph.nodes) >= self.max_nodes:
                        break

                    call_node = await self._call_hierarchy_to_node(
                        call, lsp_bridge
                    )
                    if call_node and call_node.id != node.id:
                        if call_node.id not in graph.nodes:
                            graph.add_node(call_node)
                            new_nodes.append((call_node, depth + 1))
                        # Incoming call: call_node calls node
                        graph.add_edge(call_node.id, node.id, "calls")

            except Exception as e:
                logger.warning(
                    "Error during node expansion for %s: %s", node.id, e
                )

        return new_nodes

    def clear_cache(self) -> None:
        """Clear the document symbols cache.

        Call this between searches to free memory and ensure fresh data.
        """
        self._document_symbols_cache.clear()

    async def _get_symbol_at_location(
        self,
        file_path: str,
        line: int,
        lsp_bridge: LspBridge,
    ) -> Optional[Dict[str, Any]]:
        """Find symbol at location using cached document symbols.

        This is much more efficient than individual hover queries because
        document symbols are fetched once per file and cached.

        Args:
            file_path: Path to the source file.
            line: Line number (1-based).
            lsp_bridge: LSP bridge for fetching document symbols.

        Returns:
            Symbol dictionary with name, kind, range, etc., or None if not found.
        """
        # Get or fetch document symbols for this file
        if file_path not in self._document_symbols_cache:
            symbols = await lsp_bridge.get_document_symbols(file_path)
            self._document_symbols_cache[file_path] = symbols

        symbols = self._document_symbols_cache[file_path]

        # Find symbol containing this line (best match = smallest range)
        best_match: Optional[Dict[str, Any]] = None
        best_range_size = float("inf")

        for symbol in symbols:
            sym_range = symbol.get("range", {})
            start = sym_range.get("start", {})
            end = sym_range.get("end", {})

            # LSP ranges are 0-based, our line is 1-based
            start_line = start.get("line", 0) + 1
            end_line = end.get("line", 0) + 1

            if start_line <= line <= end_line:
                range_size = end_line - start_line
                if range_size < best_range_size:
                    best_match = symbol
                    best_range_size = range_size

        return best_match

    async def _location_to_node(
        self,
        location: Location,
        lsp_bridge: LspBridge,
    ) -> Optional[CodeSymbolNode]:
        """Convert LSP location to CodeSymbolNode.

        Uses cached document symbols instead of individual hover queries
        for better performance.

        Args:
            location: LSP location to convert.
            lsp_bridge: LSP bridge for additional queries.

        Returns:
            CodeSymbolNode or None if conversion fails.
        """
        try:
            file_path = location.file_path
            start_line = location.line

            # Try to find symbol info from cached document symbols (fast)
            symbol_info = await self._get_symbol_at_location(
                file_path, start_line, lsp_bridge
            )

            if symbol_info:
                name = symbol_info.get("name", f"symbol_L{start_line}")
                kind = symbol_info.get("kind", "unknown")

                # Extract range from symbol if available
                sym_range = symbol_info.get("range", {})
                start = sym_range.get("start", {})
                end = sym_range.get("end", {})

                location_range = Range(
                    start_line=start.get("line", start_line - 1) + 1,
                    start_character=start.get("character", location.character - 1) + 1,
                    end_line=end.get("line", start_line - 1) + 1,
                    end_character=end.get("character", location.character - 1) + 1,
                )
            else:
                # Fallback to basic node without symbol info
                name = f"symbol_L{start_line}"
                kind = "unknown"
                location_range = Range(
                    start_line=location.line,
                    start_character=location.character,
                    end_line=location.line,
                    end_character=location.character,
                )

            node_id = self._create_node_id(file_path, name, start_line)

            return CodeSymbolNode(
                id=node_id,
                name=name,
                kind=kind,
                file_path=file_path,
                range=location_range,
                docstring="",  # Skip hover for performance
            )

        except Exception as e:
            logger.debug("Failed to convert location to node: %s", e)
            return None

    async def _call_hierarchy_to_node(
        self,
        call_item: CallHierarchyItem,
        lsp_bridge: LspBridge,
    ) -> Optional[CodeSymbolNode]:
        """Convert CallHierarchyItem to CodeSymbolNode.

        Args:
            call_item: Call hierarchy item to convert.
            lsp_bridge: LSP bridge (unused, kept for API consistency).

        Returns:
            CodeSymbolNode or None if conversion fails.
        """
        try:
            file_path = call_item.file_path
            name = call_item.name
            start_line = call_item.range.start_line
            # CallHierarchyItem.kind is already a string
            kind = call_item.kind

            node_id = self._create_node_id(file_path, name, start_line)

            return CodeSymbolNode(
                id=node_id,
                name=name,
                kind=kind,
                file_path=file_path,
                range=call_item.range,
                docstring=call_item.detail or "",
            )

        except Exception as e:
            logger.debug(
                "Failed to convert call hierarchy item to node: %s", e
            )
            return None

    def _create_node_id(
        self, file_path: str, name: str, line: int
    ) -> str:
        """Create unique node ID.

        Args:
            file_path: Path to the file.
            name: Symbol name.
            line: Line number (1-based, as produced by the LSP bridge).

        Returns:
            Unique node ID string.
        """
        return f"{file_path}:{name}:{line}"
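A sketch of driving the builder end to end, assuming codex-lens is installed; the seed symbol is a hypothetical placeholder. max_nodes bounds graph growth and the semaphore caps in-flight LSP requests, so expansion cost stays predictable even in large workspaces.

import asyncio

from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp.lsp_bridge import LspBridge
from codexlens.lsp.lsp_graph_builder import LspGraphBuilder


async def demo() -> None:
    # Hypothetical seed; real seeds typically come from a search result.
    seed = CodeSymbolNode(
        id="src/app.py:handler:10",
        name="handler",
        kind="function",
        file_path="src/app.py",
        range=Range(start_line=10, start_character=5, end_line=20, end_character=1),
    )
    builder = LspGraphBuilder(max_depth=2, max_nodes=50, max_concurrent=5)
    async with LspBridge(workspace_root=".") as bridge:
        graph = await builder.build_from_seeds([seed], bridge)
    print(f"{len(graph.nodes)} nodes after BFS expansion")
    builder.clear_cache()  # free the per-file document-symbol cache between searches


asyncio.run(demo())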
177
codex-lens/build/lib/codexlens/lsp/providers.py
Normal file
@@ -0,0 +1,177 @@
"""LSP feature providers."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.storage.registry import RegistryStore

logger = logging.getLogger(__name__)


@dataclass
class HoverInfo:
    """Hover information for a symbol."""

    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: tuple  # (start_line, end_line)


class HoverProvider:
    """Provides hover information for symbols."""

    def __init__(
        self,
        global_index: "GlobalSymbolIndex",
        registry: Optional["RegistryStore"] = None,
    ) -> None:
        """Initialize hover provider.

        Args:
            global_index: Global symbol index for lookups
            registry: Optional registry store for index path resolution
        """
        self.global_index = global_index
        self.registry = registry

    def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
        """Get hover information for a symbol.

        Args:
            symbol_name: Name of the symbol to look up

        Returns:
            HoverInfo or None if symbol not found
        """
        # Look up symbol in global index using exact match
        symbols = self.global_index.search(
            name=symbol_name,
            limit=1,
            prefix_mode=False,
        )

        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == symbol_name]

        if not exact_matches:
            return None

        symbol = exact_matches[0]

        # Extract signature from source file
        signature = self._extract_signature(symbol)

        # Symbol uses 'file' attribute and 'range' tuple
        file_path = symbol.file or ""
        start_line, end_line = symbol.range

        return HoverInfo(
            name=symbol.name,
            kind=symbol.kind,
            signature=signature,
            documentation=None,  # Symbol doesn't have docstring field
            file_path=file_path,
            line_range=(start_line, end_line),
        )

    def _extract_signature(self, symbol) -> str:
        """Extract function/class signature from source file.

        Args:
            symbol: Symbol object with file and range information

        Returns:
            Extracted signature string or fallback kind + name
        """
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return f"{symbol.kind} {symbol.name}"

            content = file_path.read_text(encoding="utf-8", errors="ignore")
            lines = content.split("\n")

            # Extract signature lines (first line of definition + continuation)
            start_line = symbol.range[0] - 1  # Convert 1-based to 0-based
            if start_line >= len(lines) or start_line < 0:
                return f"{symbol.kind} {symbol.name}"

            signature_lines = []
            first_line = lines[start_line]
            signature_lines.append(first_line)

            # Continue if multiline signature (no closing paren + colon yet)
            # Look for patterns like "def func(", "class Foo(", etc.
            i = start_line + 1
            max_lines = min(start_line + 5, len(lines))
            while i < max_lines:
                line = signature_lines[-1]
                # Stop if we see closing pattern
                if "):" in line or line.rstrip().endswith(":"):
                    break
                signature_lines.append(lines[i])
                i += 1

            return "\n".join(signature_lines)

        except Exception as e:
            logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
            return f"{symbol.kind} {symbol.name}"

    def format_hover_markdown(self, info: HoverInfo) -> str:
        """Format hover info as Markdown.

        Args:
            info: HoverInfo object to format

        Returns:
            Markdown-formatted hover content
        """
        parts = []

        # Detect language for code fence based on file extension
        ext = Path(info.file_path).suffix.lower() if info.file_path else ""
        lang_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".tsx": "typescript",
            ".jsx": "javascript",
            ".java": "java",
            ".go": "go",
            ".rs": "rust",
            ".c": "c",
            ".cpp": "cpp",
            ".h": "c",
            ".hpp": "cpp",
            ".cs": "csharp",
            ".rb": "ruby",
            ".php": "php",
        }
        lang = lang_map.get(ext, "")

        # Code block with signature
        parts.append(f"```{lang}\n{info.signature}\n```")

        # Documentation if available
        if info.documentation:
            parts.append(f"\n---\n\n{info.documentation}")

        # Location info
        file_name = Path(info.file_path).name if info.file_path else "unknown"
        parts.append(
            f"\n---\n\n*{info.kind}* defined in "
            f"`{file_name}` "
            f"(line {info.line_range[0]})"
        )

        return "\n".join(parts)
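A sketch of the hover flow, assuming an already-indexed project; the database path, project id and symbol name are hypothetical placeholders. GlobalSymbolIndex takes the db path and project id positionally, as shown in server.py below.

from pathlib import Path

from codexlens.lsp.providers import HoverProvider
from codexlens.storage.global_index import GlobalSymbolIndex

db_path = Path.home() / ".codexlens" / "global_symbols.db"  # hypothetical location
index = GlobalSymbolIndex(db_path, 1)  # hypothetical project id
index.initialize()

provider = HoverProvider(global_index=index)
info = provider.get_hover_info("ChainSearchEngine")  # hypothetical symbol name
if info:
    print(provider.format_hover_markdown(info))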
263
codex-lens/build/lib/codexlens/lsp/server.py
Normal file
@@ -0,0 +1,263 @@
"""codex-lens LSP Server implementation using pygls.

This module provides the main Language Server class and entry point.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path
from typing import Optional

try:
    from lsprotocol import types as lsp
    from pygls.lsp.server import LanguageServer
except ImportError as exc:
    raise ImportError(
        "LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
    ) from exc

from codexlens.config import Config
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

logger = logging.getLogger(__name__)


class CodexLensLanguageServer(LanguageServer):
    """Language Server for codex-lens code indexing.

    Provides IDE features using codex-lens symbol index:
    - Go to Definition
    - Find References
    - Code Completion
    - Hover Information
    - Workspace Symbol Search

    Attributes:
        registry: Global project registry for path lookups
        mapper: Path mapper for source/index conversions
        global_index: Project-wide symbol index
        search_engine: Chain search engine for symbol search
        workspace_root: Current workspace root path
    """

    def __init__(self) -> None:
        super().__init__(name="codexlens-lsp", version="0.1.0")

        self.registry: Optional[RegistryStore] = None
        self.mapper: Optional[PathMapper] = None
        self.global_index: Optional[GlobalSymbolIndex] = None
        self.search_engine: Optional[ChainSearchEngine] = None
        self.workspace_root: Optional[Path] = None
        self._config: Optional[Config] = None

    def initialize_components(self, workspace_root: Path) -> bool:
        """Initialize codex-lens components for the workspace.

        Args:
            workspace_root: Root path of the workspace

        Returns:
            True if initialization succeeded, False otherwise
        """
        self.workspace_root = workspace_root.resolve()
        logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)

        try:
            # Initialize registry
            self.registry = RegistryStore()
            self.registry.initialize()

            # Initialize path mapper
            self.mapper = PathMapper()

            # Try to find project in registry
            project_info = self.registry.find_by_source_path(str(self.workspace_root))

            if project_info:
                project_id = int(project_info["id"])
                index_root = Path(project_info["index_root"])

                # Initialize global symbol index
                global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
                self.global_index = GlobalSymbolIndex(global_db, project_id)
                self.global_index.initialize()

                # Initialize search engine
                self._config = Config()
                self.search_engine = ChainSearchEngine(
                    registry=self.registry,
                    mapper=self.mapper,
                    config=self._config,
                )

                logger.info("codex-lens initialized for project: %s", project_info["source_root"])
                return True
            else:
                logger.warning(
                    "Workspace not indexed by codex-lens: %s. "
                    "Run 'codexlens index %s' to index first.",
                    self.workspace_root,
                    self.workspace_root,
                )
                return False

        except Exception as exc:
            logger.error("Failed to initialize codex-lens: %s", exc)
            return False

    def shutdown_components(self) -> None:
        """Clean up codex-lens components."""
        if self.global_index:
            try:
                self.global_index.close()
            except Exception as exc:
                logger.debug("Error closing global index: %s", exc)
            self.global_index = None

        if self.search_engine:
            try:
                self.search_engine.close()
            except Exception as exc:
                logger.debug("Error closing search engine: %s", exc)
            self.search_engine = None

        if self.registry:
            try:
                self.registry.close()
            except Exception as exc:
                logger.debug("Error closing registry: %s", exc)
            self.registry = None


# Create server instance
server = CodexLensLanguageServer()


@server.feature(lsp.INITIALIZE)
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
    """Handle LSP initialize request."""
    logger.info("LSP initialize request received")

    # Get workspace root
    workspace_root: Optional[Path] = None
    if params.root_uri:
        workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
    elif params.root_path:
        workspace_root = Path(params.root_path)

    if workspace_root:
        server.initialize_components(workspace_root)

    # Declare server capabilities
    return lsp.InitializeResult(
        capabilities=lsp.ServerCapabilities(
            text_document_sync=lsp.TextDocumentSyncOptions(
                open_close=True,
                change=lsp.TextDocumentSyncKind.Incremental,
                save=lsp.SaveOptions(include_text=False),
            ),
            definition_provider=True,
            references_provider=True,
            completion_provider=lsp.CompletionOptions(
                trigger_characters=[".", ":"],
                resolve_provider=False,
            ),
            hover_provider=True,
            workspace_symbol_provider=True,
        ),
        server_info=lsp.ServerInfo(
            name="codexlens-lsp",
            version="0.1.0",
        ),
    )


@server.feature(lsp.SHUTDOWN)
def lsp_shutdown(params: None) -> None:
    """Handle LSP shutdown request."""
    logger.info("LSP shutdown request received")
    server.shutdown_components()


def main() -> int:
    """Entry point for codexlens-lsp command.

    Returns:
        Exit code (0 for success)
    """
    # Import handlers to register them with the server
    # This must be done before starting the server
    import codexlens.lsp.handlers  # noqa: F401

    parser = argparse.ArgumentParser(
        description="codex-lens Language Server",
        prog="codexlens-lsp",
    )
    parser.add_argument(
        "--stdio",
        action="store_true",
        default=True,
        help="Use stdio for communication (default)",
    )
    parser.add_argument(
        "--tcp",
        action="store_true",
        help="Use TCP for communication",
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="TCP host (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=2087,
        help="TCP port (default: 2087)",
    )
    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Log level (default: INFO)",
    )
    parser.add_argument(
        "--log-file",
        help="Log file path (optional)",
    )

    args = parser.parse_args()

    # Configure logging
    log_handlers = []
    if args.log_file:
        log_handlers.append(logging.FileHandler(args.log_file))
    else:
        log_handlers.append(logging.StreamHandler(sys.stderr))

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=log_handlers,
    )

    logger.info("Starting codexlens-lsp server")

    if args.tcp:
        logger.info("Starting TCP server on %s:%d", args.host, args.port)
        server.start_tcp(args.host, args.port)
    else:
        logger.info("Starting stdio server")
        server.start_io()

    return 0


if __name__ == "__main__":
    sys.exit(main())
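A launch sketch, assuming the codex-lens[lsp] extra is installed so the codexlens-lsp entry point is on PATH. stdio is the default transport an editor would use; TCP mode is mainly useful for debugging a client against a long-running server.

import subprocess
import time

# --tcp keeps the server reachable for manual inspection on 127.0.0.1:2087.
proc = subprocess.Popen(
    ["codexlens-lsp", "--tcp", "--port", "2087", "--log-level", "DEBUG"]
)
time.sleep(1)  # give the server a moment to bind before a client connects
# ... attach any LSP client to 127.0.0.1:2087 ...
proc.terminate()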
1159
codex-lens/build/lib/codexlens/lsp/standalone_manager.py
Normal file
File diff suppressed because it is too large
20
codex-lens/build/lib/codexlens/mcp/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt

__all__ = [
    "MCPContext",
    "SymbolInfo",
    "ReferenceInfo",
    "RelatedSymbol",
    "MCPProvider",
    "HookManager",
    "create_context_for_prompt",
]
170
codex-lens/build/lib/codexlens/mcp/hooks.py
Normal file
@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING

from codexlens.mcp.schema import MCPContext

if TYPE_CHECKING:
    from codexlens.mcp.provider import MCPProvider

logger = logging.getLogger(__name__)


class HookManager:
    """Manages hook registration and execution."""

    def __init__(self, mcp_provider: "MCPProvider") -> None:
        self.mcp_provider = mcp_provider
        self._pre_hooks: Dict[str, Callable] = {}
        self._post_hooks: Dict[str, Callable] = {}

        # Register default hooks
        self._register_default_hooks()

    def _register_default_hooks(self) -> None:
        """Register built-in hooks."""
        self._pre_hooks["explain"] = self._pre_explain_hook
        self._pre_hooks["refactor"] = self._pre_refactor_hook
        self._pre_hooks["document"] = self._pre_document_hook

    def execute_pre_hook(
        self,
        action: str,
        params: Dict[str, Any],
    ) -> Optional[MCPContext]:
        """Execute pre-tool hook to gather context.

        Args:
            action: The action being performed (e.g., "explain", "refactor")
            params: Parameters for the action

        Returns:
            MCPContext to inject into prompt, or None
        """
        hook = self._pre_hooks.get(action)

        if not hook:
            logger.debug(f"No pre-hook for action: {action}")
            return None

        try:
            return hook(params)
        except Exception as e:
            logger.error(f"Pre-hook failed for {action}: {e}")
            return None

    def execute_post_hook(
        self,
        action: str,
        result: Any,
    ) -> None:
        """Execute post-tool hook for proactive caching.

        Args:
            action: The action that was performed
            result: Result of the action
        """
        hook = self._post_hooks.get(action)

        if not hook:
            return

        try:
            hook(result)
        except Exception as e:
            logger.error(f"Post-hook failed for {action}: {e}")

    def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'explain' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="symbol_explanation",
            include_references=True,
            include_related=True,
        )

    def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'refactor' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="refactor_context",
            include_references=True,
            include_related=True,
            max_references=20,
        )

    def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'document' action."""
        symbol_name = params.get("symbol")
        file_path = params.get("file_path")

        if symbol_name:
            return self.mcp_provider.build_context(
                symbol_name=symbol_name,
                context_type="documentation_context",
                include_references=False,
                include_related=True,
            )
        elif file_path:
            return self.mcp_provider.build_context_for_file(
                Path(file_path),
                context_type="file_documentation",
            )

        return None

    def register_pre_hook(
        self,
        action: str,
        hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
    ) -> None:
        """Register a custom pre-tool hook."""
        self._pre_hooks[action] = hook

    def register_post_hook(
        self,
        action: str,
        hook: Callable[[Any], None],
    ) -> None:
        """Register a custom post-tool hook."""
        self._post_hooks[action] = hook


def create_context_for_prompt(
    mcp_provider: "MCPProvider",
    action: str,
    params: Dict[str, Any],
) -> str:
    """Create context string for prompt injection.

    This is the main entry point for Claude Code hook integration.

    Args:
        mcp_provider: The MCP provider instance
        action: Action being performed
        params: Action parameters

    Returns:
        Formatted context string for prompt injection
    """
    manager = HookManager(mcp_provider)
    context = manager.execute_pre_hook(action, params)

    if context:
        return context.to_prompt_injection()

    return ""
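A sketch of registering a custom pre-hook, written as a function so the provider wiring stays explicit; the "test" action name and context_type are hypothetical, while register_pre_hook and build_context are the APIs shown above.

from typing import Any, Dict, Optional

from codexlens.mcp.hooks import HookManager
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.schema import MCPContext


def make_manager_with_test_hook(provider: MCPProvider) -> HookManager:
    """Attach a pre-hook for a hypothetical 'test' action."""

    def pre_test_hook(params: Dict[str, Any]) -> Optional[MCPContext]:
        symbol = params.get("symbol")
        if not symbol:
            return None
        return provider.build_context(
            symbol_name=symbol,
            context_type="test_context",  # hypothetical context type
        )

    manager = HookManager(provider)
    manager.register_pre_hook("test", pre_test_hook)
    return manager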
202
codex-lens/build/lib/codexlens/mcp/provider.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""MCP context provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from codexlens.mcp.schema import (
|
||||
MCPContext,
|
||||
SymbolInfo,
|
||||
ReferenceInfo,
|
||||
RelatedSymbol,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
from codexlens.search.chain_search import ChainSearchEngine

logger = logging.getLogger(__name__)


class MCPProvider:
    """Builds MCP context objects from codex-lens data."""

    def __init__(
        self,
        global_index: "GlobalSymbolIndex",
        search_engine: "ChainSearchEngine",
        registry: "RegistryStore",
    ) -> None:
        self.global_index = global_index
        self.search_engine = search_engine
        self.registry = registry

    def build_context(
        self,
        symbol_name: str,
        context_type: str = "symbol_explanation",
        include_references: bool = True,
        include_related: bool = True,
        max_references: int = 10,
    ) -> Optional[MCPContext]:
        """Build comprehensive context for a symbol.

        Args:
            symbol_name: Name of the symbol to contextualize
            context_type: Type of context being requested
            include_references: Whether to include reference locations
            include_related: Whether to include related symbols
            max_references: Maximum number of references to include

        Returns:
            MCPContext object or None if symbol not found
        """
        # Look up symbol
        symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)

        if not symbols:
            logger.debug(f"Symbol not found for MCP context: {symbol_name}")
            return None

        symbol = symbols[0]

        # Build SymbolInfo
        symbol_info = SymbolInfo(
            name=symbol.name,
            kind=symbol.kind,
            file_path=symbol.file or "",
            line_start=symbol.range[0],
            line_end=symbol.range[1],
            signature=None,  # Symbol entity doesn't have signature
            documentation=None,  # Symbol entity doesn't have docstring
        )

        # Extract definition source code
        definition = self._extract_definition(symbol)

        # Get references
        references = []
        if include_references:
            refs = self.search_engine.search_references(
                symbol_name,
                limit=max_references,
            )
            references = [
                ReferenceInfo(
                    file_path=r.file_path,
                    line=r.line,
                    column=r.column,
                    context=r.context,
                    relationship_type=r.relationship_type,
                )
                for r in refs
            ]

        # Get related symbols
        related_symbols = []
        if include_related:
            related_symbols = self._get_related_symbols(symbol)

        return MCPContext(
            context_type=context_type,
            symbol=symbol_info,
            definition=definition,
            references=references,
            related_symbols=related_symbols,
            metadata={
                "source": "codex-lens",
            },
        )

    def _extract_definition(self, symbol) -> Optional[str]:
        """Extract source code for symbol definition."""
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return None

            content = file_path.read_text(encoding='utf-8', errors='ignore')
            lines = content.split("\n")

            start = symbol.range[0] - 1
            end = symbol.range[1]

            if start >= len(lines):
                return None

            return "\n".join(lines[start:end])
        except Exception as e:
            logger.debug(f"Failed to extract definition: {e}")
            return None

    def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
        """Get symbols related to the given symbol."""
        related = []

        try:
            # Search for symbols that might be related by name patterns.
            # This is a simplified implementation - could be enhanced with relationship data.

            # Look for imports/callers via reference search
            refs = self.search_engine.search_references(symbol.name, limit=20)

            seen_names = set()
            for ref in refs:
                # One related entry per relationship type, named after the referencing module
                if ref.relationship_type and ref.relationship_type not in seen_names:
                    related.append(RelatedSymbol(
                        name=f"{Path(ref.file_path).stem}",
                        kind="module",
                        relationship=ref.relationship_type,
                        file_path=ref.file_path,
                    ))
                    seen_names.add(ref.relationship_type)
                    if len(related) >= 10:
                        break

        except Exception as e:
            logger.debug(f"Failed to get related symbols: {e}")

        return related

    def build_context_for_file(
        self,
        file_path: Path,
        context_type: str = "file_overview",
    ) -> MCPContext:
        """Build context for an entire file."""
        # Try to get symbols by searching with file path.
        # Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach.
        symbols = []

        # Search for common symbols that might be in this file.
        # This is a simplified approach - a full implementation would query by file path.
        try:
            # Use the global index to search for symbols from this file
            file_str = str(file_path.resolve())
            # Get all symbols and filter by file path (not efficient but works)
            all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
            symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
        except Exception as e:
            logger.debug(f"Failed to get file symbols: {e}")

        related = [
            RelatedSymbol(
                name=s.name,
                kind=s.kind,
                relationship="defines",
            )
            for s in symbols
        ]

        return MCPContext(
            context_type=context_type,
            related_symbols=related,
            metadata={
                "file_path": str(file_path),
                "symbol_count": len(symbols),
            },
        )
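A minimal usage sketch for the provider above. The zero-argument constructors for GlobalSymbolIndex, ChainSearchEngine, and RegistryStore are assumptions here (their signatures are not part of this diff); only the MCPProvider calls mirror the code above.

# Illustrative sketch only; collaborator constructors are assumed.
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.mcp.provider import MCPProvider

index = GlobalSymbolIndex()    # assumed default-constructible
engine = ChainSearchEngine()   # assumed default-constructible
registry = RegistryStore()     # assumed default-constructible

provider = MCPProvider(global_index=index, search_engine=engine, registry=registry)

# build_context returns None when the symbol is not in the index.
ctx = provider.build_context("ParserFactory", max_references=5)
if ctx is not None:
    print(ctx.to_prompt_injection())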
113
codex-lens/build/lib/codexlens/mcp/schema.py
Normal file
@@ -0,0 +1,113 @@
"""MCP data models."""

from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional


@dataclass
class SymbolInfo:
    """Information about a code symbol."""
    name: str
    kind: str
    file_path: str
    line_start: int
    line_end: int
    signature: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class ReferenceInfo:
    """Information about a symbol reference."""
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class RelatedSymbol:
    """Related symbol (import, call target, etc.)."""
    name: str
    kind: str
    relationship: str  # "imports", "calls", "inherits", "uses"
    file_path: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MCPContext:
    """Model Context Protocol context object.

    This is the structured context that gets injected into
    LLM prompts to provide code understanding.
    """
    version: str = "1.0"
    context_type: str = "code_context"
    symbol: Optional[SymbolInfo] = None
    definition: Optional[str] = None
    references: List[ReferenceInfo] = field(default_factory=list)
    related_symbols: List[RelatedSymbol] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        result = {
            "version": self.version,
            "context_type": self.context_type,
            "metadata": self.metadata,
        }

        if self.symbol:
            result["symbol"] = self.symbol.to_dict()
        if self.definition:
            result["definition"] = self.definition
        if self.references:
            result["references"] = [r.to_dict() for r in self.references]
        if self.related_symbols:
            result["related_symbols"] = [s.to_dict() for s in self.related_symbols]

        return result

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    def to_prompt_injection(self) -> str:
        """Format for injection into LLM prompt."""
        parts = ["<code_context>"]

        if self.symbol:
            parts.append(f"## Symbol: {self.symbol.name}")
            parts.append(f"Type: {self.symbol.kind}")
            parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")

        if self.definition:
            parts.append("\n## Definition")
            parts.append(f"```\n{self.definition}\n```")

        if self.references:
            parts.append(f"\n## References ({len(self.references)} found)")
            for ref in self.references[:5]:  # Limit to 5
                parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
                parts.append(f"  ```\n  {ref.context}\n  ```")

        if self.related_symbols:
            parts.append("\n## Related Symbols")
            for sym in self.related_symbols[:10]:  # Limit to 10
                parts.append(f"- {sym.name} ({sym.relationship})")

        parts.append("</code_context>")
        return "\n".join(parts)
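Since these dataclasses are self-contained, they can be exercised directly. A short sketch constructing a context by hand and rendering both serializations (the file paths and line numbers are made up for illustration):

from codexlens.mcp.schema import MCPContext, SymbolInfo, ReferenceInfo

ctx = MCPContext(
    context_type="symbol_explanation",
    symbol=SymbolInfo(name="read_file_safe", kind="function",
                      file_path="codexlens/parsers/encoding.py",
                      line_start=10, line_end=42),
    definition="def read_file_safe(path): ...",
    references=[ReferenceInfo(file_path="indexer.py", line=7, column=4,
                              context="content, enc = read_file_safe(p)",
                              relationship_type="calls")],
)

print(ctx.to_json())              # structured JSON; None fields are omitted
print(ctx.to_prompt_injection())  # <code_context> block for LLM prompts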
8
codex-lens/build/lib/codexlens/parsers/__init__.py
Normal file
@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""

from __future__ import annotations

from .factory import ParserFactory

__all__ = ["ParserFactory"]
202
codex-lens/build/lib/codexlens/parsers/encoding.py
Normal file
@@ -0,0 +1,202 @@
"""Optional encoding detection module for CodexLens.

Provides automatic encoding detection with graceful fallback to UTF-8.
Install with: pip install codexlens[encoding]
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Tuple, Optional

log = logging.getLogger(__name__)

# Feature flag for encoding detection availability
ENCODING_DETECTION_AVAILABLE = False
_import_error: Optional[str] = None


def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
    """Detect if chardet or charset-normalizer is available."""
    try:
        import chardet
        return True, None
    except ImportError:
        pass

    try:
        from charset_normalizer import from_bytes
        return True, None
    except ImportError:
        pass

    return False, "chardet not available. Install with: pip install codexlens[encoding]"


# Initialize on module load
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()


def check_encoding_available() -> Tuple[bool, Optional[str]]:
    """Check if encoding detection dependencies are available.

    Returns:
        Tuple of (available, error_message)
    """
    return ENCODING_DETECTION_AVAILABLE, _import_error


def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
    """Detect encoding from file content bytes.

    Uses chardet or charset-normalizer with configurable confidence threshold.
    Falls back to UTF-8 if confidence is too low or detection unavailable.

    Args:
        content_bytes: Raw file content as bytes
        confidence_threshold: Minimum confidence (0.0-1.0) to accept detection

    Returns:
        Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk').
        Returns 'utf-8' as fallback if detection fails or confidence is too low.
    """
    if not ENCODING_DETECTION_AVAILABLE:
        log.debug("Encoding detection not available, using UTF-8 fallback")
        return "utf-8"

    if not content_bytes:
        return "utf-8"

    try:
        # Try chardet first
        try:
            import chardet
            result = chardet.detect(content_bytes)
            encoding = result.get("encoding")
            confidence = result.get("confidence", 0.0)

            if encoding and confidence >= confidence_threshold:
                log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
                # Normalize encoding name: replace underscores with hyphens
                return encoding.lower().replace('_', '-')
            else:
                log.debug(
                    f"Low confidence encoding detection: {encoding} "
                    f"(confidence: {confidence:.2f}), using UTF-8 fallback"
                )
                return "utf-8"
        except ImportError:
            pass

        # Fallback to charset-normalizer
        try:
            from charset_normalizer import from_bytes
            results = from_bytes(content_bytes)
            if results:
                best = results.best()
                if best and best.encoding:
                    log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
                    # Normalize encoding name: replace underscores with hyphens
                    return best.encoding.lower().replace('_', '-')
        except ImportError:
            pass

    except Exception as e:
        log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")

    return "utf-8"


def read_file_safe(
    path: Path | str,
    confidence_threshold: float = 0.7,
    max_detection_bytes: int = 100_000,
) -> Tuple[str, str]:
    """Read file with automatic encoding detection and safe decoding.

    Reads file bytes, detects encoding, and decodes with error replacement
    to preserve file structure even with encoding issues.

    Args:
        path: Path to file to read
        confidence_threshold: Minimum confidence for encoding detection
        max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)

    Returns:
        Tuple of (content, detected_encoding)
        - content: Decoded file content (unmappable bytes become U+FFFD replacement characters)
        - detected_encoding: Detected encoding name

    Raises:
        OSError: If file cannot be read
        IsADirectoryError: If path is a directory
    """
    file_path = Path(path) if isinstance(path, str) else path

    # Read file bytes
    try:
        content_bytes = file_path.read_bytes()
    except Exception as e:
        log.error(f"Failed to read file {file_path}: {e}")
        raise

    # Detect encoding from first N bytes for performance
    detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
    encoding = detect_encoding(detection_sample, confidence_threshold)

    # Decode with error replacement to preserve structure
    try:
        content = content_bytes.decode(encoding, errors='replace')
        log.debug(f"Successfully decoded {file_path} using {encoding}")
        return content, encoding
    except Exception as e:
        # Final fallback to UTF-8 with replacement
        log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
        content = content_bytes.decode('utf-8', errors='replace')
        return content, 'utf-8'


def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
    """Check if file is likely binary by sampling first bytes.

    Uses a heuristic: if >30% of sample bytes are null or >50% are
    non-text control bytes, consider the file binary.

    Args:
        path: Path to file to check
        sample_size: Number of bytes to sample (default 8KB)

    Returns:
        True if file appears to be binary, False otherwise
    """
    file_path = Path(path) if isinstance(path, str) else path

    try:
        with file_path.open('rb') as f:
            sample = f.read(sample_size)

        if not sample:
            return False

        # Count null bytes and non-printable characters
        null_count = sample.count(b'\x00')
        non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))

        # If >30% null bytes or >50% non-text, consider binary
        null_ratio = null_count / len(sample)
        non_text_ratio = non_text_count / len(sample)

        return null_ratio > 0.3 or non_text_ratio > 0.5

    except Exception as e:
        log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
        return False


__all__ = [
    "ENCODING_DETECTION_AVAILABLE",
    "check_encoding_available",
    "detect_encoding",
    "read_file_safe",
    "is_binary_file",
]
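A usage sketch for the helpers above, assuming a text file named `example.py` exists in the working directory:

from pathlib import Path
from codexlens.parsers.encoding import (
    check_encoding_available, is_binary_file, read_file_safe,
)

available, err = check_encoding_available()
if not available:
    print(f"Detection backends missing, UTF-8 fallback in effect: {err}")

path = Path("example.py")  # hypothetical file
if not is_binary_file(path):
    content, encoding = read_file_safe(path, confidence_threshold=0.8)
    print(f"Decoded {path} as {encoding}, {len(content)} chars")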
385
codex-lens/build/lib/codexlens/parsers/factory.py
Normal file
@@ -0,0 +1,385 @@
"""Parser factory for CodexLens.

Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Protocol

from codexlens.config import Config
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser


class Parser(Protocol):
    def parse(self, text: str, path: Path) -> IndexedFile: ...


@dataclass
class SimpleRegexParser:
    language_id: str

    def parse(self, text: str, path: Path) -> IndexedFile:
        # Try tree-sitter first for supported languages
        if self.language_id in {"python", "javascript", "typescript"}:
            ts_parser = TreeSitterSymbolParser(self.language_id, path)
            if ts_parser.is_available():
                indexed = ts_parser.parse(text, path)
                if indexed is not None:
                    return indexed

        # Fallback to regex parsing
        if self.language_id == "python":
            symbols = _parse_python_symbols_regex(text)
            relationships = _parse_python_relationships_regex(text, path)
        elif self.language_id in {"javascript", "typescript"}:
            symbols = _parse_js_ts_symbols_regex(text)
            relationships = _parse_js_ts_relationships_regex(text, path)
        elif self.language_id == "java":
            symbols = _parse_java_symbols(text)
            relationships = []
        elif self.language_id == "go":
            symbols = _parse_go_symbols(text)
            relationships = []
        elif self.language_id == "markdown":
            symbols = _parse_markdown_symbols(text)
            relationships = []
        elif self.language_id == "text":
            symbols = _parse_text_symbols(text)
            relationships = []
        else:
            symbols = _parse_generic_symbols(text)
            relationships = []

        return IndexedFile(
            path=str(path.resolve()),
            language=self.language_id,
            symbols=symbols,
            chunks=[],
            relationships=relationships,
        )


class ParserFactory:
    def __init__(self, config: Config) -> None:
        self.config = config
        self._parsers: Dict[str, Parser] = {}

    def get_parser(self, language_id: str) -> Parser:
        if language_id not in self._parsers:
            self._parsers[language_id] = SimpleRegexParser(language_id)
        return self._parsers[language_id]


# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")

_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")


def _parse_python_symbols(text: str) -> List[Symbol]:
    """Parse Python symbols, using tree-sitter if available, regex fallback."""
    ts_parser = TreeSitterSymbolParser("python")
    if ts_parser.is_available():
        symbols = ts_parser.parse_symbols(text)
        if symbols is not None:
            return symbols
    return _parse_python_symbols_regex(text)


def _parse_js_ts_symbols(
    text: str,
    language_id: str = "javascript",
    path: Optional[Path] = None,
) -> List[Symbol]:
    """Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
    ts_parser = TreeSitterSymbolParser(language_id, path)
    if ts_parser.is_available():
        symbols = ts_parser.parse_symbols(text)
        if symbols is not None:
            return symbols
    return _parse_js_ts_symbols_regex(text)


def _parse_python_symbols_regex(text: str) -> List[Symbol]:
    symbols: List[Symbol] = []
    current_class_indent: Optional[int] = None
    for i, line in enumerate(text.splitlines(), start=1):
        class_match = _PY_CLASS_RE.match(line)
        if class_match:
            current_class_indent = len(line) - len(line.lstrip(" "))
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            continue
        def_match = _PY_DEF_RE.match(line)
        if def_match:
            indent = len(line) - len(line.lstrip(" "))
            kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
            symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
            continue
        if current_class_indent is not None:
            indent = len(line) - len(line.lstrip(" "))
            if line.strip() and indent <= current_class_indent:
                current_class_indent = None
    return symbols


def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
    relationships: List[CodeRelationship] = []
    current_scope: str | None = None
    source_file = str(path.resolve())

    for line_num, line in enumerate(text.splitlines(), start=1):
        class_match = _PY_CLASS_RE.match(line)
        if class_match:
            current_scope = class_match.group(1)
            continue

        def_match = _PY_DEF_RE.match(line)
        if def_match:
            current_scope = def_match.group(1)
            continue

        if current_scope is None:
            continue

        import_match = _PY_IMPORT_RE.search(line)
        if import_match:
            import_target = import_match.group(1) or import_match.group(2)
            if import_target:
                relationships.append(
                    CodeRelationship(
                        source_symbol=current_scope,
                        target_symbol=import_target.strip(),
                        relationship_type=RelationshipType.IMPORTS,
                        source_file=source_file,
                        target_file=None,
                        source_line=line_num,
                    )
                )

        for call_match in _PY_CALL_RE.finditer(line):
            call_name = call_match.group(1)
            # Skip keywords, common builtins, and self-recursive calls
            if call_name in {
                "if", "for", "while", "return", "print", "len",
                "str", "int", "float", "list", "dict", "set", "tuple",
                current_scope,
            }:
                continue
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=call_name,
                    relationship_type=RelationshipType.CALL,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

    return relationships


_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
    r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")


def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
    symbols: List[Symbol] = []
    in_class = False
    class_brace_depth = 0
    brace_depth = 0

    for i, line in enumerate(text.splitlines(), start=1):
        brace_depth += line.count("{") - line.count("}")

        class_match = _JS_CLASS_RE.match(line)
        if class_match:
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            in_class = True
            class_brace_depth = brace_depth
            continue

        if in_class and brace_depth < class_brace_depth:
            in_class = False

        func_match = _JS_FUNC_RE.match(line)
        if func_match:
            symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
            continue

        arrow_match = _JS_ARROW_RE.match(line)
        if arrow_match:
            symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
            continue

        if in_class:
            method_match = _JS_METHOD_RE.match(line)
            if method_match:
                name = method_match.group(1)
                if name != "constructor":
                    symbols.append(Symbol(name=name, kind="method", range=(i, i)))

    return symbols


def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
    relationships: List[CodeRelationship] = []
    current_scope: str | None = None
    source_file = str(path.resolve())

    for line_num, line in enumerate(text.splitlines(), start=1):
        class_match = _JS_CLASS_RE.match(line)
        if class_match:
            current_scope = class_match.group(1)
            continue

        func_match = _JS_FUNC_RE.match(line)
        if func_match:
            current_scope = func_match.group(1)
            continue

        arrow_match = _JS_ARROW_RE.match(line)
        if arrow_match:
            current_scope = arrow_match.group(1)
            continue

        if current_scope is None:
            continue

        import_match = _JS_IMPORT_RE.search(line)
        if import_match:
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=import_match.group(1),
                    relationship_type=RelationshipType.IMPORTS,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

        for call_match in _JS_CALL_RE.finditer(line):
            call_name = call_match.group(1)
            if call_name == current_scope:
                continue
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=call_name,
                    relationship_type=RelationshipType.CALL,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

    return relationships


_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
    r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
)


def _parse_java_symbols(text: str) -> List[Symbol]:
    symbols: List[Symbol] = []
    for i, line in enumerate(text.splitlines(), start=1):
        class_match = _JAVA_CLASS_RE.match(line)
        if class_match:
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            continue
        method_match = _JAVA_METHOD_RE.match(line)
        if method_match:
            symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
    return symbols


_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")


def _parse_go_symbols(text: str) -> List[Symbol]:
    symbols: List[Symbol] = []
    for i, line in enumerate(text.splitlines(), start=1):
        type_match = _GO_TYPE_RE.match(line)
        if type_match:
            symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
            continue
        func_match = _GO_FUNC_RE.match(line)
        if func_match:
            symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
    return symbols


_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")


def _parse_generic_symbols(text: str) -> List[Symbol]:
    symbols: List[Symbol] = []
    for i, line in enumerate(text.splitlines(), start=1):
        class_match = _GENERIC_CLASS_RE.match(line)
        if class_match:
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            continue
        def_match = _GENERIC_DEF_RE.match(line)
        if def_match:
            symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
    return symbols


# Markdown heading regex: # Heading, ## Heading, etc.
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")


def _parse_markdown_symbols(text: str) -> List[Symbol]:
    """Parse Markdown headings as symbols.

    Extracts headings as symbols whose kind encodes the heading level (h1-h6).
    """
    symbols: List[Symbol] = []
    for i, line in enumerate(text.splitlines(), start=1):
        heading_match = _MD_HEADING_RE.match(line)
        if heading_match:
            level = len(heading_match.group(1))
            title = heading_match.group(2).strip()
            # Kind is the heading level, e.g. "h1" for a top-level heading
            kind = f"h{level}"
            symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
    return symbols


def _parse_text_symbols(text: str) -> List[Symbol]:
    """Parse plain text files - no symbols, just index content."""
    # Text files don't have structured symbols, so return an empty list.
    # The file content will still be indexed for FTS search.
    return []
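A usage sketch for the factory, assuming `Config` is default-constructible (its fields are not shown in this diff); the parse call and Symbol fields mirror the code above.

from pathlib import Path
from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory

config = Config()  # assumed default-constructible
factory = ParserFactory(config)

source = "class Greeter:\n    def hello(self):\n        return 'hi'\n"
parser = factory.get_parser("python")
indexed = parser.parse(source, Path("greeter.py"))
for sym in indexed.symbols:
    print(sym.name, sym.kind, sym.range)  # Greeter (class), hello (method)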
98
codex-lens/build/lib/codexlens/parsers/tokenizer.py
Normal file
@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.

Provides accurate token counting using tiktoken with character count fallback.
"""

from __future__ import annotations

from typing import Optional

try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False


class Tokenizer:
    """Token counter with tiktoken primary and character count fallback."""

    def __init__(self, encoding_name: str = "cl100k_base") -> None:
        """Initialize tokenizer.

        Args:
            encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
        """
        self._encoding: Optional[object] = None
        self._encoding_name = encoding_name

        if TIKTOKEN_AVAILABLE:
            try:
                self._encoding = tiktoken.get_encoding(encoding_name)
            except Exception:
                # Fall back to character counting if encoding lookup fails
                self._encoding = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text.

        Uses tiktoken if available, otherwise falls back to character count / 4.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated token count
        """
        if not text:
            return 0

        if self._encoding is not None:
            try:
                return len(self._encoding.encode(text))  # type: ignore[attr-defined]
            except Exception:
                # Fall through to character count fallback
                pass

        # Fallback: rough estimate using character count
        # (average of ~4 characters per token for English text)
        return max(1, len(text) // 4)

    def is_using_tiktoken(self) -> bool:
        """Check if tiktoken is being used.

        Returns:
            True if tiktoken is available and initialized
        """
        return self._encoding is not None


# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None


def get_default_tokenizer() -> Tokenizer:
    """Get the global default tokenizer instance.

    Returns:
        Shared Tokenizer instance
    """
    global _default_tokenizer
    if _default_tokenizer is None:
        _default_tokenizer = Tokenizer()
    return _default_tokenizer


def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
    """Count tokens in text using default or provided tokenizer.

    Args:
        text: Text to count tokens for
        tokenizer: Optional tokenizer instance (uses default if None)

    Returns:
        Estimated token count
    """
    if tokenizer is None:
        tokenizer = get_default_tokenizer()
    return tokenizer.count_tokens(text)
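A usage sketch for the tokenizer utilities above; the "gpt2" encoding name is just an example of requesting a non-default tiktoken encoding.

from codexlens.parsers.tokenizer import Tokenizer, count_tokens, get_default_tokenizer

print(count_tokens("def add(a, b): return a + b"))  # tiktoken count, or ~len//4

tok = get_default_tokenizer()
if not tok.is_using_tiktoken():
    print("tiktoken unavailable; estimates use ~4 chars per token")

# Unknown encoding names fall back to character counting rather than raising.
alt_tok = Tokenizer(encoding_name="gpt2")
print(alt_tok.count_tokens("hello world"))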
809
codex-lens/build/lib/codexlens/parsers/treesitter_parser.py
Normal file
@@ -0,0 +1,809 @@
|
||||
"""Tree-sitter based parser for CodexLens.
|
||||
|
||||
Provides precise AST-level parsing via tree-sitter.
|
||||
|
||||
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
|
||||
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
|
||||
return `None`; callers should use a regex-based fallback such as
|
||||
`codexlens.parsers.factory.SimpleRegexParser`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
from tree_sitter import Language as TreeSitterLanguage
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
from tree_sitter import Parser as TreeSitterParser
|
||||
TREE_SITTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TreeSitterLanguage = None # type: ignore[assignment]
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TreeSitterParser = None # type: ignore[assignment]
|
||||
TREE_SITTER_AVAILABLE = False
|
||||
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
class TreeSitterSymbolParser:
|
||||
"""Parser using tree-sitter for AST-level symbol extraction."""
|
||||
|
||||
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
|
||||
"""Initialize tree-sitter parser for a language.
|
||||
|
||||
Args:
|
||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||
path: Optional file path for language variant detection (e.g., .tsx)
|
||||
"""
|
||||
self.language_id = language_id
|
||||
self.path = path
|
||||
self._parser: Optional[object] = None
|
||||
self._language: Optional[TreeSitterLanguage] = None
|
||||
self._tokenizer = get_default_tokenizer()
|
||||
|
||||
if TREE_SITTER_AVAILABLE:
|
||||
self._initialize_parser()
|
||||
|
||||
def _initialize_parser(self) -> None:
|
||||
"""Initialize tree-sitter parser and language."""
|
||||
if TreeSitterParser is None or TreeSitterLanguage is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Load language grammar
|
||||
if self.language_id == "python":
|
||||
import tree_sitter_python
|
||||
self._language = TreeSitterLanguage(tree_sitter_python.language())
|
||||
elif self.language_id == "javascript":
|
||||
import tree_sitter_javascript
|
||||
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
|
||||
elif self.language_id == "typescript":
|
||||
import tree_sitter_typescript
|
||||
# Detect TSX files by extension
|
||||
if self.path is not None and self.path.suffix.lower() == ".tsx":
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
|
||||
else:
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
|
||||
else:
|
||||
return
|
||||
|
||||
# Create parser
|
||||
self._parser = TreeSitterParser()
|
||||
if hasattr(self._parser, "set_language"):
|
||||
self._parser.set_language(self._language) # type: ignore[attr-defined]
|
||||
else:
|
||||
self._parser.language = self._language # type: ignore[assignment]
|
||||
|
||||
except Exception:
|
||||
# Gracefully handle missing language bindings
|
||||
self._parser = None
|
||||
self._language = None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if tree-sitter parser is available.
|
||||
|
||||
Returns:
|
||||
True if parser is initialized and ready
|
||||
"""
|
||||
return self._parser is not None and self._language is not None
|
||||
|
||||
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
|
||||
if not self.is_available() or self._parser is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
return source_bytes, tree.root_node
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
|
||||
"""Parse source code and extract symbols without creating IndexedFile.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
|
||||
Returns:
|
||||
List of symbols if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
return self._extract_symbols(source_bytes, root)
|
||||
except Exception:
|
||||
# Gracefully handle extraction errors
|
||||
return None
|
||||
|
||||
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
|
||||
"""Parse source code and extract symbols.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
path: File path
|
||||
|
||||
Returns:
|
||||
IndexedFile if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
symbols = self._extract_symbols(source_bytes, root)
|
||||
relationships = self._extract_relationships(source_bytes, root, path)
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
relationships=relationships,
|
||||
)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return None
|
||||
|
||||
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of extracted symbols
|
||||
"""
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_symbols(source_bytes, root)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_symbols(source_bytes, root)
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_relationships(source_bytes, root, path)
|
||||
if self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_relationships(source_bytes, root, path)
|
||||
return []
|
||||
|
||||
def _extract_python_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"self", "cls"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "class_definition" and pushed_scope:
|
||||
superclasses = node.child_by_field_name("superclasses")
|
||||
if superclasses is not None:
|
||||
for child in superclasses.children:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, child)
|
||||
if not dotted:
|
||||
continue
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type in {"import_statement", "import_from_statement"}:
|
||||
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
if node.type == "call":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _extract_js_ts_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"this", "super"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if pushed_scope:
|
||||
superclass = node.child_by_field_name("superclass")
|
||||
if superclass is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, superclass)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type in {"identifier", "property_identifier"}
|
||||
and value_node.type == "arrow_function"
|
||||
):
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name and scope_name != "constructor":
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"import_declaration", "import_statement"}:
|
||||
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
# Best-effort support for CommonJS require() imports:
|
||||
# const fs = require("fs")
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type == "identifier"
|
||||
and value_node.type == "call_expression"
|
||||
):
|
||||
callee = value_node.child_by_field_name("function")
|
||||
args = value_node.child_by_field_name("arguments")
|
||||
if (
|
||||
callee is not None
|
||||
and self._node_text(source_bytes, callee).strip() == "require"
|
||||
and args is not None
|
||||
):
|
||||
module_name = self._js_first_string_argument(source_bytes, args)
|
||||
if module_name:
|
||||
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
|
||||
record_import(module_name, self._node_start_line(node))
|
||||
|
||||
if node.type == "call_expression":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _node_start_line(self, node: TreeSitterNode) -> int:
|
||||
return node.start_point[0] + 1
|
||||
|
||||
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
|
||||
dotted = (dotted or "").strip()
|
||||
if not dotted:
|
||||
return ""
|
||||
|
||||
base, sep, rest = dotted.partition(".")
|
||||
resolved_base = aliases.get(base, base)
|
||||
if not rest:
|
||||
return resolved_base
|
||||
if resolved_base and rest:
|
||||
return f"{resolved_base}.{rest}"
|
||||
return resolved_base
|
||||
|
||||
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"identifier", "dotted_name"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "attribute":
|
||||
obj = node.child_by_field_name("object")
|
||||
attr = node.child_by_field_name("attribute")
|
||||
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
|
||||
if obj_text and attr_text:
|
||||
return f"{obj_text}.{attr_text}"
|
||||
return obj_text or attr_text
|
||||
return ""
|
||||
|
||||
def _python_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
if node.type == "import_statement":
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
module_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else module_name.split(".", 1)[0]
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = module_name
|
||||
targets.append(module_name)
|
||||
elif child.type == "dotted_name":
|
||||
module_name = self._node_text(source_bytes, child).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = module_name.split(".", 1)[0]
|
||||
if bound_name:
|
||||
aliases[bound_name] = bound_name
|
||||
targets.append(module_name)
|
||||
|
||||
if node.type == "import_from_statement":
|
||||
module_name = ""
|
||||
module_node = node.child_by_field_name("module_name")
|
||||
if module_node is None:
|
||||
for child in node.children:
|
||||
if child.type == "dotted_name":
|
||||
module_node = child
|
||||
break
|
||||
if module_node is not None:
|
||||
module_name = self._node_text(source_bytes, module_node).strip()
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported_name or imported_name == "*":
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported_name
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = target
|
||||
targets.append(target)
|
||||
elif child.type == "identifier":
|
||||
imported_name = self._node_text(source_bytes, child).strip()
|
||||
if not imported_name or imported_name in {"from", "import", "*"}:
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
aliases[imported_name] = target
|
||||
targets.append(target)
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"this", "super"}:
|
||||
return node.type
|
||||
if node.type in {"identifier", "property_identifier"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "member_expression":
|
||||
obj = node.child_by_field_name("object")
|
||||
prop = node.child_by_field_name("property")
|
||||
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
|
||||
if obj_text and prop_text:
|
||||
return f"{obj_text}.{prop_text}"
|
||||
return obj_text or prop_text
|
||||
return ""
|
||||
|
||||
def _js_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
module_name = ""
|
||||
source_node = node.child_by_field_name("source")
|
||||
if source_node is not None:
|
||||
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
|
||||
if module_name:
|
||||
targets.append(module_name)
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "import_clause":
|
||||
for clause_child in child.children:
|
||||
if clause_child.type == "identifier":
|
||||
# Default import: import React from "react"
|
||||
local = self._node_text(source_bytes, clause_child).strip()
|
||||
if local and module_name:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "namespace_import":
|
||||
# Namespace import: import * as fs from "fs"
|
||||
name_node = clause_child.child_by_field_name("name")
|
||||
if name_node is not None and module_name:
|
||||
local = self._node_text(source_bytes, name_node).strip()
|
||||
if local:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "named_imports":
|
||||
for spec in clause_child.children:
|
||||
if spec.type != "import_specifier":
|
||||
continue
|
||||
name_node = spec.child_by_field_name("name")
|
||||
alias_node = spec.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported:
|
||||
continue
|
||||
local = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported
|
||||
)
|
||||
if local and module_name:
|
||||
aliases[local] = f"{module_name}.{imported}"
|
||||
targets.append(f"{module_name}.{imported}")
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
|
||||
for child in args_node.children:
|
||||
if child.type == "string":
|
||||
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
|
||||
return ""
|
||||
|
||||
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract Python symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of Python symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind=self._python_function_kind(node),
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract JavaScript/TypeScript symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of JS/TS symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is None
|
||||
or value_node is None
|
||||
or name_node.type not in {"identifier", "property_identifier"}
|
||||
or value_node.type != "arrow_function"
|
||||
):
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name == "constructor":
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=name,
|
||||
kind="method",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _python_function_kind(self, node: TreeSitterNode) -> str:
|
||||
"""Determine if Python function is a method or standalone function.
|
||||
|
||||
Args:
|
||||
node: Function definition node
|
||||
|
||||
Returns:
|
||||
'method' if inside a class, 'function' otherwise
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"function_definition", "async_function_definition"}:
|
||||
return "function"
|
||||
if parent.type == "class_definition":
|
||||
return "method"
|
||||
parent = parent.parent
|
||||
return "function"
|
||||
|
||||
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
|
||||
"""Check if node has a class ancestor.
|
||||
|
||||
Args:
|
||||
node: AST node to check
|
||||
|
||||
Returns:
|
||||
True if node is inside a class
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"class_declaration", "class"}:
|
||||
return True
|
||||
parent = parent.parent
|
||||
return False
|
||||
|
||||
def _iter_nodes(self, root: TreeSitterNode):
|
||||
"""Iterate over all nodes in AST.
|
||||
|
||||
Args:
|
||||
root: Root node to start iteration
|
||||
|
||||
Yields:
|
||||
AST nodes in depth-first order
|
||||
"""
|
||||
stack = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
for child in reversed(node.children):
|
||||
stack.append(child)
|
||||
|
||||
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
"""Extract text for a node.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
Text content of node
|
||||
"""
|
||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||
|
||||
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
|
||||
"""Get line range for a node.
|
||||
|
||||
Args:
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
(start_line, end_line) tuple, 1-based inclusive
|
||||
"""
|
||||
start_line = node.start_point[0] + 1
|
||||
end_line = node.end_point[0] + 1
|
||||
return (start_line, max(start_line, end_line))
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens in text.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
return self._tokenizer.count_tokens(text)
|
||||
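A note on the traversal above: `_iter_nodes` replaces recursion with an explicit stack, and pushing children in reverse makes the pop order identical to a left-to-right recursive walk, so symbols come out in source order without risking Python's recursion limit on deep ASTs. A self-contained sketch of the same pattern (the `Node` class here is a hypothetical stand-in for a tree-sitter node, not part of codex-lens):

from dataclasses import dataclass, field
from typing import Iterator, List

@dataclass
class Node:
    type: str
    children: List["Node"] = field(default_factory=list)

def iter_nodes(root: Node) -> Iterator[Node]:
    # Explicit stack; reversed() keeps left-to-right order under pop().
    stack = [root]
    while stack:
        node = stack.pop()
        yield node
        for child in reversed(node.children):
            stack.append(child)

tree = Node("module", [
    Node("class_definition", [Node("function_definition")]),
    Node("function_definition"),
])
print([n.type for n in iter_nodes(tree)])
# -> ['module', 'class_definition', 'function_definition', 'function_definition']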
codex-lens/build/lib/codexlens/search/__init__.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from .chain_search import (
    ChainSearchEngine,
    SearchOptions,
    SearchStats,
    ChainSearchResult,
    quick_search,
)

# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None

try:
    from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
    from .clustering import check_clustering_available
    CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
    _clustering_import_error = str(e)

    def check_clustering_available() -> tuple[bool, str | None]:
        """Fallback when clustering module not loadable."""
        return False, _clustering_import_error


# Clustering module exports (conditional)
try:
    from .clustering import (
        BaseClusteringStrategy,
        ClusteringConfig,
        ClusteringStrategyFactory,
        get_strategy,
    )
    _clustering_exports = [
        "BaseClusteringStrategy",
        "ClusteringConfig",
        "ClusteringStrategyFactory",
        "get_strategy",
    ]
except ImportError:
    _clustering_exports = []


__all__ = [
    "ChainSearchEngine",
    "SearchOptions",
    "SearchStats",
    "ChainSearchResult",
    "quick_search",
    # Clustering
    "CLUSTERING_AVAILABLE",
    "check_clustering_available",
    *_clustering_exports,
]
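The flag-plus-fallback pattern above means downstream code never needs its own try/except around the optional clustering extra; a short usage sketch:

from codexlens.search import CLUSTERING_AVAILABLE, check_clustering_available

ok, err = check_clustering_available()
if not ok:
    # err carries the install hint, e.g. "pip install codexlens[clustering]"
    print(f"Clustering disabled: {err}")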
@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.

This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""

from .builder import AssociationTreeBuilder
from .data_structures import (
    CallTree,
    TreeNode,
    UniqueNode,
)
from .deduplicator import ResultDeduplicator

__all__ = [
    "AssociationTreeBuilder",
    "CallTree",
    "TreeNode",
    "UniqueNode",
    "ResultDeduplicator",
]
@@ -0,0 +1,450 @@
"""Association tree builder using LSP call hierarchy.

Builds call relationship trees by recursively expanding from seed locations
using Language Server Protocol (LSP) call hierarchy capabilities.
"""

from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set

from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.lsp.standalone_manager import StandaloneLspManager
from .data_structures import CallTree, TreeNode

logger = logging.getLogger(__name__)


class AssociationTreeBuilder:
    """Builds association trees from seed locations using LSP call hierarchy.

    Uses depth-first recursive expansion to build a tree of code relationships
    starting from seed locations (typically from vector search results).

    Strategy:
        - Start from seed locations (vector search results)
        - For each seed, get call hierarchy items via LSP
        - Recursively expand incoming calls (callers) if expand_callers=True
        - Recursively expand outgoing calls (callees) if expand_callees=True
        - Track visited nodes to prevent cycles
        - Stop at max_depth or when no more relations found

    Attributes:
        lsp_manager: StandaloneLspManager for LSP communication
        visited: Set of visited node IDs to prevent cycles
        timeout: Timeout for individual LSP requests (seconds)
    """

    def __init__(
        self,
        lsp_manager: StandaloneLspManager,
        timeout: float = 5.0,
        analysis_wait: float = 2.0,
    ):
        """Initialize AssociationTreeBuilder.

        Args:
            lsp_manager: StandaloneLspManager instance for LSP communication
            timeout: Timeout for individual LSP requests in seconds
            analysis_wait: Time to wait for LSP analysis on first file (seconds)
        """
        self.lsp_manager = lsp_manager
        self.timeout = timeout
        self.analysis_wait = analysis_wait
        self.visited: Set[str] = set()
        self._analyzed_files: Set[str] = set()  # Track files already analyzed

    async def build_tree(
        self,
        seed_file_path: str,
        seed_line: int,
        seed_character: int = 1,
        max_depth: int = 5,
        expand_callers: bool = True,
        expand_callees: bool = True,
    ) -> CallTree:
        """Build call tree from a single seed location.

        Args:
            seed_file_path: Path to the seed file
            seed_line: Line number of the seed symbol (1-based)
            seed_character: Character position (1-based, default 1)
            max_depth: Maximum recursion depth (default 5)
            expand_callers: Whether to expand incoming calls (callers)
            expand_callees: Whether to expand outgoing calls (callees)

        Returns:
            CallTree containing all discovered nodes and relationships
        """
        tree = CallTree()
        self.visited.clear()

        # Determine wait time - only wait for analysis on first encounter of file
        wait_time = 0.0
        if seed_file_path not in self._analyzed_files:
            wait_time = self.analysis_wait
            self._analyzed_files.add(seed_file_path)

        # Get call hierarchy items for the seed position
        try:
            hierarchy_items = await asyncio.wait_for(
                self.lsp_manager.get_call_hierarchy_items(
                    file_path=seed_file_path,
                    line=seed_line,
                    character=seed_character,
                    wait_for_analysis=wait_time,
                ),
                timeout=self.timeout + wait_time,
            )
        except asyncio.TimeoutError:
            logger.warning(
                "Timeout getting call hierarchy items for %s:%d",
                seed_file_path,
                seed_line,
            )
            return tree
        except Exception as e:
            logger.error(
                "Error getting call hierarchy items for %s:%d: %s",
                seed_file_path,
                seed_line,
                e,
            )
            return tree

        if not hierarchy_items:
            logger.debug(
                "No call hierarchy items found for %s:%d",
                seed_file_path,
                seed_line,
            )
            return tree

        # Create root nodes from hierarchy items
        for item_dict in hierarchy_items:
            # Convert LSP dict to CallHierarchyItem
            item = self._dict_to_call_hierarchy_item(item_dict)
            if not item:
                continue

            root_node = TreeNode(
                item=item,
                depth=0,
                path_from_root=[self._create_node_id(item)],
            )
            tree.roots.append(root_node)
            tree.add_node(root_node)

            # Mark as visited
            self.visited.add(root_node.node_id)

            # Recursively expand the tree
            await self._expand_node(
                node=root_node,
                node_dict=item_dict,
                tree=tree,
                current_depth=0,
                max_depth=max_depth,
                expand_callers=expand_callers,
                expand_callees=expand_callees,
            )

        tree.depth_reached = max_depth
        return tree

    async def _expand_node(
        self,
        node: TreeNode,
        node_dict: Dict,
        tree: CallTree,
        current_depth: int,
        max_depth: int,
        expand_callers: bool,
        expand_callees: bool,
    ) -> None:
        """Recursively expand a node by fetching its callers and callees.

        Args:
            node: TreeNode to expand
            node_dict: LSP CallHierarchyItem dict (for LSP requests)
            tree: CallTree to add discovered nodes to
            current_depth: Current recursion depth
            max_depth: Maximum allowed depth
            expand_callers: Whether to expand incoming calls
            expand_callees: Whether to expand outgoing calls
        """
        # Stop if max depth reached
        if current_depth >= max_depth:
            return

        # Prepare tasks for parallel expansion
        tasks = []

        if expand_callers:
            tasks.append(
                self._expand_incoming_calls(
                    node=node,
                    node_dict=node_dict,
                    tree=tree,
                    current_depth=current_depth,
                    max_depth=max_depth,
                    expand_callers=expand_callers,
                    expand_callees=expand_callees,
                )
            )

        if expand_callees:
            tasks.append(
                self._expand_outgoing_calls(
                    node=node,
                    node_dict=node_dict,
                    tree=tree,
                    current_depth=current_depth,
                    max_depth=max_depth,
                    expand_callers=expand_callers,
                    expand_callees=expand_callees,
                )
            )

        # Execute expansions in parallel
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _expand_incoming_calls(
        self,
        node: TreeNode,
        node_dict: Dict,
        tree: CallTree,
        current_depth: int,
        max_depth: int,
        expand_callers: bool,
        expand_callees: bool,
    ) -> None:
        """Expand incoming calls (callers) for a node.

        Args:
            node: TreeNode being expanded
            node_dict: LSP dict for the node
            tree: CallTree to add nodes to
            current_depth: Current depth
            max_depth: Maximum depth
            expand_callers: Whether to continue expanding callers
            expand_callees: Whether to expand callees
        """
        try:
            incoming_calls = await asyncio.wait_for(
                self.lsp_manager.get_incoming_calls(item=node_dict),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            logger.debug("Timeout getting incoming calls for %s", node.node_id)
            return
        except Exception as e:
            logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
            return

        if not incoming_calls:
            return

        # Process each incoming call
        for call_dict in incoming_calls:
            caller_dict = call_dict.get("from")
            if not caller_dict:
                continue

            # Convert to CallHierarchyItem
            caller_item = self._dict_to_call_hierarchy_item(caller_dict)
            if not caller_item:
                continue

            caller_id = self._create_node_id(caller_item)

            # Check for cycles
            if caller_id in self.visited:
                # Create cycle marker node
                cycle_node = TreeNode(
                    item=caller_item,
                    depth=current_depth + 1,
                    is_cycle=True,
                    path_from_root=node.path_from_root + [caller_id],
                )
                node.parents.append(cycle_node)
                continue

            # Create new caller node
            caller_node = TreeNode(
                item=caller_item,
                depth=current_depth + 1,
                path_from_root=node.path_from_root + [caller_id],
            )

            # Add to tree
            tree.add_node(caller_node)
            tree.add_edge(caller_node, node)

            # Update relationships
            node.parents.append(caller_node)
            caller_node.children.append(node)

            # Mark as visited
            self.visited.add(caller_id)

            # Recursively expand the caller
            await self._expand_node(
                node=caller_node,
                node_dict=caller_dict,
                tree=tree,
                current_depth=current_depth + 1,
                max_depth=max_depth,
                expand_callers=expand_callers,
                expand_callees=expand_callees,
            )

    async def _expand_outgoing_calls(
        self,
        node: TreeNode,
        node_dict: Dict,
        tree: CallTree,
        current_depth: int,
        max_depth: int,
        expand_callers: bool,
        expand_callees: bool,
    ) -> None:
        """Expand outgoing calls (callees) for a node.

        Args:
            node: TreeNode being expanded
            node_dict: LSP dict for the node
            tree: CallTree to add nodes to
            current_depth: Current depth
            max_depth: Maximum depth
            expand_callers: Whether to expand callers
            expand_callees: Whether to continue expanding callees
        """
        try:
            outgoing_calls = await asyncio.wait_for(
                self.lsp_manager.get_outgoing_calls(item=node_dict),
                timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            logger.debug("Timeout getting outgoing calls for %s", node.node_id)
            return
        except Exception as e:
            logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
            return

        if not outgoing_calls:
            return

        # Process each outgoing call
        for call_dict in outgoing_calls:
            callee_dict = call_dict.get("to")
            if not callee_dict:
                continue

            # Convert to CallHierarchyItem
            callee_item = self._dict_to_call_hierarchy_item(callee_dict)
            if not callee_item:
                continue

            callee_id = self._create_node_id(callee_item)

            # Check for cycles
            if callee_id in self.visited:
                # Create cycle marker node
                cycle_node = TreeNode(
                    item=callee_item,
                    depth=current_depth + 1,
                    is_cycle=True,
                    path_from_root=node.path_from_root + [callee_id],
                )
                node.children.append(cycle_node)
                continue

            # Create new callee node
            callee_node = TreeNode(
                item=callee_item,
                depth=current_depth + 1,
                path_from_root=node.path_from_root + [callee_id],
            )

            # Add to tree
            tree.add_node(callee_node)
            tree.add_edge(node, callee_node)

            # Update relationships
            node.children.append(callee_node)
            callee_node.parents.append(node)

            # Mark as visited
            self.visited.add(callee_id)

            # Recursively expand the callee
            await self._expand_node(
                node=callee_node,
                node_dict=callee_dict,
                tree=tree,
                current_depth=current_depth + 1,
                max_depth=max_depth,
                expand_callers=expand_callers,
                expand_callees=expand_callees,
            )

    def _dict_to_call_hierarchy_item(
        self, item_dict: Dict
    ) -> Optional[CallHierarchyItem]:
        """Convert LSP dict to CallHierarchyItem.

        Args:
            item_dict: LSP CallHierarchyItem dictionary

        Returns:
            CallHierarchyItem or None if conversion fails
        """
        try:
            # Extract URI and convert to file path
            uri = item_dict.get("uri", "")
            file_path = uri.replace("file:///", "").replace("file://", "")

            # Handle Windows paths (file:///C:/...)
            if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
                file_path = file_path[1:]

            # Extract range
            range_dict = item_dict.get("range", {})
            start = range_dict.get("start", {})
            end = range_dict.get("end", {})

            # Create Range (convert from 0-based to 1-based)
            item_range = Range(
                start_line=start.get("line", 0) + 1,
                start_character=start.get("character", 0) + 1,
                end_line=end.get("line", 0) + 1,
                end_character=end.get("character", 0) + 1,
            )

            return CallHierarchyItem(
                name=item_dict.get("name", "unknown"),
                kind=str(item_dict.get("kind", "unknown")),
                file_path=file_path,
                range=item_range,
                detail=item_dict.get("detail"),
            )

        except Exception as e:
            logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
            return None

    def _create_node_id(self, item: CallHierarchyItem) -> str:
        """Create unique node ID from CallHierarchyItem.

        Args:
            item: CallHierarchyItem

        Returns:
            Unique node ID string
        """
        return f"{item.file_path}:{item.name}:{item.range.start_line}"
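A minimal usage sketch for the builder, assuming an already-started StandaloneLspManager (its setup is not part of this diff) and a hypothetical seed location; the module path is inferred from the surrounding package layout:

import asyncio

from codexlens.search.association_tree import AssociationTreeBuilder

async def explore(lsp_manager) -> None:
    builder = AssociationTreeBuilder(lsp_manager, timeout=5.0, analysis_wait=2.0)
    tree = await builder.build_tree(
        seed_file_path="src/service.py",  # hypothetical seed
        seed_line=42,
        max_depth=3,
        expand_callees=False,  # walk callers only
    )
    print(tree)  # e.g. CallTree(roots=1, nodes=12, depth=3)
    print([root.item.name for root in tree.roots])

# asyncio.run(explore(lsp_manager))  # manager construction elided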
@@ -0,0 +1,191 @@
"""Data structures for association tree building.

Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range


@dataclass
class TreeNode:
    """Node in the call association tree.

    Represents a single function/method in the tree, including its position
    in the hierarchy and relationships.

    Attributes:
        item: LSP CallHierarchyItem containing symbol information
        depth: Distance from the root node (seed) - 0 for roots
        children: List of child nodes (functions called by this node)
        parents: List of parent nodes (functions that call this node)
        is_cycle: Whether this node creates a circular reference
        path_from_root: Path (list of node IDs) from root to this node
    """

    item: CallHierarchyItem
    depth: int = 0
    children: List[TreeNode] = field(default_factory=list)
    parents: List[TreeNode] = field(default_factory=list)
    is_cycle: bool = False
    path_from_root: List[str] = field(default_factory=list)

    @property
    def node_id(self) -> str:
        """Unique identifier for this node."""
        return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"

    def __hash__(self) -> int:
        """Hash based on node ID."""
        return hash(self.node_id)

    def __eq__(self, other: object) -> bool:
        """Equality based on node ID."""
        if not isinstance(other, TreeNode):
            return False
        return self.node_id == other.node_id

    def __repr__(self) -> str:
        """String representation of the node."""
        cycle_marker = " [CYCLE]" if self.is_cycle else ""
        return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"


@dataclass
class CallTree:
    """Complete call tree structure built from seeds.

    Contains all nodes discovered through recursive expansion and
    the relationships between them.

    Attributes:
        roots: List of root nodes (seed symbols)
        all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
        node_list: Flat list of all nodes in tree order
        edges: List of (from_node_id, to_node_id) tuples representing calls
        depth_reached: Maximum depth achieved in expansion
    """

    roots: List[TreeNode] = field(default_factory=list)
    all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
    node_list: List[TreeNode] = field(default_factory=list)
    edges: List[tuple[str, str]] = field(default_factory=list)
    depth_reached: int = 0

    def add_node(self, node: TreeNode) -> None:
        """Add a node to the tree.

        Args:
            node: TreeNode to add
        """
        if node.node_id not in self.all_nodes:
            self.all_nodes[node.node_id] = node
            self.node_list.append(node)

    def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
        """Add an edge between two nodes.

        Args:
            from_node: Source node
            to_node: Target node
        """
        edge = (from_node.node_id, to_node.node_id)
        if edge not in self.edges:
            self.edges.append(edge)

    def get_node(self, node_id: str) -> Optional[TreeNode]:
        """Get a node by ID.

        Args:
            node_id: Node identifier

        Returns:
            TreeNode if found, None otherwise
        """
        return self.all_nodes.get(node_id)

    def __len__(self) -> int:
        """Return total number of nodes in tree."""
        return len(self.all_nodes)

    def __repr__(self) -> str:
        """String representation of the tree."""
        return (
            f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
            f"depth={self.depth_reached})"
        )


@dataclass
class UniqueNode:
    """Deduplicated unique code symbol from the tree.

    Represents a single unique code location that may appear multiple times
    in the tree under different contexts. Contains aggregated information
    about all occurrences.

    Attributes:
        file_path: Absolute path to the file
        name: Symbol name (function, method, class, etc.)
        kind: Symbol kind (function, method, class, etc.)
        range: Code range in the file
        min_depth: Minimum depth at which this node appears in the tree
        occurrences: Number of times this node appears in the tree
        paths: List of paths from roots to this node
        context_nodes: Related nodes from the tree
        score: Composite relevance score (higher is better)
    """

    file_path: str
    name: str
    kind: str
    range: Range
    min_depth: int = 0
    occurrences: int = 1
    paths: List[List[str]] = field(default_factory=list)
    context_nodes: List[str] = field(default_factory=list)
    score: float = 0.0

    @property
    def node_key(self) -> tuple[str, int, int]:
        """Unique key for deduplication.

        Uses (file_path, start_line, end_line) as the unique identifier
        for this symbol across all occurrences.
        """
        return (
            self.file_path,
            self.range.start_line,
            self.range.end_line,
        )

    def add_path(self, path: List[str]) -> None:
        """Add a path from root to this node.

        Args:
            path: List of node IDs from root to this node
        """
        if path not in self.paths:
            self.paths.append(path)

    def __hash__(self) -> int:
        """Hash based on node key."""
        return hash(self.node_key)

    def __eq__(self, other: object) -> bool:
        """Equality based on node key."""
        if not isinstance(other, UniqueNode):
            return False
        return self.node_key == other.node_key

    def __repr__(self) -> str:
        """String representation of the unique node."""
        return (
            f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
            f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
        )
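To make the relationships concrete, here is a hand-built two-node tree; field values are illustrative only, and the import path for CallTree/TreeNode is assumed from the package layout:

from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.search.association_tree import CallTree, TreeNode

def make_item(name: str, line: int) -> CallHierarchyItem:
    return CallHierarchyItem(
        name=name,
        kind="function",
        file_path="pkg/mod.py",  # hypothetical file
        range=Range(start_line=line, start_character=1,
                    end_line=line + 5, end_character=1),
        detail=None,
    )

tree = CallTree()
root = TreeNode(item=make_item("handler", 10), depth=0)
callee = TreeNode(item=make_item("validate", 30), depth=1,
                  path_from_root=[root.node_id])
tree.roots.append(root)
tree.add_node(root)
tree.add_node(callee)
tree.add_edge(root, callee)      # handler calls validate
root.children.append(callee)
callee.parents.append(root)
print(tree)                      # CallTree(roots=1, nodes=2, depth=0)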
@@ -0,0 +1,301 @@
"""Result deduplication for association tree nodes.

Provides functionality to extract unique nodes from a call tree and assign
relevance scores based on various factors.
"""

from __future__ import annotations

import logging
from typing import Dict, List, Optional

from .data_structures import (
    CallTree,
    TreeNode,
    UniqueNode,
)

logger = logging.getLogger(__name__)


# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
    # Functions and methods are primary targets
    "function": 1.0,
    "method": 1.0,
    "12": 1.0,  # LSP SymbolKind.Function
    "6": 1.0,   # LSP SymbolKind.Method
    # Classes are important but secondary
    "class": 0.8,
    "5": 0.8,   # LSP SymbolKind.Class
    # Interfaces and types
    "interface": 0.7,
    "11": 0.7,  # LSP SymbolKind.Interface
    "type": 0.6,
    # Constructors
    "constructor": 0.9,
    "9": 0.9,   # LSP SymbolKind.Constructor
    # Variables and constants
    "variable": 0.4,
    "13": 0.4,  # LSP SymbolKind.Variable
    "constant": 0.5,
    "14": 0.5,  # LSP SymbolKind.Constant
    # Default for unknown kinds
    "unknown": 0.3,
}


class ResultDeduplicator:
    """Extracts and scores unique nodes from call trees.

    Processes a CallTree to extract unique code locations, merging duplicates
    and assigning relevance scores based on:
        - Depth: Shallower nodes (closer to seeds) score higher
        - Frequency: Nodes appearing multiple times score higher
        - Kind: Function/method > class > variable

    Attributes:
        depth_weight: Weight for depth factor in scoring (default 0.4)
        frequency_weight: Weight for frequency factor (default 0.3)
        kind_weight: Weight for symbol kind factor (default 0.3)
        max_depth_penalty: Maximum depth before full penalty applied
    """

    def __init__(
        self,
        depth_weight: float = 0.4,
        frequency_weight: float = 0.3,
        kind_weight: float = 0.3,
        max_depth_penalty: int = 10,
    ):
        """Initialize ResultDeduplicator.

        Args:
            depth_weight: Weight for depth factor (0.0-1.0)
            frequency_weight: Weight for frequency factor (0.0-1.0)
            kind_weight: Weight for symbol kind factor (0.0-1.0)
            max_depth_penalty: Depth at which score becomes 0 for depth factor
        """
        self.depth_weight = depth_weight
        self.frequency_weight = frequency_weight
        self.kind_weight = kind_weight
        self.max_depth_penalty = max_depth_penalty

    def deduplicate(
        self,
        tree: CallTree,
        max_results: Optional[int] = None,
    ) -> List[UniqueNode]:
        """Extract unique nodes from the call tree.

        Traverses the tree, groups nodes by their unique key (file_path,
        start_line, end_line), and merges duplicate occurrences.

        Args:
            tree: CallTree to process
            max_results: Maximum number of results to return (None = all)

        Returns:
            List of UniqueNode objects, sorted by score descending
        """
        if not tree.node_list:
            return []

        # Group nodes by unique key
        unique_map: Dict[tuple, UniqueNode] = {}

        for node in tree.node_list:
            if node.is_cycle:
                # Skip cycle markers - they point to already-counted nodes
                continue

            key = self._get_node_key(node)

            if key in unique_map:
                # Update existing unique node
                unique_node = unique_map[key]
                unique_node.occurrences += 1
                unique_node.min_depth = min(unique_node.min_depth, node.depth)
                unique_node.add_path(node.path_from_root)

                # Collect context from relationships
                for parent in node.parents:
                    if not parent.is_cycle:
                        unique_node.context_nodes.append(parent.node_id)
                for child in node.children:
                    if not child.is_cycle:
                        unique_node.context_nodes.append(child.node_id)
            else:
                # Create new unique node
                unique_node = UniqueNode(
                    file_path=node.item.file_path,
                    name=node.item.name,
                    kind=node.item.kind,
                    range=node.item.range,
                    min_depth=node.depth,
                    occurrences=1,
                    paths=[node.path_from_root.copy()],
                    context_nodes=[],
                    score=0.0,
                )

                # Collect initial context
                for parent in node.parents:
                    if not parent.is_cycle:
                        unique_node.context_nodes.append(parent.node_id)
                for child in node.children:
                    if not child.is_cycle:
                        unique_node.context_nodes.append(child.node_id)

                unique_map[key] = unique_node

        # Calculate scores for all unique nodes
        unique_nodes = list(unique_map.values())

        # Find max frequency for normalization
        max_frequency = max((n.occurrences for n in unique_nodes), default=1)

        for node in unique_nodes:
            node.score = self._score_node(node, max_frequency)

        # Sort by score descending
        unique_nodes.sort(key=lambda n: n.score, reverse=True)

        # Apply max_results limit
        if max_results is not None and max_results > 0:
            unique_nodes = unique_nodes[:max_results]

        logger.debug(
            "Deduplicated %d tree nodes to %d unique nodes",
            len(tree.node_list),
            len(unique_nodes),
        )

        return unique_nodes

    def _score_node(
        self,
        node: UniqueNode,
        max_frequency: int,
    ) -> float:
        """Calculate composite score for a unique node.

        Score = depth_weight * depth_score +
                frequency_weight * frequency_score +
                kind_weight * kind_score

        Args:
            node: UniqueNode to score
            max_frequency: Maximum occurrence count for normalization

        Returns:
            Composite score between 0.0 and 1.0
        """
        # Depth score: closer to root = higher score
        # Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
        depth_score = max(
            0.0,
            1.0 - (node.min_depth / self.max_depth_penalty),
        )

        # Frequency score: more occurrences = higher score
        frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0

        # Kind score: function/method > class > variable
        kind_str = str(node.kind).lower()
        kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])

        # Composite score
        score = (
            self.depth_weight * depth_score
            + self.frequency_weight * frequency_score
            + self.kind_weight * kind_score
        )

        return score

    def _get_node_key(self, node: TreeNode) -> tuple:
        """Get unique key for a tree node.

        Uses (file_path, start_line, end_line) as the unique identifier.

        Args:
            node: TreeNode

        Returns:
            Tuple key for deduplication
        """
        return (
            node.item.file_path,
            node.item.range.start_line,
            node.item.range.end_line,
        )

    def filter_by_kind(
        self,
        nodes: List[UniqueNode],
        kinds: List[str],
    ) -> List[UniqueNode]:
        """Filter unique nodes by symbol kind.

        Args:
            nodes: List of UniqueNode to filter
            kinds: List of allowed kinds (e.g., ["function", "method"])

        Returns:
            Filtered list of UniqueNode
        """
        kinds_lower = [k.lower() for k in kinds]
        return [
            node
            for node in nodes
            if str(node.kind).lower() in kinds_lower
        ]

    def filter_by_file(
        self,
        nodes: List[UniqueNode],
        file_patterns: List[str],
    ) -> List[UniqueNode]:
        """Filter unique nodes by file path patterns.

        Args:
            nodes: List of UniqueNode to filter
            file_patterns: List of path substrings to match

        Returns:
            Filtered list of UniqueNode
        """
        return [
            node
            for node in nodes
            if any(pattern in node.file_path for pattern in file_patterns)
        ]

    def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
        """Convert list of UniqueNode to JSON-serializable dicts.

        Args:
            nodes: List of UniqueNode

        Returns:
            List of dictionaries
        """
        return [
            {
                "file_path": node.file_path,
                "name": node.name,
                "kind": node.kind,
                "range": {
                    "start_line": node.range.start_line,
                    "start_character": node.range.start_character,
                    "end_line": node.range.end_line,
                    "end_character": node.range.end_character,
                },
                "min_depth": node.min_depth,
                "occurrences": node.occurrences,
                "path_count": len(node.paths),
                "score": round(node.score, 4),
            }
            for node in nodes
        ]
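With the default weights the composite score is easy to verify by hand: a function at min_depth 1 that occurs 3 times (when the max frequency in the tree is also 3) gets depth_score = 1 - 1/10 = 0.9, frequency_score = 1.0, and kind_score = 1.0, so score = 0.4*0.9 + 0.3*1.0 + 0.3*1.0 = 0.96. A usage sketch (import path assumed from the package layout):

from codexlens.search.association_tree import ResultDeduplicator

dedup = ResultDeduplicator()  # depth 0.4 / frequency 0.3 / kind 0.3
unique = dedup.deduplicate(tree, max_results=20)  # tree: a CallTree from the builder
unique = dedup.filter_by_kind(unique, ["function", "method"])
for row in dedup.to_dict_list(unique):
    print(row["name"], row["score"], row["occurrences"])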
codex-lens/build/lib/codexlens/search/binary_searcher.py (new file, 277 lines)
@@ -0,0 +1,277 @@
"""Binary vector searcher for cascade search.

This module provides fast binary vector search using Hamming distance
for the first stage of cascade search (coarse filtering).

Supports two loading modes:
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
2. Database loading (fallback): Loads all vectors into RAM
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

# Pre-computed popcount lookup table for vectorized Hamming distance
# Each byte value (0-255) maps to its bit count
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)


class BinarySearcher:
    """Fast binary vector search using Hamming distance.

    This class implements the first stage of cascade search:
    fast, approximate retrieval using binary vectors and Hamming distance.

    The binary vectors are derived from dense embeddings by thresholding:
        binary[i] = 1 if dense[i] > 0 else 0

    Hamming distance between two binary vectors counts the number of
    differing bits, which can be computed very efficiently using XOR
    and population count.

    Supports two loading modes:
    - Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
    - Database (fallback): Loads all vectors into memory from SQLite
    """

    def __init__(self, index_root_or_meta_path: Path) -> None:
        """Initialize BinarySearcher.

        Args:
            index_root_or_meta_path: Either:
                - Path to index root directory (containing _binary_vectors.mmap)
                - Path to _vectors_meta.db (legacy mode, loads from DB)
        """
        path = Path(index_root_or_meta_path)

        # Determine if this is an index root or a specific DB path
        if path.suffix == '.db':
            # Legacy mode: specific DB path
            self.index_root = path.parent
            self.meta_store_path = path
        else:
            # New mode: index root directory
            self.index_root = path
            self.meta_store_path = path / "_vectors_meta.db"

        self._chunk_ids: Optional[np.ndarray] = None
        self._binary_matrix: Optional[np.ndarray] = None
        self._is_memmap = False
        self._loaded = False

    def load(self) -> bool:
        """Load binary vectors using memory-mapped file or database fallback.

        Tries to load from memory-mapped file first (preferred for large indexes),
        falls back to database loading if mmap file doesn't exist.

        Returns:
            True if vectors were loaded successfully.
        """
        if self._loaded:
            return True

        # Try memory-mapped file first (preferred)
        mmap_path = self.index_root / "_binary_vectors.mmap"
        meta_path = mmap_path.with_suffix('.meta.json')

        if mmap_path.exists() and meta_path.exists():
            try:
                with open(meta_path, 'r') as f:
                    meta = json.load(f)

                shape = tuple(meta['shape'])
                self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)

                # Memory-map the binary matrix (read-only)
                self._binary_matrix = np.memmap(
                    str(mmap_path),
                    dtype=np.uint8,
                    mode='r',
                    shape=shape
                )
                self._is_memmap = True
                self._loaded = True

                logger.info(
                    "Memory-mapped %d binary vectors (%d bytes each)",
                    len(self._chunk_ids), shape[1]
                )
                return True

            except Exception as e:
                logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)

        # Fallback: load from database
        return self._load_from_db()

    def _load_from_db(self) -> bool:
        """Load binary vectors from database (legacy/fallback mode).

        Returns:
            True if vectors were loaded successfully.
        """
        try:
            from codexlens.storage.vector_meta_store import VectorMetadataStore

            with VectorMetadataStore(self.meta_store_path) as store:
                rows = store.get_all_binary_vectors()

            if not rows:
                logger.warning("No binary vectors found in %s", self.meta_store_path)
                return False

            # Convert to numpy arrays for fast computation
            self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)

            # Unpack bytes to numpy array
            binary_arrays = []
            for _, vec_bytes in rows:
                arr = np.frombuffer(vec_bytes, dtype=np.uint8)
                binary_arrays.append(arr)

            self._binary_matrix = np.vstack(binary_arrays)
            self._is_memmap = False
            self._loaded = True

            logger.info(
                "Loaded %d binary vectors from DB (%d bytes each)",
                len(self._chunk_ids), self._binary_matrix.shape[1]
            )
            return True

        except Exception as e:
            logger.error("Failed to load binary vectors: %s", e)
            return False

    def search(
        self,
        query_vector: np.ndarray,
        top_k: int = 100
    ) -> List[Tuple[int, int]]:
        """Search for similar vectors using Hamming distance.

        Args:
            query_vector: Dense query vector (will be binarized).
            top_k: Number of top results to return.

        Returns:
            List of (chunk_id, hamming_distance) tuples sorted by distance.
        """
        if not self._loaded and not self.load():
            return []

        # Binarize query vector
        query_binary = (query_vector > 0).astype(np.uint8)
        query_packed = np.packbits(query_binary)

        # Compute Hamming distances using XOR and popcount
        # XOR gives 1 for differing bits
        xor_result = np.bitwise_xor(self._binary_matrix, query_packed)

        # Vectorized popcount using lookup table (orders of magnitude faster)
        # Sum the bit counts for each byte across all columns
        distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)

        # Get top-k with smallest distances
        if top_k >= len(distances):
            top_indices = np.argsort(distances)
        else:
            # Partial sort for efficiency
            top_indices = np.argpartition(distances, top_k)[:top_k]
            top_indices = top_indices[np.argsort(distances[top_indices])]

        results = [
            (int(self._chunk_ids[i]), int(distances[i]))
            for i in top_indices
        ]

        return results

    def search_with_rerank(
        self,
        query_dense: np.ndarray,
        dense_vectors: np.ndarray,
        dense_chunk_ids: np.ndarray,
        top_k: int = 10,
        candidates: int = 100
    ) -> List[Tuple[int, float]]:
        """Two-stage cascade search: binary filter + dense rerank.

        Args:
            query_dense: Dense query vector.
            dense_vectors: Dense vectors for reranking (from HNSW or stored).
            dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
            top_k: Final number of results.
            candidates: Number of candidates from binary search.

        Returns:
            List of (chunk_id, cosine_similarity) tuples.
        """
        # Stage 1: Binary filtering
        binary_results = self.search(query_dense, top_k=candidates)
        if not binary_results:
            return []

        candidate_ids = {r[0] for r in binary_results}

        # Stage 2: Dense reranking
        # Find indices of candidates in dense_vectors
        candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
        candidate_indices = np.where(candidate_mask)[0]

        if len(candidate_indices) == 0:
            # Fallback: return binary results with normalized distance
            max_dist = max(r[1] for r in binary_results) if binary_results else 1
            return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]

        # Compute cosine similarities for candidates
        candidate_vectors = dense_vectors[candidate_indices]
        candidate_ids_array = dense_chunk_ids[candidate_indices]

        # Normalize vectors
        query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
        cand_norms = candidate_vectors / (
            np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
        )

        # Cosine similarities
        similarities = np.dot(cand_norms, query_norm)

        # Sort by similarity (descending)
        sorted_indices = np.argsort(-similarities)[:top_k]

        results = [
            (int(candidate_ids_array[i]), float(similarities[i]))
            for i in sorted_indices
        ]

        return results

    @property
    def vector_count(self) -> int:
        """Get number of loaded binary vectors."""
        return len(self._chunk_ids) if self._chunk_ids is not None else 0

    @property
    def is_memmap(self) -> bool:
        """Check if using memory-mapped file (vs in-memory array)."""
        return self._is_memmap

    def clear(self) -> None:
        """Clear loaded vectors from memory."""
        # For memmap, just delete the reference (OS will handle cleanup)
        if self._is_memmap and self._binary_matrix is not None:
            del self._binary_matrix
        self._chunk_ids = None
        self._binary_matrix = None
        self._is_memmap = False
        self._loaded = False
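The lookup-table popcount used above is a standard trick: XOR marks the differing bits, and summing per-byte bit counts gives the Hamming distance with no Python-level loop. A self-contained demonstration with toy data:

import numpy as np

POPCOUNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

rng = np.random.default_rng(0)
dense_db = rng.standard_normal((4, 64))   # 4 stored dense vectors
query = rng.standard_normal(64)

db_packed = np.packbits((dense_db > 0).astype(np.uint8), axis=1)  # (4, 8) bytes
q_packed = np.packbits((query > 0).astype(np.uint8))              # (8,) bytes

xor = np.bitwise_xor(db_packed, q_packed)  # differing bits, per byte
dist = POPCOUNT[xor].sum(axis=1)           # Hamming distance per stored vector
print(dist, "nearest row:", int(dist.argmin()))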
codex-lens/build/lib/codexlens/search/chain_search.py (new file, 3268 lines)
File diff suppressed because it is too large
codex-lens/build/lib/codexlens/search/clustering/__init__.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""Clustering strategies for the staged hybrid search pipeline.

This module provides extensible clustering infrastructure for grouping
similar search results and selecting representative results.

Install with: pip install codexlens[clustering]

Example:
    >>> from codexlens.search.clustering import (
    ...     CLUSTERING_AVAILABLE,
    ...     ClusteringConfig,
    ...     get_strategy,
    ... )
    >>> config = ClusteringConfig(min_cluster_size=3)
    >>> # Auto-select best available strategy with fallback
    >>> strategy = get_strategy("auto", config)
    >>> representatives = strategy.fit_predict(embeddings, results)
    >>>
    >>> # Or explicitly use a specific strategy
    >>> if CLUSTERING_AVAILABLE:
    ...     from codexlens.search.clustering import HDBSCANStrategy
    ...     strategy = HDBSCANStrategy(config)
    ...     representatives = strategy.fit_predict(embeddings, results)
"""

from __future__ import annotations

# Always export base classes and factory (no heavy dependencies)
from .base import BaseClusteringStrategy, ClusteringConfig
from .factory import (
    ClusteringStrategyFactory,
    check_clustering_strategy_available,
    get_strategy,
)
from .noop_strategy import NoOpStrategy
from .frequency_strategy import FrequencyStrategy, FrequencyConfig

# Feature flag for clustering availability (hdbscan + sklearn)
CLUSTERING_AVAILABLE = False
HDBSCAN_AVAILABLE = False
DBSCAN_AVAILABLE = False
_import_error: str | None = None


def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
    """Detect if clustering dependencies are available.

    Returns:
        Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
    """
    hdbscan_ok = False
    dbscan_ok = False

    try:
        import hdbscan  # noqa: F401
        hdbscan_ok = True
    except ImportError:
        pass

    try:
        from sklearn.cluster import DBSCAN  # noqa: F401
        dbscan_ok = True
    except ImportError:
        pass

    all_ok = hdbscan_ok and dbscan_ok
    error = None
    if not all_ok:
        missing = []
        if not hdbscan_ok:
            missing.append("hdbscan")
        if not dbscan_ok:
            missing.append("scikit-learn")
        error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"

    return all_ok, hdbscan_ok, dbscan_ok, error


# Initialize on module load
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
    _detect_clustering_available()
)


def check_clustering_available() -> tuple[bool, str | None]:
    """Check if all clustering dependencies are available.

    Returns:
        Tuple of (is_available, error_message).
        error_message is None if available, otherwise contains install instructions.
    """
    return CLUSTERING_AVAILABLE, _import_error


# Conditionally export strategy implementations
__all__ = [
    # Feature flags
    "CLUSTERING_AVAILABLE",
    "HDBSCAN_AVAILABLE",
    "DBSCAN_AVAILABLE",
    "check_clustering_available",
    # Base classes
    "BaseClusteringStrategy",
    "ClusteringConfig",
    # Factory
    "ClusteringStrategyFactory",
    "get_strategy",
    "check_clustering_strategy_available",
    # Always-available strategies
    "NoOpStrategy",
    "FrequencyStrategy",
    "FrequencyConfig",
]

# Conditionally add strategy classes to __all__ and module namespace
if HDBSCAN_AVAILABLE:
    from .hdbscan_strategy import HDBSCANStrategy

    __all__.append("HDBSCANStrategy")

if DBSCAN_AVAILABLE:
    from .dbscan_strategy import DBSCANStrategy

    __all__.append("DBSCANStrategy")
codex-lens/build/lib/codexlens/search/clustering/base.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""Base classes for clustering strategies in the hybrid search pipeline.

This module defines the abstract base class for clustering strategies used
in the staged hybrid search pipeline. Strategies cluster search results
based on their embeddings and select representative results from each cluster.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    import numpy as np
    from codexlens.entities import SearchResult


@dataclass
class ClusteringConfig:
    """Configuration parameters for clustering strategies.

    Attributes:
        min_cluster_size: Minimum number of results to form a cluster.
            HDBSCAN default is 5, but for search results 2-3 is often better.
        min_samples: Number of samples in a neighborhood for a point to be
            considered a core point. Lower values allow more clusters.
        metric: Distance metric for clustering. Common options:
            - 'euclidean': Standard L2 distance
            - 'cosine': Cosine distance (1 - cosine_similarity)
            - 'manhattan': L1 distance
        cluster_selection_epsilon: Distance threshold for cluster selection.
            Results within this distance may be merged into the same cluster.
        allow_single_cluster: If True, allow all results to form one cluster.
            Useful when results are very similar.
        prediction_data: If True, generate prediction data for new points.
    """

    min_cluster_size: int = 3
    min_samples: int = 2
    metric: str = "cosine"
    cluster_selection_epsilon: float = 0.0
    allow_single_cluster: bool = True
    prediction_data: bool = False

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        if self.min_cluster_size < 2:
            raise ValueError("min_cluster_size must be >= 2")
        if self.min_samples < 1:
            raise ValueError("min_samples must be >= 1")
        if self.metric not in ("euclidean", "cosine", "manhattan"):
            raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
        if self.cluster_selection_epsilon < 0:
            raise ValueError("cluster_selection_epsilon must be >= 0")


class BaseClusteringStrategy(ABC):
    """Abstract base class for clustering strategies.

    Clustering strategies are used in the staged hybrid search pipeline to
    group similar search results and select representative results from each
    cluster, reducing redundancy while maintaining diversity.

    Subclasses must implement:
        - cluster(): Group results into clusters based on embeddings
        - select_representatives(): Choose best result(s) from each cluster
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize the clustering strategy.

        Args:
            config: Clustering configuration. Uses defaults if not provided.
        """
        self.config = config or ClusteringConfig()

    @abstractmethod
    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results based on their embeddings.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.
                Used for additional metadata during clustering.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Results not assigned to any cluster
            (noise points) should be returned as single-element clusters.

        Example:
            >>> strategy = HDBSCANStrategy()
            >>> clusters = strategy.cluster(embeddings, results)
            >>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
            >>> # Result indices 0, 2, 5 are in cluster 0
            >>> # Result indices 1, 3 are in cluster 1
            >>> # Result index 4 is a noise point (singleton cluster)
            >>> # Result indices 6, 7, 8 are in cluster 2
        """
        ...

    @abstractmethod
    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        This method chooses the best result(s) from each cluster to include
        in the final search results. The selection can be based on:
            - Highest score within cluster
            - Closest to cluster centroid
            - Custom selection logic

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings array for centroid-based selection.

        Returns:
            List of representative SearchResult objects, one or more per cluster,
            ordered by relevance (highest score first).

        Example:
            >>> representatives = strategy.select_representatives(clusters, results)
            >>> # Returns best result from each cluster
        """
        ...

    def fit_predict(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List["SearchResult"]:
        """Convenience method to cluster and select representatives in one call.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).
            results: List of SearchResult objects.

        Returns:
            List of representative SearchResult objects.
        """
        clusters = self.cluster(embeddings, results)
        return self.select_representatives(clusters, results, embeddings)
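For orientation, a sketch of the minimum a concrete strategy must provide, in the spirit of the NoOp strategy (this is not the shipped implementation, and it assumes SearchResult exposes a numeric score attribute, which is also how the DBSCAN strategy below selects representatives):

from typing import List, Optional

import numpy as np

from codexlens.search.clustering import BaseClusteringStrategy

class SingletonTopScoreStrategy(BaseClusteringStrategy):
    """Every result is its own cluster; representatives sorted by score."""

    def cluster(self, embeddings: np.ndarray, results: List) -> List[List[int]]:
        # No grouping at all: each result index becomes a singleton cluster.
        return [[i] for i in range(len(results))]

    def select_representatives(self, clusters, results,
                               embeddings: Optional[np.ndarray] = None):
        reps = [results[indices[0]] for indices in clusters if indices]
        # Highest score first, per the base-class contract.
        return sorted(reps, key=lambda r: r.score, reverse=True)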
@@ -0,0 +1,197 @@
"""DBSCAN-based clustering strategy for search results.

DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
is the fallback clustering strategy when HDBSCAN is not available.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np

    from codexlens.entities import SearchResult


class DBSCANStrategy(BaseClusteringStrategy):
    """DBSCAN-based clustering strategy.

    Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
    DBSCAN requires an explicit eps parameter, which is auto-computed from the
    distance distribution if not provided.

    Example:
        >>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
        >>> strategy = DBSCANStrategy(config)
        >>> clusters = strategy.cluster(embeddings, results)
        >>> representatives = strategy.select_representatives(clusters, results)
    """

    # Default eps percentile for auto-computation
    DEFAULT_EPS_PERCENTILE: float = 15.0

    def __init__(
        self,
        config: Optional[ClusteringConfig] = None,
        eps: Optional[float] = None,
        eps_percentile: float = DEFAULT_EPS_PERCENTILE,
    ) -> None:
        """Initialize DBSCAN clustering strategy.

        Args:
            config: Clustering configuration. Uses defaults if not provided.
            eps: Explicit eps parameter for DBSCAN. If None, auto-computed
                from the distance distribution.
            eps_percentile: Percentile of pairwise distances to use for
                auto-computing eps. Default is 15th percentile.

        Raises:
            ImportError: If sklearn is not installed.
        """
        super().__init__(config)
        self.eps = eps
        self.eps_percentile = eps_percentile

        # Validate sklearn is available
        try:
            from sklearn.cluster import DBSCAN  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "scikit-learn package is required for DBSCANStrategy. "
                "Install with: pip install codexlens[clustering]"
            ) from exc

    def _compute_eps(self, embeddings: "np.ndarray") -> float:
        """Auto-compute eps from pairwise distance distribution.

        Uses the specified percentile of pairwise distances as eps,
        which typically captures local density well.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).

        Returns:
            Computed eps value.
        """
        import numpy as np
        from sklearn.metrics import pairwise_distances

        # Compute pairwise distances
        distances = pairwise_distances(embeddings, metric=self.config.metric)

        # Get upper triangle (excluding diagonal)
        upper_tri = distances[np.triu_indices_from(distances, k=1)]

        if len(upper_tri) == 0:
            # Only one point, return a default small eps
            return 0.1

        # Use percentile of distances as eps
        eps = float(np.percentile(upper_tri, self.eps_percentile))

        # Ensure eps is positive
        return max(eps, 1e-6)

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results using DBSCAN algorithm.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Noise points are returned as singleton clusters.
        """
        from sklearn.cluster import DBSCAN
        import numpy as np

        n_results = len(results)
        if n_results == 0:
            return []

        # Handle edge case: single result
        if n_results == 1:
            return [[0]]

        # Determine eps value
        eps = self.eps if self.eps is not None else self._compute_eps(embeddings)

        # Configure DBSCAN clusterer
        # Note: DBSCAN min_samples corresponds to min_cluster_size concept
        clusterer = DBSCAN(
            eps=eps,
            min_samples=self.config.min_samples,
            metric=self.config.metric,
        )

        # Fit and get cluster labels
        # Labels: -1 = noise, 0+ = cluster index
        labels = clusterer.fit_predict(embeddings)

        # Group indices by cluster label
        cluster_map: dict[int, list[int]] = {}
        for idx, label in enumerate(labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(idx)

        # Build result: non-noise clusters first, then noise as singletons
        clusters: List[List[int]] = []

        # Add proper clusters (label >= 0)
        for label in sorted(cluster_map.keys()):
            if label >= 0:
                clusters.append(cluster_map[label])

        # Add noise points as singleton clusters (label == -1)
        if -1 in cluster_map:
            for idx in cluster_map[-1]:
                clusters.append([idx])

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        Selects the result with the highest score from each cluster.

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used in score-based selection).

        Returns:
            List of representative SearchResult objects, one per cluster,
            ordered by score (highest first).
        """
        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            # Find the result with the highest score in this cluster
            best_idx = max(cluster_indices, key=lambda i: results[i].score)
            representatives.append(results[best_idx])

        # Sort by score descending
        representatives.sort(key=lambda r: r.score, reverse=True)

        return representatives
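A minimal sketch of the eps auto-computation path above, assuming the clustering extras (scikit-learn, NumPy) are installed; _FakeResult is a hypothetical stand-in for SearchResult that models only the score attribute the strategy reads, and the explicit min_samples/metric values are illustrative choices rather than documented defaults.

import numpy as np

from codexlens.search.clustering import ClusteringConfig, DBSCANStrategy


class _FakeResult:
    def __init__(self, score: float) -> None:
        self.score = score


rng = np.random.default_rng(0)
# Two well-separated blobs of 8 points each in a 4-dim embedding space
embeddings = np.vstack([
    rng.normal(0.0, 0.05, size=(8, 4)),
    rng.normal(1.0, 0.05, size=(8, 4)),
])
results = [_FakeResult(score=float(i)) for i in range(16)]

strategy = DBSCANStrategy(ClusteringConfig(min_cluster_size=2, min_samples=2, metric="euclidean"))
# Roughly 47% of pairwise distances are within-blob, so the 15th percentile
# lands inside the blobs; DBSCAN recovers the two groups, and fit_predict
# keeps the best-scored point of each (typically scores 15.0 and 7.0 here).
print(strategy._compute_eps(embeddings))
print([r.score for r in strategy.fit_predict(embeddings, results)])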
202 codex-lens/build/lib/codexlens/search/clustering/factory.py (Normal file)
@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.

Provides a unified interface for instantiating different clustering backends
with automatic fallback chain: hdbscan -> dbscan -> noop.
"""

from __future__ import annotations

from typing import Any, Optional

from .base import BaseClusteringStrategy, ClusteringConfig
from .noop_strategy import NoOpStrategy


def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
    """Check whether a specific clustering strategy can be used.

    Args:
        strategy: Strategy name to check. Options:
            - "hdbscan": HDBSCAN clustering (requires hdbscan package)
            - "dbscan": DBSCAN clustering (requires sklearn)
            - "frequency": Frequency-based clustering (always available)
            - "noop": No-op strategy (always available)

    Returns:
        Tuple of (is_available, error_message).
        error_message is None if available, otherwise contains install instructions.
    """
    strategy = (strategy or "").strip().lower()

    if strategy == "hdbscan":
        try:
            import hdbscan  # noqa: F401
        except ImportError:
            return False, (
                "hdbscan package not available. "
                "Install with: pip install codexlens[clustering]"
            )
        return True, None

    if strategy == "dbscan":
        try:
            from sklearn.cluster import DBSCAN  # noqa: F401
        except ImportError:
            return False, (
                "scikit-learn package not available. "
                "Install with: pip install codexlens[clustering]"
            )
        return True, None

    if strategy == "frequency":
        # Frequency strategy is always available (no external deps)
        return True, None

    if strategy == "noop":
        return True, None

    return False, (
        f"Invalid clustering strategy: {strategy}. "
        "Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
    )


def get_strategy(
    strategy: str = "hdbscan",
    config: Optional[ClusteringConfig] = None,
    *,
    fallback: bool = True,
    **kwargs: Any,
) -> BaseClusteringStrategy:
    """Factory function to create clustering strategy with fallback chain.

    The fallback chain is: hdbscan -> dbscan -> frequency -> noop

    Args:
        strategy: Clustering strategy to use. Options:
            - "hdbscan": HDBSCAN clustering (default, recommended)
            - "dbscan": DBSCAN clustering (fallback)
            - "frequency": Frequency-based clustering (groups by symbol occurrence)
            - "noop": No-op strategy (returns all results ungrouped)
            - "auto": Try hdbscan, then dbscan, then noop
        config: Clustering configuration. Uses defaults if not provided.
            For frequency strategy, pass FrequencyConfig for full control.
        fallback: If True (default), automatically fall back to next strategy
            in the chain when primary is unavailable. If False, raise ImportError
            when requested strategy is unavailable.
        **kwargs: Additional strategy-specific arguments.
            For DBSCANStrategy: eps, eps_percentile
            For FrequencyStrategy: group_by, min_frequency, etc.

    Returns:
        BaseClusteringStrategy: Configured clustering strategy instance.

    Raises:
        ValueError: If strategy is not recognized.
        ImportError: If required dependencies are not installed and fallback=False.

    Example:
        >>> from codexlens.search.clustering import get_strategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3)
        >>> # Auto-select best available strategy
        >>> strategy = get_strategy("auto", config)
        >>> # Explicitly use HDBSCAN (will fall back if unavailable)
        >>> strategy = get_strategy("hdbscan", config)
        >>> # Use frequency-based strategy
        >>> from codexlens.search.clustering import FrequencyConfig
        >>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
        >>> strategy = get_strategy("frequency", freq_config)
    """
    strategy = (strategy or "").strip().lower()

    # Handle "auto" - try strategies in order
    if strategy == "auto":
        return _get_best_available_strategy(config, **kwargs)

    if strategy == "hdbscan":
        ok, err = check_clustering_strategy_available("hdbscan")
        if ok:
            from .hdbscan_strategy import HDBSCANStrategy
            return HDBSCANStrategy(config)

        if fallback:
            # Try dbscan fallback
            ok_dbscan, _ = check_clustering_strategy_available("dbscan")
            if ok_dbscan:
                from .dbscan_strategy import DBSCANStrategy
                return DBSCANStrategy(config, **kwargs)
            # Final fallback to noop
            return NoOpStrategy(config)

        raise ImportError(err)

    if strategy == "dbscan":
        ok, err = check_clustering_strategy_available("dbscan")
        if ok:
            from .dbscan_strategy import DBSCANStrategy
            return DBSCANStrategy(config, **kwargs)

        if fallback:
            # Fallback to noop
            return NoOpStrategy(config)

        raise ImportError(err)

    if strategy == "frequency":
        from .frequency_strategy import FrequencyStrategy, FrequencyConfig

        # If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
        if config is None or not isinstance(config, FrequencyConfig):
            freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
        else:
            freq_config = config
        return FrequencyStrategy(freq_config)

    if strategy == "noop":
        return NoOpStrategy(config)

    raise ValueError(
        f"Unknown clustering strategy: {strategy}. "
        "Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
    )


def _get_best_available_strategy(
    config: Optional[ClusteringConfig] = None,
    **kwargs: Any,
) -> BaseClusteringStrategy:
    """Get the best available clustering strategy.

    Tries strategies in order: hdbscan -> dbscan -> noop

    Args:
        config: Clustering configuration.
        **kwargs: Additional strategy-specific arguments.

    Returns:
        Best available clustering strategy instance.
    """
    # Try HDBSCAN first
    ok, _ = check_clustering_strategy_available("hdbscan")
    if ok:
        from .hdbscan_strategy import HDBSCANStrategy
        return HDBSCANStrategy(config)

    # Try DBSCAN second
    ok, _ = check_clustering_strategy_available("dbscan")
    if ok:
        from .dbscan_strategy import DBSCANStrategy
        return DBSCANStrategy(config, **kwargs)

    # Fallback to NoOp
    return NoOpStrategy(config)


# Alias for backward compatibility
ClusteringStrategyFactory = type(
    "ClusteringStrategyFactory",
    (),
    {
        "get_strategy": staticmethod(get_strategy),
        "check_available": staticmethod(check_clustering_strategy_available),
    },
)
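A minimal sketch of the fallback chain in practice, assuming an environment where neither of the optional clustering extras is installed.

from codexlens.search.clustering import get_strategy

strategy = get_strategy("hdbscan", fallback=True)
# With neither hdbscan nor scikit-learn installed this degrades silently:
print(type(strategy).__name__)  # -> "NoOpStrategy"

# With fallback disabled, the missing dependency surfaces immediately:
try:
    get_strategy("hdbscan", fallback=False)
except ImportError as exc:
    print(exc)  # install instructions from check_clustering_strategy_available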
@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.

This strategy groups search results by symbol/method name and prunes based on
occurrence frequency. High-frequency symbols (frequently referenced methods)
are considered more important and retained, while low-frequency results
(potentially noise) can be filtered out.

Use cases:
    - Prioritize commonly called methods/functions
    - Filter out one-off results that may be less relevant
    - Deduplicate results pointing to the same symbol from different locations
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Literal

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np

    from codexlens.entities import SearchResult


@dataclass
class FrequencyConfig(ClusteringConfig):
    """Configuration for frequency-based clustering strategy.

    Attributes:
        group_by: Field to group results by for frequency counting.
            - 'symbol': Group by symbol_name (default, for method/function dedup)
            - 'file': Group by file path
            - 'symbol_kind': Group by symbol type (function, class, etc.)
        min_frequency: Minimum occurrence count to keep a result.
            Results appearing fewer times than this are considered noise and pruned.
        max_representatives_per_group: Maximum results to keep per symbol group.
        frequency_weight: How much to boost score based on frequency.
            Final score = original_score * (1 + frequency_weight * log(frequency + 1))
        keep_mode: How to handle low-frequency results.
            - 'filter': Remove results below min_frequency
            - 'demote': Keep but lower their score ranking
    """

    group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
    min_frequency: int = 1  # 1 means keep all, 2+ filters singletons
    max_representatives_per_group: int = 3
    frequency_weight: float = 0.1  # Boost factor for frequency
    keep_mode: Literal["filter", "demote"] = "demote"

    def __post_init__(self) -> None:
        """Validate configuration parameters."""
        # Skip parent validation since we don't use HDBSCAN params
        if self.min_frequency < 1:
            raise ValueError("min_frequency must be >= 1")
        if self.max_representatives_per_group < 1:
            raise ValueError("max_representatives_per_group must be >= 1")
        if self.frequency_weight < 0:
            raise ValueError("frequency_weight must be >= 0")
        if self.group_by not in ("symbol", "file", "symbol_kind"):
            raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
        if self.keep_mode not in ("filter", "demote"):
            raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")


class FrequencyStrategy(BaseClusteringStrategy):
    """Frequency-based clustering strategy for search result deduplication.

    This strategy groups search results by symbol name (or file/kind) and:
    1. Counts how many times each symbol appears in results
    2. Treats higher frequency as more important (frequently referenced method)
    3. Filters or demotes low-frequency results
    4. Selects top representatives from each frequency group

    Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
    - Does NOT require embeddings (works with metadata only)
    - Is very fast (O(n) complexity)
    - Is deterministic (no random initialization)
    - Works well for symbol-level deduplication

    Example:
        >>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
        >>> strategy = FrequencyStrategy(config)
        >>> # Results with symbol "authenticate" appearing 5 times
        >>> # will be prioritized over "helper_func" appearing once
        >>> representatives = strategy.fit_predict(embeddings, results)
    """

    def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
        """Initialize the frequency strategy.

        Args:
            config: Frequency configuration. Uses defaults if not provided.
        """
        self.config: FrequencyConfig = config or FrequencyConfig()

    def _get_group_key(self, result: "SearchResult") -> str:
        """Extract grouping key from a search result.

        Args:
            result: SearchResult to extract key from.

        Returns:
            String key for grouping (symbol name, file path, or kind).
        """
        if self.config.group_by == "symbol":
            # Use symbol_name if available, otherwise fall back to file:line
            symbol = getattr(result, "symbol_name", None)
            if symbol:
                return str(symbol)
            # Fallback: use file path + start_line as pseudo-symbol
            start_line = getattr(result, "start_line", 0) or 0
            return f"{result.path}:{start_line}"

        elif self.config.group_by == "file":
            return str(result.path)

        elif self.config.group_by == "symbol_kind":
            kind = getattr(result, "symbol_kind", None)
            return str(kind) if kind else "unknown"

        return str(result.path)  # Default fallback

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Group search results by frequency of occurrence.

        Note: This method ignores embeddings and groups by metadata only.
        The embeddings parameter is kept for interface compatibility.

        Args:
            embeddings: Ignored (kept for interface compatibility).
            results: List of SearchResult objects to cluster.

        Returns:
            List of clusters (groups), where each cluster contains indices
            of results with the same grouping key. Clusters are ordered by
            frequency (highest frequency first).
        """
        if not results:
            return []

        # Group results by key
        groups: Dict[str, List[int]] = defaultdict(list)
        for idx, result in enumerate(results):
            key = self._get_group_key(result)
            groups[key].append(idx)

        # Sort groups by frequency (descending) then by key (for stability)
        sorted_groups = sorted(
            groups.items(),
            key=lambda x: (-len(x[1]), x[0])  # -frequency, then alphabetical
        )

        # Convert to list of clusters
        clusters = [indices for _, indices in sorted_groups]

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results based on frequency and score.

        For each frequency group:
        1. If frequency < min_frequency: filter or demote based on keep_mode
        2. Sort by score within group
        3. Apply frequency boost to scores
        4. Select top N representatives

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (used for tie-breaking if provided).

        Returns:
            List of representative SearchResult objects, ordered by
            frequency-adjusted score (highest first).
        """
        import math

        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []
        demoted: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            frequency = len(cluster_indices)

            # Get results in this cluster, sorted by score
            cluster_results = [results[i] for i in cluster_indices]
            cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)

            # Check frequency threshold
            if frequency < self.config.min_frequency:
                if self.config.keep_mode == "filter":
                    # Skip low-frequency results entirely
                    continue
                else:  # demote mode
                    # Keep but add to demoted list (lower priority)
                    for result in cluster_results[: self.config.max_representatives_per_group]:
                        demoted.append(result)
                    continue

            # Apply frequency boost and select top representatives
            for result in cluster_results[: self.config.max_representatives_per_group]:
                # Calculate frequency-boosted score
                original_score = getattr(result, "score", 0.0)
                # log(frequency + 1) to handle frequency=1 case smoothly
                frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
                boosted_score = original_score * frequency_boost

                # Note: SearchResult might be immutable, so we preserve original
                # and track boosted score in metadata
                if hasattr(result, "metadata") and isinstance(result.metadata, dict):
                    result.metadata["frequency"] = frequency
                    result.metadata["frequency_boosted_score"] = boosted_score

                representatives.append(result)

        # Sort representatives by boosted score (or original score as fallback)
        def get_sort_score(r: "SearchResult") -> float:
            if hasattr(r, "metadata") and isinstance(r.metadata, dict):
                return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
            return getattr(r, "score", 0.0)

        representatives.sort(key=get_sort_score, reverse=True)

        # Add demoted results at the end
        if demoted:
            demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
            representatives.extend(demoted)

        return representatives

    def fit_predict(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List["SearchResult"]:
        """Convenience method to cluster and select representatives in one call.

        Args:
            embeddings: NumPy array (may be ignored for frequency-based clustering).
            results: List of SearchResult objects.

        Returns:
            List of representative SearchResult objects.
        """
        clusters = self.cluster(embeddings, results)
        return self.select_representatives(clusters, results, embeddings)
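A minimal sketch of the frequency boost worked through, using a hypothetical _StubResult in place of SearchResult (only the attributes the strategy actually reads are modeled).

import numpy as np

from codexlens.search.clustering import FrequencyConfig, FrequencyStrategy


class _StubResult:
    def __init__(self, path, symbol_name, score):
        self.path = path
        self.symbol_name = symbol_name
        self.start_line = 1
        self.score = score
        self.metadata = {}


results = (
    [_StubResult(f"a{i}.py", "authenticate", 1.0) for i in range(5)]
    + [_StubResult("b.py", "helper_func", 1.0)]
)
strategy = FrequencyStrategy(FrequencyConfig(min_frequency=2, keep_mode="demote"))
reps = strategy.fit_predict(np.empty((0, 0)), results)  # embeddings are ignored

# "authenticate" (frequency 5) gets 1.0 * (1 + 0.1 * ln 6) ≈ 1.179, while the
# singleton "helper_func" falls below min_frequency and is demoted to the tail.
print(reps[0].symbol_name, round(reps[0].metadata["frequency_boosted_score"], 3))
print(reps[-1].symbol_name)  # -> "helper_func"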
@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.

HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np

    from codexlens.entities import SearchResult


class HDBSCANStrategy(BaseClusteringStrategy):
    """HDBSCAN-based clustering strategy.

    Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
    HDBSCAN is preferred over DBSCAN because it:
    - Automatically determines the number of clusters
    - Handles varying density clusters well
    - Identifies noise points (outliers) effectively

    Example:
        >>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
        >>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
        >>> strategy = HDBSCANStrategy(config)
        >>> clusters = strategy.cluster(embeddings, results)
        >>> representatives = strategy.select_representatives(clusters, results)
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize HDBSCAN clustering strategy.

        Args:
            config: Clustering configuration. Uses defaults if not provided.

        Raises:
            ImportError: If hdbscan package is not installed.
        """
        super().__init__(config)
        # Validate hdbscan is available
        try:
            import hdbscan  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "hdbscan package is required for HDBSCANStrategy. "
                "Install with: pip install codexlens[clustering]"
            ) from exc

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Cluster search results using HDBSCAN algorithm.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim)
                containing the embedding vectors for each result.
            results: List of SearchResult objects corresponding to embeddings.

        Returns:
            List of clusters, where each cluster is a list of indices
            into the results list. Noise points are returned as singleton clusters.
        """
        import hdbscan
        import numpy as np

        n_results = len(results)
        if n_results == 0:
            return []

        # Handle edge case: fewer results than min_cluster_size
        if n_results < self.config.min_cluster_size:
            # Return each result as its own singleton cluster
            return [[i] for i in range(n_results)]

        # Configure HDBSCAN clusterer
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=self.config.min_cluster_size,
            min_samples=self.config.min_samples,
            metric=self.config.metric,
            cluster_selection_epsilon=self.config.cluster_selection_epsilon,
            allow_single_cluster=self.config.allow_single_cluster,
            prediction_data=self.config.prediction_data,
        )

        # Fit and get cluster labels
        # Labels: -1 = noise, 0+ = cluster index
        labels = clusterer.fit_predict(embeddings)

        # Group indices by cluster label
        cluster_map: dict[int, list[int]] = {}
        for idx, label in enumerate(labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(idx)

        # Build result: non-noise clusters first, then noise as singletons
        clusters: List[List[int]] = []

        # Add proper clusters (label >= 0)
        for label in sorted(cluster_map.keys()):
            if label >= 0:
                clusters.append(cluster_map[label])

        # Add noise points as singleton clusters (label == -1)
        if -1 in cluster_map:
            for idx in cluster_map[-1]:
                clusters.append([idx])

        return clusters

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Select representative results from each cluster.

        Selects the result with the highest score from each cluster.

        Args:
            clusters: List of clusters from cluster() method.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used in score-based selection).

        Returns:
            List of representative SearchResult objects, one per cluster,
            ordered by score (highest first).
        """
        if not clusters or not results:
            return []

        representatives: List["SearchResult"] = []

        for cluster_indices in clusters:
            if not cluster_indices:
                continue

            # Find the result with the highest score in this cluster
            best_idx = max(cluster_indices, key=lambda i: results[i].score)
            representatives.append(results[best_idx])

        # Sort by score descending
        representatives.sort(key=lambda r: r.score, reverse=True)

        return representatives
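A quick property check of the clustering contract above, assuming hdbscan is installed and that ClusteringConfig accepts an explicit metric (the integer "results" are a hypothetical stand-in, since cluster() only reads len(results)): every input index comes back exactly once, because noise points are re-emitted as singleton clusters.

import numpy as np

from codexlens.search.clustering import ClusteringConfig, HDBSCANStrategy

rng = np.random.default_rng(1)
embeddings = rng.normal(size=(20, 8))
results = list(range(20))  # placeholders; only the length matters for cluster()

clusters = HDBSCANStrategy(ClusteringConfig(min_cluster_size=3, metric="euclidean")).cluster(embeddings, results)
assert sorted(i for c in clusters for i in c) == list(range(20))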
@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.

NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from .base import BaseClusteringStrategy, ClusteringConfig

if TYPE_CHECKING:
    import numpy as np

    from codexlens.entities import SearchResult


class NoOpStrategy(BaseClusteringStrategy):
    """No-op clustering strategy that returns all results ungrouped.

    This strategy is used as a final fallback when no clustering dependencies
    are available, or when clustering is explicitly disabled. Each result
    is treated as its own singleton cluster.

    Example:
        >>> from codexlens.search.clustering import NoOpStrategy
        >>> strategy = NoOpStrategy()
        >>> clusters = strategy.cluster(embeddings, results)
        >>> # Returns [[0], [1], [2], ...] - each result in its own cluster
        >>> representatives = strategy.select_representatives(clusters, results)
        >>> # Returns all results sorted by score
    """

    def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
        """Initialize NoOp clustering strategy.

        Args:
            config: Clustering configuration. Ignored for NoOpStrategy
                but accepted for interface compatibility.
        """
        super().__init__(config)

    def cluster(
        self,
        embeddings: "np.ndarray",
        results: List["SearchResult"],
    ) -> List[List[int]]:
        """Return each result as its own singleton cluster.

        Args:
            embeddings: NumPy array of shape (n_results, embedding_dim).
                Not used but accepted for interface compatibility.
            results: List of SearchResult objects.

        Returns:
            List of singleton clusters, one per result.
        """
        return [[i] for i in range(len(results))]

    def select_representatives(
        self,
        clusters: List[List[int]],
        results: List["SearchResult"],
        embeddings: Optional["np.ndarray"] = None,
    ) -> List["SearchResult"]:
        """Return all results sorted by score.

        Since each cluster is a singleton, this effectively returns all
        results sorted by score descending.

        Args:
            clusters: List of singleton clusters.
            results: Original list of SearchResult objects.
            embeddings: Optional embeddings (not used).

        Returns:
            All SearchResult objects sorted by score (highest first).
        """
        if not results:
            return []

        # Return all results sorted by score
        return sorted(results, key=lambda r: r.score, reverse=True)
171 codex-lens/build/lib/codexlens/search/enrichment.py (Normal file)
@@ -0,0 +1,171 @@
# codex-lens/src/codexlens/search/enrichment.py
"""Relationship enrichment for search results."""
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper


class RelationshipEnricher:
    """Enriches search results with code graph relationships."""

    def __init__(self, index_path: Path):
        """Initialize with path to index database.

        Args:
            index_path: Path to _index.db SQLite database
        """
        self.index_path = index_path
        self.db_conn: Optional[sqlite3.Connection] = None
        self._connect()

    def _connect(self) -> None:
        """Establish read-only database connection."""
        if self.index_path.exists():
            self.db_conn = sqlite3.connect(
                f"file:{self.index_path}?mode=ro",
                uri=True,
                check_same_thread=False
            )
            self.db_conn.row_factory = sqlite3.Row

    def enrich(self, results: List[Dict[str, Any]], limit: int = 10) -> List[Dict[str, Any]]:
        """Add relationship data to search results.

        Args:
            results: List of search result dictionaries
            limit: Maximum number of results to enrich

        Returns:
            Results with relationships field added
        """
        if not self.db_conn:
            return results

        for result in results[:limit]:
            file_path = result.get('file') or result.get('path')
            symbol_name = result.get('symbol')
            result['relationships'] = self._find_relationships(file_path, symbol_name)
        return results

    def _find_relationships(self, file_path: Optional[str], symbol_name: Optional[str]) -> List[Dict[str, Any]]:
        """Query relationships for a symbol.

        Args:
            file_path: Path to file containing the symbol
            symbol_name: Name of the symbol

        Returns:
            List of relationship dictionaries with type, direction, target/source, file, line
        """
        if not self.db_conn or not symbol_name:
            return []

        relationships = []
        cursor = self.db_conn.cursor()

        try:
            # Find symbol ID(s) by name and optionally file
            if file_path:
                cursor.execute(
                    'SELECT id FROM symbols WHERE name = ? AND file_path = ?',
                    (symbol_name, file_path)
                )
            else:
                cursor.execute('SELECT id FROM symbols WHERE name = ?', (symbol_name,))

            symbol_ids = [row[0] for row in cursor.fetchall()]

            if not symbol_ids:
                return []

            # Query outgoing relationships (symbol is source)
            placeholders = ','.join('?' * len(symbol_ids))
            cursor.execute(f'''
                SELECT sr.relationship_type, sr.target_symbol_fqn, sr.file_path, sr.line
                FROM symbol_relationships sr
                WHERE sr.source_symbol_id IN ({placeholders})
            ''', symbol_ids)

            for row in cursor.fetchall():
                relationships.append({
                    'type': row[0],
                    'direction': 'outgoing',
                    'target': row[1],
                    'file': row[2],
                    'line': row[3],
                })

            # Query incoming relationships (symbol is target)
            # Match against symbol name or qualified name patterns
            cursor.execute('''
                SELECT sr.relationship_type, s.name AS source_name, sr.file_path, sr.line
                FROM symbol_relationships sr
                JOIN symbols s ON sr.source_symbol_id = s.id
                WHERE sr.target_symbol_fqn = ? OR sr.target_symbol_fqn LIKE ?
            ''', (symbol_name, f'%.{symbol_name}'))

            for row in cursor.fetchall():
                rel_type = row[0]
                # Convert to incoming type
                incoming_type = self._to_incoming_type(rel_type)
                relationships.append({
                    'type': incoming_type,
                    'direction': 'incoming',
                    'source': row[1],
                    'file': row[2],
                    'line': row[3],
                })

        except sqlite3.Error:
            return []

        return relationships

    def _to_incoming_type(self, outgoing_type: str) -> str:
        """Convert outgoing relationship type to incoming type.

        Args:
            outgoing_type: The outgoing relationship type (e.g., 'calls', 'imports')

        Returns:
            Corresponding incoming type (e.g., 'called_by', 'imported_by')
        """
        type_map = {
            'calls': 'called_by',
            'imports': 'imported_by',
            'extends': 'extended_by',
        }
        return type_map.get(outgoing_type, f'{outgoing_type}_by')

    def close(self) -> None:
        """Close database connection."""
        if self.db_conn:
            self.db_conn.close()
            self.db_conn = None

    def __enter__(self) -> 'RelationshipEnricher':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()


class SearchEnrichmentPipeline:
    """Search post-processing pipeline (optional enrichments)."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._config = config
        self._graph_expander = GraphExpander(mapper, config=config)

    def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
        """Expand base results with related symbols when enabled in config."""
        if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
            return []

        depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
        return self._graph_expander.expand(results, depth=depth)
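A minimal usage sketch of the enricher as a context manager; the index path below is hypothetical, and the dict shape mirrors the keys enrich() reads ('file'/'path' and 'symbol').

from pathlib import Path

from codexlens.search.enrichment import RelationshipEnricher

hits = [{"path": "src/auth.py", "symbol": "authenticate", "score": 0.92}]

with RelationshipEnricher(Path(".codexlens/_index.db")) as enricher:
    enriched = enricher.enrich(hits, limit=5)

for rel in enriched[0]["relationships"]:
    # e.g. {'type': 'called_by', 'direction': 'incoming', 'source': 'login', ...}
    print(rel["direction"], rel["type"], rel.get("target") or rel.get("source"))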
264 codex-lens/build/lib/codexlens/search/graph_expander.py (Normal file)
@@ -0,0 +1,264 @@
"""Graph expansion for search results using precomputed neighbors.

Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""

from __future__ import annotations

import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper

logger = logging.getLogger(__name__)


def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
    return (result.path, result.symbol_name, result.start_line, result.end_line)


def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
    if content is None:
        return None
    if start_line is None or end_line is None:
        return None
    if start_line < 1 or end_line < start_line:
        return None

    lines = content.splitlines()
    start_idx = max(0, start_line - 1)
    end_idx = min(len(lines), end_line)
    if start_idx >= len(lines):
        return None
    return "\n".join(lines[start_idx:end_idx])


class GraphExpander:
    """Expands SearchResult lists with related symbols from the code graph."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._mapper = mapper
        self._config = config
        self._logger = logging.getLogger(__name__)

    def expand(
        self,
        results: Sequence[SearchResult],
        *,
        depth: Optional[int] = None,
        max_expand: int = 10,
        max_related: int = 50,
    ) -> List[SearchResult]:
        """Expand top results with related symbols.

        Args:
            results: Base ranked results.
            depth: Maximum relationship depth to include (defaults to Config or 2).
            max_expand: Only expand the top-N base results to bound cost.
            max_related: Maximum related results to return.

        Returns:
            A list of related SearchResult objects with relationship_depth metadata.
        """
        if not results:
            return []

        configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
        max_depth = int(depth if depth is not None else configured_depth)
        if max_depth <= 0:
            return []
        max_depth = min(max_depth, 2)

        expand_count = max(0, int(max_expand))
        related_limit = max(0, int(max_related))
        if expand_count == 0 or related_limit == 0:
            return []

        seen = {_result_key(r) for r in results}
        related_results: List[SearchResult] = []
        conn_cache: Dict[Path, sqlite3.Connection] = {}

        try:
            for base in list(results)[:expand_count]:
                if len(related_results) >= related_limit:
                    break

                if not base.symbol_name or not base.path:
                    continue

                index_path = self._mapper.source_to_index_db(Path(base.path).parent)
                conn = conn_cache.get(index_path)
                if conn is None:
                    conn = self._connect_readonly(index_path)
                    if conn is None:
                        continue
                    conn_cache[index_path] = conn

                source_ids = self._resolve_source_symbol_ids(
                    conn,
                    file_path=base.path,
                    symbol_name=base.symbol_name,
                    symbol_kind=base.symbol_kind,
                )
                if not source_ids:
                    continue

                for source_id in source_ids:
                    neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
                    for neighbor_id, rel_depth in neighbors:
                        if len(related_results) >= related_limit:
                            break
                        row = self._get_symbol_details(conn, neighbor_id)
                        if row is None:
                            continue

                        path = str(row["full_path"])
                        symbol_name = str(row["name"])
                        symbol_kind = str(row["kind"])
                        start_line = int(row["start_line"]) if row["start_line"] is not None else None
                        end_line = int(row["end_line"]) if row["end_line"] is not None else None
                        content_block = _slice_content_block(
                            str(row["content"]) if row["content"] is not None else "",
                            start_line,
                            end_line,
                        )

                        score = float(base.score) * (0.5 ** int(rel_depth))
                        candidate = SearchResult(
                            path=path,
                            score=max(0.0, score),
                            excerpt=None,
                            content=content_block,
                            start_line=start_line,
                            end_line=end_line,
                            symbol_name=symbol_name,
                            symbol_kind=symbol_kind,
                            metadata={"relationship_depth": int(rel_depth)},
                        )

                        key = _result_key(candidate)
                        if key in seen:
                            continue
                        seen.add(key)
                        related_results.append(candidate)

        finally:
            for conn in conn_cache.values():
                try:
                    conn.close()
                except Exception:
                    pass

        return related_results

    def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
        try:
            if not index_path.exists() or index_path.stat().st_size == 0:
                return None
        except OSError:
            return None

        try:
            conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            return conn
        except Exception as exc:
            self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
            return None

    def _resolve_source_symbol_ids(
        self,
        conn: sqlite3.Connection,
        *,
        file_path: str,
        symbol_name: str,
        symbol_kind: Optional[str],
    ) -> List[int]:
        try:
            if symbol_kind:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
                    """,
                    (file_path, symbol_name, symbol_kind),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT s.id
                    FROM symbols s
                    JOIN files f ON f.id = s.file_id
                    WHERE f.full_path = ? AND s.name = ?
                    """,
                    (file_path, symbol_name),
                ).fetchall()
        except sqlite3.Error:
            return []

        ids: List[int] = []
        for row in rows:
            try:
                ids.append(int(row["id"]))
            except Exception:
                continue
        return ids

    def _get_neighbors(
        self,
        conn: sqlite3.Connection,
        source_symbol_id: int,
        *,
        max_depth: int,
        limit: int,
    ) -> List[Tuple[int, int]]:
        try:
            rows = conn.execute(
                """
                SELECT neighbor_symbol_id, relationship_depth
                FROM graph_neighbors
                WHERE source_symbol_id = ? AND relationship_depth <= ?
                ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
                LIMIT ?
                """,
                (int(source_symbol_id), int(max_depth), int(limit)),
            ).fetchall()
        except sqlite3.Error:
            return []

        neighbors: List[Tuple[int, int]] = []
        for row in rows:
            try:
                neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
            except Exception:
                continue
        return neighbors

    def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
        try:
            return conn.execute(
                """
                SELECT
                    s.id,
                    s.name,
                    s.kind,
                    s.start_line,
                    s.end_line,
                    f.full_path,
                    f.content
                FROM symbols s
                JOIN files f ON f.id = s.file_id
                WHERE s.id = ?
                """,
                (int(symbol_id),),
            ).fetchone()
        except sqlite3.Error:
            return None
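A short worked sketch of the depth-based score decay in expand(): a depth-d neighbor inherits base.score * 0.5**d, and depth is capped at 2 hops, so a base hit never spawns related results scoring above half of its own score.

base_score = 0.8
for rel_depth in (1, 2):
    # depth 1 -> 0.4, depth 2 -> 0.2; deeper neighbors are never queried
    print(rel_depth, base_score * (0.5 ** rel_depth))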
1409 codex-lens/build/lib/codexlens/search/hybrid_search.py (Normal file)
File diff suppressed because it is too large.
242 codex-lens/build/lib/codexlens/search/query_parser.py (Normal file)
@@ -0,0 +1,242 @@
"""Query preprocessing for CodexLens search.

Provides query expansion for better identifier matching:
- CamelCase splitting: UserAuth → User OR Auth
- snake_case splitting: user_auth → user OR auth
- Preserves original query for exact matching
"""

from __future__ import annotations

import logging
import re
from typing import Set, List

log = logging.getLogger(__name__)


class QueryParser:
    """Parser for preprocessing search queries before FTS5 execution.

    Expands identifier-style queries (CamelCase, snake_case) into OR queries
    to improve recall when searching for code symbols.

    Example transformations:
        - 'UserAuth' → 'UserAuth OR User OR Auth'
        - 'user_auth' → 'user_auth OR user OR auth'
        - 'getUserData' → 'getUserData OR get OR User OR Data'
    """

    # Patterns for identifier splitting
    CAMEL_CASE_PATTERN = re.compile(r'([a-z])([A-Z])')
    SNAKE_CASE_PATTERN = re.compile(r'_+')
    KEBAB_CASE_PATTERN = re.compile(r'-+')

    # Minimum token length to include in expansion (avoid noise from single chars)
    MIN_TOKEN_LENGTH = 2

    # All-caps acronyms pattern (e.g., HTTP, SQL, API)
    ALL_CAPS_PATTERN = re.compile(r'^[A-Z]{2,}$')

    def __init__(self, enable: bool = True, min_token_length: int = 2):
        """Initialize query parser.

        Args:
            enable: Whether to enable query preprocessing
            min_token_length: Minimum token length to include in expansion
        """
        self.enable = enable
        self.min_token_length = min_token_length

    def preprocess_query(self, query: str) -> str:
        """Preprocess query with identifier expansion.

        Args:
            query: Original search query

        Returns:
            Expanded query with OR operator connecting original and split tokens

        Example:
            >>> parser = QueryParser()
            >>> parser.preprocess_query('UserAuth')
            'UserAuth OR User OR Auth'
            >>> parser.preprocess_query('get_user_data')
            'get_user_data OR get OR user OR data'
        """
        if not self.enable:
            return query

        query = query.strip()
        if not query:
            return query

        # Extract tokens from query (handle multiple words/terms).
        # Simple queries are expanded as a whole; complex FTS5 queries with
        # operators are preserved as-is to keep their structure intact.
        if self._is_simple_query(query):
            return self._expand_simple_query(query)
        else:
            # Complex query with FTS5 operators, don't expand
            log.debug(f"Skipping expansion for complex FTS5 query: {query}")
            return query

    def _is_simple_query(self, query: str) -> bool:
        """Check if query is simple (no FTS5 operators).

        Args:
            query: Search query

        Returns:
            True if query is simple (safe to expand), False otherwise
        """
        # Check for FTS5 operators that indicate complex query
        fts5_operators = ['OR', 'AND', 'NOT', 'NEAR', '*', '^', '"']
        return not any(op in query for op in fts5_operators)

    def _expand_simple_query(self, query: str) -> str:
        """Expand a simple query with identifier splitting.

        Args:
            query: Simple search query

        Returns:
            Expanded query with OR operators
        """
        tokens: Set[str] = set()

        # Always include original query
        tokens.add(query)

        # Split on whitespace first
        words = query.split()

        for word in words:
            # Extract tokens from this word
            word_tokens = self._extract_tokens(word)
            tokens.update(word_tokens)

        # Filter out short tokens and duplicates
        filtered_tokens = [
            t for t in tokens
            if len(t) >= self.min_token_length
        ]

        # Remove duplicates while preserving original query first
        unique_tokens: List[str] = []
        seen: Set[str] = set()

        # Always put original query first
        if query not in seen and len(query) >= self.min_token_length:
            unique_tokens.append(query)
            seen.add(query)

        # Add other tokens
        for token in filtered_tokens:
            if token not in seen:
                unique_tokens.append(token)
                seen.add(token)

        # Join with OR operator (only if we have multiple tokens)
        if len(unique_tokens) > 1:
            expanded = ' OR '.join(unique_tokens)
            log.debug(f"Expanded query: '{query}' → '{expanded}'")
            return expanded
        else:
            return query

    def _extract_tokens(self, word: str) -> Set[str]:
        """Extract tokens from a single word using various splitting strategies.

        Args:
            word: Single word/identifier to split

        Returns:
            Set of extracted tokens
        """
        tokens: Set[str] = set()

        # Add original word
        tokens.add(word)

        # Handle all-caps acronyms (don't split)
        if self.ALL_CAPS_PATTERN.match(word):
            return tokens

        # CamelCase splitting
        camel_tokens = self._split_camel_case(word)
        tokens.update(camel_tokens)

        # snake_case splitting
        snake_tokens = self._split_snake_case(word)
        tokens.update(snake_tokens)

        # kebab-case splitting
        kebab_tokens = self._split_kebab_case(word)
        tokens.update(kebab_tokens)

        return tokens

    def _split_camel_case(self, word: str) -> List[str]:
        """Split CamelCase identifier into tokens.

        Args:
            word: CamelCase identifier (e.g., 'getUserData')

        Returns:
            List of tokens (e.g., ['get', 'User', 'Data'])
        """
        # Insert space before uppercase letters preceded by lowercase
        spaced = self.CAMEL_CASE_PATTERN.sub(r'\1 \2', word)
        # Split on spaces and filter empty
        return [t for t in spaced.split() if t]

    def _split_snake_case(self, word: str) -> List[str]:
        """Split snake_case identifier into tokens.

        Args:
            word: snake_case identifier (e.g., 'get_user_data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on underscores
        return [t for t in self.SNAKE_CASE_PATTERN.split(word) if t]

    def _split_kebab_case(self, word: str) -> List[str]:
        """Split kebab-case identifier into tokens.

        Args:
            word: kebab-case identifier (e.g., 'get-user-data')

        Returns:
            List of tokens (e.g., ['get', 'user', 'data'])
        """
        # Split on hyphens
        return [t for t in self.KEBAB_CASE_PATTERN.split(word) if t]


# Global default parser instance
_default_parser = QueryParser(enable=True)


def preprocess_query(query: str, enable: bool = True) -> str:
    """Convenience function for query preprocessing.

    Args:
        query: Original search query
        enable: Whether to enable preprocessing

    Returns:
        Preprocessed query with identifier expansion
    """
    if not enable:
        return query

    return _default_parser.preprocess_query(query)


__all__ = [
    "QueryParser",
    "preprocess_query",
]
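A quick check of the expansion behavior; note that token order after the original query comes from a set, so the sketch asserts membership rather than exact order (the doctest output above assumes one particular ordering).

from codexlens.search.query_parser import preprocess_query

expanded = preprocess_query("getUserData")
assert expanded.split(" OR ")[0] == "getUserData"  # original query kept first
assert {"get", "User", "Data"} <= set(expanded.split(" OR "))

# Queries that already use FTS5 operators pass through untouched:
assert preprocess_query('name:"exact phrase"') == 'name:"exact phrase"'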
942 codex-lens/build/lib/codexlens/search/ranking.py (Normal file)
@@ -0,0 +1,942 @@
|
||||
"""Ranking algorithms for hybrid search result fusion.
|
||||
|
||||
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
|
||||
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import math
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from codexlens.entities import SearchResult, AdditionalLocation
|
||||
|
||||
|
||||
# Default RRF weights for SPLADE-based hybrid search
|
||||
DEFAULT_WEIGHTS = {
|
||||
"splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1)
|
||||
"vector": 0.5,
|
||||
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
|
||||
}
|
||||
|
||||
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
    "exact": 0.25,
    "fuzzy": 0.1,
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}


class QueryIntent(str, Enum):
    """Query intent for adaptive RRF weights (Python/TypeScript parity)."""

    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    MIXED = "mixed"


def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
    """Normalize weights to sum to 1.0 (best-effort)."""
    total = sum(float(v) for v in weights.values() if v is not None)

    # NaN total: do not attempt to normalize (division would propagate NaNs).
    if math.isnan(total):
        return dict(weights)

    # Infinite total: do not attempt to normalize (division yields 0 or NaN).
    if not math.isfinite(total):
        return dict(weights)

    # Zero/negative total: do not attempt to normalize (invalid denominator).
    if total <= 0:
        return dict(weights)

    return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}


def detect_query_intent(query: str) -> QueryIntent:
    """Detect whether a query is code-like, natural-language, or mixed.

    Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
    """
    trimmed = (query or "").strip()
    if not trimmed:
        return QueryIntent.MIXED

    lower = trimmed.lower()
    word_count = len([w for w in re.split(r"\s+", trimmed) if w])

    has_code_signals = bool(
        re.search(r"(::|->|\.)", trimmed)
        or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
        or re.search(r"\b\w+_\w+\b", trimmed)
        or re.search(
            r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
            lower,
            flags=re.IGNORECASE,
        )
    )
    has_natural_signals = bool(
        word_count > 5
        or "?" in trimmed
        or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
        or re.search(
            r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
            trimmed,
            flags=re.IGNORECASE,
        )
    )

    if has_code_signals and has_natural_signals:
        return QueryIntent.MIXED
    if has_code_signals:
        return QueryIntent.KEYWORD
    if has_natural_signals:
        return QueryIntent.SEMANTIC
    return QueryIntent.MIXED


def adjust_weights_by_intent(
    intent: QueryIntent,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Adjust RRF weights based on query intent."""
    # Check if using SPLADE or FTS mode
    use_splade = "splade" in base_weights

    if intent == QueryIntent.KEYWORD:
        if use_splade:
            target = {"splade": 0.6, "vector": 0.4}
        else:
            target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
    elif intent == QueryIntent.SEMANTIC:
        if use_splade:
            target = {"splade": 0.3, "vector": 0.7}
        else:
            target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
    else:
        target = dict(base_weights)

    # Filter to active backends
    keys = list(base_weights.keys())
    filtered = {k: float(target.get(k, 0.0)) for k in keys}
    return normalize_weights(filtered)


def get_rrf_weights(
    query: str,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Compute adaptive RRF weights from query intent."""
    return adjust_weights_by_intent(detect_query_intent(query), base_weights)

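
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# A minimal demonstration of the intent -> weight pipeline defined above. The
# expected outputs follow directly from detect_query_intent() and
# adjust_weights_by_intent(); only the sample queries are invented here.
def _demo_adaptive_weights() -> None:
    base = {"exact": 0.25, "fuzzy": 0.1, "vector": 0.5}

    # "UserService.login" has code signals (dot access, CamelCase) -> KEYWORD.
    assert detect_query_intent("UserService.login") is QueryIntent.KEYWORD
    assert get_rrf_weights("UserService.login", base) == {
        "exact": 0.5, "fuzzy": 0.1, "vector": 0.4,
    }

    # A plain-English question (more than five words) -> SEMANTIC, which
    # shifts weight toward the vector source.
    assert detect_query_intent("how does the login flow work?") is QueryIntent.SEMANTIC
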
# File extensions to category mapping for fast lookup
_EXT_TO_CATEGORY: Dict[str, str] = {
    # Code extensions
    ".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
    ".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
    ".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
    ".rs": "code",
    # Doc extensions
    ".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
}


def get_file_category(path: str) -> Optional[str]:
    """Get file category ('code' or 'doc') from path extension.

    Args:
        path: File path string

    Returns:
        'code', 'doc', or None if unknown
    """
    ext = Path(path).suffix.lower()
    return _EXT_TO_CATEGORY.get(ext)


def filter_results_by_category(
    results: List[SearchResult],
    intent: QueryIntent,
    allow_mixed: bool = True,
) -> List[SearchResult]:
    """Filter results by category based on query intent.

    Strategy:
    - KEYWORD (code intent): Only return code files
    - SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
    - MIXED: Return all results

    Args:
        results: List of SearchResult objects
        intent: Query intent from detect_query_intent()
        allow_mixed: If True, SEMANTIC intent includes code files with lower priority

    Returns:
        Filtered and re-ranked list of SearchResult objects
    """
    if not results or intent == QueryIntent.MIXED:
        return results

    code_results = []
    doc_results = []
    unknown_results = []

    for r in results:
        category = get_file_category(r.path)
        if category == "code":
            code_results.append(r)
        elif category == "doc":
            doc_results.append(r)
        else:
            unknown_results.append(r)

    if intent == QueryIntent.KEYWORD:
        # Code intent: return only code files + unknown (might be code)
        filtered = code_results + unknown_results
    elif intent == QueryIntent.SEMANTIC:
        if allow_mixed:
            # Semantic intent with mixed: docs first, then code
            filtered = doc_results + code_results + unknown_results
        else:
            # Semantic intent strict: only docs
            filtered = doc_results + unknown_results
    else:
        filtered = results

    return filtered

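
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Demonstrates the category routing described in filter_results_by_category().
# SearchResult construction mirrors the doctests elsewhere in this module.
def _demo_category_filter() -> None:
    results = [
        SearchResult(path="README.md", score=0.9, excerpt="..."),
        SearchResult(path="auth.py", score=0.8, excerpt="..."),
    ]
    # KEYWORD intent keeps only code files (plus unknown extensions).
    kept = filter_results_by_category(results, QueryIntent.KEYWORD)
    assert [r.path for r in kept] == ["auth.py"]
    # SEMANTIC intent with allow_mixed=True puts docs first but keeps code.
    kept = filter_results_by_category(results, QueryIntent.SEMANTIC)
    assert [r.path for r in kept] == ["README.md", "auth.py"]
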
def simple_weighted_fusion(
    results_map: Dict[str, List[SearchResult]],
    weights: Optional[Dict[str, float]] = None,
) -> List[SearchResult]:
    """Combine search results using a simple weighted sum of normalized scores.

    This is an alternative to RRF that preserves score magnitude information.
    Scores are min-max normalized per source before weighted combination.

    Formula: score(d) = Σ weight_source * normalized_score_source(d)

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}

    Returns:
        List of SearchResult objects sorted by fused score (descending)

    Examples:
        >>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
        >>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
        >>> results_map = {'exact': fts_results, 'vector': vector_results}
        >>> fused = simple_weighted_fusion(results_map)
    """
    if not results_map:
        return []

    # Default equal weights if not provided
    if weights is None:
        num_sources = len(results_map)
        weights = {source: 1.0 / num_sources for source in results_map}

    # Normalize weights to sum to 1.0
    weight_sum = sum(weights.values())
    if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
        weights = {source: w / weight_sum for source, w in weights.items()}

    # Compute min-max normalization parameters per source
    source_stats: Dict[str, tuple] = {}
    for source_name, results in results_map.items():
        if not results:
            continue
        scores = [r.score for r in results]
        min_s, max_s = min(scores), max(scores)
        source_stats[source_name] = (min_s, max_s)

    def normalize_score(score: float, source: str) -> float:
        """Normalize score to [0, 1] range using min-max scaling."""
        if source not in source_stats:
            return 0.0
        min_s, max_s = source_stats[source]
        if max_s == min_s:
            return 1.0 if score >= min_s else 0.0
        return (score - min_s) / (max_s - min_s)

    # Build unified result set with weighted scores
    path_to_result: Dict[str, SearchResult] = {}
    path_to_fusion_score: Dict[str, float] = {}
    path_to_source_scores: Dict[str, Dict[str, float]] = {}

    for source_name, results in results_map.items():
        weight = weights.get(source_name, 0.0)
        if weight == 0:
            continue

        for result in results:
            path = result.path
            normalized = normalize_score(result.score, source_name)
            contribution = weight * normalized

            if path not in path_to_fusion_score:
                path_to_fusion_score[path] = 0.0
                path_to_result[path] = result
                path_to_source_scores[path] = {}

            path_to_fusion_score[path] += contribution
            path_to_source_scores[path][source_name] = normalized

    # Create final results with fusion scores
    fused_results = []
    for path, base_result in path_to_result.items():
        fusion_score = path_to_fusion_score[path]

        fused_result = SearchResult(
            path=base_result.path,
            score=fusion_score,
            excerpt=base_result.excerpt,
            content=base_result.content,
            symbol=base_result.symbol,
            chunk=base_result.chunk,
            metadata={
                **base_result.metadata,
                "fusion_method": "simple_weighted",
                "fusion_score": fusion_score,
                "original_score": base_result.score,
                "source_scores": path_to_source_scores[path],
            },
            start_line=base_result.start_line,
            end_line=base_result.end_line,
            symbol_name=base_result.symbol_name,
            symbol_kind=base_result.symbol_kind,
        )
        fused_results.append(fused_result)

    fused_results.sort(key=lambda r: r.score, reverse=True)
    return fused_results

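
# --- Worked example (editor's sketch, not part of the original module) ---
# Traces simple_weighted_fusion() on two sources. Per source, min-max scaling
# maps the best hit to 1.0 and the worst to 0.0, then the weights are applied.
def _demo_simple_weighted_fusion() -> None:
    results_map = {
        "exact": [
            SearchResult(path="a.py", score=10.0, excerpt="..."),
            SearchResult(path="b.py", score=5.0, excerpt="..."),
        ],
        "vector": [
            SearchResult(path="b.py", score=0.9, excerpt="..."),
            SearchResult(path="a.py", score=0.6, excerpt="..."),
        ],
    }
    fused = simple_weighted_fusion(results_map, weights={"exact": 0.6, "vector": 0.4})
    # a.py: 0.6 * 1.0 + 0.4 * 0.0 = 0.6;  b.py: 0.6 * 0.0 + 0.4 * 1.0 = 0.4
    assert [(r.path, round(r.score, 2)) for r in fused] == [("a.py", 0.6), ("b.py", 0.4)]
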
def reciprocal_rank_fusion(
    results_map: Dict[str, List[SearchResult]],
    weights: Optional[Dict[str, float]] = None,
    k: int = 60,
) -> List[SearchResult]:
    """Combine search results from multiple sources using Reciprocal Rank Fusion.

    RRF formula: score(d) = Σ weight_source / (k + rank_source(d))

    Supports three-way fusion with FTS, Vector, and SPLADE sources.

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
            Or: {'splade': 0.4, 'vector': 0.6}
        k: Constant to avoid division by zero and control rank influence (default 60)

    Returns:
        List of SearchResult objects sorted by fused score (descending)

    Examples:
        >>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
        >>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
        >>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
        >>> fused = reciprocal_rank_fusion(results_map)

        # Three-way fusion with SPLADE
        >>> results_map = {
        ...     'exact': exact_results,
        ...     'vector': vector_results,
        ...     'splade': splade_results
        ... }
        >>> fused = reciprocal_rank_fusion(results_map, k=60)
    """
    if not results_map:
        return []

    # Default equal weights if not provided
    if weights is None:
        num_sources = len(results_map)
        weights = {source: 1.0 / num_sources for source in results_map}

    # Normalize weights to sum to 1.0 (guard against a zero or invalid sum)
    weight_sum = sum(weights.values())
    if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
        weights = {source: w / weight_sum for source, w in weights.items()}

    # Build unified result set with RRF scores
    path_to_result: Dict[str, SearchResult] = {}
    path_to_fusion_score: Dict[str, float] = {}
    path_to_source_ranks: Dict[str, Dict[str, int]] = {}

    for source_name, results in results_map.items():
        weight = weights.get(source_name, 0.0)
        if weight == 0:
            continue

        for rank, result in enumerate(results, start=1):
            path = result.path
            rrf_contribution = weight / (k + rank)

            # Initialize or accumulate fusion score
            if path not in path_to_fusion_score:
                path_to_fusion_score[path] = 0.0
                path_to_result[path] = result
                path_to_source_ranks[path] = {}

            path_to_fusion_score[path] += rrf_contribution
            path_to_source_ranks[path][source_name] = rank

    # Create final results with fusion scores
    fused_results = []
    for path, base_result in path_to_result.items():
        fusion_score = path_to_fusion_score[path]

        # Create new SearchResult with fusion_score in metadata
        fused_result = SearchResult(
            path=base_result.path,
            score=fusion_score,
            excerpt=base_result.excerpt,
            content=base_result.content,
            symbol=base_result.symbol,
            chunk=base_result.chunk,
            metadata={
                **base_result.metadata,
                "fusion_method": "rrf",
                "fusion_score": fusion_score,
                "original_score": base_result.score,
                "rrf_k": k,
                "source_ranks": path_to_source_ranks[path],
            },
            start_line=base_result.start_line,
            end_line=base_result.end_line,
            symbol_name=base_result.symbol_name,
            symbol_kind=base_result.symbol_kind,
        )
        fused_results.append(fused_result)

    # Sort by fusion score descending
    fused_results.sort(key=lambda r: r.score, reverse=True)

    return fused_results

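
# --- Worked example (editor's sketch, not part of the original module) ---
# Traces the RRF formula score(d) = Σ weight_s / (k + rank_s(d)) with k=60.
def _demo_rrf() -> None:
    results_map = {
        "exact": [
            SearchResult(path="a.py", score=10.0, excerpt="..."),
            SearchResult(path="b.py", score=5.0, excerpt="..."),
        ],
        "vector": [
            SearchResult(path="b.py", score=0.9, excerpt="..."),
            SearchResult(path="a.py", score=0.6, excerpt="..."),
        ],
    }
    fused = reciprocal_rank_fusion(results_map, weights={"exact": 0.5, "vector": 0.5}, k=60)
    # Each document appears once at rank 1 and once at rank 2, so both score
    # 0.5/61 + 0.5/62 ≈ 0.01626; the tie is broken by insertion order.
    assert all(abs(r.score - (0.5 / 61 + 0.5 / 62)) < 1e-12 for r in fused)
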
def apply_symbol_boost(
    results: List[SearchResult],
    boost_factor: float = 1.5,
) -> List[SearchResult]:
    """Boost fused scores for results that include an explicit symbol match.

    The boost is multiplicative on the current result.score (typically the RRF fusion score).
    When boosted, the original score is preserved in metadata["original_fusion_score"] and
    metadata["boosted"] is set to True.
    """
    if not results:
        return []

    if boost_factor <= 1.0:
        # Still return new objects to follow the immutable transformation pattern.
        return [
            SearchResult(
                path=r.path,
                score=r.score,
                excerpt=r.excerpt,
                content=r.content,
                symbol=r.symbol,
                chunk=r.chunk,
                metadata={**r.metadata},
                start_line=r.start_line,
                end_line=r.end_line,
                symbol_name=r.symbol_name,
                symbol_kind=r.symbol_kind,
                additional_locations=list(r.additional_locations),
            )
            for r in results
        ]

    boosted_results: List[SearchResult] = []
    for result in results:
        has_symbol = bool(result.symbol_name)
        original_score = float(result.score)
        boosted_score = original_score * boost_factor if has_symbol else original_score

        metadata = {**result.metadata}
        if has_symbol:
            metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
            metadata["boosted"] = True
            metadata["symbol_boost_factor"] = boost_factor

        boosted_results.append(
            SearchResult(
                path=result.path,
                score=boosted_score,
                excerpt=result.excerpt,
                content=result.content,
                symbol=result.symbol,
                chunk=result.chunk,
                metadata=metadata,
                start_line=result.start_line,
                end_line=result.end_line,
                symbol_name=result.symbol_name,
                symbol_kind=result.symbol_kind,
                additional_locations=list(result.additional_locations),
            )
        )

    boosted_results.sort(key=lambda r: r.score, reverse=True)
    return boosted_results

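
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# A symbol-bearing hit overtakes a slightly better plain hit once boosted.
def _demo_symbol_boost() -> None:
    results = [
        SearchResult(path="a.py", score=0.020, excerpt="..."),
        SearchResult(path="b.py", score=0.018, excerpt="...", symbol_name="login"),
    ]
    boosted = apply_symbol_boost(results, boost_factor=1.5)
    # b.py: 0.018 * 1.5 = 0.027 > 0.020, so it now ranks first.
    assert [r.path for r in boosted] == ["b.py", "a.py"]
    assert boosted[0].metadata["boosted"] is True
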
def rerank_results(
    query: str,
    results: List[SearchResult],
    embedder: Any,
    top_k: int = 50,
) -> List[SearchResult]:
    """Re-rank results with embedding cosine similarity, combined with current score.

    Combined score formula:
        0.5 * rrf_score + 0.5 * cosine_similarity

    If embedder is None or embedding fails, returns results as-is.
    """
    if not results:
        return []

    if embedder is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
        # Defensive: handle mismatched lengths and zero vectors.
        n = min(len(vec_a), len(vec_b))
        if n == 0:
            return 0.0
        dot = 0.0
        norm_a = 0.0
        norm_b = 0.0
        for i in range(n):
            a = float(vec_a[i])
            b = float(vec_b[i])
            dot += a * b
            norm_a += a * a
            norm_b += b * b
        if norm_a <= 0.0 or norm_b <= 0.0:
            return 0.0
        sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
        # SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
        return max(0.0, min(1.0, sim))

    def text_for_embedding(r: SearchResult) -> str:
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        # Fallback: stable, non-empty text.
        return r.symbol_name or r.path

    try:
        if hasattr(embedder, "embed_single"):
            query_vec = embedder.embed_single(query)
        else:
            query_vec = embedder.embed(query)[0]

        doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
        doc_vecs = embedder.embed(doc_texts)
    except Exception:
        return results

    reranked_results: List[SearchResult] = []

    for idx, result in enumerate(results):
        if idx < rerank_count:
            rrf_score = float(result.score)
            sim = cosine_similarity(query_vec, doc_vecs[idx])
            combined_score = 0.5 * rrf_score + 0.5 * sim

            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "rrf_score": rrf_score,
                        "cosine_similarity": sim,
                        "reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            # Preserve remaining results without re-ranking, but keep immutability.
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results

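
# --- Illustrative embedder stub (editor's sketch, not part of the original module) ---
# rerank_results() only needs `embed_single(text) -> vector` or
# `embed(texts) -> list[vector]`. This toy embedder satisfies that protocol;
# any real model wrapper can be dropped in instead.
class _ToyEmbedder:
    def embed_single(self, text: str) -> List[float]:
        # Crude 2-d "embedding": text length and vowel count, for demonstration only.
        return [float(len(text)), float(sum(text.count(v) for v in "aeiou"))]

    def embed(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_single(t) for t in texts]
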
def cross_encoder_rerank(
    query: str,
    results: List[SearchResult],
    reranker: Any,
    top_k: int = 50,
    batch_size: int = 32,
    chunk_type_weights: Optional[Dict[str, float]] = None,
    test_file_penalty: float = 0.0,
) -> List[SearchResult]:
    """Second-stage reranking using a cross-encoder model.

    This function is dependency-agnostic: callers can pass any object that exposes
    a compatible `score_pairs(pairs, batch_size=...)` method.

    Args:
        query: Search query string
        results: List of search results to rerank
        reranker: Cross-encoder model with score_pairs or predict method
        top_k: Number of top results to rerank
        batch_size: Batch size for reranking
        chunk_type_weights: Optional weights for different chunk types.
            Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
        test_file_penalty: Penalty applied to test files (0.0-1.0).
            Example: 0.2 means test files get a 20% score reduction
    """
    if not results:
        return []

    if reranker is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def text_for_pair(r: SearchResult) -> str:
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        return r.symbol_name or r.path

    pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]

    try:
        if hasattr(reranker, "score_pairs"):
            raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
        elif hasattr(reranker, "predict"):
            raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
        else:
            return results
    except Exception:
        return results

    # Avoid truthiness checks here: predict() may return a numpy array,
    # whose truth value is ambiguous. A length mismatch means the output
    # cannot be trusted.
    if raw_scores is None or len(raw_scores) != rerank_count:
        return results

    scores = [float(s) for s in raw_scores]
    min_s = min(scores)
    max_s = max(scores)

    def sigmoid(x: float) -> float:
        # Clamp to keep exp() stable.
        x = max(-50.0, min(50.0, x))
        return 1.0 / (1.0 + math.exp(-x))

    if 0.0 <= min_s and max_s <= 1.0:
        probs = scores
    else:
        probs = [sigmoid(s) for s in scores]

    reranked_results: List[SearchResult] = []

    # Helper to detect test files
    def is_test_file(path: str) -> bool:
        if not path:
            return False
        basename = path.split("/")[-1].split("\\")[-1]
        return (
            basename.startswith("test_") or
            basename.endswith("_test.py") or
            basename.endswith(".test.ts") or
            basename.endswith(".test.js") or
            basename.endswith(".spec.ts") or
            basename.endswith(".spec.js") or
            "/tests/" in path or
            "\\tests\\" in path or
            "/test/" in path or
            "\\test\\" in path
        )

    for idx, result in enumerate(results):
        if idx < rerank_count:
            prev_score = float(result.score)
            ce_score = scores[idx]
            ce_prob = probs[idx]

            # Base combined score
            combined_score = 0.5 * prev_score + 0.5 * ce_prob

            # Apply chunk_type weight adjustment
            if chunk_type_weights:
                chunk_type = None
                if result.chunk and hasattr(result.chunk, "metadata"):
                    chunk_type = result.chunk.metadata.get("chunk_type")
                elif result.metadata:
                    chunk_type = result.metadata.get("chunk_type")

                if chunk_type and chunk_type in chunk_type_weights:
                    weight = chunk_type_weights[chunk_type]
                    # Apply weight to CE contribution only
                    combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight

            # Apply test file penalty
            if test_file_penalty > 0 and is_test_file(result.path):
                combined_score = combined_score * (1.0 - test_file_penalty)

            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "pre_cross_encoder_score": prev_score,
                        "cross_encoder_score": ce_score,
                        "cross_encoder_prob": ce_prob,
                        "cross_encoder_reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results

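
# --- Illustrative reranker stub (editor's sketch, not part of the original module) ---
# cross_encoder_rerank() is dependency-agnostic: anything exposing
# score_pairs(pairs, batch_size=...) (or predict(...)) works. A toy scorer:
class _ToyCrossEncoder:
    def score_pairs(self, pairs: List[tuple], batch_size: int = 32) -> List[float]:
        # Score by naive token overlap between query and candidate text.
        # Returns values in [0, 1], so the sigmoid branch above is skipped.
        scores = []
        for query, text in pairs:
            q_tokens = set(query.lower().split())
            t_tokens = set(text.lower().split())
            scores.append(len(q_tokens & t_tokens) / max(len(q_tokens), 1))
        return scores
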
def normalize_bm25_score(score: float) -> float:
    """Normalize BM25 scores from SQLite FTS5 to 0-1 range.

    SQLite FTS5 returns negative BM25 scores (more negative = better match).
    Uses sigmoid transformation for normalization.

    Args:
        score: Raw BM25 score from SQLite (typically negative)

    Returns:
        Normalized score in range [0, 1]

    Examples:
        >>> round(normalize_bm25_score(-10.5), 2)  # Good match
        0.74
        >>> round(normalize_bm25_score(-1.2), 2)  # Weak match
        0.53
    """
    # Take absolute value (BM25 is negative in SQLite)
    abs_score = abs(score)

    # Sigmoid transformation: 1 / (1 + e^(-x))
    # Scale factor of 0.1 maps the typical BM25 range (-20 to 0) into (0.5, 0.88)
    normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))

    return normalized

def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
    """Tag search results with their source for RRF tracking.

    Args:
        results: List of SearchResult objects
        source: Source identifier ('exact', 'fuzzy', 'vector', 'splade')

    Returns:
        List of SearchResult objects with 'search_source' in metadata
    """
    tagged_results = []
    for result in results:
        tagged_result = SearchResult(
            path=result.path,
            score=result.score,
            excerpt=result.excerpt,
            content=result.content,
            symbol=result.symbol,
            chunk=result.chunk,
            metadata={**result.metadata, "search_source": source},
            start_line=result.start_line,
            end_line=result.end_line,
            symbol_name=result.symbol_name,
            symbol_kind=result.symbol_kind,
        )
        tagged_results.append(tagged_result)

    return tagged_results

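
# --- Pipeline sketch (editor's example, not part of the original module) ---
# How the pieces above are typically composed: tag each backend's results,
# fuse with adaptive weights, boost symbol hits, then collapse duplicates
# (group_similar_results is defined immediately below).
def _demo_hybrid_pipeline(
    query: str,
    exact: List[SearchResult],
    fuzzy: List[SearchResult],
    vector: List[SearchResult],
) -> List[SearchResult]:
    results_map = {
        "exact": tag_search_source(exact, "exact"),
        "fuzzy": tag_search_source(fuzzy, "fuzzy"),
        "vector": tag_search_source(vector, "vector"),
    }
    weights = get_rrf_weights(query, {"exact": 0.25, "fuzzy": 0.1, "vector": 0.65})
    fused = reciprocal_rank_fusion(results_map, weights=weights)
    boosted = apply_symbol_boost(fused)
    return group_similar_results(boosted)
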
def group_similar_results(
    results: List[SearchResult],
    score_threshold_abs: float = 0.01,
    content_field: str = "excerpt"
) -> List[SearchResult]:
    """Group search results by content and score similarity.

    Groups results that have similar content and similar scores into a single
    representative result, with other locations stored in additional_locations.

    Algorithm:
    1. Group results by content (using excerpt or content field)
    2. Within each content group, create subgroups based on score similarity
    3. Select highest-scoring result as representative for each subgroup
    4. Store other results in subgroup as additional_locations

    Args:
        results: A list of SearchResult objects (typically sorted by score)
        score_threshold_abs: Absolute score difference to consider results similar.
            Results with |score_a - score_b| <= threshold are grouped.
            Default 0.01 is suitable for RRF fusion scores.
        content_field: The field to use for content grouping ('excerpt' or 'content')

    Returns:
        A new list of SearchResult objects where similar items are grouped.
        The list is sorted by score descending.

    Examples:
        >>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
        ...            SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
        >>> grouped = group_similar_results(results)
        >>> len(grouped)  # Two results merged into one
        1
        >>> len(grouped[0].additional_locations)  # One additional location
        1
    """
    if not results:
        return []

    # Group results by content
    content_map: Dict[str, List[SearchResult]] = {}
    unidentifiable_results: List[SearchResult] = []

    for r in results:
        key = getattr(r, content_field, None)
        if key and key.strip():
            content_map.setdefault(key, []).append(r)
        else:
            # Results without content can't be grouped by content
            unidentifiable_results.append(r)

    final_results: List[SearchResult] = []

    # Process each content group
    for content_group in content_map.values():
        # Sort by score descending within group
        content_group.sort(key=lambda r: r.score, reverse=True)

        while content_group:
            # Take highest scoring as representative
            representative = content_group.pop(0)
            others_in_group = []
            remaining_for_next_pass = []

            # Find results with similar scores
            for item in content_group:
                if abs(representative.score - item.score) <= score_threshold_abs:
                    others_in_group.append(item)
                else:
                    remaining_for_next_pass.append(item)

            # Create grouped result with additional locations
            if others_in_group:
                # Build new result with additional_locations populated
                grouped_result = SearchResult(
                    path=representative.path,
                    score=representative.score,
                    excerpt=representative.excerpt,
                    content=representative.content,
                    symbol=representative.symbol,
                    chunk=representative.chunk,
                    metadata={
                        **representative.metadata,
                        "grouped_count": len(others_in_group) + 1,
                    },
                    start_line=representative.start_line,
                    end_line=representative.end_line,
                    symbol_name=representative.symbol_name,
                    symbol_kind=representative.symbol_kind,
                    additional_locations=[
                        AdditionalLocation(
                            path=other.path,
                            score=other.score,
                            start_line=other.start_line,
                            end_line=other.end_line,
                            symbol_name=other.symbol_name,
                        ) for other in others_in_group
                    ],
                )
                final_results.append(grouped_result)
            else:
                final_results.append(representative)

            content_group = remaining_for_next_pass

    # Add ungroupable results
    final_results.extend(unidentifiable_results)

    # Sort final results by score descending
    final_results.sort(key=lambda r: r.score, reverse=True)

    return final_results
118 codex-lens/build/lib/codexlens/semantic/__init__.py Normal file
@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.
|
||||
|
||||
Install with: pip install codexlens[semantic]
|
||||
Uses fastembed (ONNX-based, lightweight ~200MB)
|
||||
|
||||
GPU Acceleration:
|
||||
- Automatic GPU detection and usage when available
|
||||
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
|
||||
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
SEMANTIC_AVAILABLE = False
|
||||
SEMANTIC_BACKEND: str | None = None
|
||||
GPU_AVAILABLE = False
|
||||
LITELLM_AVAILABLE = False
|
||||
_import_error: str | None = None
|
||||
|
||||
|
||||
def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
|
||||
"""Detect if fastembed and GPU are available."""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError as e:
|
||||
return False, None, False, f"numpy not available: {e}"
|
||||
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"
|
||||
|
||||
# Check GPU availability
|
||||
gpu_available = False
|
||||
try:
|
||||
from .gpu_support import is_gpu_available
|
||||
gpu_available = is_gpu_available()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return True, "fastembed", gpu_available, None
|
||||
|
||||
|
||||
# Initialize on module load
|
||||
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()
|
||||
|
||||
|
||||
def check_semantic_available() -> tuple[bool, str | None]:
|
||||
"""Check if semantic search dependencies are available."""
|
||||
return SEMANTIC_AVAILABLE, _import_error
|
||||
|
||||
|
||||
def check_gpu_available() -> tuple[bool, str]:
|
||||
"""Check if GPU acceleration is available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, status_message)
|
||||
"""
|
||||
if not SEMANTIC_AVAILABLE:
|
||||
return False, "Semantic search not available"
|
||||
|
||||
try:
|
||||
from .gpu_support import is_gpu_available, get_gpu_summary
|
||||
if is_gpu_available():
|
||||
return True, get_gpu_summary()
|
||||
return False, "No GPU detected (using CPU)"
|
||||
except ImportError:
|
||||
return False, "GPU support module not available"
|
||||
|
||||
|
||||
# Export embedder components
|
||||
# BaseEmbedder is always available (abstract base class)
|
||||
from .base import BaseEmbedder
|
||||
|
||||
# Factory function for creating embedders
|
||||
from .factory import get_embedder as get_embedder_factory
|
||||
|
||||
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
|
||||
try:
|
||||
import ccw_litellm # noqa: F401
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
LITELLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
LiteLLMEmbedderWrapper = None
|
||||
LITELLM_AVAILABLE = False
|
||||
|
||||
|
||||
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
|
||||
"""Check whether a specific embedding backend can be used.
|
||||
|
||||
Notes:
|
||||
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
|
||||
- "litellm" requires ccw-litellm to be installed in the same environment.
|
||||
"""
|
||||
backend = (backend or "").strip().lower()
|
||||
if backend == "fastembed":
|
||||
if SEMANTIC_AVAILABLE:
|
||||
return True, None
|
||||
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
|
||||
if backend == "litellm":
|
||||
if LITELLM_AVAILABLE:
|
||||
return True, None
|
||||
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
|
||||
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SEMANTIC_AVAILABLE",
|
||||
"SEMANTIC_BACKEND",
|
||||
"GPU_AVAILABLE",
|
||||
"LITELLM_AVAILABLE",
|
||||
"check_semantic_available",
|
||||
"is_embedding_backend_available",
|
||||
"check_gpu_available",
|
||||
"BaseEmbedder",
|
||||
"get_embedder_factory",
|
||||
"LiteLLMEmbedderWrapper",
|
||||
]
|
||||
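
# --- Illustrative usage (editor's sketch, not part of the original file) ---
# Typical capability probing before constructing an embedder; relies only on
# the public helpers defined above.
def _demo_backend_probe() -> None:
    ok, err = is_embedding_backend_available("fastembed")
    if not ok:
        print(f"fastembed unavailable: {err}")
    gpu_ok, gpu_msg = check_gpu_available()
    print(f"GPU available: {gpu_ok} ({gpu_msg})")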
1068 codex-lens/build/lib/codexlens/semantic/ann_index.py Normal file
(File diff suppressed because it is too large)
61 codex-lens/build/lib/codexlens/semantic/base.py Normal file
@@ -0,0 +1,61 @@
"""Base class for embedders.
|
||||
|
||||
Defines the interface that all embedders must implement.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseEmbedder(ABC):
|
||||
"""Base class for all embedders.
|
||||
|
||||
All embedder implementations must inherit from this class and implement
|
||||
the abstract methods to ensure a consistent interface.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimensions.
|
||||
|
||||
Returns:
|
||||
int: Dimension of the embedding vectors.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self) -> str:
|
||||
"""Return model name.
|
||||
|
||||
Returns:
|
||||
str: Name or identifier of the underlying model.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def max_tokens(self) -> int:
|
||||
"""Return maximum token limit for embeddings.
|
||||
|
||||
Returns:
|
||||
int: Maximum number of tokens that can be embedded at once.
|
||||
Default is 8192 if not overridden by implementation.
|
||||
"""
|
||||
return 8192
|
||||
|
||||
@abstractmethod
|
||||
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
|
||||
"""Embed texts to numpy array.
|
||||
|
||||
Args:
|
||||
texts: Single text or iterable of texts to embed.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
|
||||
"""
|
||||
...
|
||||
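
# --- Illustrative subclass (editor's sketch, not part of the original file) ---
# A minimal BaseEmbedder implementation showing the required surface. The
# hashing scheme is a stand-in for demonstration, not a real embedding model.
class HashingEmbedder(BaseEmbedder):
    """Toy embedder: hashes tokens into a fixed-size bag-of-words vector."""

    def __init__(self, dim: int = 64) -> None:
        self._dim = dim

    @property
    def embedding_dim(self) -> int:
        return self._dim

    @property
    def model_name(self) -> str:
        return "hashing-demo"

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        items = [texts] if isinstance(texts, str) else list(texts)
        out = np.zeros((len(items), self._dim), dtype=np.float32)
        for row, text in enumerate(items):
            for token in text.split():
                out[row, hash(token) % self._dim] += 1.0
        return out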
821 codex-lens/build/lib/codexlens/semantic/chunker.py Normal file
@@ -0,0 +1,821 @@
"""Code chunking strategies for semantic search.
|
||||
|
||||
This module provides various chunking strategies for breaking down source code
|
||||
into semantic chunks suitable for embedding and search.
|
||||
|
||||
Lightweight Mode:
|
||||
The ChunkConfig supports a `skip_token_count` option for performance optimization.
|
||||
When enabled, token counting uses a fast character-based estimation (char/4)
|
||||
instead of expensive tiktoken encoding.
|
||||
|
||||
Use cases for lightweight mode:
|
||||
- Large-scale indexing where speed is critical
|
||||
- Scenarios where approximate token counts are acceptable
|
||||
- Memory-constrained environments
|
||||
- Initial prototyping and development
|
||||
|
||||
Example:
|
||||
# Default mode (accurate tiktoken encoding)
|
||||
config = ChunkConfig()
|
||||
chunker = Chunker(config)
|
||||
|
||||
# Lightweight mode (fast char/4 estimation)
|
||||
config = ChunkConfig(skip_token_count=True)
|
||||
chunker = Chunker(config)
|
||||
chunks = chunker.chunk_file(content, symbols, path, language)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codexlens.entities import SemanticChunk, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkConfig:
|
||||
"""Configuration for chunking strategies."""
|
||||
max_chunk_size: int = 1000 # Max characters per chunk
|
||||
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
|
||||
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
|
||||
min_chunk_size: int = 50 # Minimum chunk size
|
||||
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
|
||||
strip_comments: bool = True # Remove comments from chunk content for embedding
|
||||
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
|
||||
preserve_original: bool = True # Store original content in metadata when stripping
|
||||
|
||||
|
||||
class CommentStripper:
    """Remove comments from source code while preserving structure."""

    @staticmethod
    def strip_python_comments(content: str) -> str:
        """Strip Python comments (# style) but preserve docstrings.

        Args:
            content: Python source code

        Returns:
            Code with comments removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        in_string = False
        string_char = None

        for line in lines:
            new_line = []
            i = 0
            while i < len(line):
                char = line[i]

                # Handle string literals
                if char in ('"', "'") and not in_string:
                    # Check for triple quotes
                    if line[i:i+3] in ('"""', "'''"):
                        in_string = True
                        string_char = line[i:i+3]
                        new_line.append(line[i:i+3])
                        i += 3
                        continue
                    else:
                        in_string = True
                        string_char = char
                elif in_string:
                    if string_char and len(string_char) == 3:
                        if line[i:i+3] == string_char:
                            in_string = False
                            new_line.append(line[i:i+3])
                            i += 3
                            string_char = None
                            continue
                    elif char == string_char:
                        # Check for escape
                        if i > 0 and line[i-1] != '\\':
                            in_string = False
                            string_char = None

                # Handle comments (only outside strings)
                if char == '#' and not in_string:
                    # Rest of line is comment, skip it
                    new_line.append('\n' if line.endswith('\n') else '')
                    break

                new_line.append(char)
                i += 1

            result_lines.append(''.join(new_line))

            # Single-quoted strings cannot span lines in Python; reset any
            # string state that would otherwise leak past the end of the line.
            if in_string and string_char and len(string_char) == 1:
                in_string = False
                string_char = None

        return ''.join(result_lines)

    @staticmethod
    def strip_c_style_comments(content: str) -> str:
        """Strip C-style comments (// and /* */) from code.

        Args:
            content: Source code with C-style comments

        Returns:
            Code with comments removed
        """
        result = []
        i = 0
        in_string = False
        string_char = None
        in_multiline_comment = False

        while i < len(content):
            # Handle multi-line comment end
            if in_multiline_comment:
                if content[i:i+2] == '*/':
                    in_multiline_comment = False
                    i += 2
                    continue
                i += 1
                continue

            char = content[i]

            # Handle string literals
            if char in ('"', "'", '`') and not in_string:
                in_string = True
                string_char = char
                result.append(char)
                i += 1
                continue
            elif in_string:
                result.append(char)
                if char == string_char and (i == 0 or content[i-1] != '\\'):
                    in_string = False
                    string_char = None
                i += 1
                continue

            # Handle comments
            if content[i:i+2] == '//':
                # Single line comment - skip to end of line
                while i < len(content) and content[i] != '\n':
                    i += 1
                if i < len(content):
                    result.append('\n')
                    i += 1
                continue

            if content[i:i+2] == '/*':
                in_multiline_comment = True
                i += 2
                continue

            result.append(char)
            i += 1

        return ''.join(result)

    @classmethod
    def strip_comments(cls, content: str, language: str) -> str:
        """Strip comments based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with comments removed
        """
        if language == "python":
            return cls.strip_python_comments(content)
        elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
            return cls.strip_c_style_comments(content)
        return content
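
# --- Illustrative usage (editor's sketch, not part of the original file) ---
# Comments are dropped while string literals survive intact; the trace below
# follows directly from strip_python_comments().
def _demo_comment_stripper() -> None:
    src = 'x = "# not a comment"  # trailing comment\n'
    stripped = CommentStripper.strip_comments(src, "python")
    assert stripped == 'x = "# not a comment"  \n'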
class DocstringStripper:
    """Remove docstrings from source code."""

    @staticmethod
    def strip_python_docstrings(content: str) -> str:
        """Strip Python docstrings (triple-quoted strings at module/class/function level).

        Note: this is a line-based heuristic; any statement-leading triple-quoted
        string is treated as a docstring, including triple-quoted literals used
        as values.

        Args:
            content: Python source code

        Returns:
            Code with docstrings removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        i = 0

        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            # Check for docstring start
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"

                # Single line docstring
                if stripped.count(quote_type) >= 2:
                    # Skip this line (docstring)
                    i += 1
                    continue

                # Multi-line docstring - skip until closing
                i += 1
                while i < len(lines):
                    if quote_type in lines[i]:
                        i += 1
                        break
                    i += 1
                continue

            result_lines.append(line)
            i += 1

        return ''.join(result_lines)

    @staticmethod
    def strip_jsdoc_comments(content: str) -> str:
        """Strip JSDoc comments (/** ... */) from code.

        Args:
            content: JavaScript/TypeScript source code

        Returns:
            Code with JSDoc comments removed
        """
        result = []
        i = 0
        in_jsdoc = False

        while i < len(content):
            if in_jsdoc:
                if content[i:i+2] == '*/':
                    in_jsdoc = False
                    i += 2
                    continue
                i += 1
                continue

            # Check for JSDoc start (/** but not /*)
            if content[i:i+3] == '/**':
                in_jsdoc = True
                i += 3
                continue

            result.append(content[i])
            i += 1

        return ''.join(result)

    @classmethod
    def strip_docstrings(cls, content: str, language: str) -> str:
        """Strip docstrings based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with docstrings removed
        """
        if language == "python":
            return cls.strip_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.strip_jsdoc_comments(content)
        return content
class Chunker:
    """Chunk code files for semantic embedding."""

    def __init__(self, config: ChunkConfig | None = None) -> None:
        self.config = config or ChunkConfig()
        self._tokenizer = get_default_tokenizer()
        self._comment_stripper = CommentStripper()
        self._docstring_stripper = DocstringStripper()

    def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
        """Process chunk content by stripping comments/docstrings if configured.

        Args:
            content: Original chunk content
            language: Programming language

        Returns:
            Tuple of (processed_content, original_content_if_preserved)
        """
        original = content if self.config.preserve_original else None
        processed = content

        if self.config.strip_comments:
            processed = self._comment_stripper.strip_comments(processed, language)

        if self.config.strip_docstrings:
            processed = self._docstring_stripper.strip_docstrings(processed, language)

        # If nothing changed, don't store original
        if processed == content:
            original = None

        return processed, original

    def _estimate_token_count(self, text: str) -> int:
        """Estimate token count based on config.

        If skip_token_count is True, uses character-based estimation (char/4).
        Otherwise, uses accurate tiktoken encoding.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated token count
        """
        if self.config.skip_token_count:
            # Fast character-based estimation: ~4 chars per token
            return max(1, len(text) // 4)
        return self._tokenizer.count_tokens(text)

    def chunk_by_symbol(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code by extracted symbols (functions, classes).

        Each symbol becomes one chunk with its full content.
        Large symbols exceeding max_chunk_size are recursively split using sliding window.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)

        for symbol in symbols:
            start_line, end_line = symbol.range
            # Convert to 0-indexed
            start_idx = max(0, start_line - 1)
            end_idx = min(len(lines), end_line)

            chunk_content = "".join(lines[start_idx:end_idx])
            if len(chunk_content.strip()) < self.config.min_chunk_size:
                continue

            # Check if symbol content exceeds max_chunk_size
            if len(chunk_content) > self.config.max_chunk_size:
                # Create line mapping for correct line number tracking
                line_mapping = list(range(start_line, end_line + 1))

                # Use sliding window to split large symbol
                sub_chunks = self.chunk_sliding_window(
                    chunk_content,
                    file_path=file_path,
                    language=language,
                    line_mapping=line_mapping
                )

                # Update sub_chunks with parent symbol metadata
                for sub_chunk in sub_chunks:
                    sub_chunk.metadata["symbol_name"] = symbol.name
                    sub_chunk.metadata["symbol_kind"] = symbol.kind
                    sub_chunk.metadata["strategy"] = "symbol_split"
                    sub_chunk.metadata["chunk_type"] = "code"
                    sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)

                chunks.extend(sub_chunks)
            else:
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    continue

                # Calculate token count if not provided
                token_count = None
                if symbol_token_counts and symbol.name in symbol_token_counts:
                    token_count = symbol_token_counts[symbol.name]
                else:
                    token_count = self._estimate_token_count(processed_content)

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "symbol_name": symbol.name,
                    "symbol_kind": symbol.kind,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "symbol",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
                    content=processed_content,
                    embedding=None,
                    metadata=metadata
                ))

        return chunks

    def chunk_sliding_window(
        self,
        content: str,
        file_path: str | Path,
        language: str,
        line_mapping: Optional[List[int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code using a sliding window approach.

        Used for files without clear symbol boundaries or very long functions.

        Args:
            content: Source code content
            file_path: Path to source file
            language: Programming language
            line_mapping: Optional list mapping content line indices to original line numbers
                (1-indexed). If provided, line_mapping[i] is the original line number
                for the i-th line in content.
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)

        if not lines:
            return chunks

        # Calculate lines per chunk based on average line length
        avg_line_len = len(content) / max(len(lines), 1)
        lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
        overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
        # Ensure overlap is less than chunk size to prevent infinite loop
        overlap_lines = min(overlap_lines, lines_per_chunk - 1)

        start = 0
        chunk_idx = 0

        while start < len(lines):
            end = min(start + lines_per_chunk, len(lines))
            chunk_content = "".join(lines[start:end])

            if len(chunk_content.strip()) >= self.config.min_chunk_size:
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    # Move window forward
                    step = lines_per_chunk - overlap_lines
                    if step <= 0:
                        step = 1
                    start += step
                    continue

                token_count = self._estimate_token_count(processed_content)

                # Calculate correct line numbers
                if line_mapping:
                    # Use line mapping to get original line numbers
                    start_line = line_mapping[start]
                    end_line = line_mapping[end - 1]
                else:
                    # Default behavior: treat content as starting at line 1
                    start_line = start + 1
                    end_line = end

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_index": chunk_idx,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "sliding_window",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
                    content=processed_content,
                    embedding=None,
                    metadata=metadata
                ))
                chunk_idx += 1

            # Move window, accounting for overlap
            step = lines_per_chunk - overlap_lines
            if step <= 0:
                step = 1  # Failsafe to prevent infinite loop
            start += step

            # Break if we've reached the end
            if end >= len(lines):
                break

        return chunks

    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk a file using the best strategy.

        Uses symbol-based chunking if symbols are available and
        falls back to sliding window for files without symbols.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        if symbols:
            return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
        return self.chunk_sliding_window(content, file_path, language)

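
# --- Illustrative usage (editor's sketch, not part of the original file) ---
# Symbol-aware chunking with the lightweight token estimator. The sample
# source is invented, and the Symbol(name=..., kind=..., range=...) kwargs
# are an assumption about codexlens.entities.Symbol (the chunker only reads
# those three attributes).
def _demo_chunker() -> None:
    source = (
        "def add(a, b):\n"
        "    # Sum two values.\n"
        "    return a + b\n"
    )
    symbols = [Symbol(name="add", kind="function", range=(1, 3))]
    chunker = Chunker(ChunkConfig(min_chunk_size=10, skip_token_count=True))
    chunks = chunker.chunk_file(source, symbols, "demo.py", "python")
    assert chunks and chunks[0].metadata["symbol_name"] == "add"
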
class DocstringExtractor:
    """Extract docstrings from source code."""

    @staticmethod
    def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
        """Extract Python docstrings with their line ranges.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        docstrings: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"
                start_line = i + 1

                if stripped.count(quote_type) >= 2:
                    docstring_content = line
                    end_line = i + 1
                    docstrings.append((docstring_content, start_line, end_line))
                    i += 1
                    continue

                docstring_lines = [line]
                i += 1
                while i < len(lines):
                    docstring_lines.append(lines[i])
                    if quote_type in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                docstring_content = "".join(docstring_lines)
                docstrings.append((docstring_content, start_line, end_line))

            i += 1

        return docstrings

    @staticmethod
    def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
        """Extract JSDoc comments with their line ranges.

        Returns: List of (comment_content, start_line, end_line) tuples
        """
        comments: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            if stripped.startswith('/**'):
                start_line = i + 1
                comment_lines = [line]
                i += 1

                while i < len(lines):
                    comment_lines.append(lines[i])
                    if '*/' in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                comment_content = "".join(comment_lines)
                comments.append((comment_content, start_line, end_line))

            i += 1

        return comments

    @classmethod
    def extract_docstrings(
        cls,
        content: str,
        language: str
    ) -> List[Tuple[str, int, int]]:
        """Extract docstrings based on language.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        if language == "python":
            return cls.extract_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.extract_jsdoc_comments(content)
        return []

class HybridChunker:
|
||||
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
|
||||
|
||||
Composition-based strategy that:
|
||||
1. Extracts docstrings as dedicated chunks
|
||||
2. For remaining code, uses base chunker (symbol or sliding window)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_chunker: Chunker | None = None,
|
||||
config: ChunkConfig | None = None
|
||||
) -> None:
|
||||
"""Initialize hybrid chunker.
|
||||
|
||||
Args:
|
||||
base_chunker: Chunker to use for non-docstring content
|
||||
config: Configuration for chunking
|
||||
"""
|
||||
self.config = config or ChunkConfig()
|
||||
self.base_chunker = base_chunker or Chunker(self.config)
|
||||
self.docstring_extractor = DocstringExtractor()
|
||||
|
||||
def _get_excluded_line_ranges(
|
||||
self,
|
||||
docstrings: List[Tuple[str, int, int]]
|
||||
) -> set[int]:
|
||||
"""Get set of line numbers that are part of docstrings."""
|
||||
excluded_lines: set[int] = set()
|
||||
for _, start_line, end_line in docstrings:
|
||||
for line_num in range(start_line, end_line + 1):
|
||||
excluded_lines.add(line_num)
|
||||
return excluded_lines
|
||||
|
||||
def _filter_symbols_outside_docstrings(
|
||||
self,
|
||||
symbols: List[Symbol],
|
||||
excluded_lines: set[int]
|
||||
) -> List[Symbol]:
|
||||
"""Filter symbols to exclude those completely within docstrings."""
|
||||
filtered: List[Symbol] = []
|
||||
for symbol in symbols:
|
||||
start_line, end_line = symbol.range
|
||||
symbol_lines = set(range(start_line, end_line + 1))
|
||||
if not symbol_lines.issubset(excluded_lines):
|
||||
filtered.append(symbol)
|
||||
return filtered
|
||||
|
||||
def _find_parent_symbol(
|
||||
self,
|
||||
start_line: int,
|
||||
end_line: int,
|
||||
symbols: List[Symbol],
|
||||
) -> Optional[Symbol]:
|
||||
"""Find the smallest symbol range that fully contains a docstring span."""
|
||||
candidates: List[Symbol] = []
|
||||
for symbol in symbols:
|
||||
sym_start, sym_end = symbol.range
|
||||
if sym_start <= start_line and end_line <= sym_end:
|
||||
candidates.append(symbol)
|
||||
if not candidates:
|
||||
return None
|
||||
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
|
||||
|
||||
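The tie-break in `_find_parent_symbol` picks the tightest enclosing range first, then the earliest start. A minimal, self-contained sketch of just that selection, using a stand-in dataclass in place of `codexlens.entities.Symbol`:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class Symbol:  # stand-in for codexlens.entities.Symbol
    name: str
    range: Tuple[int, int]

# A docstring on lines 12-14 is contained by both a class (1-40) and a method (10-20).
candidates = [Symbol("MyClass", (1, 40)), Symbol("my_method", (10, 20))]
parent = min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
assert parent.name == "my_method"  # smallest enclosing span wins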
    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk file using hybrid strategy.

        Extracts docstrings first, then chunks remaining code.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []

        # Step 1: Extract docstrings as dedicated chunks
        docstrings: List[Tuple[str, int, int]] = []
        if language == "python":
            # Fast path: avoid expensive docstring extraction if delimiters are absent.
            if '"""' in content or "'''" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        elif language in {"javascript", "typescript"}:
            if "/**" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        else:
            docstrings = self.docstring_extractor.extract_docstrings(content, language)

        # Fast path: no docstrings -> delegate to base chunker directly.
        if not docstrings:
            if symbols:
                base_chunks = self.base_chunker.chunk_by_symbol(
                    content, symbols, file_path, language, symbol_token_counts
                )
            else:
                base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)

            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
            return base_chunks

        for docstring_content, start_line, end_line in docstrings:
            if len(docstring_content.strip()) >= self.config.min_chunk_size:
                parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
                # Use base chunker's token estimation method
                token_count = self.base_chunker._estimate_token_count(docstring_content)
                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_type": "docstring",
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "hybrid",
                    "token_count": token_count,
                }
                if parent_symbol is not None:
                    metadata["parent_symbol"] = parent_symbol.name
                    metadata["parent_symbol_kind"] = parent_symbol.kind
                    metadata["parent_symbol_range"] = parent_symbol.range
                chunks.append(SemanticChunk(
                    content=docstring_content,
                    embedding=None,
                    metadata=metadata
                ))

        # Step 2: Get line ranges occupied by docstrings
        excluded_lines = self._get_excluded_line_ranges(docstrings)

        # Step 3: Filter symbols to exclude docstring-only ranges
        filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)

        # Step 4: Chunk remaining content using base chunker
        if filtered_symbols:
            base_chunks = self.base_chunker.chunk_by_symbol(
                content, filtered_symbols, file_path, language, symbol_token_counts
            )
            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
                chunks.append(chunk)
        else:
            lines = content.splitlines(keepends=True)
            remaining_lines: List[str] = []

            for i, line in enumerate(lines, start=1):
                if i not in excluded_lines:
                    remaining_lines.append(line)

            if remaining_lines:
                remaining_content = "".join(remaining_lines)
                if len(remaining_content.strip()) >= self.config.min_chunk_size:
                    base_chunks = self.base_chunker.chunk_sliding_window(
                        remaining_content, file_path, language
                    )
                    for chunk in base_chunks:
                        chunk.metadata["strategy"] = "hybrid"
                        chunk.metadata["chunk_type"] = "code"
                        chunks.append(chunk)

        return chunks
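A hedged usage sketch of the hybrid strategy above. `HybridChunker` and `ChunkConfig` come from this module; the sample source string and the `Symbol(name=..., kind=..., range=...)` constructor shape are assumptions for illustration only:

from pathlib import Path

source = '"""Module docstring long enough to become its own chunk."""\n\ndef add(a, b):\n    return a + b\n'
symbols = [Symbol(name="add", kind="function", range=(3, 4))]  # hypothetical Symbol shape

chunker = HybridChunker(config=ChunkConfig())
chunks = chunker.chunk_file(source, symbols, Path("example.py"), "python")
for c in chunks:
    print(c.metadata["chunk_type"], c.metadata.get("parent_symbol"))
# With a permissive min_chunk_size, expect one "docstring" chunk plus "code" chunks for the rest.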
274  codex-lens/build/lib/codexlens/semantic/code_extractor.py  Normal file
@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codexlens.entities import SearchResult, Symbol
|
||||
|
||||
|
||||
def extract_complete_code_block(
|
||||
result: SearchResult,
|
||||
source_file_path: Optional[str] = None,
|
||||
context_lines: int = 0,
|
||||
) -> str:
|
||||
"""Extract complete code block from a search result.
|
||||
|
||||
Args:
|
||||
result: SearchResult from semantic search.
|
||||
source_file_path: Optional path to source file for re-reading.
|
||||
context_lines: Additional lines of context to include above/below.
|
||||
|
||||
Returns:
|
||||
Complete code block as string.
|
||||
"""
|
||||
# If we have full content stored, use it
|
||||
if result.content:
|
||||
if context_lines == 0:
|
||||
return result.content
|
||||
# Need to add context, read from file
|
||||
|
||||
# Try to read from source file
|
||||
file_path = source_file_path or result.path
|
||||
if not file_path or not Path(file_path).exists():
|
||||
# Fall back to excerpt
|
||||
return result.excerpt or ""
|
||||
|
||||
try:
|
||||
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
|
||||
lines = content.splitlines()
|
||||
|
||||
# Get line range
|
||||
start_line = result.start_line or 1
|
||||
end_line = result.end_line or len(lines)
|
||||
|
||||
# Add context
|
||||
start_idx = max(0, start_line - 1 - context_lines)
|
||||
end_idx = min(len(lines), end_line + context_lines)
|
||||
|
||||
return "\n".join(lines[start_idx:end_idx])
|
||||
except Exception:
|
||||
return result.excerpt or result.content or ""
|
||||
|
||||
|
||||
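The context-window arithmetic above clamps to file bounds: for a hit on lines 10-12 with context_lines=2, the slice becomes lines[7:14], i.e. source lines 8-14 inclusive. A minimal check of just that index math:

lines = [f"line {n}" for n in range(1, 21)]  # 20-line stand-in file
start_line, end_line, context_lines = 10, 12, 2
start_idx = max(0, start_line - 1 - context_lines)   # 7
end_idx = min(len(lines), end_line + context_lines)  # 14
assert lines[start_idx:end_idx][0] == "line 8"
assert lines[start_idx:end_idx][-1] == "line 14"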
def extract_symbol_with_context(
    file_path: str,
    symbol: Symbol,
    include_docstring: bool = True,
    include_decorators: bool = True,
) -> str:
    """Extract a symbol (function/class) with its docstring and decorators.

    Args:
        file_path: Path to source file.
        symbol: Symbol to extract.
        include_docstring: Include docstring if present.
        include_decorators: Include decorators/annotations above symbol.

    Returns:
        Complete symbol code with context.
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        start_line, end_line = symbol.range
        start_idx = start_line - 1
        end_idx = end_line

        # Look for decorators above the symbol
        if include_decorators and start_idx > 0:
            decorator_start = start_idx
            # Search backwards for decorators
            i = start_idx - 1
            while i >= 0 and i >= start_idx - 20:  # Look up to 20 lines back
                line = lines[i].strip()
                if line.startswith("@"):
                    decorator_start = i
                    i -= 1
                elif line == "" or line.startswith("#"):
                    # Skip empty lines and comments, continue looking
                    i -= 1
                elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
                    # JavaScript/Java style comments
                    decorator_start = i
                    i -= 1
                else:
                    # Found non-decorator, non-comment line, stop
                    break
            start_idx = decorator_start

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return ""


def format_search_result_code(
    result: SearchResult,
    max_lines: Optional[int] = None,
    show_line_numbers: bool = True,
    highlight_match: bool = False,
) -> str:
    """Format search result code for display.

    Args:
        result: SearchResult to format.
        max_lines: Maximum lines to show (None for all).
        show_line_numbers: Include line numbers in output.
        highlight_match: Add markers for matched region.

    Returns:
        Formatted code string.
    """
    content = result.content or result.excerpt or ""
    if not content:
        return ""

    lines = content.splitlines()

    # Truncate if needed
    truncated = False
    if max_lines and len(lines) > max_lines:
        lines = lines[:max_lines]
        truncated = True

    # Format with line numbers
    if show_line_numbers:
        start = result.start_line or 1
        formatted_lines = []
        for i, line in enumerate(lines):
            line_num = start + i
            formatted_lines.append(f"{line_num:4d} | {line}")
        output = "\n".join(formatted_lines)
    else:
        output = "\n".join(lines)

    if truncated:
        output += "\n... (truncated)"

    return output


def get_code_block_summary(result: SearchResult) -> str:
    """Get a concise summary of a code block.

    Args:
        result: SearchResult to summarize.

    Returns:
        Summary string like "function hello_world (lines 10-25)"
    """
    parts = []

    if result.symbol_kind:
        parts.append(result.symbol_kind)

    if result.symbol_name:
        parts.append(f"`{result.symbol_name}`")
    elif result.excerpt:
        # Extract first meaningful identifier
        first_line = result.excerpt.split("\n")[0][:50]
        parts.append(f'"{first_line}..."')

    if result.start_line and result.end_line:
        if result.start_line == result.end_line:
            parts.append(f"(line {result.start_line})")
        else:
            parts.append(f"(lines {result.start_line}-{result.end_line})")

    if result.path:
        file_name = Path(result.path).name
        parts.append(f"in {file_name}")

    return " ".join(parts) if parts else "unknown code block"


class CodeBlockResult:
    """Enhanced search result with complete code block."""

    def __init__(self, result: SearchResult, source_path: Optional[str] = None):
        self.result = result
        self.source_path = source_path or result.path
        self._full_code: Optional[str] = None

    @property
    def score(self) -> float:
        return self.result.score

    @property
    def path(self) -> str:
        return self.result.path

    @property
    def file_name(self) -> str:
        return Path(self.result.path).name

    @property
    def symbol_name(self) -> Optional[str]:
        return self.result.symbol_name

    @property
    def symbol_kind(self) -> Optional[str]:
        return self.result.symbol_kind

    @property
    def line_range(self) -> Tuple[int, int]:
        return (
            self.result.start_line or 1,
            self.result.end_line or 1
        )

    @property
    def full_code(self) -> str:
        """Get full code block content."""
        if self._full_code is None:
            self._full_code = extract_complete_code_block(self.result, self.source_path)
        return self._full_code

    @property
    def excerpt(self) -> str:
        """Get short excerpt."""
        return self.result.excerpt or ""

    @property
    def summary(self) -> str:
        """Get code block summary."""
        return get_code_block_summary(self.result)

    def format(
        self,
        max_lines: Optional[int] = None,
        show_line_numbers: bool = True,
    ) -> str:
        """Format code for display."""
        # Use full code if available
        display_result = SearchResult(
            path=self.result.path,
            score=self.result.score,
            content=self.full_code,
            start_line=self.result.start_line,
            end_line=self.result.end_line,
        )
        return format_search_result_code(
            display_result,
            max_lines=max_lines,
            show_line_numbers=show_line_numbers
        )

    def __repr__(self) -> str:
        return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"


def enhance_search_results(
    results: List[SearchResult],
) -> List[CodeBlockResult]:
    """Enhance search results with complete code block access.

    Args:
        results: List of SearchResult from semantic search.

    Returns:
        List of CodeBlockResult with full code access.
    """
    return [CodeBlockResult(r) for r in results]
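A hedged sketch of the wrapper in use. The keyword arguments mirror the `SearchResult` construction inside `CodeBlockResult.format` above; the field values themselves are invented for illustration:

hit = SearchResult(
    path="src/app.py",  # hypothetical path
    score=0.91,
    content="def handler(event):\n    return event",
    start_line=10,
    end_line=11,
)
enhanced = enhance_search_results([hit])[0]
print(enhanced.summary)              # e.g. "(lines 10-11) in app.py"
print(enhanced.format(max_lines=5))  # numbered listing via format_search_result_code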
288  codex-lens/build/lib/codexlens/semantic/embedder.py  Normal file
@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.
|
||||
|
||||
Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
|
||||
GPU acceleration is automatic when available, with transparent CPU fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import SEMANTIC_AVAILABLE
|
||||
from .base import BaseEmbedder
|
||||
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global embedder cache for singleton pattern
|
||||
_embedder_cache: Dict[str, "Embedder"] = {}
|
||||
_cache_lock = threading.RLock()
|
||||
|
||||
|
||||
def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
|
||||
"""Get or create a cached Embedder instance (thread-safe singleton).
|
||||
|
||||
This function provides significant performance improvement by reusing
|
||||
Embedder instances across multiple searches, avoiding repeated model
|
||||
loading overhead (~0.8s per load).
|
||||
|
||||
Args:
|
||||
profile: Model profile ("fast", "code", "multilingual", "balanced")
|
||||
use_gpu: If True, use GPU acceleration when available (default: True)
|
||||
|
||||
Returns:
|
||||
Cached Embedder instance for the given profile
|
||||
"""
|
||||
global _embedder_cache
|
||||
|
||||
# Cache key includes GPU preference to support mixed configurations
|
||||
cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"
|
||||
|
||||
# All cache access is protected by _cache_lock to avoid races with
|
||||
# clear_embedder_cache() during concurrent access.
|
||||
with _cache_lock:
|
||||
embedder = _embedder_cache.get(cache_key)
|
||||
if embedder is not None:
|
||||
return embedder
|
||||
|
||||
# Create new embedder and cache it
|
||||
embedder = Embedder(profile=profile, use_gpu=use_gpu)
|
||||
# Pre-load model to ensure it's ready
|
||||
embedder._load_model()
|
||||
_embedder_cache[cache_key] = embedder
|
||||
|
||||
# Log GPU status on first embedder creation
|
||||
if use_gpu and is_gpu_available():
|
||||
logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
|
||||
elif use_gpu:
|
||||
logger.debug("GPU not available, using CPU for embeddings")
|
||||
|
||||
return embedder
|
||||
|
||||
|
||||
def clear_embedder_cache() -> None:
|
||||
"""Clear the embedder cache and release ONNX resources.
|
||||
|
||||
This method ensures proper cleanup of ONNX model resources to prevent
|
||||
memory leaks when embedders are no longer needed.
|
||||
"""
|
||||
global _embedder_cache
|
||||
with _cache_lock:
|
||||
# Release ONNX resources before clearing cache
|
||||
for embedder in _embedder_cache.values():
|
||||
if embedder._model is not None:
|
||||
del embedder._model
|
||||
embedder._model = None
|
||||
_embedder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
|
||||
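The cache key is "profile:gpu|cpu", so repeated lookups return the same instance; a sketch assuming the codexlens[semantic] extra is installed:

e1 = get_embedder(profile="fast", use_gpu=False)
e2 = get_embedder(profile="fast", use_gpu=False)
assert e1 is e2            # same cached instance, no second model load
e3 = get_embedder(profile="fast", use_gpu=True)
assert e3 is not e1        # GPU preference is part of the cache key
clear_embedder_cache()     # releases ONNX resources and empties the cache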
class Embedder(BaseEmbedder):
    """Generate embeddings for code chunks using fastembed (ONNX-based).

    Supported Model Profiles:
    - fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
    - code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
    - multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
    - balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
    """

    # Model profiles for different use cases
    MODELS = {
        "fast": "BAAI/bge-small-en-v1.5",                  # 384 dim - Fast, lightweight
        "code": "jinaai/jina-embeddings-v2-base-code",     # 768 dim - Code-optimized
        "multilingual": "intfloat/multilingual-e5-large",  # 1024 dim - Multilingual
        "balanced": "mixedbread-ai/mxbai-embed-large-v1",  # 1024 dim - High accuracy
    }

    # Dimension mapping for each model
    MODEL_DIMS = {
        "BAAI/bge-small-en-v1.5": 384,
        "jinaai/jina-embeddings-v2-base-code": 768,
        "intfloat/multilingual-e5-large": 1024,
        "mixedbread-ai/mxbai-embed-large-v1": 1024,
    }

    # Default model (fast profile)
    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
    DEFAULT_PROFILE = "fast"

    def __init__(
        self,
        model_name: str | None = None,
        profile: str | None = None,
        use_gpu: bool = True,
        providers: List[str] | None = None,
    ) -> None:
        """Initialize embedder with model or profile.

        Args:
            model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
            profile: Model profile shortcut ("fast", "code", "multilingual", "balanced").
                If both are provided, model_name takes precedence.
            use_gpu: If True, use GPU acceleration when available (default: True)
            providers: Explicit ONNX providers list (overrides use_gpu if provided)
        """
        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        # Resolve model name from profile or use explicit name
        if model_name:
            self._model_name = model_name
        elif profile and profile in self.MODELS:
            self._model_name = self.MODELS[profile]
        else:
            self._model_name = self.DEFAULT_MODEL

        # Configure ONNX execution providers with device_id options for GPU selection
        # Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
        if providers is not None:
            self._providers = providers
        else:
            self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)

        self._use_gpu = use_gpu
        self._model = None

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimension for current model."""
        return self.MODEL_DIMS.get(self._model_name, 768)  # Default to 768 if unknown

    @property
    def max_tokens(self) -> int:
        """Get maximum token limit for current model.

        Returns:
            int: Maximum number of tokens based on model profile.
                - fast: 512 (lightweight, optimized for speed)
                - code: 8192 (code-optimized, larger context)
                - multilingual: 512 (standard multilingual model)
                - balanced: 512 (general purpose)
        """
        # Determine profile from model name
        profile = None
        for prof, model in self.MODELS.items():
            if model == self._model_name:
                profile = prof
                break

        # Return token limit based on profile
        if profile == "code":
            return 8192
        elif profile in ("fast", "multilingual", "balanced"):
            return 512
        else:
            # Default for unknown models
            return 512

    @property
    def providers(self) -> List[str]:
        """Get configured ONNX execution providers."""
        return self._providers

    @property
    def is_gpu_enabled(self) -> bool:
        """Check if GPU acceleration is enabled for this embedder."""
        gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
                         "DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
        # Handle both string providers and tuple providers (name, options)
        for p in self._providers:
            provider_name = p[0] if isinstance(p, tuple) else p
            if provider_name in gpu_providers:
                return True
        return False

    def _load_model(self) -> None:
        """Lazy-load the embedding model with configured providers."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding

        # providers already include device_id options via get_optimal_providers(with_device_options=True)
        # DO NOT pass device_ids separately - fastembed ignores it when providers is specified
        # See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
        try:
            self._model = TextEmbedding(
                model_name=self.model_name,
                providers=self._providers,
            )
            logger.debug(f"Model loaded with providers: {self._providers}")
        except TypeError:
            # Fallback for older fastembed versions without providers parameter
            logger.warning(
                "fastembed version doesn't support 'providers' parameter. "
                "Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
            )
            self._model = TextEmbedding(model_name=self.model_name)

    def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
        """Generate embeddings for one or more texts.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            List of embedding vectors (each is a list of floats).

        Note:
            This method converts numpy arrays to Python lists for backward compatibility.
            For memory-efficient processing, use embed_to_numpy() instead.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        embeddings = list(self._model.embed(texts))
        return [emb.tolist() for emb in embeddings]

    def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Generate embeddings for one or more texts (returns numpy arrays).

        This method is more memory-efficient than embed() as it avoids converting
        numpy arrays to Python lists, which can significantly reduce memory usage
        during batch processing.

        Args:
            texts: Single text or iterable of texts to embed.
            batch_size: Optional batch size for fastembed processing.
                Larger values improve GPU utilization but use more memory.

        Returns:
            numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Pass batch_size to fastembed for optimal GPU utilization
        # Default batch_size in fastembed is 256, but larger values can improve throughput
        if batch_size is not None:
            embeddings = list(self._model.embed(texts, batch_size=batch_size))
        else:
            embeddings = list(self._model.embed(texts))
        return np.array(embeddings)

    def embed_single(self, text: str) -> List[float]:
        """Generate embedding for a single text."""
        return self.embed(text)[0]
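A short sketch contrasting the two embedding entry points above; the model downloads on first use, and the shape assumes the default 384-dim "fast" profile:

import numpy as np

embedder = Embedder(profile="fast", use_gpu=False)
vecs = embedder.embed(["def add(a, b): return a + b"])
assert isinstance(vecs[0], list)  # Python lists, backward compatible

arr = embedder.embed_to_numpy(["def add(a, b): return a + b"], batch_size=64)
assert isinstance(arr, np.ndarray) and arr.shape == (1, 384)  # no list conversion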
158  codex-lens/build/lib/codexlens/semantic/factory.py  Normal file
@@ -0,0 +1,158 @@
"""Factory for creating embedders.
|
||||
|
||||
Provides a unified interface for instantiating different embedder backends.
|
||||
Includes caching to avoid repeated model loading overhead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
# Module-level cache for embedder instances
|
||||
# Key: (backend, profile, model, use_gpu) -> embedder instance
|
||||
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_embedder(
|
||||
backend: str = "fastembed",
|
||||
profile: str = "code",
|
||||
model: str = "default",
|
||||
use_gpu: bool = True,
|
||||
endpoints: Optional[List[Dict[str, Any]]] = None,
|
||||
strategy: str = "latency_aware",
|
||||
cooldown: float = 60.0,
|
||||
**kwargs: Any,
|
||||
) -> BaseEmbedder:
|
||||
"""Factory function to create embedder based on backend.
|
||||
|
||||
Args:
|
||||
backend: Embedder backend to use. Options:
|
||||
- "fastembed": Use fastembed (ONNX-based) embedder (default)
|
||||
- "litellm": Use ccw-litellm embedder
|
||||
profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
|
||||
Used only when backend="fastembed". Default: "code"
|
||||
model: Model identifier for litellm backend.
|
||||
Used only when backend="litellm". Default: "default"
|
||||
use_gpu: Whether to use GPU acceleration when available (default: True).
|
||||
Used only when backend="fastembed".
|
||||
endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
|
||||
Each endpoint is a dict with keys: model, api_key, api_base, weight.
|
||||
Used only when backend="litellm" and multiple endpoints provided.
|
||||
strategy: Selection strategy for multi-endpoint mode:
|
||||
"round_robin", "latency_aware", "weighted_random".
|
||||
Default: "latency_aware"
|
||||
cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
|
||||
**kwargs: Additional backend-specific arguments
|
||||
|
||||
Returns:
|
||||
BaseEmbedder: Configured embedder instance
|
||||
|
||||
Raises:
|
||||
ValueError: If backend is not recognized
|
||||
ImportError: If required backend dependencies are not installed
|
||||
|
||||
Examples:
|
||||
Create fastembed embedder with code profile:
|
||||
>>> embedder = get_embedder(backend="fastembed", profile="code")
|
||||
|
||||
Create fastembed embedder with fast profile and CPU only:
|
||||
>>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
|
||||
|
||||
Create litellm embedder:
|
||||
>>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
|
||||
|
||||
Create rotational embedder with multiple endpoints:
|
||||
>>> endpoints = [
|
||||
... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
|
||||
... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
|
||||
... ]
|
||||
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
|
||||
"""
|
||||
# Build cache key from immutable configuration
|
||||
if backend == "fastembed":
|
||||
cache_key = ("fastembed", profile, None, use_gpu)
|
||||
elif backend == "litellm":
|
||||
# For litellm, use model as part of cache key
|
||||
# Multi-endpoint mode is not cached as it's more complex
|
||||
if endpoints and len(endpoints) > 1:
|
||||
cache_key = None # Skip cache for multi-endpoint
|
||||
else:
|
||||
effective_model = endpoints[0]["model"] if endpoints else model
|
||||
cache_key = ("litellm", None, effective_model, None)
|
||||
else:
|
||||
cache_key = None
|
||||
|
||||
# Check cache first (thread-safe)
|
||||
if cache_key is not None:
|
||||
with _cache_lock:
|
||||
if cache_key in _embedder_cache:
|
||||
_logger.debug("Returning cached embedder for %s", cache_key)
|
||||
return _embedder_cache[cache_key]
|
||||
|
||||
# Create new embedder instance
|
||||
embedder: Optional[BaseEmbedder] = None
|
||||
|
||||
if backend == "fastembed":
|
||||
from .embedder import Embedder
|
||||
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
|
||||
elif backend == "litellm":
|
||||
# Check if multi-endpoint mode is requested
|
||||
if endpoints and len(endpoints) > 1:
|
||||
from .rotational_embedder import create_rotational_embedder
|
||||
# Multi-endpoint is not cached
|
||||
return create_rotational_embedder(
|
||||
endpoints_config=endpoints,
|
||||
strategy=strategy,
|
||||
default_cooldown=cooldown,
|
||||
)
|
||||
elif endpoints and len(endpoints) == 1:
|
||||
# Single endpoint in list - use it directly
|
||||
ep = endpoints[0]
|
||||
ep_kwargs = {**kwargs}
|
||||
if "api_key" in ep:
|
||||
ep_kwargs["api_key"] = ep["api_key"]
|
||||
if "api_base" in ep:
|
||||
ep_kwargs["api_base"] = ep["api_base"]
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
|
||||
else:
|
||||
# No endpoints list - use model parameter
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend}. "
|
||||
f"Supported backends: 'fastembed', 'litellm'"
|
||||
)
|
||||
|
||||
# Cache the embedder for future use (thread-safe)
|
||||
if cache_key is not None and embedder is not None:
|
||||
with _cache_lock:
|
||||
# Double-check to avoid race condition
|
||||
if cache_key not in _embedder_cache:
|
||||
_embedder_cache[cache_key] = embedder
|
||||
_logger.debug("Cached new embedder for %s", cache_key)
|
||||
else:
|
||||
# Another thread created it already, use that one
|
||||
embedder = _embedder_cache[cache_key]
|
||||
|
||||
return embedder # type: ignore
|
||||
|
||||
|
||||
def clear_embedder_cache() -> int:
|
||||
"""Clear the embedder cache.
|
||||
|
||||
Returns:
|
||||
Number of embedders cleared from cache
|
||||
"""
|
||||
with _cache_lock:
|
||||
count = len(_embedder_cache)
|
||||
_embedder_cache.clear()
|
||||
_logger.debug("Cleared %d embedders from cache", count)
|
||||
return count
|
||||
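The factory keys its cache on (backend, profile, model, use_gpu), so repeated calls are cheap; a sketch:

a = get_embedder(backend="fastembed", profile="code", use_gpu=False)
b = get_embedder(backend="fastembed", profile="code", use_gpu=False)
assert a is b                    # served from _embedder_cache

cleared = clear_embedder_cache() # returns how many instances were dropped
assert cleared >= 1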
431  codex-lens/build/lib/codexlens/semantic/gpu_support.py  Normal file
@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.
|
||||
|
||||
This module provides GPU detection, initialization, and fallback handling
|
||||
for ONNX-based embedding generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUDevice:
|
||||
"""Individual GPU device info."""
|
||||
device_id: int
|
||||
name: str
|
||||
is_discrete: bool # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
|
||||
vendor: str # "nvidia", "amd", "intel", "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUInfo:
|
||||
"""GPU availability and configuration info."""
|
||||
|
||||
gpu_available: bool = False
|
||||
cuda_available: bool = False
|
||||
gpu_count: int = 0
|
||||
gpu_name: Optional[str] = None
|
||||
onnx_providers: List[str] = None
|
||||
devices: List[GPUDevice] = None # List of detected GPU devices
|
||||
preferred_device_id: Optional[int] = None # Preferred GPU for embedding
|
||||
|
||||
def __post_init__(self):
|
||||
if self.onnx_providers is None:
|
||||
self.onnx_providers = ["CPUExecutionProvider"]
|
||||
if self.devices is None:
|
||||
self.devices = []
|
||||
|
||||
|
||||
_gpu_info_cache: Optional[GPUInfo] = None
|
||||
|
||||
|
||||
def _enumerate_gpus() -> List[GPUDevice]:
|
||||
"""Enumerate available GPU devices using WMI on Windows.
|
||||
|
||||
Returns:
|
||||
List of GPUDevice with device info, ordered by device_id.
|
||||
"""
|
||||
devices = []
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
if sys.platform == "win32":
|
||||
# Use PowerShell to query GPU information via WMI
|
||||
cmd = [
|
||||
"powershell", "-NoProfile", "-Command",
|
||||
"Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
import json
|
||||
gpu_data = json.loads(result.stdout)
|
||||
|
||||
# Handle single GPU case (returns dict instead of list)
|
||||
if isinstance(gpu_data, dict):
|
||||
gpu_data = [gpu_data]
|
||||
|
||||
for idx, gpu in enumerate(gpu_data):
|
||||
name = gpu.get("Name", "Unknown GPU")
|
||||
compat = gpu.get("AdapterCompatibility", "").lower()
|
||||
|
||||
# Determine vendor
|
||||
name_lower = name.lower()
|
||||
if "nvidia" in name_lower or "nvidia" in compat:
|
||||
vendor = "nvidia"
|
||||
is_discrete = True
|
||||
elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
|
||||
vendor = "amd"
|
||||
is_discrete = True
|
||||
elif "intel" in name_lower or "intel" in compat:
|
||||
vendor = "intel"
|
||||
# Intel UHD/Iris are integrated, Intel Arc is discrete
|
||||
is_discrete = "arc" in name_lower
|
||||
else:
|
||||
vendor = "unknown"
|
||||
is_discrete = False
|
||||
|
||||
devices.append(GPUDevice(
|
||||
device_id=idx,
|
||||
name=name,
|
||||
is_discrete=is_discrete,
|
||||
vendor=vendor
|
||||
))
|
||||
logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"GPU enumeration failed: {e}")
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
|
||||
"""Determine the preferred GPU device_id for embedding.
|
||||
|
||||
Preference order:
|
||||
1. NVIDIA discrete GPU (best DirectML/CUDA support)
|
||||
2. AMD discrete GPU
|
||||
3. Intel Arc (discrete)
|
||||
4. Intel integrated (fallback)
|
||||
|
||||
Returns:
|
||||
device_id of preferred GPU, or None to use default.
|
||||
"""
|
||||
if not devices:
|
||||
return None
|
||||
|
||||
# Priority: NVIDIA > AMD > Intel Arc > Intel integrated
|
||||
priority_order = [
|
||||
("nvidia", True), # NVIDIA discrete
|
||||
("amd", True), # AMD discrete
|
||||
("intel", True), # Intel Arc (discrete)
|
||||
("intel", False), # Intel integrated (fallback)
|
||||
]
|
||||
|
||||
for target_vendor, target_discrete in priority_order:
|
||||
for device in devices:
|
||||
if device.vendor == target_vendor and device.is_discrete == target_discrete:
|
||||
logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
|
||||
return device.device_id
|
||||
|
||||
# If no match, use first device
|
||||
if devices:
|
||||
return devices[0].device_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
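The preference walk above is a first-match scan over (vendor, is_discrete) pairs; a sketch with fabricated devices showing an NVIDIA card beating an integrated Intel GPU regardless of enumeration order:

devices = [
    GPUDevice(device_id=0, name="Intel(R) UHD Graphics", is_discrete=False, vendor="intel"),
    GPUDevice(device_id=1, name="NVIDIA GeForce RTX 3060", is_discrete=True, vendor="nvidia"),
]
assert _get_preferred_device_id(devices) == 1  # NVIDIA discrete outranks integrated Intel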
def detect_gpu(force_refresh: bool = False) -> GPUInfo:
    """Detect available GPU resources for embedding acceleration.

    Args:
        force_refresh: If True, re-detect GPU even if cached.

    Returns:
        GPUInfo with detection results.
    """
    global _gpu_info_cache

    if _gpu_info_cache is not None and not force_refresh:
        return _gpu_info_cache

    info = GPUInfo()

    # Enumerate GPU devices first
    info.devices = _enumerate_gpus()
    info.gpu_count = len(info.devices)
    if info.devices:
        # Set preferred device (discrete GPU preferred over integrated)
        info.preferred_device_id = _get_preferred_device_id(info.devices)
        # Set gpu_name to preferred device name
        for dev in info.devices:
            if dev.device_id == info.preferred_device_id:
                info.gpu_name = dev.name
                break

    # Check PyTorch CUDA availability (most reliable detection)
    try:
        import torch
        if torch.cuda.is_available():
            info.cuda_available = True
            info.gpu_available = True
            info.gpu_count = torch.cuda.device_count()
            if info.gpu_count > 0:
                info.gpu_name = torch.cuda.get_device_name(0)
            logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
    except ImportError:
        logger.debug("PyTorch not available for GPU detection")

    # Check ONNX Runtime providers with validation
    try:
        import onnxruntime as ort
        available_providers = ort.get_available_providers()

        # Build provider list with priority order
        providers = []

        # Test each provider to ensure it actually works
        def test_provider(provider_name: str) -> bool:
            """Test if a provider actually works by creating a dummy session."""
            try:
                # Create a minimal ONNX model to test provider
                import numpy as np
                # Simple test: just check if provider can be instantiated
                sess_options = ort.SessionOptions()
                sess_options.log_severity_level = 4  # Suppress warnings
                return True
            except Exception:
                return False

        # CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
        if "CUDAExecutionProvider" in available_providers:
            # Verify CUDA is actually usable by checking for cuBLAS
            cuda_works = False
            try:
                import ctypes
                # Try to load cuBLAS to verify CUDA installation
                try:
                    ctypes.CDLL("cublas64_12.dll")
                    cuda_works = True
                except OSError:
                    try:
                        ctypes.CDLL("cublas64_11.dll")
                        cuda_works = True
                    except OSError:
                        pass
            except Exception:
                pass

            if cuda_works:
                providers.append("CUDAExecutionProvider")
                info.gpu_available = True
                logger.debug("ONNX CUDAExecutionProvider available and working")
            else:
                logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")

        # TensorRT provider (optimized NVIDIA inference)
        if "TensorrtExecutionProvider" in available_providers:
            # TensorRT requires additional libraries, skip for now
            logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")

        # DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
        if "DmlExecutionProvider" in available_providers:
            providers.append("DmlExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX DmlExecutionProvider available (DirectML)")

        # ROCm provider (AMD GPU on Linux)
        if "ROCMExecutionProvider" in available_providers:
            providers.append("ROCMExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX ROCMExecutionProvider available (AMD)")

        # CoreML provider (Apple Silicon)
        if "CoreMLExecutionProvider" in available_providers:
            providers.append("CoreMLExecutionProvider")
            info.gpu_available = True
            logger.debug("ONNX CoreMLExecutionProvider available (Apple)")

        # Always include CPU as fallback
        providers.append("CPUExecutionProvider")

        info.onnx_providers = providers

    except ImportError:
        logger.debug("ONNX Runtime not available")
        info.onnx_providers = ["CPUExecutionProvider"]

    _gpu_info_cache = info
    return info


def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
    """Get optimal ONNX execution providers based on availability.

    Args:
        use_gpu: If True, include GPU providers when available.
            If False, force CPU-only execution.
        with_device_options: If True, return providers as tuples with device_id options
            for proper GPU device selection (required for DirectML).

    Returns:
        List of provider names or tuples (provider_name, options_dict) in priority order.
    """
    if not use_gpu:
        return ["CPUExecutionProvider"]

    gpu_info = detect_gpu()

    # Check if GPU was requested but not available - log warning
    if not gpu_info.gpu_available:
        try:
            import onnxruntime as ort
            available_providers = ort.get_available_providers()
        except ImportError:
            available_providers = []
        logger.warning(
            "GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
            f"was found. Available providers: {available_providers}. Falling back to CPU."
        )
    else:
        # Log which GPU provider is being used
        gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
        if gpu_providers:
            logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")

    if not with_device_options:
        return gpu_info.onnx_providers

    # Build providers with device_id options for GPU providers
    device_id = get_selected_device_id()
    providers = []

    for provider in gpu_info.onnx_providers:
        if provider == "DmlExecutionProvider" and device_id is not None:
            # DirectML requires device_id in provider_options tuple
            providers.append(("DmlExecutionProvider", {"device_id": device_id}))
            logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
        elif provider == "CUDAExecutionProvider" and device_id is not None:
            # CUDA also supports device_id in provider_options
            providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
            logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
        elif provider == "ROCMExecutionProvider" and device_id is not None:
            # ROCm supports device_id
            providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
            logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
        else:
            # CPU and other providers don't need device_id
            providers.append(provider)

    return providers
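The tuple form returned by get_optimal_providers(with_device_options=True) matches what onnxruntime's InferenceSession accepts; a hedged sketch (the model file path is hypothetical):

import onnxruntime as ort

providers = get_optimal_providers(use_gpu=True, with_device_options=True)
# e.g. [("DmlExecutionProvider", {"device_id": 1}), "CPUExecutionProvider"]
session = ort.InferenceSession("model.onnx", providers=providers)  # hypothetical model file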
def is_gpu_available() -> bool:
    """Check if any GPU acceleration is available."""
    return detect_gpu().gpu_available


def get_gpu_summary() -> str:
    """Get human-readable GPU status summary."""
    info = detect_gpu()

    if not info.gpu_available:
        return "GPU: Not available (using CPU)"

    parts = []
    if info.gpu_name:
        parts.append(f"GPU: {info.gpu_name}")
    if info.gpu_count > 1:
        parts.append(f"({info.gpu_count} devices)")

    # Show active providers (excluding CPU fallback)
    gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
    if gpu_providers:
        parts.append(f"Providers: {', '.join(gpu_providers)}")

    return " | ".join(parts) if parts else "GPU: Available"


def clear_gpu_cache() -> None:
    """Clear cached GPU detection info."""
    global _gpu_info_cache
    _gpu_info_cache = None


# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None


def get_gpu_devices() -> List[dict]:
    """Get list of available GPU devices for frontend selection.

    Returns:
        List of dicts with device info for each GPU.
    """
    info = detect_gpu()
    devices = []

    for dev in info.devices:
        devices.append({
            "device_id": dev.device_id,
            "name": dev.name,
            "vendor": dev.vendor,
            "is_discrete": dev.is_discrete,
            "is_preferred": dev.device_id == info.preferred_device_id,
            "is_selected": dev.device_id == get_selected_device_id(),
        })

    return devices


def get_selected_device_id() -> Optional[int]:
    """Get the user-selected GPU device_id.

    Returns:
        User-selected device_id, or auto-detected preferred device_id if not set.
    """
    global _selected_device_id

    if _selected_device_id is not None:
        return _selected_device_id

    # Fall back to auto-detected preferred device
    info = detect_gpu()
    return info.preferred_device_id


def set_selected_device_id(device_id: Optional[int]) -> bool:
    """Set the GPU device_id to use for embeddings.

    Args:
        device_id: GPU device_id to use, or None to use auto-detection.

    Returns:
        True if device_id is valid, False otherwise.
    """
    global _selected_device_id

    if device_id is None:
        _selected_device_id = None
        logger.info("GPU selection reset to auto-detection")
        return True

    # Validate device_id exists
    info = detect_gpu()
    valid_ids = [dev.device_id for dev in info.devices]

    if device_id in valid_ids:
        _selected_device_id = device_id
        device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
        logger.info(f"GPU selection set to device {device_id}: {device_name}")
        return True
    else:
        logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
        return False
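A sketch of the selection-override round trip; device 0 is only valid if enumeration actually found it:

print(get_gpu_devices())        # list of dicts with device_id/name/vendor flags
if set_selected_device_id(0):   # returns False for unknown ids
    assert get_selected_device_id() == 0
set_selected_device_id(None)    # back to the auto-detected preferred device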
144  codex-lens/build/lib/codexlens/semantic/litellm_embedder.py  Normal file
@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.
|
||||
|
||||
Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
|
||||
class LiteLLMEmbedderWrapper(BaseEmbedder):
|
||||
"""Wrapper for ccw-litellm LiteLLMEmbedder.
|
||||
|
||||
This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
|
||||
BaseEmbedder interface, enabling seamless integration with CodexLens
|
||||
semantic search functionality.
|
||||
|
||||
Args:
|
||||
model: Model identifier for LiteLLM (default: "default")
|
||||
**kwargs: Additional arguments passed to LiteLLMEmbedder
|
||||
|
||||
Raises:
|
||||
ImportError: If ccw-litellm package is not installed
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "default", **kwargs) -> None:
|
||||
"""Initialize LiteLLM embedder wrapper.
|
||||
|
||||
Args:
|
||||
model: Model identifier for LiteLLM (default: "default")
|
||||
**kwargs: Additional arguments passed to LiteLLMEmbedder
|
||||
|
||||
Raises:
|
||||
ImportError: If ccw-litellm package is not installed
|
||||
"""
|
||||
try:
|
||||
from ccw_litellm import LiteLLMEmbedder
|
||||
self._embedder = LiteLLMEmbedder(model=model, **kwargs)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ccw-litellm not installed. Install with: pip install ccw-litellm"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimensions from LiteLLMEmbedder.
|
||||
|
||||
Returns:
|
||||
int: Dimension of the embedding vectors.
|
||||
"""
|
||||
return self._embedder.dimensions
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return model name from LiteLLMEmbedder.
|
||||
|
||||
Returns:
|
||||
str: Name or identifier of the underlying model.
|
||||
"""
|
||||
return self._embedder.model_name
|
||||
|
||||
@property
|
||||
def max_tokens(self) -> int:
|
||||
"""Return maximum token limit for the embedding model.
|
||||
|
||||
Returns:
|
||||
int: Maximum number of tokens that can be embedded at once.
|
||||
Reads from LiteLLM config's max_input_tokens property.
|
||||
"""
|
||||
# Get from LiteLLM embedder's max_input_tokens property (now exposed)
|
||||
if hasattr(self._embedder, 'max_input_tokens'):
|
||||
return self._embedder.max_input_tokens
|
||||
|
||||
# Fallback: infer from model name
|
||||
model_name_lower = self.model_name.lower()
|
||||
|
||||
# Large models (8B or "large" in name)
|
||||
if '8b' in model_name_lower or 'large' in model_name_lower:
|
||||
return 32768
|
||||
|
||||
# OpenAI text-embedding-3-* models
|
||||
if 'text-embedding-3' in model_name_lower:
|
||||
return 8191
|
||||
|
||||
# Default fallback
|
||||
return 8192
|
||||
|
||||
def _sanitize_text(self, text: str) -> str:
|
||||
"""Sanitize text to work around ModelScope API routing bug.
|
||||
|
||||
ModelScope incorrectly routes text starting with lowercase 'import'
|
||||
to an Ollama endpoint, causing failures. This adds a leading space
|
||||
to work around the issue without affecting embedding quality.
|
||||
|
||||
Args:
|
||||
text: Text to sanitize.
|
||||
|
||||
Returns:
|
||||
Sanitized text safe for embedding API.
|
||||
"""
|
||||
if text.startswith('import'):
|
||||
return ' ' + text
|
||||
return text
|
||||
|
||||
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
|
||||
"""Embed texts to numpy array using LiteLLMEmbedder.
|
||||
|
||||
Args:
|
||||
texts: Single text or iterable of texts to embed.
|
||||
**kwargs: Additional arguments (ignored for LiteLLM backend).
|
||||
Accepts batch_size for API compatibility with fastembed.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
else:
|
||||
texts = list(texts)
|
||||
|
||||
# Sanitize texts to avoid ModelScope routing bug
|
||||
texts = [self._sanitize_text(t) for t in texts]
|
||||
|
||||
# LiteLLM handles batching internally, ignore batch_size parameter
|
||||
return self._embedder.embed(texts)
|
||||
|
||||
def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
list[float]: Embedding vector as a list of floats.
|
||||
"""
|
||||
# Sanitize text before embedding
|
||||
sanitized = self._sanitize_text(text)
|
||||
embedding = self._embedder.embed([sanitized])
|
||||
return embedding[0].tolist()
|
||||
|
||||
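A hedged sketch of the wrapper in use; it requires the optional ccw-litellm package plus a configured endpoint, and the model name here is illustrative only:

embedder = LiteLLMEmbedderWrapper(model="text-embedding-3-small")  # assumes ccw-litellm is configured
vec = embedder.embed_single("import os")  # leading-space workaround applied internally
print(len(vec), embedder.embedding_dim)  # the two should agree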
25  codex-lens/build/lib/codexlens/semantic/reranker/__init__.py  Normal file
@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.
|
||||
|
||||
This subpackage provides a unified interface and factory for different reranking
|
||||
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import BaseReranker
|
||||
from .factory import check_reranker_available, get_reranker
|
||||
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
|
||||
from .legacy import CrossEncoderReranker, check_cross_encoder_available
|
||||
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
|
||||
|
||||
__all__ = [
|
||||
"BaseReranker",
|
||||
"check_reranker_available",
|
||||
"get_reranker",
|
||||
"CrossEncoderReranker",
|
||||
"check_cross_encoder_available",
|
||||
"FastEmbedReranker",
|
||||
"check_fastembed_reranker_available",
|
||||
"ONNXReranker",
|
||||
"check_onnx_reranker_available",
|
||||
]
|
||||
403  codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py  Normal file
@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.
|
||||
|
||||
Supported providers:
|
||||
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
|
||||
- Cohere: https://api.cohere.ai/v1/rerank
|
||||
- Jina: https://api.jina.ai/v1/rerank
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
|
||||
|
||||
|
||||
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
|
||||
"""Get environment variable with .env file fallback."""
|
||||
# Check os.environ first
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
# Try loading from .env files
|
||||
try:
|
||||
from codexlens.env_config import get_env
|
||||
return get_env(key, workspace_root=workspace_root)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def check_httpx_available() -> tuple[bool, str | None]:
|
||||
try:
|
||||
import httpx # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return False, f"httpx not available: {exc}. Install with: pip install httpx"
|
||||
return True, None
|
||||
|
||||
|
||||
class APIReranker(BaseReranker):
|
||||
"""Reranker backed by a remote reranking HTTP API."""
|
||||
|
||||
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
|
||||
"siliconflow": {
|
||||
"api_base": "https://api.siliconflow.cn",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "BAAI/bge-reranker-v2-m3",
|
||||
},
|
||||
"cohere": {
|
||||
"api_base": "https://api.cohere.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "rerank-english-v3.0",
|
||||
},
|
||||
"jina": {
|
||||
"api_base": "https://api.jina.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "jina-reranker-v2-base-multilingual",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: str = "siliconflow",
|
||||
model_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
backoff_base_s: float = 0.5,
|
||||
backoff_max_s: float = 8.0,
|
||||
env_api_key: str = _DEFAULT_ENV_API_KEY,
|
||||
workspace_root: Path | str | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> None:
|
||||
ok, err = check_httpx_available()
|
||||
if not ok: # pragma: no cover - exercised via factory availability tests
|
||||
raise ImportError(err)
|
||||
|
||||
import httpx
|
||||
|
||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||
|
||||
self.provider = (provider or "").strip().lower()
|
||||
if self.provider not in self._PROVIDER_DEFAULTS:
|
||||
raise ValueError(
|
||||
f"Unknown reranker provider: {provider}. "
|
||||
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
|
||||
)
|
||||
|
||||
defaults = self._PROVIDER_DEFAULTS[self.provider]
|
||||
|
||||
# Load api_base from env with .env fallback
|
||||
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
|
||||
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
|
||||
self.endpoint = defaults["endpoint"]
|
||||
|
||||
# Load model from env with .env fallback
|
||||
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
|
||||
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
# Load API key from env with .env fallback
|
||||
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
|
||||
resolved_key = resolved_key.strip()
|
||||
if not resolved_key:
|
||||
raise ValueError(
|
||||
f"Missing API key for reranker provider '{self.provider}'. "
|
||||
f"Pass api_key=... or set ${env_api_key}."
|
||||
)
|
||||
self._api_key = resolved_key
|
||||
|
||||
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
|
||||
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
|
||||
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
|
||||
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.provider == "cohere":
|
||||
headers.setdefault("Cohere-Version", "2022-12-06")
|
||||
|
||||
self._client = httpx.Client(
|
||||
base_url=self.api_base,
|
||||
headers=headers,
|
||||
timeout=self.timeout_s,
|
||||
)
|
||||
|
||||
# Store max_input_tokens with model-aware defaults
|
||||
if max_input_tokens is not None:
|
||||
self._max_input_tokens = max_input_tokens
|
||||
else:
|
||||
# Infer from model name
|
||||
model_lower = self.model_name.lower()
|
||||
if '8b' in model_lower or 'large' in model_lower:
|
||||
self._max_input_tokens = 32768
|
||||
else:
|
||||
self._max_input_tokens = 8192
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int:
|
||||
"""Return maximum token limit for reranking."""
|
||||
return self._max_input_tokens
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return
|
||||
|
||||
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
|
||||
if retry_after_s is not None and retry_after_s > 0:
|
||||
time.sleep(min(float(retry_after_s), self.backoff_max_s))
|
||||
return
|
||||
|
||||
exp = self.backoff_base_s * (2**attempt)
|
||||
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
|
||||
time.sleep(min(self.backoff_max_s, exp + jitter))
|
||||
|
||||
@staticmethod
|
||||
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
|
||||
value = (headers.get("Retry-After") or "").strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _should_retry_status(status_code: int) -> bool:
|
||||
return status_code == 429 or 500 <= status_code <= 599
|
||||
|
||||
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
last_exc: Exception | None = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
response = self._client.post(self.endpoint, json=dict(payload))
|
||||
except Exception as exc: # httpx is optional at import-time
|
||||
last_exc = exc
|
||||
if attempt < self.max_retries:
|
||||
self._sleep_backoff(attempt)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}' after "
|
||||
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
|
||||
) from exc
|
||||
|
||||
status = int(getattr(response, "status_code", 0) or 0)
|
||||
if status >= 400:
|
||||
body_preview = ""
|
||||
try:
|
||||
body_preview = (response.text or "").strip()
|
||||
except Exception:
|
||||
body_preview = ""
|
||||
if len(body_preview) > 300:
|
||||
body_preview = body_preview[:300] + "…"
|
||||
|
||||
if self._should_retry_status(status) and attempt < self.max_retries:
|
||||
retry_after = self._parse_retry_after_seconds(response.headers)
|
||||
logger.warning(
|
||||
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
|
||||
self.api_base,
|
||||
self.endpoint,
|
||||
status,
|
||||
attempt + 1,
|
||||
self.max_retries + 1,
|
||||
)
|
||||
self._sleep_backoff(attempt, retry_after_s=retry_after)
|
||||
continue
|
||||
|
||||
if status in {401, 403}:
|
||||
raise RuntimeError(
|
||||
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
|
||||
"Check your API key."
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
|
||||
f"Response: {body_preview or '<empty>'}"
|
||||
)
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Rerank response from provider '{self.provider}' is not valid JSON: "
|
||||
f"{type(exc).__name__}: {exc}"
|
||||
) from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError(
|
||||
f"Rerank response from provider '{self.provider}' must be a JSON object; "
|
||||
f"got {type(data).__name__}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
|
||||
if not isinstance(results, list):
|
||||
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
|
||||
|
||||
scores: list[float] = [0.0 for _ in range(expected)]
|
||||
filled = 0
|
||||
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
idx = item.get("index")
|
||||
score = item.get("relevance_score", item.get("score"))
|
||||
if idx is None or score is None:
|
||||
continue
|
||||
try:
|
||||
idx_int = int(idx)
|
||||
score_f = float(score)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if 0 <= idx_int < expected:
|
||||
scores[idx_int] = score_f
|
||||
filled += 1
|
||||
|
||||
if filled != expected:
|
||||
raise RuntimeError(
|
||||
f"Rerank response contained {filled}/{expected} scored documents; "
|
||||
"ensure top_n matches the number of documents."
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": list(documents),
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
return payload
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""Estimate token count using fast heuristic.
|
||||
|
||||
Uses len(text) // 4 as approximation (~4 chars per token for English).
|
||||
Not perfectly accurate for all models/languages but sufficient for
|
||||
batch sizing decisions where exact counts aren't critical.
|
||||
"""
|
||||
return len(text) // 4
|
||||
|
||||
def _create_token_aware_batches(
|
||||
self,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
) -> list[list[tuple[int, str]]]:
|
||||
"""Split documents into batches that fit within token limits.
|
||||
|
||||
Uses 90% of max_input_tokens as safety margin.
|
||||
Each batch includes the query tokens overhead.
|
||||
"""
|
||||
max_tokens = int(self._max_input_tokens * 0.9)
|
||||
query_tokens = self._estimate_tokens(query)
|
||||
|
||||
batches: list[list[tuple[int, str]]] = []
|
||||
current_batch: list[tuple[int, str]] = []
|
||||
current_tokens = query_tokens # Start with query overhead
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_tokens = self._estimate_tokens(doc)
|
||||
|
||||
# Warn if single document exceeds token limit (will be truncated by API)
|
||||
if doc_tokens > max_tokens - query_tokens:
|
||||
logger.warning(
|
||||
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
|
||||
f"(limit: {max_tokens - query_tokens} after query overhead). "
|
||||
"Document will likely be truncated by the API."
|
||||
)
|
||||
|
||||
# If batch would exceed limit, start new batch
|
||||
if current_tokens + doc_tokens > max_tokens and current_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = []
|
||||
current_tokens = query_tokens
|
||||
|
||||
current_batch.append((idx, doc))
|
||||
current_tokens += doc_tokens
|
||||
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
|
||||
return batches
|
||||
|
||||
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
# Create token-aware batches
|
||||
batches = self._create_token_aware_batches(query, documents)
|
||||
|
||||
if len(batches) == 1:
|
||||
# Single batch - original behavior
|
||||
payload = self._build_payload(query=query, documents=documents)
|
||||
data = self._request_json(payload)
|
||||
results = data.get("results")
|
||||
return self._extract_scores_from_results(results, expected=len(documents))
|
||||
|
||||
# Multiple batches - process each and merge results
|
||||
logger.info(
|
||||
f"Splitting {len(documents)} documents into {len(batches)} batches "
|
||||
f"(max_input_tokens: {self._max_input_tokens})"
|
||||
)
|
||||
|
||||
all_scores: list[float] = [0.0] * len(documents)
|
||||
|
||||
for batch in batches:
|
||||
batch_docs = [doc for _, doc in batch]
|
||||
payload = self._build_payload(query=query, documents=batch_docs)
|
||||
data = self._request_json(payload)
|
||||
results = data.get("results")
|
||||
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
|
||||
|
||||
# Map scores back to original indices
|
||||
for (orig_idx, _), score in zip(batch, batch_scores):
|
||||
all_scores[orig_idx] = score
|
||||
|
||||
return all_scores
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
|
||||
) -> list[float]:
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
grouped: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, (query, doc) in enumerate(pairs):
|
||||
grouped.setdefault(str(query), []).append((idx, str(doc)))
|
||||
|
||||
scores: list[float] = [0.0 for _ in range(len(pairs))]
|
||||
|
||||
for query, items in grouped.items():
|
||||
documents = [doc for _, doc in items]
|
||||
query_scores = self._rerank_one_query(query=query, documents=documents)
|
||||
for (orig_idx, _), score in zip(items, query_scores):
|
||||
scores[orig_idx] = float(score)
|
||||
|
||||
return scores
|
||||
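A minimal usage sketch of the APIReranker above, not part of the diff. It assumes a valid key for the default "siliconflow" provider is available via the environment variable named by _DEFAULT_ENV_API_KEY (defined earlier in the file); the import path follows the file layout shown in this PR.

# Usage sketch (assumption: the env var named by _DEFAULT_ENV_API_KEY holds a valid key).
from codexlens.semantic.reranker.api_reranker import APIReranker

reranker = APIReranker(provider="siliconflow")
pairs = [
    ("how do I parse JSON in python", "Use json.loads() to parse a JSON string."),
    ("how do I parse JSON in python", "The Eiffel Tower is located in Paris."),
]
scores = reranker.score_pairs(pairs)  # one relevance score per input pair
print(scores)  # higher score = more relevant document
reranker.close()  # release the underlying httpx.Client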
codex-lens/build/lib/codexlens/semantic/reranker/base.py (new file, 46 lines)
@@ -0,0 +1,46 @@
"""Base class for rerankers.

Defines the interface that all rerankers must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence


class BaseReranker(ABC):
    """Base class for all rerankers.

    All reranker implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking.

        Returns:
            int: Maximum number of tokens that can be processed at once.
                 Default is 8192 if not overridden by the implementation.
        """
        return 8192

    @abstractmethod
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair).
        """
        ...
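A minimal sketch of what the interface contract requires of a backend. ConstantReranker is hypothetical and exists only to illustrate the two obligations: implement score_pairs, and inherit the max_input_tokens default unless overridden.

from typing import Sequence

from codexlens.semantic.reranker.base import BaseReranker


class ConstantReranker(BaseReranker):
    """Toy backend (illustration only): scores every pair the same."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        return [0.5 for _ in pairs]


r = ConstantReranker()
assert r.max_input_tokens == 8192  # inherited default from BaseReranker
assert r.score_pairs([("q", "d1"), ("q", "d2")]) == [0.5, 0.5]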
codex-lens/build/lib/codexlens/semantic/reranker/factory.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""Factory for creating rerankers.

Provides a unified interface for instantiating different reranker backends.
"""

from __future__ import annotations

from typing import Any

from .base import BaseReranker


def check_reranker_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific reranker backend can be used.

    Notes:
        - "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
        - "onnx" redirects to "fastembed" for backward compatibility.
        - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
        - "api" uses a remote reranking HTTP API (requires httpx).
        - "litellm" uses `ccw-litellm` for unified access to LLM providers.
    """
    backend = (backend or "").strip().lower()

    if backend == "legacy":
        from .legacy import check_cross_encoder_available

        return check_cross_encoder_available()

    if backend == "fastembed":
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "litellm":
        try:
            import ccw_litellm  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional dependency
            return (
                False,
                f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
            )

        try:
            from .litellm_reranker import LiteLLMReranker  # noqa: F401
        except Exception as exc:  # pragma: no cover - defensive
            return False, f"LiteLLM reranker backend not available: {exc}"

        return True, None

    if backend == "api":
        from .api_reranker import check_httpx_available

        return check_httpx_available()

    return False, (
        f"Invalid reranker backend: {backend}. "
        "Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
    )


def get_reranker(
    backend: str = "fastembed",
    model_name: str | None = None,
    *,
    device: str | None = None,
    **kwargs: Any,
) -> BaseReranker:
    """Factory function to create a reranker based on the backend.

    Args:
        backend: Reranker backend to use. Options:
            - "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
            - "onnx": Redirects to fastembed for backward compatibility
            - "api": HTTP API backend (remote providers)
            - "litellm": LiteLLM backend (LLM-based, for API mode)
            - "legacy": sentence-transformers CrossEncoder backend (optional)
        model_name: Model identifier for model-based backends. Defaults depend on backend:
            - fastembed: Xenova/ms-marco-MiniLM-L-6-v2
            - onnx: (redirects to fastembed)
            - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
            - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: default
        device: Optional device string for backends that support it (legacy only).
        **kwargs: Additional backend-specific arguments.

    Returns:
        BaseReranker: Configured reranker instance.

    Raises:
        ValueError: If backend is not recognized.
        ImportError: If required backend dependencies are not installed or the backend is unavailable.
    """
    backend = (backend or "").strip().lower()

    if backend == "fastembed":
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "legacy":
        ok, err = check_reranker_available("legacy")
        if not ok:
            raise ImportError(err)

        from .legacy import CrossEncoderReranker

        resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
        return CrossEncoderReranker(model_name=resolved_model_name, device=device)

    if backend == "litellm":
        ok, err = check_reranker_available("litellm")
        if not ok:
            raise ImportError(err)

        from .litellm_reranker import LiteLLMReranker

        _ = device  # Device selection is not applicable to remote LLM backends.
        resolved_model_name = (model_name or "").strip() or "default"
        return LiteLLMReranker(model=resolved_model_name, **kwargs)

    if backend == "api":
        ok, err = check_reranker_available("api")
        if not ok:
            raise ImportError(err)

        from .api_reranker import APIReranker

        _ = device  # Device selection is not applicable to remote HTTP backends.
        resolved_model_name = (model_name or "").strip() or None
        return APIReranker(model_name=resolved_model_name, **kwargs)

    raise ValueError(
        f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
    )
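A usage sketch of the factory pair above. The backend names and check/get functions come from the code; the fallback order is this sketch's own choice, not something the module prescribes.

from codexlens.semantic.reranker.factory import check_reranker_available, get_reranker

# Pick the first backend whose optional dependencies are actually installed.
for backend in ("fastembed", "api", "legacy"):
    ok, err = check_reranker_available(backend)
    if ok:
        reranker = get_reranker(backend)
        break
else:
    raise ImportError("No reranker backend available")

scores = reranker.score_pairs([("query", "candidate document")])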
@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.

This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.

Install:
    pip install fastembed>=0.4.0
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_fastembed_reranker_available() -> tuple[bool, str | None]:
    """Check whether fastembed reranker dependencies are available."""
    try:
        import fastembed  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
        )

    try:
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed TextCrossEncoder not available: {exc}. "
            "Upgrade with: pip install fastembed>=0.4.0",
        )

    return True, None


class FastEmbedReranker(BaseReranker):
    """Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    # Alternative models supported by fastembed:
    # - "BAAI/bge-reranker-base"
    # - "BAAI/bge-reranker-large"
    # - "cross-encoder/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        cache_dir: str | None = None,
        threads: int | None = None,
    ) -> None:
        """Initialize the FastEmbed reranker.

        Args:
            model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
            use_gpu: Whether to use GPU acceleration when available.
            cache_dir: Optional directory for caching downloaded models.
            threads: Optional number of threads for ONNX Runtime.
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.cache_dir = cache_dir
        self.threads = threads

        self._encoder: Any | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        """Lazy-load the TextCrossEncoder model."""
        if self._encoder is not None:
            return

        ok, err = check_fastembed_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._encoder is not None:
                return

            from fastembed.rerank.cross_encoder import TextCrossEncoder

            # Determine providers based on GPU preference
            providers: list[str] | None = None
            if self.use_gpu:
                try:
                    from ..gpu_support import get_optimal_providers

                    providers = get_optimal_providers(use_gpu=True, with_device_options=False)
                except Exception:
                    # Fallback: let fastembed decide
                    providers = None

            # Build initialization kwargs
            init_kwargs: dict[str, Any] = {}
            if self.cache_dir:
                init_kwargs["cache_dir"] = self.cache_dir
            if self.threads is not None:
                init_kwargs["threads"] = self.threads
            if providers:
                init_kwargs["providers"] = providers

            logger.debug(
                "Loading FastEmbed reranker model: %s (use_gpu=%s)",
                self.model_name,
                self.use_gpu,
            )

            self._encoder = TextCrossEncoder(
                model_name=self.model_name,
                **init_kwargs,
            )

            logger.debug("FastEmbed reranker model loaded successfully")

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Numerically stable sigmoid function."""
        if x < -709:
            return 0.0
        if x > 709:
            return 1.0
        import math

        return 1.0 / (1.0 + math.exp(-x))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair), normalized to the [0, 1] range.
        """
        if not pairs:
            return []

        self._load_model()

        if self._encoder is None:  # pragma: no cover - defensive
            return []

        # FastEmbed's TextCrossEncoder.rerank() expects a query and a list of documents.
        # For batch scoring of multiple query-doc pairs, group by query for efficiency
        # when the same query appears multiple times.
        query_to_docs: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            if query not in query_to_docs:
                query_to_docs[query] = []
            query_to_docs[query].append((idx, doc))

        # Score each query group
        scores: list[float] = [0.0] * len(pairs)

        for query, indexed_docs in query_to_docs.items():
            docs = [doc for _, doc in indexed_docs]
            indices = [idx for idx, _ in indexed_docs]

            try:
                # TextCrossEncoder.rerank returns raw float scores in the same order as the input
                raw_scores = list(
                    self._encoder.rerank(
                        query=query,
                        documents=docs,
                        batch_size=batch_size,
                    )
                )

                # Map scores back to original positions and normalize with sigmoid
                for i, raw_score in enumerate(raw_scores):
                    if i < len(indices):
                        original_idx = indices[i]
                        # Normalize score to [0, 1] using the stable sigmoid
                        scores[original_idx] = self._sigmoid(float(raw_score))

            except Exception as e:
                logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
                # Leave scores as 0.0 for failed queries

        return scores

    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        top_k: int | None = None,
        batch_size: int = 32,
    ) -> list[tuple[float, str, int]]:
        """Rerank documents for a single query.

        This is a convenience method that provides results in ranked order.

        Args:
            query: The query string.
            documents: List of documents to rerank.
            top_k: Return only the top K results. None returns all.
            batch_size: Batch size for scoring.

        Returns:
            List of (score, document, original_index) tuples, sorted by score descending.
        """
        if not documents:
            return []

        self._load_model()

        if self._encoder is None:  # pragma: no cover - defensive
            return []

        try:
            # TextCrossEncoder.rerank returns raw float scores in the same order as the input
            raw_scores = list(
                self._encoder.rerank(
                    query=query,
                    documents=list(documents),
                    batch_size=batch_size,
                )
            )

            # Convert to our format: (normalized_score, document, original_index)
            ranked = []
            for idx, raw_score in enumerate(raw_scores):
                if idx < len(documents):
                    # Normalize score to [0, 1] using the stable sigmoid
                    normalized = self._sigmoid(float(raw_score))
                    ranked.append((normalized, documents[idx], idx))

            # Sort by score descending
            ranked.sort(key=lambda x: x[0], reverse=True)

            if top_k is not None and top_k > 0:
                ranked = ranked[:top_k]

            return ranked

        except Exception as e:
            logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
            return []
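A usage sketch of the convenience rerank() method above. It assumes fastembed>=0.4.0 is installed; the default model is downloaded on first use.

from codexlens.semantic.reranker.fastembed_reranker import FastEmbedReranker

reranker = FastEmbedReranker(use_gpu=False)  # CPU-only; model downloads lazily
docs = [
    "fastembed wraps ONNX Runtime cross-encoders behind a unified API",
    "bananas are rich in potassium",
]
for score, doc, idx in reranker.rerank("what is fastembed", docs, top_k=2):
    print(f"{score:.3f}  #{idx}  {doc}")  # sigmoid-normalized, best first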
codex-lens/build/lib/codexlens/semantic/reranker/legacy.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.

Install with: pip install codexlens[reranker-legacy]
"""

from __future__ import annotations

import logging
import threading
from typing import List, Sequence, Tuple

from .base import BaseReranker

logger = logging.getLogger(__name__)

try:
    from sentence_transformers import CrossEncoder as _CrossEncoder

    CROSS_ENCODER_AVAILABLE = True
    _import_error: str | None = None
except ImportError as exc:  # pragma: no cover - optional dependency
    _CrossEncoder = None  # type: ignore[assignment]
    CROSS_ENCODER_AVAILABLE = False
    _import_error = str(exc)


def check_cross_encoder_available() -> tuple[bool, str | None]:
    if CROSS_ENCODER_AVAILABLE:
        return True, None
    return (
        False,
        _import_error
        or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
    )


class CrossEncoderReranker(BaseReranker):
    """Cross-encoder reranker with lazy model loading."""

    def __init__(self, model_name: str, *, device: str | None = None) -> None:
        self.model_name = (model_name or "").strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.device = (device or "").strip() or None
        self._model = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None:
            return

        ok, err = check_cross_encoder_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None:
                return

            try:
                if self.device:
                    self._model = _CrossEncoder(self.model_name, device=self.device)  # type: ignore[misc]
                else:
                    self._model = _CrossEncoder(self.model_name)  # type: ignore[misc]
            except Exception as exc:
                logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
                raise

    def score_pairs(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> List[float]:
        """Score (query, doc) pairs using the cross-encoder.

        Returns:
            List of scores (one per pair) in the model's native scale (usually logits).
        """
        if not pairs:
            return []

        self._load_model()

        if self._model is None:  # pragma: no cover - defensive
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores = self._model.predict(list(pairs), batch_size=bs)  # type: ignore[union-attr]
        return [float(s) for s in scores]
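Unlike the fastembed backend above, this legacy backend returns scores in the model's native scale (usually logits), so callers who want [0, 1] values must normalize themselves. A minimal sketch, assuming sentence-transformers is installed:

import math

from codexlens.semantic.reranker.legacy import CrossEncoderReranker

reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
logits = reranker.score_pairs(
    [("what is python", "Python is a programming language.")]
)
# Caller-side sigmoid maps the raw logits into [0, 1] for comparability.
probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]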
@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.

This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.

Notes:
    - This backend is experimental and may be slow/expensive compared to local
      rerankers.
    - It relies on `ccw-litellm` for a unified LLM API across providers.
"""

from __future__ import annotations

import json
import logging
import re
import threading
import time
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _coerce_score_to_unit_interval(score: float) -> float:
    """Coerce a numeric score into [0, 1].

    The prompt asks for a float in [0, 1], but some models may respond with 0-10
    or 0-100 scales. This function attempts a conservative normalization.
    """
    if 0.0 <= score <= 1.0:
        return score
    if 0.0 <= score <= 10.0:
        return score / 10.0
    if 0.0 <= score <= 100.0:
        return score / 100.0
    return max(0.0, min(1.0, score))


def _extract_score(text: str) -> float | None:
    """Extract a numeric relevance score from an LLM response."""
    content = (text or "").strip()
    if not content:
        return None

    # Prefer JSON if present.
    if "{" in content and "}" in content:
        try:
            start = content.index("{")
            end = content.rindex("}") + 1
            payload = json.loads(content[start:end])
            if isinstance(payload, dict) and "score" in payload:
                return float(payload["score"])
        except Exception:
            pass

    match = _NUMBER_RE.search(content)
    if not match:
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None


class LiteLLMReranker(BaseReranker):
    """Experimental reranker that uses a LiteLLM-compatible model.

    This reranker scores each (query, doc) pair in isolation (single-pair mode)
    to improve prompt reliability across providers.
    """

    _SYSTEM_PROMPT = (
        "You are a relevance scoring assistant.\n"
        "Given a search query and a document snippet, output a single numeric "
        "relevance score between 0 and 1.\n\n"
        "Scoring guidance:\n"
        "- 1.0: The document directly answers the query.\n"
        "- 0.5: The document is partially relevant.\n"
        "- 0.0: The document is unrelated.\n\n"
        "Output requirements:\n"
        "- Output ONLY the number (e.g., 0.73).\n"
        "- Do not include any other text."
    )

    def __init__(
        self,
        model: str = "default",
        *,
        requests_per_minute: float | None = None,
        min_interval_seconds: float | None = None,
        default_score: float = 0.0,
        max_doc_chars: int = 8000,
        **litellm_kwargs: Any,
    ) -> None:
        """Initialize the reranker.

        Args:
            model: Model name from ccw-litellm configuration (default: "default").
            requests_per_minute: Optional rate limit in requests per minute.
            min_interval_seconds: Optional minimum interval between requests. If set,
                it takes precedence over requests_per_minute.
            default_score: Score to use when an API call fails or parsing fails.
            max_doc_chars: Maximum number of document characters to include in the prompt.
            **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.

        Raises:
            ImportError: If ccw-litellm is not installed.
            ValueError: If model is blank.
        """
        self.model_name = (model or "").strip()
        if not self.model_name:
            raise ValueError("model cannot be blank")

        self.default_score = float(default_score)

        self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0

        if min_interval_seconds is not None:
            self._min_interval_seconds = max(0.0, float(min_interval_seconds))
        elif requests_per_minute is not None and float(requests_per_minute) > 0:
            self._min_interval_seconds = 60.0 / float(requests_per_minute)
        else:
            self._min_interval_seconds = 0.0

        # Prefer deterministic output by default; allow overrides via kwargs.
        litellm_kwargs = dict(litellm_kwargs)
        litellm_kwargs.setdefault("temperature", 0.0)
        litellm_kwargs.setdefault("max_tokens", 16)

        try:
            from ccw_litellm import ChatMessage, LiteLLMClient
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from exc

        self._ChatMessage = ChatMessage
        self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)

        self._lock = threading.RLock()
        self._last_request_at = 0.0

    def _sanitize_text(self, text: str) -> str:
        # Keep consistent with LiteLLMEmbedderWrapper workaround.
        if text.startswith("import"):
            return " " + text
        return text

    def _rate_limit(self) -> None:
        if self._min_interval_seconds <= 0:
            return
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_request_at
            if elapsed < self._min_interval_seconds:
                time.sleep(self._min_interval_seconds - elapsed)
            self._last_request_at = time.monotonic()

    def _build_user_prompt(self, query: str, doc: str) -> str:
        sanitized_query = self._sanitize_text(query or "")
        sanitized_doc = self._sanitize_text(doc or "")
        if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
            sanitized_doc = sanitized_doc[: self.max_doc_chars]

        return (
            "Query:\n"
            f"{sanitized_query}\n\n"
            "Document:\n"
            f"{sanitized_doc}\n\n"
            "Return the relevance score (0 to 1) as a single number:"
        )

    def _score_single_pair(self, query: str, doc: str) -> float:
        messages = [
            self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
            self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
        ]

        try:
            self._rate_limit()
            response = self._client.chat(messages)
        except Exception as exc:
            logger.debug("LiteLLM reranker request failed: %s", exc)
            return self.default_score

        raw = getattr(response, "content", "") or ""
        score = _extract_score(raw)
        if score is None:
            logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
            return self.default_score
        return _coerce_score_to_unit_interval(float(score))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with per-pair LLM calls."""
        if not pairs:
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32

        scores: list[float] = []
        for i in range(0, len(pairs), bs):
            batch = pairs[i : i + bs]
            for query, doc in batch:
                scores.append(self._score_single_pair(query, doc))
        return scores
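The two parsing helpers above are pure functions, so their behavior can be pinned down with a few concrete cases (a sketch of how loosely formatted LLM replies are handled):

from codexlens.semantic.reranker.litellm_reranker import (
    _coerce_score_to_unit_interval,
    _extract_score,
)

assert _extract_score("0.73") == 0.73                       # bare number
assert _extract_score('{"score": 0.4} trailing text') == 0.4  # embedded JSON wins
assert _extract_score("Relevance: 7") == 7.0                # first number in prose
assert _extract_score("no number here") is None             # unparseable -> None

assert _coerce_score_to_unit_interval(7.0) == 0.7    # treated as a 0-10 scale
assert _coerce_score_to_unit_interval(73.0) == 0.73  # treated as a 0-100 scale
assert _coerce_score_to_unit_interval(-2.0) == 0.0   # clamped into [0, 1]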
@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.

This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Iterable, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_onnx_reranker_available() -> tuple[bool, str | None]:
    """Check whether Optimum + ONNXRuntime reranker dependencies are available."""
    try:
        import numpy  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


class ONNXReranker(BaseReranker):
    """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        providers: list[Any] | None = None,
        max_length: int | None = None,
    ) -> None:
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.providers = providers

        self.max_length = int(max_length) if max_length is not None else None

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._model_input_names: set[str] | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_onnx_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer

            if self.providers is None:
                from ..gpu_support import get_optimal_providers

                # Include device_id options for DirectML/CUDA selection when available.
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                model_kwargs = {}

            try:
                self._model = ORTModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments.
                self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache model input names to filter tokenizer outputs defensively.
            input_names: set[str] | None = None
            for attr in ("input_names", "model_input_names"):
                names = getattr(self._model, attr, None)
                if isinstance(names, (list, tuple)) and names:
                    input_names = {str(n) for n in names}
                    break
            if input_names is None:
                try:
                    session = getattr(self._model, "model", None)
                    if session is not None and hasattr(session, "get_inputs"):
                        input_names = {i.name for i in session.get_inputs()}
                except Exception:
                    input_names = None
            self._model_input_names = input_names

    @staticmethod
    def _sigmoid(x: Any) -> Any:
        import numpy as np

        x = np.clip(x, -50.0, 50.0)
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def _select_relevance_logit(logits: Any) -> Any:
        import numpy as np

        arr = np.asarray(logits)
        if arr.ndim == 0:
            return arr.reshape(1)
        if arr.ndim == 1:
            return arr
        if arr.ndim >= 2:
            # Common cases:
            # - Regression: (batch, 1)
            # - Binary classification: (batch, 2)
            if arr.shape[-1] == 1:
                return arr[..., 0]
            if arr.shape[-1] == 2:
                # Convert 2-logit softmax into a single logit via the difference.
                return arr[..., 1] - arr[..., 0]
            return arr.max(axis=-1)
        return arr.reshape(-1)

    def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")  # pragma: no cover - defensive

        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]

        tokenizer_kwargs: dict[str, Any] = {
            "text": queries,
            "text_pair": docs,
            "padding": True,
            "truncation": True,
            "return_tensors": "np",
        }

        max_len = self.max_length
        if max_len is None:
            try:
                model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
                if 0 < model_max < 10_000:
                    max_len = model_max
                else:
                    max_len = 512
            except Exception:
                max_len = 512
        if max_len is not None and max_len > 0:
            tokenizer_kwargs["max_length"] = int(max_len)

        encoded = self._tokenizer(**tokenizer_kwargs)
        inputs = dict(encoded)

        # Some models do not accept token_type_ids; filter to known input names if available.
        if self._model_input_names:
            inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}

        return inputs

    def _forward_logits(self, inputs: dict[str, Any]) -> Any:
        if self._model is None:
            raise RuntimeError("Model not loaded")  # pragma: no cover - defensive

        outputs = self._model(**inputs)
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        if isinstance(outputs, (list, tuple)) and outputs:
            return outputs[0]
        raise RuntimeError("Unexpected model output format")  # pragma: no cover - defensive

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
        if not pairs:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:  # pragma: no cover - defensive
            return []

        import numpy as np

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores: list[float] = []

        for batch in _iter_batches(list(pairs), bs):
            inputs = self._tokenize_batch(batch)
            logits = self._forward_logits(inputs)
            rel_logits = self._select_relevance_logit(logits)
            probs = self._sigmoid(rel_logits)
            probs = np.clip(probs, 0.0, 1.0)
            scores.extend([float(p) for p in probs.reshape(-1).tolist()])

        if len(scores) != len(pairs):
            logger.debug(
                "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
            )
            return scores[: len(pairs)]

        return scores
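A usage sketch of the ONNXReranker above. The scrape is missing this file's header, so the module path codexlens.semantic.reranker.onnx_reranker is inferred from the package layout, not confirmed; it assumes onnxruntime, optimum[onnxruntime], and transformers are installed.

# Assumed module path (inferred from the package layout in this PR).
from codexlens.semantic.reranker.onnx_reranker import ONNXReranker

reranker = ONNXReranker(use_gpu=False)  # falls back to CPUExecutionProvider
scores = reranker.score_pairs(
    [
        ("what is onnx", "ONNX is an open format for machine learning models."),
        ("what is onnx", "Soup recipes for winter evenings."),
    ]
)
# Scores are sigmoid-normalized, so both values land in [0, 1].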
codex-lens/build/lib/codexlens/semantic/rotational_embedder.py (new file, 434 lines)
@@ -0,0 +1,434 @@
|
||||
"""Rotational embedder for multi-endpoint API load balancing.
|
||||
|
||||
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
|
||||
to maximize throughput while respecting rate limits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndpointStatus(Enum):
|
||||
"""Status of an API endpoint."""
|
||||
AVAILABLE = "available"
|
||||
COOLING = "cooling" # Rate limited, temporarily unavailable
|
||||
FAILED = "failed" # Permanent failure (auth error, etc.)
|
||||
|
||||
|
||||
class SelectionStrategy(Enum):
|
||||
"""Strategy for selecting endpoints."""
|
||||
ROUND_ROBIN = "round_robin"
|
||||
LATENCY_AWARE = "latency_aware"
|
||||
WEIGHTED_RANDOM = "weighted_random"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointConfig:
|
||||
"""Configuration for a single API endpoint."""
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
weight: float = 1.0 # Higher weight = more requests
|
||||
max_concurrent: int = 4 # Max concurrent requests to this endpoint
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointState:
|
||||
"""Runtime state for an endpoint."""
|
||||
config: EndpointConfig
|
||||
embedder: Any = None # LiteLLMEmbedderWrapper instance
|
||||
|
||||
# Health metrics
|
||||
status: EndpointStatus = EndpointStatus.AVAILABLE
|
||||
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
|
||||
|
||||
# Performance metrics
|
||||
total_requests: int = 0
|
||||
total_failures: int = 0
|
||||
avg_latency_ms: float = 0.0
|
||||
last_latency_ms: float = 0.0
|
||||
|
||||
# Concurrency tracking
|
||||
active_requests: int = 0
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if endpoint is available for requests."""
|
||||
if self.status == EndpointStatus.FAILED:
|
||||
return False
|
||||
if self.status == EndpointStatus.COOLING:
|
||||
if time.time() >= self.cooldown_until:
|
||||
self.status = EndpointStatus.AVAILABLE
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_cooldown(self, seconds: float) -> None:
|
||||
"""Put endpoint in cooldown state."""
|
||||
self.status = EndpointStatus.COOLING
|
||||
self.cooldown_until = time.time() + seconds
|
||||
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
|
||||
|
||||
def mark_failed(self) -> None:
|
||||
"""Mark endpoint as permanently failed."""
|
||||
self.status = EndpointStatus.FAILED
|
||||
logger.error(f"Endpoint {self.config.model} marked as failed")
|
||||
|
||||
def record_success(self, latency_ms: float) -> None:
|
||||
"""Record successful request."""
|
||||
self.total_requests += 1
|
||||
self.last_latency_ms = latency_ms
|
||||
# Exponential moving average for latency
|
||||
alpha = 0.3
|
||||
if self.avg_latency_ms == 0:
|
||||
self.avg_latency_ms = latency_ms
|
||||
else:
|
||||
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
|
||||
|
||||
def record_failure(self) -> None:
|
||||
"""Record failed request."""
|
||||
self.total_requests += 1
|
||||
self.total_failures += 1
|
||||
|
||||
@property
|
||||
def health_score(self) -> float:
|
||||
"""Calculate health score (0-1) based on metrics."""
|
||||
if not self.is_available():
|
||||
return 0.0
|
||||
|
||||
# Base score from success rate
|
||||
if self.total_requests > 0:
|
||||
success_rate = 1 - (self.total_failures / self.total_requests)
|
||||
else:
|
||||
success_rate = 1.0
|
||||
|
||||
# Latency factor (faster = higher score)
|
||||
# Normalize: 100ms = 1.0, 1000ms = 0.1
|
||||
if self.avg_latency_ms > 0:
|
||||
latency_factor = min(1.0, 100 / self.avg_latency_ms)
|
||||
else:
|
||||
latency_factor = 1.0
|
||||
|
||||
# Availability factor (less concurrent = more available)
|
||||
if self.config.max_concurrent > 0:
|
||||
availability = 1 - (self.active_requests / self.config.max_concurrent)
|
||||
else:
|
||||
availability = 1.0
|
||||
|
||||
# Combined score with weights
|
||||
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
|
||||
|
||||
|
||||
class RotationalEmbedder(BaseEmbedder):
|
||||
"""Embedder that load balances across multiple API endpoints.
|
||||
|
||||
Features:
|
||||
- Intelligent endpoint selection based on latency and health
|
||||
- Automatic failover on rate limits (429) and server errors
|
||||
- Cooldown management to respect rate limits
|
||||
- Thread-safe concurrent request handling
|
||||
|
||||
Args:
|
||||
endpoints: List of endpoint configurations
|
||||
strategy: Selection strategy (default: latency_aware)
|
||||
default_cooldown: Default cooldown seconds for rate limits (default: 60)
|
||||
max_retries: Maximum retry attempts across all endpoints (default: 3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoints: List[EndpointConfig],
|
||||
strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
|
||||
default_cooldown: float = 60.0,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
if not endpoints:
|
||||
raise ValueError("At least one endpoint must be provided")
|
||||
|
||||
self.strategy = strategy
|
||||
self.default_cooldown = default_cooldown
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Initialize endpoint states
|
||||
self._endpoints: List[EndpointState] = []
|
||||
self._lock = threading.Lock()
|
||||
self._round_robin_index = 0
|
||||
|
||||
# Create embedder instances for each endpoint
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
|
||||
for config in endpoints:
|
||||
# Build kwargs for LiteLLMEmbedderWrapper
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if config.api_key:
|
||||
kwargs["api_key"] = config.api_key
|
||||
if config.api_base:
|
||||
kwargs["api_base"] = config.api_base
|
||||
|
||||
try:
|
||||
embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
|
||||
state = EndpointState(config=config, embedder=embedder)
|
||||
self._endpoints.append(state)
|
||||
logger.info(f"Initialized endpoint: {config.model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize endpoint {config.model}: {e}")
|
||||
|
||||
if not self._endpoints:
|
||||
raise ValueError("Failed to initialize any endpoints")
|
||||
|
||||
# Cache embedding properties from first endpoint
|
||||
self._embedding_dim = self._endpoints[0].embedder.embedding_dim
|
||||
self._model_name = f"rotational({len(self._endpoints)} endpoints)"
|
||||
self._max_tokens = self._endpoints[0].embedder.max_tokens
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimensions."""
|
||||
return self._embedding_dim
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return model name."""
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def max_tokens(self) -> int:
|
||||
"""Return maximum token limit."""
|
||||
return self._max_tokens
|
||||
|
||||
@property
|
||||
def endpoint_count(self) -> int:
|
||||
"""Return number of configured endpoints."""
|
||||
return len(self._endpoints)
|
||||
|
||||
@property
|
||||
def available_endpoint_count(self) -> int:
|
||||
"""Return number of available endpoints."""
|
||||
return sum(1 for ep in self._endpoints if ep.is_available())
|
||||
|
||||
def get_endpoint_stats(self) -> List[Dict[str, Any]]:
|
||||
"""Get statistics for all endpoints."""
|
||||
stats = []
|
||||
for ep in self._endpoints:
|
||||
stats.append({
|
||||
"model": ep.config.model,
|
||||
"status": ep.status.value,
|
||||
"total_requests": ep.total_requests,
|
||||
"total_failures": ep.total_failures,
|
||||
"avg_latency_ms": round(ep.avg_latency_ms, 2),
|
||||
"health_score": round(ep.health_score, 3),
|
||||
"active_requests": ep.active_requests,
|
||||
})
|
||||
return stats
|
||||
|
||||
def _select_endpoint(self) -> Optional[EndpointState]:
|
||||
"""Select best available endpoint based on strategy."""
|
||||
available = [ep for ep in self._endpoints if ep.is_available()]
|
||||
|
||||
if not available:
|
||||
return None
|
||||
|
||||
        if self.strategy == SelectionStrategy.ROUND_ROBIN:
            with self._lock:
                self._round_robin_index = (self._round_robin_index + 1) % len(available)
                return available[self._round_robin_index]

        elif self.strategy == SelectionStrategy.LATENCY_AWARE:
            # Sort by health score (descending) and pick top candidate
            # Add small random factor to prevent thundering herd
            scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
            scored.sort(key=lambda x: x[1], reverse=True)
            return scored[0][0]

        elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
            # Weighted random selection based on health scores
            scores = [ep.health_score for ep in available]
            total = sum(scores)
            if total == 0:
                return random.choice(available)

            weights = [s / total for s in scores]
            return random.choices(available, weights=weights, k=1)[0]

        return available[0]

    def _parse_retry_after(self, error: Exception) -> Optional[float]:
        """Extract Retry-After value from error if available."""
        error_str = str(error)

        # Try to find Retry-After in error message
        import re
        match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
        if match:
            return float(match.group(1))

        return None

    def _is_rate_limit_error(self, error: Exception) -> bool:
        """Check if error is a rate limit error."""
        error_str = str(error).lower()
        return any(x in error_str for x in ["429", "rate limit", "too many requests"])

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if error is retryable (not auth/config error)."""
        error_str = str(error).lower()
        # Retryable errors
        if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
                                        "timeout", "connection", "service unavailable"]):
            return True
        # Non-retryable errors (auth, config)
        if any(x in error_str for x in ["401", "403", "invalid", "authentication",
                                        "unauthorized", "api key"]):
            return False
        # Default to retryable for unknown errors
        return True

    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts using load-balanced endpoint selection.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments passed to underlying embedder.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.

        Raises:
            RuntimeError: If all endpoints fail after retries.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        last_error: Optional[Exception] = None
        tried_endpoints: set = set()

        for attempt in range(self.max_retries + 1):
            endpoint = self._select_endpoint()

            if endpoint is None:
                # All endpoints unavailable, wait for shortest cooldown
                min_cooldown = min(
                    (ep.cooldown_until - time.time() for ep in self._endpoints
                     if ep.status == EndpointStatus.COOLING),
                    default=self.default_cooldown
                )
                if min_cooldown > 0 and attempt < self.max_retries:
                    wait_time = min(min_cooldown, 30)  # Cap wait at 30s
                    logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    continue
                break

            # Track tried endpoints to avoid infinite loops
            endpoint_id = id(endpoint)
            if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
                # Already tried all endpoints
                break
            tried_endpoints.add(endpoint_id)

            # Acquire slot
            with endpoint.lock:
                endpoint.active_requests += 1

            try:
                start_time = time.time()
                result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
                latency_ms = (time.time() - start_time) * 1000

                # Record success
                endpoint.record_success(latency_ms)

                return result

            except Exception as e:
                last_error = e
                endpoint.record_failure()

                if self._is_rate_limit_error(e):
                    # Rate limited - set cooldown
                    retry_after = self._parse_retry_after(e) or self.default_cooldown
                    endpoint.set_cooldown(retry_after)
                    logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
                                   f"cooling for {retry_after}s")

                elif not self._is_retryable_error(e):
                    # Permanent failure (auth error, etc.)
                    endpoint.mark_failed()
                    logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")

                else:
                    # Temporary error - short cooldown
                    endpoint.set_cooldown(5.0)
                    logger.warning(f"Endpoint {endpoint.config.model} error: {e}")

            finally:
                with endpoint.lock:
                    endpoint.active_requests -= 1

        # All retries exhausted
        available = self.available_endpoint_count
        raise RuntimeError(
            f"All embedding attempts failed after {self.max_retries + 1} tries. "
            f"Available endpoints: {available}/{len(self._endpoints)}. "
            f"Last error: {last_error}"
        )


def create_rotational_embedder(
    endpoints_config: List[Dict[str, Any]],
    strategy: str = "latency_aware",
    default_cooldown: float = 60.0,
) -> RotationalEmbedder:
    """Factory function to create RotationalEmbedder from config dicts.

    Args:
        endpoints_config: List of endpoint configuration dicts with keys:
            - model: Model identifier (required)
            - api_key: API key (optional)
            - api_base: API base URL (optional)
            - weight: Request weight (optional, default 1.0)
            - max_concurrent: Max concurrent requests (optional, default 4)
        strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
        default_cooldown: Default cooldown seconds for rate limits

    Returns:
        Configured RotationalEmbedder instance

    Example config:
        endpoints_config = [
            {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
            {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ]
    """
    endpoints = []
    for cfg in endpoints_config:
        endpoints.append(EndpointConfig(
            model=cfg["model"],
            api_key=cfg.get("api_key"),
            api_base=cfg.get("api_base"),
            weight=cfg.get("weight", 1.0),
            max_concurrent=cfg.get("max_concurrent", 4),
        ))

    strategy_enum = SelectionStrategy[strategy.upper()]

    return RotationalEmbedder(
        endpoints=endpoints,
        strategy=strategy_enum,
        default_cooldown=default_cooldown,
    )
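A minimal usage sketch for the factory above; the endpoint models, keys, and base URLs are placeholders, and the behavior follows the rotation/cooldown logic documented in this file:

# Usage sketch (hypothetical endpoints; keys/base URLs are placeholders):
embedder = create_rotational_embedder(
    endpoints_config=[
        {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 2.0},
        {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
    ],
    strategy="latency_aware",  # or "round_robin" / "weighted_random"
    default_cooldown=60.0,
)
# Rate-limited endpoints cool down automatically; retries rotate to the next one.
vectors = embedder.embed_to_numpy(["def parse(tokens): ...", "class Lexer: ..."])
print(vectors.shape)  # (2, embedding_dim)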
567  codex-lens/build/lib/codexlens/semantic/splade_encoder.py  Normal file
@@ -0,0 +1,567 @@
"""ONNX-optimized SPLADE sparse encoder for code search.

This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
that combine the interpretability of BM25 with neural relevance modeling.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers

Install (GPU):
    pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
"""

from __future__ import annotations

import logging
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


def check_splade_available() -> Tuple[bool, Optional[str]]:
    """Check whether SPLADE dependencies are available.

    Returns:
        Tuple of (available: bool, error_message: Optional[str])
    """
    try:
        import numpy  # noqa: F401
    except ImportError as exc:
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForMaskedLM  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


# Global cache for SPLADE encoders (singleton pattern)
_splade_cache: Dict[str, "SpladeEncoder"] = {}
_cache_lock = threading.RLock()


def get_splade_encoder(
    model_name: str = "naver/splade-cocondenser-ensembledistil",
    use_gpu: bool = True,
    max_length: int = 512,
    sparsity_threshold: float = 0.01,
    cache_dir: Optional[str] = None,
) -> "SpladeEncoder":
    """Get or create cached SPLADE encoder (thread-safe singleton).

    This function provides significant performance improvement by reusing
    SpladeEncoder instances across multiple searches, avoiding repeated model
    loading overhead.

    Args:
        model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
        use_gpu: If True, use GPU acceleration when available
        max_length: Maximum sequence length for tokenization
        sparsity_threshold: Minimum weight to include in sparse vector
        cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)

    Returns:
        Cached SpladeEncoder instance for the given configuration
    """
    global _splade_cache

    # Cache key includes all configuration parameters
    cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"

    with _cache_lock:
        encoder = _splade_cache.get(cache_key)
        if encoder is not None:
            return encoder

        # Create new encoder and cache it
        encoder = SpladeEncoder(
            model_name=model_name,
            use_gpu=use_gpu,
            max_length=max_length,
            sparsity_threshold=sparsity_threshold,
            cache_dir=cache_dir,
        )
        # Pre-load model to ensure it's ready
        encoder._load_model()
        _splade_cache[cache_key] = encoder

    return encoder


def clear_splade_cache() -> None:
    """Clear the SPLADE encoder cache and release ONNX resources.

    This function ensures proper cleanup of ONNX model resources to prevent
    memory leaks when encoders are no longer needed.
    """
    global _splade_cache
    with _cache_lock:
        # Release ONNX resources before clearing cache
        for encoder in _splade_cache.values():
            if encoder._model is not None:
                del encoder._model
                encoder._model = None
            if encoder._tokenizer is not None:
                del encoder._tokenizer
                encoder._tokenizer = None
        _splade_cache.clear()


class SpladeEncoder:
    """ONNX-optimized SPLADE sparse encoder.

    Produces sparse vectors with vocabulary-aligned dimensions.
    Output: Dict[int, float] mapping token_id to weight.

    SPLADE activation formula:
        splade_repr = log(1 + ReLU(logits)) * attention_mask
        splade_vec = max_pooling(splade_repr, axis=sequence_length)

    References:
        - SPLADE: https://arxiv.org/abs/2107.05720
        - SPLADE v2: https://arxiv.org/abs/2109.10086
    """

    DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        use_gpu: bool = True,
        max_length: int = 512,
        sparsity_threshold: float = 0.01,
        providers: Optional[List[Any]] = None,
        cache_dir: Optional[str] = None,
    ) -> None:
        """Initialize SPLADE encoder.

        Args:
            model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
            use_gpu: If True, use GPU acceleration when available
            max_length: Maximum sequence length for tokenization
            sparsity_threshold: Minimum weight to include in sparse vector
            providers: Explicit ONNX providers list (overrides use_gpu)
            cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.max_length = int(max_length) if max_length > 0 else 512
        self.sparsity_threshold = float(sparsity_threshold)
        self.providers = providers

        # Setup ONNX cache directory
        if cache_dir:
            self._cache_dir = Path(cache_dir)
        else:
            self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._vocab_size: int | None = None
        self._lock = threading.RLock()

    def _get_local_cache_path(self) -> Path:
        """Get local cache path for this model's ONNX files.

        Returns:
            Path to the local ONNX cache directory for this model
        """
        # Replace / with -- for filesystem-safe naming
        safe_name = self.model_name.replace("/", "--")
        return self._cache_dir / safe_name

    def _load_model(self) -> None:
        """Lazy load ONNX model and tokenizer.

        First checks local cache for ONNX model, falling back to
        HuggingFace download and conversion if not cached.
        """
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_splade_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForMaskedLM
            from transformers import AutoTokenizer

            if self.providers is None:
                from .gpu_support import get_optimal_providers, get_selected_device_id

                # Get providers as pure string list (cache-friendly)
                # NOTE: with_device_options=False to avoid tuple-based providers
                # which break optimum's caching mechanism
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=False
                )
                # Get device_id separately for provider_options
                self._device_id = get_selected_device_id() if self.use_gpu else None

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForMaskedLM.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                    # Pass device_id via provider_options for GPU selection
                    if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
                        # Build provider_options dict for each GPU provider
                        provider_options = {}
                        for p in self.providers:
                            if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
                                provider_options[p] = {"device_id": self._device_id}
                        if provider_options:
                            model_kwargs["provider_options"] = provider_options
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception as e:
                logger.debug(f"Failed to inspect ORTModel signature: {e}")
                model_kwargs = {}

            # Check for local ONNX cache first
            local_cache = self._get_local_cache_path()
            onnx_model_path = local_cache / "model.onnx"

            if onnx_model_path.exists():
                # Load from local cache
                logger.info(f"Loading SPLADE from local cache: {local_cache}")
                try:
                    self._model = ORTModelForMaskedLM.from_pretrained(
                        str(local_cache),
                        **model_kwargs,
                    )
                    self._tokenizer = AutoTokenizer.from_pretrained(
                        str(local_cache), use_fast=True
                    )
                    self._vocab_size = len(self._tokenizer)
                    logger.info(
                        f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
                    )
                    return
                except Exception as e:
                    logger.warning(f"Failed to load from cache, redownloading: {e}")

            # Download and convert from HuggingFace
            logger.info(f"Downloading SPLADE model: {self.model_name}")
            try:
                self._model = ORTModelForMaskedLM.from_pretrained(
                    self.model_name,
                    export=True,  # Export to ONNX
                    **model_kwargs,
                )
                logger.debug(f"SPLADE model loaded: {self.model_name}")
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments
                self._model = ORTModelForMaskedLM.from_pretrained(
                    self.model_name,
                    export=True,
                )
                logger.warning(
                    "Optimum version doesn't support provider parameters. "
                    "Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
                )

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache vocabulary size
            self._vocab_size = len(self._tokenizer)
            logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")

            # Save to local cache for future use
            try:
                local_cache.mkdir(parents=True, exist_ok=True)
                self._model.save_pretrained(str(local_cache))
                self._tokenizer.save_pretrained(str(local_cache))
                logger.info(f"SPLADE model cached to: {local_cache}")
            except Exception as e:
                logger.warning(f"Failed to cache SPLADE model: {e}")

    @staticmethod
    def _splade_activation(logits: Any, attention_mask: Any) -> Any:
        """Apply SPLADE activation function to model outputs.

        Formula: log(1 + ReLU(logits)) * attention_mask

        Args:
            logits: Model output logits (batch, seq_len, vocab_size)
            attention_mask: Attention mask (batch, seq_len)

        Returns:
            SPLADE representations (batch, seq_len, vocab_size)
        """
        import numpy as np

        # ReLU activation
        relu_logits = np.maximum(0, logits)

        # Log(1 + x) transformation
        log_relu = np.log1p(relu_logits)

        # Apply attention mask (expand to match vocab dimension)
        # attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
        mask_expanded = np.expand_dims(attention_mask, axis=-1)

        # Element-wise multiplication
        splade_repr = log_relu * mask_expanded

        return splade_repr

    @staticmethod
    def _max_pooling(splade_repr: Any) -> Any:
        """Max pooling over sequence length dimension.

        Args:
            splade_repr: SPLADE representations (batch, seq_len, vocab_size)

        Returns:
            Pooled sparse vectors (batch, vocab_size)
        """
        import numpy as np

        # Max pooling over sequence dimension (axis=1)
        return np.max(splade_repr, axis=1)

    def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
        """Convert dense vector to sparse dictionary.

        Args:
            dense_vec: Dense vector (vocab_size,)

        Returns:
            Sparse dictionary {token_id: weight} with weights above threshold
        """
        import numpy as np

        # Find non-zero indices above threshold
        nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]

        # Create sparse dictionary
        sparse_dict = {
            int(idx): float(dense_vec[idx])
            for idx in nonzero_indices
        }

        return sparse_dict

    def warmup(self, text: str = "warmup query") -> None:
        """Warmup the encoder by running a dummy inference.

        First-time model inference includes initialization overhead.
        Call this method once before the first real search to avoid
        latency spikes.

        Args:
            text: Dummy text for warmup (default: "warmup query")
        """
        logger.info("Warming up SPLADE encoder...")
        # Trigger model loading and first inference
        _ = self.encode_text(text)
        logger.info("SPLADE encoder warmup complete")

    def encode_text(self, text: str) -> Dict[int, float]:
        """Encode text to sparse vector {token_id: weight}.

        Args:
            text: Input text to encode

        Returns:
            Sparse vector as dictionary mapping token_id to weight
        """
        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        # Tokenize input
        encoded = self._tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="np",
        )

        # Forward pass through model
        outputs = self._model(**encoded)

        # Extract logits
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        elif isinstance(outputs, dict) and "logits" in outputs:
            logits = outputs["logits"]
        elif isinstance(outputs, (list, tuple)) and outputs:
            logits = outputs[0]
        else:
            raise RuntimeError("Unexpected model output format")

        # Apply SPLADE activation
        attention_mask = encoded["attention_mask"]
        splade_repr = self._splade_activation(logits, attention_mask)

        # Max pooling over sequence length
        splade_vec = self._max_pooling(splade_repr)

        # Convert to sparse dictionary (single item batch)
        sparse_dict = self._to_sparse_dict(splade_vec[0])

        return sparse_dict

    def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
        """Batch encode texts to sparse vectors.

        Args:
            texts: List of input texts to encode
            batch_size: Batch size for encoding (default: 32)

        Returns:
            List of sparse vectors as dictionaries
        """
        if not texts:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        results: List[Dict[int, float]] = []

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize batch
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="np",
            )

            # Forward pass through model
            outputs = self._model(**encoded)

            # Extract logits
            if hasattr(outputs, "logits"):
                logits = outputs.logits
            elif isinstance(outputs, dict) and "logits" in outputs:
                logits = outputs["logits"]
            elif isinstance(outputs, (list, tuple)) and outputs:
                logits = outputs[0]
            else:
                raise RuntimeError("Unexpected model output format")

            # Apply SPLADE activation
            attention_mask = encoded["attention_mask"]
            splade_repr = self._splade_activation(logits, attention_mask)

            # Max pooling over sequence length
            splade_vecs = self._max_pooling(splade_repr)

            # Convert each vector to sparse dictionary
            for vec in splade_vecs:
                sparse_dict = self._to_sparse_dict(vec)
                results.append(sparse_dict)

        return results

    @property
    def vocab_size(self) -> int:
        """Return vocabulary size (~30k for BERT-based models).

        Returns:
            Vocabulary size (number of tokens in tokenizer)
        """
        if self._vocab_size is not None:
            return self._vocab_size

        self._load_model()
        return self._vocab_size or 0

    def get_token(self, token_id: int) -> str:
        """Convert token_id to string (for debugging).

        Args:
            token_id: Token ID to convert

        Returns:
            Token string
        """
        self._load_model()

        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")

        return self._tokenizer.decode([token_id])

    def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
        """Get top-k tokens with highest weights from sparse vector.

        Useful for debugging and understanding what the model is focusing on.

        Args:
            sparse_vec: Sparse vector as {token_id: weight}
            top_k: Number of top tokens to return

        Returns:
            List of (token_string, weight) tuples, sorted by weight descending
        """
        self._load_model()

        if not sparse_vec:
            return []

        # Sort by weight descending
        sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)

        # Take top-k and convert token_ids to strings
        top_items = sorted_items[:top_k]

        return [
            (self.get_token(token_id), weight)
            for token_id, weight in top_items
        ]
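A short usage sketch of the encoder API defined above; the query and document strings are arbitrary, and the model download happens on the first call:

# Usage sketch: encode a query and inspect the expansion terms.
encoder = get_splade_encoder(use_gpu=False)  # cached singleton per configuration
encoder.warmup()                             # optional: absorb first-inference latency

sparse = encoder.encode_text("binary search in sorted array")
print(len(sparse), "active dims out of", encoder.vocab_size)

# Top expansion tokens show what the model considers relevant:
for token, weight in encoder.get_top_tokens(sparse, top_k=5):
    print(f"{token:>12s}  {weight:.3f}")

# Documents can be encoded in batches for indexing:
doc_vecs = encoder.encode_batch(["def bsearch(a, x): ...", "quick sort impl"], batch_size=32)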
1278  codex-lens/build/lib/codexlens/semantic/vector_store.py  Normal file
File diff suppressed because it is too large
32  codex-lens/build/lib/codexlens/storage/__init__.py  Normal file
@@ -0,0 +1,32 @@
"""Storage backends for CodexLens."""

from __future__ import annotations

from .sqlite_store import SQLiteStore
from .path_mapper import PathMapper
from .registry import RegistryStore, ProjectInfo, DirMapping
from .dir_index import DirIndexStore, SubdirLink, FileEntry
from .index_tree import IndexTreeBuilder, BuildResult, DirBuildResult
from .vector_meta_store import VectorMetadataStore

__all__ = [
    # Legacy (workspace-local)
    "SQLiteStore",
    # Path mapping
    "PathMapper",
    # Global registry
    "RegistryStore",
    "ProjectInfo",
    "DirMapping",
    # Directory index
    "DirIndexStore",
    "SubdirLink",
    "FileEntry",
    # Tree builder
    "IndexTreeBuilder",
    "BuildResult",
    "DirBuildResult",
    # Vector metadata
    "VectorMetadataStore",
]
2358  codex-lens/build/lib/codexlens/storage/dir_index.py  Normal file
File diff suppressed because it is too large
32  codex-lens/build/lib/codexlens/storage/file_cache.py  Normal file
@@ -0,0 +1,32 @@
"""Simple filesystem cache helpers."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class FileCache:
    """Caches file mtimes for incremental indexing."""

    cache_path: Path

    def load_mtime(self, path: Path) -> Optional[float]:
        try:
            key = self._key_for(path)
            record = (self.cache_path / key).read_text(encoding="utf-8")
            return float(record)
        except Exception:
            return None

    def store_mtime(self, path: Path, mtime: float) -> None:
        self.cache_path.mkdir(parents=True, exist_ok=True)
        key = self._key_for(path)
        (self.cache_path / key).write_text(str(mtime), encoding="utf-8")

    def _key_for(self, path: Path) -> str:
        safe = str(path).replace(":", "_").replace("\\", "_").replace("/", "_")
        return f"{safe}.mtime"
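A small sketch of how the cache above can drive incremental indexing; the paths are illustrative:

# Sketch: skip re-indexing files whose mtime is unchanged (paths illustrative).
from pathlib import Path

cache = FileCache(cache_path=Path(".codexlens_cache"))
src = Path("src/main.py")

current = src.stat().st_mtime
if cache.load_mtime(src) == current:
    print("unchanged, skip")
else:
    # ... re-index the file ...
    cache.store_mtime(src, current)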
398  codex-lens/build/lib/codexlens/storage/global_index.py  Normal file
@@ -0,0 +1,398 @@
"""Global cross-directory symbol index for fast lookups.

Stores symbols for an entire project in a single SQLite database so symbol search
does not require traversing every directory _index.db.

This index is updated incrementally during file indexing (delete+insert per file)
to avoid expensive batch rebuilds.
"""

from __future__ import annotations

import logging
import sqlite3
import threading
from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import Symbol
from codexlens.errors import StorageError


class GlobalSymbolIndex:
    """Project-wide symbol index with incremental updates."""

    SCHEMA_VERSION = 1
    DEFAULT_DB_NAME = "_global_symbols.db"

    def __init__(self, db_path: str | Path, project_id: int) -> None:
        self.db_path = Path(db_path).resolve()
        self.project_id = int(project_id)
        self._lock = threading.RLock()
        self._conn: Optional[sqlite3.Connection] = None
        self.logger = logging.getLogger(__name__)

    def initialize(self) -> None:
        """Create database and schema if not exists."""
        with self._lock:
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            conn = self._get_connection()

            current_version = self._get_schema_version(conn)
            if current_version > self.SCHEMA_VERSION:
                raise StorageError(
                    f"Database schema version {current_version} is newer than "
                    f"supported version {self.SCHEMA_VERSION}. "
                    f"Please update the application or use a compatible database.",
                    db_path=str(self.db_path),
                    operation="initialize",
                    details={
                        "current_version": current_version,
                        "supported_version": self.SCHEMA_VERSION,
                    },
                )

            if current_version == 0:
                self._create_schema(conn)
                self._set_schema_version(conn, self.SCHEMA_VERSION)
            elif current_version < self.SCHEMA_VERSION:
                self._apply_migrations(conn, current_version)
                self._set_schema_version(conn, self.SCHEMA_VERSION)

            conn.commit()

    def close(self) -> None:
        """Close database connection."""
        with self._lock:
            if self._conn is not None:
                try:
                    self._conn.close()
                except Exception:
                    pass
                finally:
                    self._conn = None

    def __enter__(self) -> "GlobalSymbolIndex":
        self.initialize()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        self.close()

    def add_symbol(self, symbol: Symbol, file_path: str | Path, index_path: str | Path) -> None:
        """Insert a single symbol (idempotent) for incremental updates."""
        file_path_str = str(Path(file_path).resolve())
        index_path_str = str(Path(index_path).resolve())

        with self._lock:
            conn = self._get_connection()
            try:
                conn.execute(
                    """
                    INSERT INTO global_symbols(
                        project_id, symbol_name, symbol_kind,
                        file_path, start_line, end_line, index_path
                    )
                    VALUES(?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(
                        project_id, symbol_name, symbol_kind,
                        file_path, start_line, end_line
                    )
                    DO UPDATE SET
                        index_path=excluded.index_path
                    """,
                    (
                        self.project_id,
                        symbol.name,
                        symbol.kind,
                        file_path_str,
                        symbol.range[0],
                        symbol.range[1],
                        index_path_str,
                    ),
                )
                conn.commit()
            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise StorageError(
                    f"Failed to add symbol {symbol.name}: {exc}",
                    db_path=str(self.db_path),
                    operation="add_symbol",
                ) from exc

    def update_file_symbols(
        self,
        file_path: str | Path,
        symbols: List[Symbol],
        index_path: str | Path | None = None,
    ) -> None:
        """Replace all symbols for a file atomically (delete + insert)."""
        file_path_str = str(Path(file_path).resolve())

        index_path_str: Optional[str]
        if index_path is not None:
            index_path_str = str(Path(index_path).resolve())
        else:
            index_path_str = self._get_existing_index_path(file_path_str)

        with self._lock:
            conn = self._get_connection()
            try:
                conn.execute("BEGIN")
                conn.execute(
                    "DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
                    (self.project_id, file_path_str),
                )

                if symbols:
                    if not index_path_str:
                        raise StorageError(
                            "index_path is required when inserting symbols for a new file",
                            db_path=str(self.db_path),
                            operation="update_file_symbols",
                            details={"file_path": file_path_str},
                        )

                    rows = [
                        (
                            self.project_id,
                            s.name,
                            s.kind,
                            file_path_str,
                            s.range[0],
                            s.range[1],
                            index_path_str,
                        )
                        for s in symbols
                    ]
                    conn.executemany(
                        """
                        INSERT INTO global_symbols(
                            project_id, symbol_name, symbol_kind,
                            file_path, start_line, end_line, index_path
                        )
                        VALUES(?, ?, ?, ?, ?, ?, ?)
                        ON CONFLICT(
                            project_id, symbol_name, symbol_kind,
                            file_path, start_line, end_line
                        )
                        DO UPDATE SET
                            index_path=excluded.index_path
                        """,
                        rows,
                    )

                conn.commit()
            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise StorageError(
                    f"Failed to update symbols for {file_path_str}: {exc}",
                    db_path=str(self.db_path),
                    operation="update_file_symbols",
                ) from exc

    def delete_file_symbols(self, file_path: str | Path) -> int:
        """Remove all symbols for a file. Returns number of rows deleted."""
        file_path_str = str(Path(file_path).resolve())
        with self._lock:
            conn = self._get_connection()
            try:
                cur = conn.execute(
                    "DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
                    (self.project_id, file_path_str),
                )
                conn.commit()
                return int(cur.rowcount or 0)
            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise StorageError(
                    f"Failed to delete symbols for {file_path_str}: {exc}",
                    db_path=str(self.db_path),
                    operation="delete_file_symbols",
                ) from exc

    def search(
        self,
        name: str,
        kind: Optional[str] = None,
        limit: int = 50,
        prefix_mode: bool = True,
    ) -> List[Symbol]:
        """Search symbols and return full Symbol objects."""
        if prefix_mode:
            pattern = f"{name}%"
        else:
            pattern = f"%{name}%"

        with self._lock:
            conn = self._get_connection()
            if kind:
                rows = conn.execute(
                    """
                    SELECT symbol_name, symbol_kind, file_path, start_line, end_line
                    FROM global_symbols
                    WHERE project_id=? AND symbol_name LIKE ? AND symbol_kind=?
                    ORDER BY symbol_name
                    LIMIT ?
                    """,
                    (self.project_id, pattern, kind, limit),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT symbol_name, symbol_kind, file_path, start_line, end_line
                    FROM global_symbols
                    WHERE project_id=? AND symbol_name LIKE ?
                    ORDER BY symbol_name
                    LIMIT ?
                    """,
                    (self.project_id, pattern, limit),
                ).fetchall()

            return [
                Symbol(
                    name=row["symbol_name"],
                    kind=row["symbol_kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=row["file_path"],
                )
                for row in rows
            ]

    def search_symbols(
        self,
        name: str,
        kind: Optional[str] = None,
        limit: int = 50,
        prefix_mode: bool = True,
    ) -> List[Tuple[str, Tuple[int, int]]]:
        """Search symbols and return only (file_path, (start_line, end_line))."""
        symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
        return [(s.file or "", s.range) for s in symbols]

    def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
        """Get all symbols in a specific file, sorted by start_line.

        Args:
            file_path: Full path to the file

        Returns:
            List of Symbol objects sorted by start_line
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT symbol_name, symbol_kind, file_path, start_line, end_line
                FROM global_symbols
                WHERE project_id=? AND file_path=?
                ORDER BY start_line
                """,
                (self.project_id, file_path_str),
            ).fetchall()

            return [
                Symbol(
                    name=row["symbol_name"],
                    kind=row["symbol_kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=row["file_path"],
                )
                for row in rows
            ]

    def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
        with self._lock:
            conn = self._get_connection()
            row = conn.execute(
                """
                SELECT index_path
                FROM global_symbols
                WHERE project_id=? AND file_path=?
                LIMIT 1
                """,
                (self.project_id, file_path_str),
            ).fetchone()
            return str(row["index_path"]) if row else None

    def _get_schema_version(self, conn: sqlite3.Connection) -> int:
        try:
            row = conn.execute("PRAGMA user_version").fetchone()
            return int(row[0]) if row else 0
        except Exception:
            return 0

    def _set_schema_version(self, conn: sqlite3.Connection, version: int) -> None:
        conn.execute(f"PRAGMA user_version = {int(version)}")

    def _apply_migrations(self, conn: sqlite3.Connection, from_version: int) -> None:
        # No migrations yet (v1).
        _ = (conn, from_version)
        return

    def _get_connection(self) -> sqlite3.Connection:
        if self._conn is None:
            self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA synchronous=NORMAL")
            self._conn.execute("PRAGMA foreign_keys=ON")
            self._conn.execute("PRAGMA mmap_size=30000000000")
        return self._conn

    def _create_schema(self, conn: sqlite3.Connection) -> None:
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS global_symbols (
                    id INTEGER PRIMARY KEY,
                    project_id INTEGER NOT NULL,
                    symbol_name TEXT NOT NULL,
                    symbol_kind TEXT NOT NULL,
                    file_path TEXT NOT NULL,
                    start_line INTEGER,
                    end_line INTEGER,
                    index_path TEXT NOT NULL,
                    UNIQUE(
                        project_id, symbol_name, symbol_kind,
                        file_path, start_line, end_line
                    )
                )
                """
            )

            # Required by optimization spec.
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_global_symbols_name_kind
                ON global_symbols(symbol_name, symbol_kind)
                """
            )
            # Used by common queries (project-scoped name lookups).
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_global_symbols_project_name_kind
                ON global_symbols(project_id, symbol_name, symbol_kind)
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_global_symbols_project_file
                ON global_symbols(project_id, file_path)
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_global_symbols_project_index_path
                ON global_symbols(project_id, index_path)
                """
            )
        except sqlite3.DatabaseError as exc:
            raise StorageError(
                f"Failed to initialize global symbol schema: {exc}",
                db_path=str(self.db_path),
                operation="_create_schema",
            ) from exc
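A usage sketch for the index above; the Symbol keyword fields follow the constructor calls in this file, and the paths are illustrative:

# Sketch: incremental update + prefix search (paths illustrative).
symbols = [
    Symbol(name="parse_config", kind="function", range=(10, 42), file="src/config.py"),
]
with GlobalSymbolIndex("/tmp/_global_symbols.db", project_id=1) as idx:
    # delete+insert keeps the file's rows consistent without a batch rebuild
    idx.update_file_symbols("src/config.py", symbols, index_path="src/_index.db")
    hits = idx.search("parse", kind="function", limit=10)  # prefix LIKE 'parse%'
    for s in hits:
        print(s.name, s.file, s.range)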
1064  codex-lens/build/lib/codexlens/storage/index_tree.py  Normal file
File diff suppressed because it is too large
136  codex-lens/build/lib/codexlens/storage/merkle_tree.py  Normal file
@@ -0,0 +1,136 @@
"""Merkle tree utilities for change detection.

This module provides a generic, file-system based Merkle tree implementation
that can be used to efficiently diff directory states.
"""

from __future__ import annotations

import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional


def sha256_bytes(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


def sha256_text(text: str) -> str:
    return sha256_bytes(text.encode("utf-8", errors="ignore"))


@dataclass
class MerkleNode:
    """A Merkle node representing either a file (leaf) or directory (internal)."""

    name: str
    rel_path: str
    hash: str
    is_dir: bool
    children: Dict[str, "MerkleNode"] = field(default_factory=dict)

    def iter_files(self) -> Iterable["MerkleNode"]:
        if not self.is_dir:
            yield self
            return
        for child in self.children.values():
            yield from child.iter_files()


@dataclass
class MerkleTree:
    """Merkle tree for a directory snapshot."""

    root: MerkleNode

    @classmethod
    def build_from_directory(cls, root_dir: Path) -> "MerkleTree":
        root_dir = Path(root_dir).resolve()
        node = cls._build_node(root_dir, base=root_dir)
        return cls(root=node)

    @classmethod
    def _build_node(cls, path: Path, *, base: Path) -> MerkleNode:
        if path.is_file():
            rel = str(path.relative_to(base)).replace("\\", "/")
            return MerkleNode(
                name=path.name,
                rel_path=rel,
                hash=sha256_bytes(path.read_bytes()),
                is_dir=False,
            )

        if not path.is_dir():
            rel = str(path.relative_to(base)).replace("\\", "/")
            return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False)

        children: Dict[str, MerkleNode] = {}
        for child in sorted(path.iterdir(), key=lambda p: p.name):
            child_node = cls._build_node(child, base=base)
            children[child_node.name] = child_node

        items = [
            f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}"
            for name, n in sorted(children.items(), key=lambda kv: kv[0])
        ]
        dir_hash = sha256_text("\n".join(items))

        rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/")
        return MerkleNode(
            name="." if path == base else path.name,
            rel_path=rel_path,
            hash=dir_hash,
            is_dir=True,
            children=children,
        )

    @staticmethod
    def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]:
        """Find changed/added/removed files between two trees.

        Returns:
            List of relative file paths (POSIX-style separators).
        """
        if old is None and new is None:
            return []
        if old is None:
            return sorted({n.rel_path for n in new.root.iter_files()})  # type: ignore[union-attr]
        if new is None:
            return sorted({n.rel_path for n in old.root.iter_files()})

        changed: set[str] = set()

        def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None:
            if old_node is None and new_node is None:
                return

            if old_node is None and new_node is not None:
                changed.update(n.rel_path for n in new_node.iter_files())
                return

            if new_node is None and old_node is not None:
                changed.update(n.rel_path for n in old_node.iter_files())
                return

            assert old_node is not None and new_node is not None

            if old_node.hash == new_node.hash:
                return

            if not old_node.is_dir and not new_node.is_dir:
                changed.add(new_node.rel_path)
                return

            if old_node.is_dir != new_node.is_dir:
                changed.update(n.rel_path for n in old_node.iter_files())
                changed.update(n.rel_path for n in new_node.iter_files())
                return

            names = set(old_node.children.keys()) | set(new_node.children.keys())
            for name in names:
                walk(old_node.children.get(name), new_node.children.get(name))

        walk(old.root, new.root)
        return sorted(changed)
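A minimal sketch of diffing two directory snapshots with the tree above; the directory names are illustrative:

# Sketch: diff two directory snapshots (directories illustrative).
old_tree = MerkleTree.build_from_directory(Path("snapshot_old"))
new_tree = MerkleTree.build_from_directory(Path("snapshot_new"))

# Unchanged subtrees share a hash, so the walk prunes them early.
for rel_path in MerkleTree.find_changed_files(old_tree, new_tree):
    print("changed:", rel_path)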
154  codex-lens/build/lib/codexlens/storage/migration_manager.py  Normal file
@@ -0,0 +1,154 @@
"""
Manages database schema migrations.

This module provides a framework for applying versioned migrations to the SQLite
database. Migrations are discovered from the `codexlens.storage.migrations`
package and applied sequentially. The database schema version is tracked using
the `user_version` pragma.
"""

import importlib
import logging
import pkgutil
from pathlib import Path
from sqlite3 import Connection
from typing import List, NamedTuple

log = logging.getLogger(__name__)


class Migration(NamedTuple):
    """Represents a single database migration."""

    version: int
    name: str
    upgrade: callable


def discover_migrations() -> List[Migration]:
    """
    Discovers and returns a sorted list of database migrations.

    Migrations are expected to be in the `codexlens.storage.migrations` package,
    with filenames in the format `migration_XXX_description.py`, where XXX is
    the version number. Each migration module must contain an `upgrade` function
    that takes a `sqlite3.Connection` object as its argument.

    Returns:
        A list of Migration objects, sorted by version.
    """
    import codexlens.storage.migrations

    migrations = []
    package_path = Path(codexlens.storage.migrations.__file__).parent

    for _, name, _ in pkgutil.iter_modules([str(package_path)]):
        if name.startswith("migration_"):
            try:
                version = int(name.split("_")[1])
                module = importlib.import_module(f"codexlens.storage.migrations.{name}")
                if hasattr(module, "upgrade"):
                    migrations.append(
                        Migration(version=version, name=name, upgrade=module.upgrade)
                    )
                else:
                    log.warning(f"Migration {name} is missing 'upgrade' function.")
            except (ValueError, IndexError) as e:
                log.warning(f"Could not parse migration name {name}: {e}")
            except ImportError as e:
                log.warning(f"Could not import migration {name}: {e}")

    migrations.sort(key=lambda m: m.version)
    return migrations


class MigrationManager:
    """
    Manages the application of migrations to a database.
    """

    def __init__(self, db_conn: Connection):
        """
        Initializes the MigrationManager.

        Args:
            db_conn: The SQLite database connection.
        """
        self.db_conn = db_conn
        self.migrations = discover_migrations()

    def get_current_version(self) -> int:
        """
        Gets the current version of the database schema.

        Returns:
            The current schema version number.
        """
        return self.db_conn.execute("PRAGMA user_version").fetchone()[0]

    def set_version(self, version: int):
        """
        Sets the database schema version.

        Args:
            version: The version number to set.
        """
        self.db_conn.execute(f"PRAGMA user_version = {version}")
        log.info(f"Database schema version set to {version}")

    def apply_migrations(self):
        """
        Applies all pending migrations to the database.

        This method checks the current database version and applies all
        subsequent migrations in order. Each migration is applied within
        a transaction, unless the migration manages its own transactions.
        """
        current_version = self.get_current_version()
        log.info(f"Current database schema version: {current_version}")

        for migration in self.migrations:
            if migration.version > current_version:
                log.info(f"Applying migration {migration.version}: {migration.name}...")
                try:
                    # Check if a transaction is already in progress
                    in_transaction = self.db_conn.in_transaction

                    # Only start transaction if not already in one
                    if not in_transaction:
                        self.db_conn.execute("BEGIN")

                    migration.upgrade(self.db_conn)
                    self.set_version(migration.version)

                    # Only commit if we started the transaction and it's still active
                    if not in_transaction and self.db_conn.in_transaction:
                        self.db_conn.execute("COMMIT")

                    log.info(
                        f"Successfully applied migration {migration.version}: {migration.name}"
                    )
                except Exception as e:
                    log.error(
                        f"Failed to apply migration {migration.version}: {migration.name}. Error: {e}",
                        exc_info=True,
                    )
                    # Try to rollback if transaction is active
                    try:
                        if self.db_conn.in_transaction:
                            self.db_conn.execute("ROLLBACK")
                    except Exception:
                        pass  # Ignore rollback errors
                    raise

        latest_migration_version = self.migrations[-1].version if self.migrations else 0
        if current_version < latest_migration_version:
            # This case can be hit if migrations were applied but the loop was exited
            # and set_version was not called for the last one for some reason.
            # To be safe, we explicitly re-check the version against the latest known migration.
            final_version = self.get_current_version()
            if final_version != latest_migration_version:
                log.warning(
                    f"Database version ({final_version}) is not the latest migration "
                    f"version ({latest_migration_version}). This may indicate a problem."
                )

        log.info("All pending migrations applied successfully.")
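A short sketch of running the manager at startup; the database path is illustrative:

# Sketch: apply pending migrations at startup (database path illustrative).
import sqlite3

conn = sqlite3.connect("index.db")
manager = MigrationManager(conn)  # discovers migration_XXX_*.py modules
manager.apply_migrations()        # no-op when the schema is already current
conn.close()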
@@ -0,0 +1 @@
# This file makes the 'migrations' directory a Python package.
@@ -0,0 +1,123 @@
"""
Migration 001: Normalize keywords into separate tables.

This migration introduces two new tables, `keywords` and `file_keywords`, to
store semantic keywords in a normalized fashion. It then migrates the existing
keywords from the JSON blob in the `semantic_metadata` table into these new
tables. This is intended to speed up keyword-based searches significantly.
"""

import json
import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to normalize keywords.

    - Creates `keywords` and `file_keywords` tables.
    - Creates indexes for efficient querying.
    - Migrates data from the `semantic_metadata` table to the new tables.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating 'keywords' and 'file_keywords' tables...")
    # Create a table to store unique keywords
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS keywords (
            id INTEGER PRIMARY KEY,
            keyword TEXT NOT NULL UNIQUE
        )
        """
    )

    # Create a join table to link files and keywords (many-to-many)
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS file_keywords (
            file_id INTEGER NOT NULL,
            keyword_id INTEGER NOT NULL,
            PRIMARY KEY (file_id, keyword_id),
            FOREIGN KEY (file_id) REFERENCES files (id) ON DELETE CASCADE,
            FOREIGN KEY (keyword_id) REFERENCES keywords (id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for new keyword tables...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords (keyword)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords (file_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_keyword_id ON file_keywords (keyword_id)")

    log.info("Migrating existing keywords from 'semantic_metadata' table...")

    # Check if semantic_metadata table exists before querying
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'")
    if not cursor.fetchone():
        log.info("No 'semantic_metadata' table found, skipping data migration.")
        return

    # Check if 'keywords' column exists in semantic_metadata table
    # (current schema may already use normalized tables without this column)
    cursor.execute("PRAGMA table_info(semantic_metadata)")
    columns = {row[1] for row in cursor.fetchall()}
    if "keywords" not in columns:
        log.info("No 'keywords' column in semantic_metadata table, skipping data migration.")
        return

    cursor.execute("SELECT file_id, keywords FROM semantic_metadata WHERE keywords IS NOT NULL AND keywords != ''")

    files_to_migrate = cursor.fetchall()
    if not files_to_migrate:
        log.info("No existing files with semantic metadata to migrate.")
        return

    log.info(f"Found {len(files_to_migrate)} files with semantic metadata to migrate.")

    for file_id, keywords_json in files_to_migrate:
        if not keywords_json:
            continue
        try:
            keywords = json.loads(keywords_json)

            if not isinstance(keywords, list):
                log.warning(f"Keywords for file_id {file_id} is not a list, skipping.")
                continue

            for keyword in keywords:
                if not isinstance(keyword, str):
                    log.warning(f"Non-string keyword '{keyword}' found for file_id {file_id}, skipping.")
                    continue

                keyword = keyword.strip()
                if not keyword:
                    continue

                # Get or create keyword_id
                cursor.execute("INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,))
                cursor.execute("SELECT id FROM keywords WHERE keyword = ?", (keyword,))
                keyword_id_result = cursor.fetchone()

                if keyword_id_result:
                    keyword_id = keyword_id_result[0]
                    # Link file to keyword
                    cursor.execute(
                        "INSERT OR IGNORE INTO file_keywords (file_id, keyword_id) VALUES (?, ?)",
                        (file_id, keyword_id),
                    )
                else:
                    log.error(f"Failed to retrieve or create keyword_id for keyword: {keyword}")

        except json.JSONDecodeError as e:
            log.warning(f"Could not parse keywords for file_id {file_id}: {e}")
        except Exception as e:
            log.error(f"An unexpected error occurred during migration for file_id {file_id}: {e}", exc_info=True)

    log.info("Finished migrating keywords.")
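With the normalized tables in place, a keyword lookup becomes a plain indexed join instead of a JSON scan; a sketch of the query this migration enables (keyword value illustrative):

# Sketch: find files tagged with a keyword via the normalized tables.
rows = db_conn.execute(
    """
    SELECT fk.file_id
    FROM keywords k
    JOIN file_keywords fk ON fk.keyword_id = k.id
    WHERE k.keyword = ?
    """,
    ("authentication",),
).fetchall()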
@@ -0,0 +1,48 @@
"""
Migration 002: Add token_count and symbol_type to symbols table.

This migration adds token counting metadata to symbols for accurate chunk
splitting and performance optimization. It also adds symbol_type for better
filtering in searches.
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add token metadata to symbols.

    - Adds token_count column to symbols table
    - Adds symbol_type column to symbols table (for future use)
    - Creates index on symbol_type for efficient filtering
    - Backfills existing symbols with NULL token_count (to be calculated lazily)

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Adding token_count column to symbols table...")
    try:
        cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
        log.info("Successfully added token_count column.")
    except Exception as e:
        # Column might already exist
        log.warning(f"Could not add token_count column (might already exist): {e}")

    log.info("Adding symbol_type column to symbols table...")
    try:
        cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
        log.info("Successfully added symbol_type column.")
    except Exception as e:
        # Column might already exist
        log.warning(f"Could not add symbol_type column (might already exist): {e}")

    log.info("Creating index on symbol_type for efficient filtering...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")

    log.info("Migration 002 completed successfully.")
@@ -0,0 +1,232 @@
"""
Migration 004: Add dual FTS tables for exact and fuzzy matching.

This migration introduces two FTS5 tables:
- files_fts_exact: Uses unicode61 tokenizer for exact token matching
- files_fts_fuzzy: Uses trigram tokenizer (or extended unicode61) for substring/fuzzy matching

Both tables are synchronized with the files table via triggers for automatic updates.
"""

import logging
from sqlite3 import Connection

from codexlens.storage.sqlite_utils import check_trigram_support, get_sqlite_version

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add dual FTS tables.

    - Drops old files_fts table and triggers
    - Creates files_fts_exact with unicode61 tokenizer
    - Creates files_fts_fuzzy with trigram or extended unicode61 tokenizer
    - Creates synchronized triggers for both tables
    - Rebuilds FTS indexes from files table

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    try:
        # Check trigram support
        has_trigram = check_trigram_support(db_conn)
        version = get_sqlite_version(db_conn)
        log.info(f"SQLite version: {'.'.join(map(str, version))}")

        if has_trigram:
            log.info("Trigram tokenizer available, using for fuzzy FTS table")
            fuzzy_tokenizer = "trigram"
        else:
            log.warning(
                "Trigram tokenizer not available (requires SQLite >= 3.34), "
                "using extended unicode61 tokenizer for fuzzy matching"
            )
            fuzzy_tokenizer = "unicode61 tokenchars '_-.'"

        # Start transaction
        cursor.execute("BEGIN TRANSACTION")

        # Check if files table has 'name' column (v2 schema doesn't have it)
        cursor.execute("PRAGMA table_info(files)")
        columns = {row[1] for row in cursor.fetchall()}

        if 'name' not in columns:
            log.info("Adding 'name' column to files table (v2 schema upgrade)...")
            # Add name column
            cursor.execute("ALTER TABLE files ADD COLUMN name TEXT")
            # Populate name from path (extract filename from last '/')
            # Use Python to do the extraction since SQLite doesn't have reverse()
            cursor.execute("SELECT rowid, path FROM files")
            rows = cursor.fetchall()
            for rowid, path in rows:
                # Extract filename from path
                name = path.split('/')[-1] if '/' in path else path
                cursor.execute("UPDATE files SET name = ? WHERE rowid = ?", (name, rowid))

        # Rename 'path' column to 'full_path' if needed
        if 'path' in columns and 'full_path' not in columns:
            log.info("Renaming 'path' to 'full_path' (v2 schema upgrade)...")
            # Check if indexed_at column exists in v2 schema
            has_indexed_at = 'indexed_at' in columns
            has_mtime = 'mtime' in columns

            # SQLite doesn't support RENAME COLUMN before 3.25, so use table recreation
            cursor.execute("""
                CREATE TABLE files_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    full_path TEXT NOT NULL UNIQUE,
                    content TEXT,
                    language TEXT,
                    mtime REAL,
                    indexed_at TEXT
                )
            """)

            # Build INSERT statement based on available columns
            # Note: v2 schema has no rowid (path is PRIMARY KEY), so use NULL for AUTOINCREMENT
            if has_indexed_at and has_mtime:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, mtime, indexed_at)
                    SELECT name, path, content, language, mtime, indexed_at FROM files
                """)
            elif has_indexed_at:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, indexed_at)
                    SELECT name, path, content, language, indexed_at FROM files
                """)
            elif has_mtime:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language, mtime)
                    SELECT name, path, content, language, mtime FROM files
                """)
            else:
                cursor.execute("""
                    INSERT INTO files_new (name, full_path, content, language)
                    SELECT name, path, content, language FROM files
                """)

            cursor.execute("DROP TABLE files")
            cursor.execute("ALTER TABLE files_new RENAME TO files")

        log.info("Dropping old FTS triggers and table...")
        # Drop old triggers
        cursor.execute("DROP TRIGGER IF EXISTS files_ai")
        cursor.execute("DROP TRIGGER IF EXISTS files_ad")
        cursor.execute("DROP TRIGGER IF EXISTS files_au")

        # Drop old FTS table
        cursor.execute("DROP TABLE IF EXISTS files_fts")

        # Create exact FTS table (unicode61 with underscores/hyphens/dots as token chars)
        # Note: tokenchars includes '.' to properly tokenize qualified names like PortRole.FLOW
        log.info("Creating files_fts_exact table with unicode61 tokenizer...")
        cursor.execute(
            """
            CREATE VIRTUAL TABLE files_fts_exact USING fts5(
                name, full_path UNINDEXED, content,
                content='files',
                content_rowid='id',
                tokenize="unicode61 tokenchars '_-.'"
            )
            """
        )

        # Create fuzzy FTS table (trigram or extended unicode61)
        log.info(f"Creating files_fts_fuzzy table with {fuzzy_tokenizer} tokenizer...")
        cursor.execute(
            f"""
            CREATE VIRTUAL TABLE files_fts_fuzzy USING fts5(
                name, full_path UNINDEXED, content,
                content='files',
                content_rowid='id',
                tokenize="{fuzzy_tokenizer}"
            )
            """
        )

        # Create synchronized triggers for files_fts_exact
        log.info("Creating triggers for files_fts_exact...")
        cursor.execute(
            """
            CREATE TRIGGER files_exact_ai AFTER INSERT ON files BEGIN
                INSERT INTO files_fts_exact(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_exact_ad AFTER DELETE ON files BEGIN
                INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_exact_au AFTER UPDATE ON files BEGIN
                INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
                INSERT INTO files_fts_exact(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )

        # Create synchronized triggers for files_fts_fuzzy
        log.info("Creating triggers for files_fts_fuzzy...")
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_ai AFTER INSERT ON files BEGIN
                INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_ad AFTER DELETE ON files BEGIN
                INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
            END
            """
        )
        cursor.execute(
            """
            CREATE TRIGGER files_fuzzy_au AFTER UPDATE ON files BEGIN
                INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
                VALUES('delete', old.id, old.name, old.full_path, old.content);
                INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
                VALUES(new.id, new.name, new.full_path, new.content);
            END
            """
        )

        # Rebuild FTS indexes from files table
        log.info("Rebuilding FTS indexes from files table...")
        cursor.execute("INSERT INTO files_fts_exact(files_fts_exact) VALUES('rebuild')")
        cursor.execute("INSERT INTO files_fts_fuzzy(files_fts_fuzzy) VALUES('rebuild')")

        # Commit transaction
        cursor.execute("COMMIT")
        log.info("Migration 004 completed successfully")

        # Vacuum to reclaim space (outside transaction)
        try:
            log.info("Running VACUUM to reclaim space...")
            cursor.execute("VACUUM")
        except Exception as e:
            log.warning(f"VACUUM failed (non-critical): {e}")

    except Exception as e:
        log.error(f"Migration 004 failed: {e}")
        try:
            cursor.execute("ROLLBACK")
        except Exception:
            pass
        raise
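The two FTS tables created above are meant to be queried in tandem at search time. A minimal sketch, assuming exact hits take precedence over fuzzy ones (the fallback strategy is illustrative, not the project's actual search pipeline):

import sqlite3

def search_files(conn: sqlite3.Connection, term: str):
    # Exact table first: unicode61 with tokenchars '_-.' keeps identifiers
    # like PortRole.FLOW as single tokens. Fall back to the fuzzy table,
    # where trigram matching finds substrings of 3+ characters.
    rows = conn.execute(
        "SELECT name, full_path FROM files_fts_exact WHERE files_fts_exact MATCH ?",
        (term,),
    ).fetchall()
    if not rows:
        rows = conn.execute(
            "SELECT name, full_path FROM files_fts_fuzzy WHERE files_fts_fuzzy MATCH ?",
            (term,),
        ).fetchall()
    return rows

Note that full_path is declared UNINDEXED, so it never matches queries but is still retrievable in the SELECT list.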
@@ -0,0 +1,196 @@
"""
Migration 005: Remove unused and redundant database fields.

This migration removes four problematic fields identified by Gemini analysis:

1. **semantic_metadata.keywords** (deprecated - replaced by file_keywords table)
   - Data: Migrated to normalized file_keywords table in migration 001
   - Impact: Column now redundant, remove to prevent sync issues

2. **symbols.token_count** (unused - always NULL)
   - Data: Never populated, always NULL
   - Impact: No data loss, just removes unused column

3. **symbols.symbol_type** (redundant - duplicates kind)
   - Data: Redundant with symbols.kind field
   - Impact: No data loss, kind field contains same information

4. **subdirs.direct_files** (unused - never displayed)
   - Data: Never used in queries or display logic
   - Impact: No data loss, just removes unused column

Schema changes use table recreation pattern (SQLite best practice):
- Create new table without deprecated columns
- Copy data from old table
- Drop old table
- Rename new table
- Recreate indexes
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """Remove unused and redundant fields from schema.

    Note: Transaction management is handled by MigrationManager.
    This migration should NOT start its own transaction.

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    # Step 1: Remove semantic_metadata.keywords (if column exists)
    log.info("Checking semantic_metadata.keywords column...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'"
    )
    if cursor.fetchone():
        # Check if keywords column exists
        cursor.execute("PRAGMA table_info(semantic_metadata)")
        columns = {row[1] for row in cursor.fetchall()}

        if "keywords" in columns:
            log.info("Removing semantic_metadata.keywords column...")
            cursor.execute("""
                CREATE TABLE semantic_metadata_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL UNIQUE,
                    summary TEXT,
                    purpose TEXT,
                    llm_tool TEXT,
                    generated_at REAL,
                    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
                )
            """)

            cursor.execute("""
                INSERT INTO semantic_metadata_new (id, file_id, summary, purpose, llm_tool, generated_at)
                SELECT id, file_id, summary, purpose, llm_tool, generated_at
                FROM semantic_metadata
            """)

            cursor.execute("DROP TABLE semantic_metadata")
            cursor.execute("ALTER TABLE semantic_metadata_new RENAME TO semantic_metadata")

            # Recreate index
            cursor.execute(
                "CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)"
            )
            log.info("Removed semantic_metadata.keywords column")
        else:
            log.info("semantic_metadata.keywords column does not exist, skipping")
    else:
        log.info("semantic_metadata table does not exist, skipping")

    # Step 2: Remove symbols.token_count and symbols.symbol_type (if columns exist)
    log.info("Checking symbols.token_count and symbols.symbol_type columns...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'"
    )
    if cursor.fetchone():
        # Check if token_count or symbol_type columns exist
        cursor.execute("PRAGMA table_info(symbols)")
        columns = {row[1] for row in cursor.fetchall()}

        if "token_count" in columns or "symbol_type" in columns:
            log.info("Removing symbols.token_count and symbols.symbol_type columns...")
            cursor.execute("""
                CREATE TABLE symbols_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT,
                    start_line INTEGER,
                    end_line INTEGER,
                    FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
                )
            """)

            cursor.execute("""
                INSERT INTO symbols_new (id, file_id, name, kind, start_line, end_line)
                SELECT id, file_id, name, kind, start_line, end_line
                FROM symbols
            """)

            cursor.execute("DROP TABLE symbols")
            cursor.execute("ALTER TABLE symbols_new RENAME TO symbols")

            # Recreate indexes
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
            log.info("Removed symbols.token_count and symbols.symbol_type columns")
        else:
            log.info("symbols.token_count/symbol_type columns do not exist, skipping")
    else:
        log.info("symbols table does not exist, skipping")

    # Step 3: Remove subdirs.direct_files (if column exists)
    log.info("Checking subdirs.direct_files column...")

    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='subdirs'"
    )
    if cursor.fetchone():
        # Check if direct_files column exists
        cursor.execute("PRAGMA table_info(subdirs)")
        columns = {row[1] for row in cursor.fetchall()}

        if "direct_files" in columns:
            log.info("Removing subdirs.direct_files column...")
            cursor.execute("""
                CREATE TABLE subdirs_new (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    index_path TEXT NOT NULL,
                    files_count INTEGER DEFAULT 0,
                    last_updated REAL
                )
            """)

            cursor.execute("""
                INSERT INTO subdirs_new (id, name, index_path, files_count, last_updated)
                SELECT id, name, index_path, files_count, last_updated
                FROM subdirs
            """)

            cursor.execute("DROP TABLE subdirs")
            cursor.execute("ALTER TABLE subdirs_new RENAME TO subdirs")

            # Recreate index
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
            log.info("Removed subdirs.direct_files column")
        else:
            log.info("subdirs.direct_files column does not exist, skipping")
    else:
        log.info("subdirs table does not exist, skipping")

    log.info("Migration 005 completed successfully")

    # Vacuum to reclaim space (outside transaction, optional)
    # Note: VACUUM cannot run inside a transaction, so we skip it here
    # The caller can run VACUUM separately if desired


def downgrade(db_conn: Connection):
    """Restore removed fields (data will be lost for keywords, token_count, symbol_type, direct_files).

    This is a placeholder - true downgrade is not feasible as data is lost.
    The migration is designed to be one-way since removed fields are unused/redundant.

    Args:
        db_conn: The SQLite database connection.
    """
    log.warning(
        "Migration 005 downgrade not supported - removed fields are unused/redundant. "
        "Data cannot be restored."
    )
    raise NotImplementedError(
        "Migration 005 downgrade not supported - this is a one-way migration"
    )
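The recreation pattern above works on any SQLite version; since 3.35.0, ALTER TABLE ... DROP COLUMN can do the same removal in one statement. A hedged sketch of that shortcut (the version gate reflects the SQLite release notes; the helper itself is illustrative, not part of this migration):

import sqlite3

def drop_column_if_supported(conn: sqlite3.Connection, table: str, column: str) -> bool:
    # ALTER TABLE ... DROP COLUMN was added in SQLite 3.35.0 (2021-03-12).
    if sqlite3.sqlite_version_info < (3, 35, 0):
        return False  # caller must fall back to table recreation
    columns = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    if column in columns:
        conn.execute(f"ALTER TABLE {table} DROP COLUMN {column}")
    return True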
@@ -0,0 +1,37 @@
"""
Migration 006: Ensure relationship tables and indexes exist.

This migration is intentionally idempotent. It creates the `code_relationships`
table (used for graph visualization) and its indexes if missing.
"""

from __future__ import annotations

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    cursor = db_conn.cursor()

    log.info("Ensuring code_relationships table exists...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT
        )
        """
    )

    log.info("Ensuring relationship indexes exist...")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")
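For context, a typical read against this table for the graph view might look like the sketch below. The hypothetical `callers_of` helper is an assumption; the join and column choices simply mirror the schema above:

from sqlite3 import Connection

def callers_of(conn: Connection, qualified_name: str):
    # "Who references this symbol?" -- served by idx_rel_target, joining back
    # to symbols for each referencing symbol's name.
    return conn.execute(
        """
        SELECT s.name, r.relationship_type, r.source_line
        FROM code_relationships AS r
        JOIN symbols AS s ON s.id = r.source_symbol_id
        WHERE r.target_qualified_name = ?
        """,
        (qualified_name,),
    ).fetchall()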
@@ -0,0 +1,47 @@
"""
Migration 007: Add precomputed graph neighbor table for search expansion.

Adds:
- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids)

This table is derived data (a cache) and is safe to rebuild at any time.
The migration is intentionally idempotent.
"""

from __future__ import annotations

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    cursor = db_conn.cursor()

    log.info("Creating graph_neighbors table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS graph_neighbors (
            source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
            neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
            relationship_depth INTEGER NOT NULL,
            PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
        )
        """
    )

    log.info("Creating indexes for graph_neighbors...")
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth
        ON graph_neighbors(source_symbol_id, relationship_depth)
        """
    )
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor
        ON graph_neighbors(neighbor_symbol_id)
        """
    )
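Search-time expansion against this cache reduces to a single indexed range scan. A minimal sketch (the `max_depth` default is an assumption, and rebuild logic for the cache is out of scope here):

from sqlite3 import Connection

def expand_neighbors(conn: Connection, symbol_id: int, max_depth: int = 2):
    # Pull precomputed neighbors up to max_depth hops; the composite index
    # idx_graph_neighbors_source_depth serves this query directly.
    rows = conn.execute(
        """
        SELECT neighbor_symbol_id
        FROM graph_neighbors
        WHERE source_symbol_id = ? AND relationship_depth <= ?
        ORDER BY relationship_depth
        """,
        (symbol_id, max_depth),
    ).fetchall()
    return [row[0] for row in rows]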