Compare commits

...

7 Commits

Author  SHA1  Message  Date

catlog22  88ff109ac4  chore: bump version to 6.3.48  2026-01-24 15:06:36 +08:00
catlog22  261196a804  docs: update CCW CLI commands with recommended commands and usage examples  2026-01-24 15:05:37 +08:00
catlog22  ea6cb8440f  chore: bump version to 6.3.47  2026-01-24 14:52:09 +08:00
  - Update ccw-coordinator.md with clarified CLI execution format
  - Command-first prompt structure: /workflow:<command> -y <parameters>
  - Simplified documentation with universal prompt template
  - Clarify that -y is a prompt parameter, not a ccw cli parameter
catlog22  bf896342f4  refactor: adjust prompt structure for command execution clarity  2026-01-24 14:51:05 +08:00
catlog22  f2b0a5bbc9  Refactor code structure and remove redundant changes  2026-01-24 14:47:47 +08:00
catlog22  cf5fecd66d  fix(codex-lens): resolve installation issues from frontend  2026-01-24 14:43:39 +08:00
  - Add missing README.md file required by setuptools
  - Fix deprecated license format in pyproject.toml (use SPDX string instead of TOML table)
  - Add MIT LICENSE file for proper packaging
  - Verified successful local installation and import
  Fixes permission denied error during npm-based installation on macOS
catlog22  86d469ccc9  build: exclude test files from TypeScript compilation  2026-01-24 14:35:05 +08:00
120 changed files with 43410 additions and 244 deletions

View File

@@ -401,17 +401,19 @@ async function executeCommandChain(chain, analysis) {
state.updated_at = new Date().toISOString();
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
// Assemble prompt with previous results
let prompt = `Task: ${analysis.goal}\n`;
// Assemble prompt: Command first, then context
let promptContent = formatCommand(cmd, state.execution_results, analysis);
// Build full prompt: Command → Task → Previous Results
let prompt = `${promptContent}\n\nTask: ${analysis.goal}`;
if (state.execution_results.length > 0) {
prompt += '\nPrevious results:\n';
prompt += '\n\nPrevious results:\n';
state.execution_results.forEach(r => {
if (r.session_id) {
prompt += `- ${r.command}: ${r.session_id} (${r.artifacts?.join(', ') || 'completed'})\n`;
}
});
}
prompt += `\n${formatCommand(cmd, state.execution_results, analysis)}\n`;
// Record prompt used
state.prompts_used.push({
@@ -421,9 +423,12 @@ async function executeCommandChain(chain, analysis) {
});
// Execute CLI command in background and stop
// Format: ccw cli -p "PROMPT" --tool <tool> --mode <mode>
// Note: -y is a command parameter INSIDE the prompt, not a ccw cli parameter
// Example prompt: "/workflow:plan -y \"task description here\""
try {
const taskId = Bash(
`ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write -y`,
`ccw cli -p "${escapePrompt(prompt)}" --tool claude --mode write`,
{ run_in_background: true }
).task_id;
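The snippet above relies on an `escapePrompt()` helper that this diff does not show. As a hedged illustration only, here is a minimal Python sketch of the same command-first assembly and shell-safe invocation, assuming `shlex.quote`-style quoting is adequate:

```python
# Hypothetical sketch (not from the repo): command-first prompt assembly
# plus shell-safe invocation. shlex.quote stands in for escapePrompt().
import shlex
import subprocess

def launch_ccw(goal: str, previous_results: list) -> subprocess.Popen:
    # Command first, then task context - the structure this diff introduces
    prompt = f'/workflow:plan -y "{goal}"\n\nTask: {goal}'
    if previous_results:
        prompt += "\n\nPrevious results:\n"
        for r in previous_results:
            if r.get("session_id"):
                artifacts = ", ".join(r.get("artifacts") or []) or "completed"
                prompt += f'- {r["command"]}: {r["session_id"]} ({artifacts})\n'
    # Note: -y lives inside the prompt; it is not a `ccw cli` flag
    cmd = f"ccw cli -p {shlex.quote(prompt)} --tool claude --mode write"
    return subprocess.Popen(cmd, shell=True)  # backgrounded, like run_in_background
```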
@@ -486,69 +491,71 @@ async function executeCommandChain(chain, analysis) {
}
// Smart parameter assembly
// Returns prompt content to be used with: ccw cli -p "RETURNED_VALUE" --tool claude --mode write
function formatCommand(cmd, previousResults, analysis) {
let line = cmd.command + ' --yes';
// Format: /workflow:<command> -y <parameters>
let prompt = `/workflow:${cmd.name} -y`;
const name = cmd.name;
// Planning commands - take task description
if (['lite-plan', 'plan', 'tdd-plan', 'multi-cli-plan'].includes(name)) {
line += ` "${analysis.goal}"`;
prompt += ` "${analysis.goal}"`;
// Lite execution - use --in-memory if plan exists
} else if (name === 'lite-execute') {
const hasPlan = previousResults.some(r => r.command.includes('plan'));
line += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
prompt += hasPlan ? ' --in-memory' : ` "${analysis.goal}"`;
// Standard execution - resume from planning session
} else if (name === 'execute') {
const plan = previousResults.find(r => r.command.includes('plan'));
if (plan?.session_id) line += ` --resume-session="${plan.session_id}"`;
if (plan?.session_id) prompt += ` --resume-session="${plan.session_id}"`;
// Bug fix commands - take bug description
} else if (['lite-fix', 'debug'].includes(name)) {
line += ` "${analysis.goal}"`;
prompt += ` "${analysis.goal}"`;
// Brainstorm - take topic description
} else if (name === 'brainstorm:auto-parallel' || name === 'auto-parallel') {
line += ` "${analysis.goal}"`;
prompt += ` "${analysis.goal}"`;
// Test generation from session - needs source session
} else if (name === 'test-gen') {
const impl = previousResults.find(r =>
r.command.includes('execute') || r.command.includes('lite-execute')
);
if (impl?.session_id) line += ` "${impl.session_id}"`;
else line += ` "${analysis.goal}"`;
if (impl?.session_id) prompt += ` "${impl.session_id}"`;
else prompt += ` "${analysis.goal}"`;
// Test fix generation - session or description
} else if (name === 'test-fix-gen') {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) line += ` "${latest.session_id}"`;
else line += ` "${analysis.goal}"`;
if (latest?.session_id) prompt += ` "${latest.session_id}"`;
else prompt += ` "${analysis.goal}"`;
// Review commands - take session or use latest
} else if (name === 'review') {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) line += ` --session="${latest.session_id}"`;
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
// Review fix - takes session from review
} else if (name === 'review-fix') {
const review = previousResults.find(r => r.command.includes('review'));
const latest = review || previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) line += ` --session="${latest.session_id}"`;
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
// TDD verify - takes execution session
} else if (name === 'tdd-verify') {
const exec = previousResults.find(r => r.command.includes('execute'));
if (exec?.session_id) line += ` --session="${exec.session_id}"`;
if (exec?.session_id) prompt += ` --session="${exec.session_id}"`;
// Session-based commands (test-cycle, review-session, plan-verify)
} else if (name.includes('test') || name.includes('review') || name.includes('verify')) {
const latest = previousResults.filter(r => r.session_id).pop();
if (latest?.session_id) line += ` --session="${latest.session_id}"`;
if (latest?.session_id) prompt += ` --session="${latest.session_id}"`;
}
return line;
return prompt;
}
// Hook callback: Called when background CLI completes
@@ -663,12 +670,12 @@ function parseOutput(output) {
{
"index": 0,
"command": "/workflow:plan",
"prompt": "Task: Implement user registration...\n\n/workflow:plan --yes \"Implement user registration...\""
"prompt": "/workflow:plan -y \"Implement user registration...\"\n\nTask: Implement user registration..."
},
{
"index": 1,
"command": "/workflow:execute",
"prompt": "Task: Implement user registration...\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)\n\n/workflow:execute --yes --resume-session=\"WFS-plan-20250124\""
"prompt": "/workflow:execute -y --resume-session=\"WFS-plan-20250124\"\n\nTask: Implement user registration\n\nPrevious results:\n- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)"
}
]
}
@@ -728,226 +735,68 @@ const cmd = registry.getCommand('lite-plan');
// {name, command, description, argumentHint, allowedTools, filePath}
```
## Execution Examples
## Universal Prompt Template
### Simple Feature
```
Goal: Add API endpoint for user profile
Scope: [api]
Complexity: simple
Constraints: []
Task Type: feature
### Standard Format
Pipeline (with Minimum Execution Units):
Requirements →【lite-plan → lite-execute】→ code →【test-fix-gen → test-cycle-execute】→ tests pass
Chain:
# Unit 1: Quick Implementation
1. /workflow:lite-plan --yes "Add API endpoint..."
2. /workflow:lite-execute --yes --in-memory
# Unit 2: Test Validation
3. /workflow:test-fix-gen --yes --session="WFS-xxx"
4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
```bash
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
```
### Complex Feature with Verification
### Prompt Content Template
```
Goal: Implement OAuth2 authentication system
Scope: [auth, database, api, frontend]
Complexity: complex
Constraints: [no breaking changes]
Task Type: feature
/workflow:<command> -y <command_parameters>
Pipeline (with Minimum Execution Units):
Requirements →【plan → plan-verify】→ verified plan → execute → code
→【review-session-cycle → review-fix】→ fixed code
→【test-fix-gen → test-cycle-execute】→ tests pass
Task: <task_description>
Chain:
# Unit 1: Full Planning (plan + plan-verify)
1. /workflow:plan --yes "Implement OAuth2..."
2. /workflow:plan-verify --yes --session="WFS-xxx"
# Execution phase
3. /workflow:execute --yes --resume-session="WFS-xxx"
# Unit 2: Code Review (review-session-cycle + review-fix)
4. /workflow:review-session-cycle --yes --session="WFS-xxx"
5. /workflow:review-fix --yes --session="WFS-xxx"
# Unit 3: Test Validation (test-fix-gen + test-cycle-execute)
6. /workflow:test-fix-gen --yes --session="WFS-xxx"
7. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
<optional_previous_results>
```
### Quick Bug Fix
```
Goal: Fix login timeout issue
Scope: [auth]
Complexity: simple
Constraints: [urgent]
Task Type: bugfix
### Template Variables
Pipeline:
Bug report → lite-fix → fixed code → test-fix-gen → test tasks → test-cycle-execute → tests pass
| Variable | Description | Examples |
|----------|-------------|----------|
| `<command>` | Workflow command name | `plan`, `lite-execute`, `test-cycle-execute` |
| `-y` | Auto-confirm flag (inside prompt) | Always include for automation |
| `<command_parameters>` | Command-specific parameters | Task description, session ID, flags |
| `<task_description>` | Brief task description | "Implement user authentication", "Fix memory leak" |
| `<optional_previous_results>` | Context from previous commands | "Previous results:\n- /workflow:plan: WFS-xxx" |
Chain:
1. /workflow:lite-fix --yes "Fix login timeout..."
2. /workflow:test-fix-gen --yes --session="WFS-xxx"
3. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
### Command Parameter Patterns
| Command Type | Parameter Pattern | Example |
|--------------|------------------|---------|
| **Planning** | `"task description"` | `/workflow:plan -y "Implement OAuth2"` |
| **Execution (with plan)** | `--resume-session="WFS-xxx"` | `/workflow:execute -y --resume-session="WFS-plan-001"` |
| **Execution (standalone)** | `--in-memory` or `"task"` | `/workflow:lite-execute -y --in-memory` |
| **Session-based** | `--session="WFS-xxx"` | `/workflow:test-fix-gen -y --session="WFS-impl-001"` |
| **Fix/Debug** | `"problem description"` | `/workflow:lite-fix -y "Fix timeout bug"` |
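As a hedged illustration of these patterns, here is a minimal Python port of the `formatCommand` logic shown earlier (the function name and simplified branches are assumptions, not repo code):

```python
# Illustrative sketch: maps command type to its parameter pattern.
from typing import Optional

def format_command(name: str, goal: str, session_id: Optional[str] = None,
                   has_plan: bool = False) -> str:
    prompt = f"/workflow:{name} -y"  # -y always travels inside the prompt
    if name in ("lite-plan", "plan", "tdd-plan", "multi-cli-plan",
                "lite-fix", "debug"):
        prompt += f' "{goal}"'                         # planning / fix commands
    elif name == "execute" and session_id:
        prompt += f' --resume-session="{session_id}"'  # resume from planning
    elif name == "lite-execute":
        prompt += " --in-memory" if has_plan else f' "{goal}"'
    elif session_id:
        prompt += f' --session="{session_id}"'         # session-based commands
    return prompt

assert format_command("plan", "Implement OAuth2") == '/workflow:plan -y "Implement OAuth2"'
```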
### Complete Examples
**Planning Command**:
```bash
ccw cli -p '/workflow:plan -y "Implement user registration with email validation"
Task: Implement user registration' --tool claude --mode write
```
### Skip Tests
```
Goal: Update documentation
Scope: [docs]
Complexity: simple
Constraints: [skip-tests]
Task Type: feature
**Execution with Context**:
```bash
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
Pipeline:
Requirements → lite-plan → plan → lite-execute → code
Task: Implement user registration
Chain:
1. /workflow:lite-plan --yes "Update documentation..."
2. /workflow:lite-execute --yes --in-memory
Previous results:
- /workflow:plan: WFS-plan-20250124 (IMPL_PLAN.md)' --tool claude --mode write
```
### TDD Workflow
```
Goal: Implement user authentication with test-first approach
Scope: [auth]
Complexity: medium
Constraints: [test-driven]
Task Type: tdd
**Standalone Lite Execution**:
```bash
ccw cli -p '/workflow:lite-fix -y "Fix login timeout in auth module"
Pipeline:
Requirements → tdd-plan → TDD tasks → execute → code → tdd-verify → TDD verification passes
Chain:
1. /workflow:tdd-plan --yes "Implement user authentication..."
2. /workflow:execute --yes --resume-session="WFS-xxx"
3. /workflow:tdd-verify --yes --session="WFS-xxx"
```
### Debug Workflow
```
Goal: Fix memory leak in WebSocket handler
Scope: [websocket]
Complexity: medium
Constraints: [production-issue]
Task Type: bugfix
Pipeline (quick fix):
Bug report → lite-fix → fixed code → test-cycle-execute → tests pass
Pipeline (systematic debugging):
Bug report → debug → debug logs → analyze and locate → fix
Chain:
1. /workflow:lite-fix --yes "Fix memory leak in WebSocket..."
2. /workflow:test-cycle-execute --yes --session="WFS-xxx"
OR (for hypothesis-driven debugging):
1. /workflow:debug --yes "Memory leak in WebSocket handler..."
```
### Test Fix Workflow
```
Goal: Fix failing authentication tests
Scope: [auth, tests]
Complexity: simple
Constraints: []
Task Type: test-fix
Pipeline:
Failing tests → test-fix-gen → test tasks → test-cycle-execute → tests pass
Chain:
1. /workflow:test-fix-gen --yes "WFS-auth-impl-001"
2. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
```
### Test Generation from Implementation
```
Goal: Generate comprehensive tests for completed user registration feature
Scope: [auth, tests]
Complexity: medium
Constraints: []
Task Type: test-gen
Pipeline (with Minimum Execution Units):
Code/session →【test-gen → execute】→ tests pass
Chain:
# Unit: Test Generation (test-gen + execute)
1. /workflow:test-gen --yes "WFS-registration-20250124"
2. /workflow:execute --yes --session="WFS-test-registration"
Note: test-gen creates IMPL-001 (test generation) and IMPL-002 (test execution & fix)
execute runs both tasks - this is a Minimum Execution Unit
```
### Review + Fix Workflow
```
Goal: Code review of payment module
Scope: [payment]
Complexity: medium
Constraints: []
Task Type: review
Pipeline (with Minimum Execution Units):
Code →【review-session-cycle → review-fix】→ fixed code
→【test-fix-gen → test-cycle-execute】→ tests pass
Chain:
# Unit 1: Code Review (review-session-cycle + review-fix)
1. /workflow:review-session-cycle --yes --session="WFS-payment-impl"
2. /workflow:review-fix --yes --session="WFS-payment-impl"
# Unit 2: Test Validation (test-fix-gen + test-cycle-execute)
3. /workflow:test-fix-gen --yes --session="WFS-payment-impl"
4. /workflow:test-cycle-execute --yes --session="WFS-test-payment-impl"
```
### Brainstorm Workflow (Uncertain Requirements)
```
Goal: Explore solutions for real-time notification system
Scope: [notifications, architecture]
Complexity: complex
Constraints: []
Task Type: brainstorm
Pipeline:
Exploration topic → brainstorm:auto-parallel → analysis results → plan → detailed plan
→ plan-verify → verified plan → execute → code → test-fix-gen → test tasks → test-cycle-execute → tests pass
Chain:
1. /workflow:brainstorm:auto-parallel --yes "Explore solutions for real-time..."
2. /workflow:plan --yes "Implement chosen notification approach..."
3. /workflow:plan-verify --yes --session="WFS-xxx"
4. /workflow:execute --yes --resume-session="WFS-xxx"
5. /workflow:test-fix-gen --yes --session="WFS-xxx"
6. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
```
### Multi-CLI Plan (Multi-Perspective Analysis)
```
Goal: Compare microservices vs monolith architecture
Scope: [architecture]
Complexity: complex
Constraints: []
Task Type: multi-cli
Pipeline:
Requirements → multi-cli-plan → comparison plan → lite-execute → code → test-fix-gen → test tasks → test-cycle-execute → tests pass
Chain:
1. /workflow:multi-cli-plan --yes "Compare microservices vs monolith..."
2. /workflow:lite-execute --yes --in-memory
3. /workflow:test-fix-gen --yes --session="WFS-xxx"
4. /workflow:test-cycle-execute --yes --session="WFS-test-xxx"
Task: Fix login timeout' --tool claude --mode write
```
## Execution Flow
@@ -983,19 +832,76 @@ async function ccwCoordinator(taskDescription) {
## CLI Execution Model
**Serial Blocking**: Commands execute one-by-one. After launching CLI in background, orchestrator stops immediately and waits for hook callback.
### CLI Invocation Format
**IMPORTANT**: The `ccw cli` command executes prompts through external tools. The format is:
```bash
ccw cli -p "PROMPT_CONTENT" --tool <tool> --mode <mode>
```
**Parameters**:
- `-p "PROMPT_CONTENT"`: The prompt content to execute (required)
- `--tool <tool>`: CLI tool to use (e.g., `claude`, `gemini`, `qwen`)
- `--mode <mode>`: Execution mode (`analysis` or `write`)
**Note**: `-y` is a **command parameter inside the prompt**, NOT a `ccw cli` parameter.
### Prompt Assembly
The prompt content MUST start with the workflow command, followed by task context:
```
/workflow:<command> -y <parameters>
Task: <description>
<optional_context>
```
**Examples**:
```bash
# Planning command
ccw cli -p '/workflow:plan -y "Implement user registration feature"
Task: Implement user registration' --tool claude --mode write
# Execution command (with session reference)
ccw cli -p '/workflow:execute -y --resume-session="WFS-plan-20250124"
Task: Implement user registration
Previous results:
- /workflow:plan: WFS-plan-20250124' --tool claude --mode write
# Lite execution (in-memory from previous plan)
ccw cli -p '/workflow:lite-execute -y --in-memory
Task: Implement user registration' --tool claude --mode write
```
### Serial Blocking
**CRITICAL**: Commands execute one-by-one. After launching CLI in background:
1. Orchestrator stops immediately (`break`)
2. Wait for hook callback - **DO NOT use TaskOutput polling**
3. Hook callback triggers next command
**Prompt Structure**: The command must come first in the prompt content
```javascript
// Example: Execute command and stop
const taskId = Bash(`ccw cli -p "..." --tool claude --mode write -y`, { run_in_background: true }).task_id;
const prompt = '/workflow:plan -y "Implement user authentication"\n\nTask: Implement user auth system';
const taskId = Bash(`ccw cli -p "${prompt}" --tool claude --mode write`, { run_in_background: true }).task_id;
state.execution_results.push({ status: 'in-progress', task_id: taskId, ... });
Write(`${stateDir}/state.json`, JSON.stringify(state, null, 2));
break; // Stop, wait for hook callback
break; // ⚠️ STOP HERE - DO NOT use TaskOutput polling
// Hook calls handleCliCompletion(sessionId, taskId, output) when done
// Hook callback will call handleCliCompletion(sessionId, taskId, output) when done
// → Updates state → Triggers next command via resumeChainExecution()
```
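For illustration, here is a hedged Python sketch of the completion flow described in the comments above (function names and the state layout are assumptions):

```python
# Hypothetical sketch of the hook-driven loop: persist state on completion,
# then launch exactly one next command - never poll for task output.
import json
from pathlib import Path

def handle_cli_completion(state_dir: str, task_id: str, output: str) -> None:
    state_path = Path(state_dir) / "state.json"
    state = json.loads(state_path.read_text())
    for result in state["execution_results"]:
        if result.get("task_id") == task_id:
            result["status"] = "completed"
            result["output"] = output
    state_path.write_text(json.dumps(state, indent=2))  # persist before resuming
    resume_chain_execution(state)

def resume_chain_execution(state: dict) -> None:
    ...  # assumed: launches the next pending `ccw cli` command in background
```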
## Available Commands
All from `~/.claude/commands/workflow/`:
@@ -1023,20 +929,20 @@ All from `~/.claude/commands/workflow/`:
- **test-gen → execute**: generates a comprehensive test suite; execute runs the generation and test tasks
- **test-fix-gen → test-cycle-execute**: generates fix tasks for specific issues; test-cycle-execute iterates testing and fixing until tests pass
### Task Type Routing (Pipeline View)
### Task Type Routing (Pipeline Summary)
**Note**: `【 】` marks Minimum Execution Units - these commands must execute together.
| Task Type | Pipeline |
|-----------|----------|
| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ code →【test-fix-gen → test-cycle-execute】→ tests pass |
| **feature** (complex) | Requirements →【plan → plan-verify】→ verified plan → execute → code →【review-session-cycle → review-fix】→ fixed code →【test-fix-gen → test-cycle-execute】→ tests pass |
| **bugfix** | Bug report → lite-fix → fixed code →【test-fix-gen → test-cycle-execute】→ tests pass |
| **tdd** | Requirements → tdd-plan → TDD tasks → execute → code → tdd-verify → TDD verification passes |
| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ tests pass |
| **test-gen** | Code/session →【test-gen → execute】→ tests pass |
| **review** | Code →【review-session-cycle/review-module-cycle → review-fix】→ fixed code →【test-fix-gen → test-cycle-execute】→ tests pass |
| **brainstorm** | Exploration topic → brainstorm:auto-parallel → analysis results →【plan → plan-verify】→ verified plan → execute → code →【test-fix-gen → test-cycle-execute】→ tests pass |
| **multi-cli** | Requirements → multi-cli-plan → comparison plan → lite-execute → code →【test-fix-gen → test-cycle-execute】→ tests pass |
| Task Type | Pipeline | Minimum Units |
|-----------|----------|---|
| **feature** (simple) | Requirements →【lite-plan → lite-execute】→ code →【test-fix-gen → test-cycle-execute】→ tests pass | Quick Implementation + Test Validation |
| **feature** (complex) | Requirements →【plan → plan-verify】→ validate → execute → code → review → fix | Full Planning + Code Review + Testing |
| **bugfix** | Bug report → lite-fix → fixed code →【test-fix-gen → test-cycle-execute】→ tests pass | Bug Fix + Test Validation |
| **tdd** | Requirements → tdd-plan → TDD tasks → execute → code → tdd-verify | TDD Planning + Execution |
| **test-fix** | Failing tests →【test-fix-gen → test-cycle-execute】→ tests pass | Test Validation |
| **test-gen** | Code/session →【test-gen → execute】→ tests pass | Test Generation + Execution |
| **review** | Code →【review-* → review-fix】→ fixed code →【test-fix-gen → test-cycle-execute】→ tests pass | Code Review + Testing |
| **brainstorm** | Exploration topic → brainstorm → analysis →【plan → plan-verify】→ execute → test | Exploration + Planning + Execution |
| **multi-cli** | Requirements → multi-cli-plan → comparative analysis → lite-execute → test | Multi-Perspective + Testing |
Use `CommandRegistry.getAllCommandsSummary()` to discover all commands dynamically.

View File

@@ -263,6 +263,49 @@ Open Dashboard via `ccw view`, manage indexes and execute searches in **CodexLen
## 💻 CCW CLI Commands
### 🌟 Recommended Commands (Main Features)
<div align="center">
<table>
<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>Auto workflow orchestrator - analyzes intent, selects workflow level, executes command chain in main process</td>
<td>✅ General tasks, auto workflow selection, quick development</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>Manual orchestrator - recommends command chains, executes via external CLI with state persistence</td>
<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
</tr>
</table>
</div>
**Quick Examples**:
```bash
# /ccw - Auto workflow selection (Main Process)
/ccw "Add user authentication" # Auto-selects workflow based on intent
/ccw "Fix memory leak in WebSocket" # Detects bugfix workflow
/ccw "Implement with TDD" # Routes to TDD workflow
# /ccw-coordinator - Manual chain orchestration (External CLI)
/ccw-coordinator "Implement OAuth2 system" # Analyzes → Recommends chain → User confirms → Executes
```
**Key Differences**:
| Aspect | /ccw | /ccw-coordinator |
|--------|------|------------------|
| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
| **Selection** | Auto intent-based | Manual chain confirmation |
| **State** | TodoWrite tracking | Persistent state.json |
| **Use Case** | General tasks, quick dev | Complex chains, resumable |
---
### Other CLI Commands
```bash
ccw install # Install workflow files
ccw view # Open dashboard

View File

@@ -263,6 +263,49 @@ codexlens index /path/to/project
## 💻 CCW CLI Commands
### 🌟 Recommended Commands (Core Features)
<div align="center">
<table>
<tr><th>Command</th><th>Description</th><th>When to Use</th></tr>
<tr>
<td><b>/ccw</b></td>
<td>Auto workflow orchestrator - analyzes intent, auto-selects the workflow level, executes the command chain in the main process</td>
<td>✅ General tasks, automatic workflow selection, quick development</td>
</tr>
<tr>
<td><b>/ccw-coordinator</b></td>
<td>Manual orchestrator - recommends command chains, executes via external CLI, persists state</td>
<td>🔧 Complex multi-step workflows, custom chains, resumable sessions</td>
</tr>
</table>
</div>
**Quick Examples**:
```bash
# /ccw - Auto workflow selection (main process)
/ccw "Add user authentication" # Auto-selects workflow based on intent
/ccw "Fix memory leak in WebSocket" # Detected as bugfix workflow
/ccw "Implement with TDD" # Routes to TDD workflow
# /ccw-coordinator - Manual chain orchestration (external CLI)
/ccw-coordinator "Implement OAuth2 system" # Analyze → Recommend chain → User confirms → Execute
```
**Key Differences**:
| Aspect | /ccw | /ccw-coordinator |
|--------|------|------------------|
| **Execution** | Main process (SlashCommand) | External CLI (background tasks) |
| **Selection** | Automatic, intent-based | Manual chain confirmation |
| **State** | TodoWrite tracking | Persistent state.json |
| **Use Case** | General tasks, quick development | Complex chains, resumable |
---
### Other CLI Commands
```bash
ccw install # Install workflow files
ccw view # Open dashboard

View File

@@ -19,5 +19,5 @@
"noEmit": false
},
"include": ["src/**/*"],
"exclude": ["src/templates/**/*", "node_modules", "dist"]
"exclude": ["src/templates/**/*", "src/**/*.test.ts", "node_modules", "dist"]
}

codex-lens/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 CodexLens Contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

codex-lens/README.md Normal file
View File

@@ -0,0 +1,59 @@
# CodexLens
CodexLens is a multi-modal code analysis platform designed to provide comprehensive code understanding and analysis capabilities.
## Features
- **Multi-language Support**: Analyze code in Python, JavaScript, TypeScript and more using Tree-sitter parsers
- **Semantic Search**: Find relevant code snippets using semantic understanding with fastembed and HNSWLIB
- **Code Parsing**: Advanced code structure parsing with tree-sitter
- **Flexible Architecture**: Modular design for easy extension and customization
## Installation
### Basic Installation
```bash
pip install codex-lens
```
### With Semantic Search
```bash
pip install codex-lens[semantic]
```
### With GPU Acceleration (NVIDIA CUDA)
```bash
pip install codex-lens[semantic-gpu]
```
### With DirectML (Windows - NVIDIA/AMD/Intel)
```bash
pip install codex-lens[semantic-directml]
```
### With All Optional Features
```bash
pip install codex-lens[full]
```
## Requirements
- Python >= 3.10
- See `pyproject.toml` for detailed dependency list
## Development
This project uses setuptools for building and packaging.
## License
MIT License
## Authors
CodexLens Contributors

View File

@@ -0,0 +1,28 @@
"""CodexLens package."""
from __future__ import annotations
from . import config, entities, errors
from .config import Config
from .entities import IndexedFile, SearchResult, SemanticChunk, Symbol
from .errors import CodexLensError, ConfigError, ParseError, SearchError, StorageError
__version__ = "0.1.0"
__all__ = [
"__version__",
"config",
"entities",
"errors",
"Config",
"IndexedFile",
"SearchResult",
"SemanticChunk",
"Symbol",
"CodexLensError",
"ConfigError",
"ParseError",
"StorageError",
"SearchError",
]
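A trivial, hypothetical smoke test of the re-exports above:

```python
# Hypothetical smoke test: the package re-exports its version and core types.
import codexlens

print(codexlens.__version__)  # "0.1.0"
print(codexlens.Config, codexlens.Symbol, codexlens.CodexLensError)
```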

View File

@@ -0,0 +1,14 @@
"""Module entrypoint for `python -m codexlens`."""
from __future__ import annotations
from codexlens.cli import app
def main() -> None:
app()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.
This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.
Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result
Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity
Example:
>>> from codexlens.api import (
... DefinitionResult,
... resolve_project,
... normalize_relationship_type
... )
>>> project = resolve_project("/path/to/project")
>>> rel_type = normalize_relationship_type("calls")
>>> rel_type
'call'
"""
from __future__ import annotations
# Dataclasses
from .models import (
CallInfo,
MethodContext,
FileContextResult,
DefinitionResult,
ReferenceResult,
GroupedReferences,
SymbolInfo,
HoverInfo,
SemanticResult,
)
# Utility functions
from .utils import (
resolve_project,
normalize_relationship_type,
rank_by_proximity,
rank_by_score,
)
# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search
__all__ = [
# Dataclasses
"CallInfo",
"MethodContext",
"FileContextResult",
"DefinitionResult",
"ReferenceResult",
"GroupedReferences",
"SymbolInfo",
"HoverInfo",
"SemanticResult",
# Utility functions
"resolve_project",
"normalize_relationship_type",
"rank_by_proximity",
"rank_by_score",
# API functions
"find_definition",
"workspace_symbols",
"get_hover",
"file_context",
"find_references",
"semantic_search",
]

View File

@@ -0,0 +1,126 @@
"""find_definition API implementation.
This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity
logger = logging.getLogger(__name__)
def find_definition(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
file_context: Optional[str] = None,
limit: int = 10
) -> List[DefinitionResult]:
"""Find definition locations for a symbol.
Uses a 3-stage fallback strategy:
1. Exact match with kind filter
2. Exact match without kind filter
3. Prefix match
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to find
symbol_kind: Optional symbol kind filter (class, function, etc.)
file_context: Optional file path for proximity ranking
limit: Maximum number of results to return
Returns:
List of DefinitionResult sorted by proximity if file_context provided
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Stage 1: Exact match with kind filter
results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
if results:
logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 2: Exact match without kind (if kind was specified)
if symbol_kind:
results = _search_with_kind(global_index, symbol_name, None, limit)
if results:
logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 3: Prefix match
results = global_index.search(
name=symbol_name,
kind=None,
limit=limit,
prefix_mode=True
)
if results:
logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
logger.debug(f"No definitions found for {symbol_name}")
return []
def _search_with_kind(
global_index: GlobalSymbolIndex,
symbol_name: str,
symbol_kind: Optional[str],
limit: int
) -> List[Symbol]:
"""Search for symbols with optional kind filter."""
return global_index.search(
name=symbol_name,
kind=symbol_kind,
limit=limit,
prefix_mode=False
)
def _rank_and_convert(
symbols: List[Symbol],
file_context: Optional[str]
) -> List[DefinitionResult]:
"""Convert symbols to DefinitionResult and rank by proximity."""
results = [
DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Could extract from file if needed
container=None, # Could extract from parent symbol
score=1.0
)
for sym in symbols
]
return rank_by_proximity(results, file_context)
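A hypothetical usage sketch (project path, symbol, and file are placeholders):

```python
# Hypothetical usage of find_definition(); all paths/names are placeholders.
from codexlens.api import find_definition

defs = find_definition(
    project_root="/path/to/project",
    symbol_name="authenticate",
    symbol_kind="function",           # stage 1 filters by this kind first
    file_context="src/auth/login.py", # closer files rank higher
)
for d in defs:
    print(f"{d.file_path}:{d.line} {d.kind} {d.name}")
```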

View File

@@ -0,0 +1,271 @@
"""file_context API implementation.
This module provides the file_context() function for retrieving
method call graphs from a source file.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
FileContextResult,
MethodContext,
CallInfo,
)
from .utils import resolve_project, normalize_relationship_type
logger = logging.getLogger(__name__)
def file_context(
project_root: str,
file_path: str,
include_calls: bool = True,
include_callers: bool = True,
max_depth: int = 1,
format: str = "brief"
) -> FileContextResult:
"""Get method call context for a code file.
Retrieves all methods/functions in the file along with their
outgoing calls and incoming callers.
Args:
project_root: Project root directory (for index location)
file_path: Path to the code file to analyze
include_calls: Whether to include outgoing calls
include_callers: Whether to include incoming callers
max_depth: Call chain depth (V1 only supports 1)
format: Output format (brief | detailed | tree)
Returns:
FileContextResult with method contexts and summary
Raises:
IndexNotFoundError: If project is not indexed
FileNotFoundError: If file does not exist
ValueError: If max_depth > 1 (V1 limitation)
"""
# V1 limitation: only depth=1 supported
if max_depth > 1:
raise ValueError(
f"max_depth > 1 not supported in V1. "
f"Requested: {max_depth}, supported: 1"
)
project_path = resolve_project(project_root)
file_path_resolved = Path(file_path).resolve()
# Validate file exists
if not file_path_resolved.exists():
raise FileNotFoundError(f"File not found: {file_path_resolved}")
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Get all symbols in the file
symbols = global_index.get_file_symbols(str(file_path_resolved))
# Filter to functions, methods, and classes
method_symbols = [
s for s in symbols
if s.kind in ("function", "method", "class")
]
logger.debug(f"Found {len(method_symbols)} methods in {file_path}")
# Try to find dir_index for relationship queries
dir_index = _find_dir_index(project_info, file_path_resolved)
# Build method contexts
methods: List[MethodContext] = []
outgoing_resolved = True
incoming_resolved = True
targets_resolved = True
for symbol in method_symbols:
calls: List[CallInfo] = []
callers: List[CallInfo] = []
if include_calls and dir_index:
try:
outgoing = dir_index.get_outgoing_calls(
str(file_path_resolved),
symbol.name
)
for target_name, rel_type, line, target_file in outgoing:
calls.append(CallInfo(
symbol_name=target_name,
file_path=target_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
if target_file is None:
targets_resolved = False
except Exception as e:
logger.debug(f"Failed to get outgoing calls: {e}")
outgoing_resolved = False
if include_callers and dir_index:
try:
incoming = dir_index.get_incoming_calls(symbol.name)
for source_name, rel_type, line, source_file in incoming:
callers.append(CallInfo(
symbol_name=source_name,
file_path=source_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
except Exception as e:
logger.debug(f"Failed to get incoming calls: {e}")
incoming_resolved = False
methods.append(MethodContext(
name=symbol.name,
kind=symbol.kind,
line_range=symbol.range if symbol.range else (1, 1),
signature=None, # Could extract from source
calls=calls,
callers=callers
))
# Detect language from file extension
language = _detect_language(file_path_resolved)
# Generate summary
summary = _generate_summary(file_path_resolved, methods, format)
return FileContextResult(
file_path=str(file_path_resolved),
language=language,
methods=methods,
summary=summary,
discovery_status={
"outgoing_resolved": outgoing_resolved,
"incoming_resolved": incoming_resolved,
"targets_resolved": targets_resolved
}
)
def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
"""Find the dir_index that contains the file.
Args:
project_info: Project information from registry
file_path: Path to the file
Returns:
DirIndexStore if found, None otherwise
"""
try:
# Look for _index.db in file's directory or parent directories
current = file_path.parent
while current != current.parent:
index_db = current / "_index.db"
if index_db.exists():
return DirIndexStore(str(index_db))
# Also check in project's index_root
relative = current.relative_to(project_info.source_root)
index_in_cache = project_info.index_root / relative / "_index.db"
if index_in_cache.exists():
return DirIndexStore(str(index_in_cache))
current = current.parent
except Exception as e:
logger.debug(f"Failed to find dir_index: {e}")
return None
def _detect_language(file_path: Path) -> str:
"""Detect programming language from file extension.
Args:
file_path: Path to the file
Returns:
Language name
"""
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".jsx": "javascript",
".tsx": "typescript",
".go": "go",
".rs": "rust",
".java": "java",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
}
return ext_map.get(file_path.suffix.lower(), "unknown")
def _generate_summary(
file_path: Path,
methods: List[MethodContext],
format: str
) -> str:
"""Generate human-readable summary of file context.
Args:
file_path: Path to the file
methods: List of method contexts
format: Output format (brief | detailed | tree)
Returns:
Markdown-formatted summary
"""
lines = [f"## {file_path.name} ({len(methods)} methods)\n"]
for method in methods:
start, end = method.line_range
lines.append(f"### {method.name} (line {start}-{end})")
if method.calls:
calls_str = ", ".join(
f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.calls
)
lines.append(f"- Calls: {calls_str}")
if method.callers:
callers_str = ", ".join(
f"{c.symbol_name} ({c.file_path}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.callers
)
lines.append(f"- Called by: {callers_str}")
if not method.calls and not method.callers:
lines.append("- (no call relationships)")
lines.append("")
return "\n".join(lines)

View File

@@ -0,0 +1,148 @@
"""get_hover API implementation.
This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def get_hover(
project_root: str,
symbol_name: str,
file_path: Optional[str] = None
) -> Optional[HoverInfo]:
"""Get detailed hover information for a symbol.
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to look up
file_path: Optional file path to disambiguate when symbol
appears in multiple files
Returns:
HoverInfo if symbol found, None otherwise
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search for the symbol
results = global_index.search(
name=symbol_name,
kind=None,
limit=50,
prefix_mode=False
)
if not results:
logger.debug(f"No hover info found for {symbol_name}")
return None
# If file_path provided, filter to that file
if file_path:
file_path_resolved = str(Path(file_path).resolve())
matching = [s for s in results if s.file == file_path_resolved]
if matching:
results = matching
# Take the first result
symbol = results[0]
# Build hover info
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=_extract_signature(symbol),
documentation=_extract_documentation(symbol),
file_path=symbol.file or "",
line_range=symbol.range if symbol.range else (1, 1),
type_info=_extract_type_info(symbol)
)
def _extract_signature(symbol: Symbol) -> str:
"""Extract signature from symbol.
For now, generates a basic signature based on kind and name.
In a full implementation, this would parse the actual source code.
Args:
symbol: The symbol to extract signature from
Returns:
Signature string
"""
if symbol.kind == "function":
return f"def {symbol.name}(...)"
elif symbol.kind == "method":
return f"def {symbol.name}(self, ...)"
elif symbol.kind == "class":
return f"class {symbol.name}"
elif symbol.kind == "variable":
return symbol.name
elif symbol.kind == "constant":
return f"{symbol.name} = ..."
else:
return f"{symbol.kind} {symbol.name}"
def _extract_documentation(symbol: Symbol) -> Optional[str]:
"""Extract documentation from symbol.
In a full implementation, this would parse docstrings from source.
For now, returns None.
Args:
symbol: The symbol to extract documentation from
Returns:
Documentation string if available, None otherwise
"""
# Would need to read source file and parse docstring
# For V1, return None
return None
def _extract_type_info(symbol: Symbol) -> Optional[str]:
"""Extract type information from symbol.
In a full implementation, this would parse type annotations.
For now, returns None.
Args:
symbol: The symbol to extract type info from
Returns:
Type info string if available, None otherwise
"""
# Would need to parse type annotations from source
# For V1, return None
return None
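A hypothetical usage sketch of `get_hover()`:

```python
# Hypothetical usage of get_hover(); returns None for unknown symbols.
from codexlens.api import get_hover

info = get_hover("/path/to/project", "UserService", file_path="src/services.py")
if info is not None:
    print(info.signature)  # e.g. "class UserService" (V1 heuristic signature)
    print(info.file_path, info.line_range)
```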

View File

@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.
This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple
# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================
@dataclass
class CallInfo:
"""Call relationship information.
Attributes:
symbol_name: Name of the called/calling symbol
file_path: Target file path (may be None if unresolved)
line: Line number of the call
relationship: Type of relationship (call | import | inheritance)
"""
symbol_name: str
file_path: Optional[str]
line: int
relationship: str # call | import | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MethodContext:
"""Method context with call relationships.
Attributes:
name: Method/function name
kind: Symbol kind (function | method | class)
line_range: Start and end line numbers
signature: Function signature (if available)
calls: List of outgoing calls
callers: List of incoming calls
"""
name: str
kind: str # function | method | class
line_range: Tuple[int, int]
signature: Optional[str]
calls: List[CallInfo] = field(default_factory=list)
callers: List[CallInfo] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"line_range": list(self.line_range),
"calls": [c.to_dict() for c in self.calls],
"callers": [c.to_dict() for c in self.callers],
}
if self.signature is not None:
result["signature"] = self.signature
return result
@dataclass
class FileContextResult:
"""File context result with method summaries.
Attributes:
file_path: Path to the analyzed file
language: Programming language
methods: List of method contexts
summary: Human-readable summary
discovery_status: Status flags for call resolution
"""
file_path: str
language: str
methods: List[MethodContext]
summary: str
discovery_status: Dict[str, bool] = field(default_factory=lambda: {
"outgoing_resolved": False,
"incoming_resolved": True,
"targets_resolved": False
})
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"file_path": self.file_path,
"language": self.language,
"methods": [m.to_dict() for m in self.methods],
"summary": self.summary,
"discovery_status": self.discovery_status,
}
# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================
@dataclass
class DefinitionResult:
"""Definition lookup result.
Attributes:
name: Symbol name
kind: Symbol kind (class, function, method, etc.)
file_path: File where symbol is defined
line: Start line number
end_line: End line number
signature: Symbol signature (if available)
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
end_line: int
signature: Optional[str] = None
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================
@dataclass
class ReferenceResult:
"""Reference lookup result.
Attributes:
file_path: File containing the reference
line: Line number
column: Column number
context_line: The line of code containing the reference
relationship: Type of reference (call | import | type_annotation | inheritance)
"""
file_path: str
line: int
column: int
context_line: str
relationship: str # call | import | type_annotation | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class GroupedReferences:
"""References grouped by definition.
Used when a symbol has multiple definitions (e.g., overloads).
Attributes:
definition: The definition this group refers to
references: List of references to this definition
"""
definition: DefinitionResult
references: List[ReferenceResult] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"definition": self.definition.to_dict(),
"references": [r.to_dict() for r in self.references],
}
# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================
@dataclass
class SymbolInfo:
"""Symbol information for workspace search.
Attributes:
name: Symbol name
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================
@dataclass
class HoverInfo:
"""Hover information for a symbol.
Attributes:
name: Symbol name
kind: Symbol kind
signature: Symbol signature
documentation: Documentation string (if available)
file_path: File where symbol is defined
line_range: Start and end line numbers
type_info: Type information (if available)
"""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: Tuple[int, int]
type_info: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"signature": self.signature,
"file_path": self.file_path,
"line_range": list(self.line_range),
}
if self.documentation is not None:
result["documentation"] = self.documentation
if self.type_info is not None:
result["type_info"] = self.type_info
return result
# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================
@dataclass
class SemanticResult:
"""Semantic search result.
Attributes:
symbol_name: Name of the matched symbol
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
vector_score: Vector similarity score (None if not available)
structural_score: Structural match score (None if not available)
fusion_score: Combined fusion score
snippet: Code snippet
match_reason: Explanation of why this matched (optional)
"""
symbol_name: str
kind: str
file_path: str
line: int
vector_score: Optional[float]
structural_score: Optional[float]
fusion_score: float
snippet: str
match_reason: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
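A hypothetical check of the None-filtering `to_dict()` convention (the module path `codexlens.api.models` is inferred from the relative imports above):

```python
# Hypothetical check: to_dict() drops None-valued fields such as an
# unresolved file_path.
from codexlens.api.models import CallInfo

call = CallInfo(symbol_name="login", file_path=None, line=42, relationship="call")
assert call.to_dict() == {"symbol_name": "login", "line": 42, "relationship": "call"}
```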

View File

@@ -0,0 +1,345 @@
"""Find references API for codexlens.
This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional, Dict
from .models import (
DefinitionResult,
ReferenceResult,
GroupedReferences,
)
from .utils import (
resolve_project,
normalize_relationship_type,
)
logger = logging.getLogger(__name__)
def _read_line_from_file(file_path: str, line: int) -> str:
"""Read a specific line from a file.
Args:
file_path: Path to the file
line: Line number (1-based)
Returns:
The line content, stripped of trailing whitespace.
Returns empty string if file cannot be read or line doesn't exist.
"""
try:
path = Path(file_path)
if not path.exists():
return ""
with path.open("r", encoding="utf-8", errors="replace") as f:
for i, content in enumerate(f, 1):
if i == line:
return content.rstrip()
return ""
except Exception as exc:
logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
return ""
def _transform_to_reference_result(
raw_ref: "RawReferenceResult",
) -> ReferenceResult:
"""Transform raw ChainSearchEngine reference to API ReferenceResult.
Args:
raw_ref: Raw reference result from ChainSearchEngine
Returns:
API ReferenceResult with context_line and normalized relationship
"""
# Read the actual line from the file
context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)
# Normalize relationship type
relationship = normalize_relationship_type(raw_ref.relationship_type)
return ReferenceResult(
file_path=raw_ref.file_path,
line=raw_ref.line,
column=raw_ref.column,
context_line=context_line,
relationship=relationship,
)
def find_references(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
include_definition: bool = True,
group_by_definition: bool = True,
limit: int = 100,
) -> List[GroupedReferences]:
"""Find all reference locations for a symbol.
Multi-definition case returns grouped results to resolve ambiguity.
This function wraps ChainSearchEngine.search_references() and groups
the results by definition location. Each GroupedReferences contains
a definition and all references that point to it.
Args:
project_root: Project root directory path
symbol_name: Name of the symbol to find references for
symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
include_definition: Whether to include the definition location
in the result (default True)
group_by_definition: Whether to group references by definition.
If False, returns a single group with all references.
(default True)
limit: Maximum number of references to return (default 100)
Returns:
List of GroupedReferences. Each group contains:
- definition: The DefinitionResult for this symbol definition
- references: List of ReferenceResult pointing to this definition
Raises:
ValueError: If project_root does not exist or is not a directory
Examples:
>>> refs = find_references("/path/to/project", "authenticate")
>>> for group in refs:
... print(f"Definition: {group.definition.file_path}:{group.definition.line}")
... for ref in group.references:
... print(f" Reference: {ref.file_path}:{ref.line} ({ref.relationship})")
Note:
Reference relationship types are normalized:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
"""
# Validate and resolve project root
project_path = resolve_project(project_root)
# Import here to avoid circular imports
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
from codexlens.entities import Symbol
# Initialize infrastructure
config = Config()
registry = RegistryStore()
mapper = PathMapper(config.index_dir)
# Create chain search engine
engine = ChainSearchEngine(registry, mapper, config=config)
try:
# Step 1: Find definitions for the symbol
definitions: List[DefinitionResult] = []
if include_definition or group_by_definition:
# Search for symbol definitions
symbols = engine.search_symbols(
name=symbol_name,
source_path=project_path,
kind=symbol_kind,
)
# Convert Symbol to DefinitionResult
for sym in symbols:
# Only include exact name matches for definitions
if sym.name != symbol_name:
continue
# Optionally filter by kind
if symbol_kind and sym.kind != symbol_kind:
continue
definitions.append(DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Not available from Symbol
container=None, # Not available from Symbol
score=1.0,
))
# Step 2: Get all references using ChainSearchEngine
raw_references = engine.search_references(
symbol_name=symbol_name,
source_path=project_path,
depth=-1,
limit=limit,
)
# Step 3: Transform raw references to API ReferenceResult
api_references: List[ReferenceResult] = []
for raw_ref in raw_references:
api_ref = _transform_to_reference_result(raw_ref)
api_references.append(api_ref)
# Step 4: Group references by definition
if group_by_definition and definitions:
return _group_references_by_definition(
definitions=definitions,
references=api_references,
include_definition=include_definition,
)
else:
# Return single group with placeholder definition or first definition
if definitions:
definition = definitions[0]
else:
# Create placeholder definition when no definition found
definition = DefinitionResult(
name=symbol_name,
kind=symbol_kind or "unknown",
file_path="",
line=0,
end_line=0,
signature=None,
container=None,
score=0.0,
)
return [GroupedReferences(
definition=definition,
references=api_references,
)]
finally:
engine.close()
def _group_references_by_definition(
definitions: List[DefinitionResult],
references: List[ReferenceResult],
include_definition: bool = True,
) -> List[GroupedReferences]:
"""Group references by their likely definition.
Uses file proximity heuristic to assign references to definitions.
References in the same file or directory as a definition are
assigned to that definition.
Args:
definitions: List of definition locations
references: List of reference locations
include_definition: Whether to include definition in results
Returns:
List of GroupedReferences with references assigned to definitions
"""
import os
if not definitions:
return []
if len(definitions) == 1:
# Single definition - all references belong to it
return [GroupedReferences(
definition=definitions[0],
references=references,
)]
# Multiple definitions - group by proximity
groups: Dict[int, List[ReferenceResult]] = {
i: [] for i in range(len(definitions))
}
for ref in references:
# Find the closest definition by file proximity
best_def_idx = 0
best_score = -1
for i, defn in enumerate(definitions):
score = _proximity_score(ref.file_path, defn.file_path)
if score > best_score:
best_score = score
best_def_idx = i
groups[best_def_idx].append(ref)
# Build result groups
result: List[GroupedReferences] = []
for i, defn in enumerate(definitions):
# Skip definitions with no references if not including definition itself
if not include_definition and not groups[i]:
continue
result.append(GroupedReferences(
definition=defn,
references=groups[i],
))
return result
def _proximity_score(ref_path: str, def_path: str) -> int:
"""Calculate proximity score between two file paths.
Args:
ref_path: Reference file path
def_path: Definition file path
Returns:
Proximity score (higher = closer):
- Same file: 1000
- Same directory: 100
- Otherwise: common path prefix length
"""
import os
if not ref_path or not def_path:
return 0
# Normalize paths
ref_path = os.path.normpath(ref_path)
def_path = os.path.normpath(def_path)
# Same file
if ref_path == def_path:
return 1000
ref_dir = os.path.dirname(ref_path)
def_dir = os.path.dirname(def_path)
# Same directory
if ref_dir == def_dir:
return 100
# Common path prefix
try:
common = os.path.commonpath([ref_path, def_path])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
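# Worked example of the proximity tiers above (hypothetical paths):
#   _proximity_score("/app/src/a.py", "/app/src/a.py")   -> 1000 (same file)
#   _proximity_score("/app/src/a.py", "/app/src/b.py")   -> 100  (same directory)
#   _proximity_score("/app/src/a.py", "/app/tests/t.py") -> len("/app") = 4 (common prefix)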
# Fallback type stub for the raw reference from ChainSearchEngine
class RawReferenceResult:
    """Type stub for ChainSearchEngine.ReferenceResult.

    Used only for static type hints; at runtime the local import of the
    real ReferenceResult takes precedence over this stub.
    """
file_path: str
line: int
column: int
context: str
relationship_type: str


@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.
This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from .models import SemanticResult
from .utils import resolve_project
logger = logging.getLogger(__name__)
def semantic_search(
project_root: str,
query: str,
mode: str = "fusion",
vector_weight: float = 0.5,
structural_weight: float = 0.3,
keyword_weight: float = 0.2,
fusion_strategy: str = "rrf",
kind_filter: Optional[List[str]] = None,
limit: int = 20,
include_match_reason: bool = False,
) -> List[SemanticResult]:
"""Semantic search - combining vector and structural search.
This function provides a high-level API for semantic code search,
combining vector similarity, structural (symbol + relationships),
and keyword-based search methods with configurable fusion.
Args:
project_root: Project root directory
query: Natural language query
mode: Search mode
- vector: Vector search only
- structural: Structural search only (symbol + relationships)
- fusion: Fusion search (default)
vector_weight: Vector search weight [0, 1] (default 0.5)
structural_weight: Structural search weight [0, 1] (default 0.3)
keyword_weight: Keyword search weight [0, 1] (default 0.2)
fusion_strategy: Fusion strategy (maps to chain_search.py)
- rrf: Reciprocal Rank Fusion (recommended, default)
- staged: Staged cascade -> staged_cascade_search
- binary: Binary rerank cascade -> binary_cascade_search
- hybrid: Hybrid cascade -> hybrid_cascade_search
kind_filter: Symbol type filter (e.g., ["function", "class"])
limit: Max return count (default 20)
include_match_reason: Generate match reason (heuristic, not LLM)
Returns:
Results sorted by fusion_score
Degradation:
- No vector index: vector_score=None, uses FTS + structural search
- No relationship data: structural_score=None, vector search only
Examples:
>>> results = semantic_search(
... "/path/to/project",
... "authentication handler",
... mode="fusion",
... fusion_strategy="rrf"
... )
>>> for r in results:
... print(f"{r.symbol_name}: {r.fusion_score:.3f}")
"""
# Validate and resolve project path
project_path = resolve_project(project_root)
# Normalize weights to sum to 1.0
total_weight = vector_weight + structural_weight + keyword_weight
if total_weight > 0:
vector_weight = vector_weight / total_weight
structural_weight = structural_weight / total_weight
keyword_weight = keyword_weight / total_weight
else:
# Default to equal weights if all zero
vector_weight = structural_weight = keyword_weight = 1.0 / 3.0
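    # e.g. vector=1.0, structural=1.0, keyword=2.0 normalizes to 0.25 / 0.25 / 0.5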
# Initialize search infrastructure
try:
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
except ImportError as exc:
logger.error("Failed to import search dependencies: %s", exc)
return []
# Load config
config = Config.load()
# Get or create registry and mapper
try:
registry = RegistryStore.default()
mapper = PathMapper(registry)
except Exception as exc:
logger.error("Failed to initialize search infrastructure: %s", exc)
return []
# Build search options based on mode
search_options = _build_search_options(
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
limit=limit,
)
# Execute search based on fusion_strategy
try:
with ChainSearchEngine(registry, mapper, config=config) as engine:
chain_result = _execute_search(
engine=engine,
query=query,
source_path=project_path,
fusion_strategy=fusion_strategy,
options=search_options,
limit=limit,
)
except Exception as exc:
logger.error("Search execution failed: %s", exc)
return []
# Transform results to SemanticResult
semantic_results = _transform_results(
results=chain_result.results,
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
kind_filter=kind_filter,
include_match_reason=include_match_reason,
query=query,
)
return semantic_results[:limit]
def _build_search_options(
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
limit: int,
) -> "SearchOptions":
"""Build SearchOptions based on mode and weights.
Args:
mode: Search mode (vector, structural, fusion)
vector_weight: Vector search weight
structural_weight: Structural search weight
keyword_weight: Keyword search weight
limit: Result limit
Returns:
Configured SearchOptions
"""
from codexlens.search.chain_search import SearchOptions
# Default options
options = SearchOptions(
total_limit=limit * 2, # Fetch extra for filtering
limit_per_dir=limit,
include_symbols=True, # Always include symbols for structural
)
if mode == "vector":
# Pure vector mode
options.hybrid_mode = True
options.enable_vector = True
options.pure_vector = True
options.enable_fuzzy = False
elif mode == "structural":
# Structural only - use FTS + symbols
options.hybrid_mode = True
options.enable_vector = False
options.enable_fuzzy = True
options.include_symbols = True
else:
# Fusion mode (default)
options.hybrid_mode = True
options.enable_vector = vector_weight > 0
options.enable_fuzzy = keyword_weight > 0
options.include_symbols = structural_weight > 0
# Set custom weights for RRF
if options.enable_vector and keyword_weight > 0:
options.hybrid_weights = {
"vector": vector_weight,
"exact": keyword_weight * 0.7,
"fuzzy": keyword_weight * 0.3,
}
return options
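# Illustrative result for the default fusion call (weights 0.5/0.3/0.2):
#   enable_vector=True, enable_fuzzy=True, include_symbols=True,
#   hybrid_weights={"vector": 0.5, "exact": 0.14, "fuzzy": 0.06}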
def _execute_search(
engine: "ChainSearchEngine",
query: str,
source_path: Path,
fusion_strategy: str,
options: "SearchOptions",
limit: int,
) -> "ChainSearchResult":
"""Execute search using appropriate strategy.
Maps fusion_strategy to ChainSearchEngine methods:
- rrf: Standard hybrid search with RRF fusion
- staged: staged_cascade_search
- binary: binary_cascade_search
- hybrid: hybrid_cascade_search
Args:
engine: ChainSearchEngine instance
query: Search query
source_path: Project root path
fusion_strategy: Strategy name
options: Search options
limit: Result limit
Returns:
ChainSearchResult from the search
"""
from codexlens.search.chain_search import ChainSearchResult
if fusion_strategy == "staged":
# Use staged cascade search (4-stage pipeline)
return engine.staged_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "binary":
# Use binary cascade search (binary coarse + dense fine)
return engine.binary_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "hybrid":
# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
return engine.hybrid_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
else:
# Default: rrf - Standard search with RRF fusion
return engine.search(
query=query,
source_path=source_path,
options=options,
)
def _transform_results(
results: List,
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
kind_filter: Optional[List[str]],
include_match_reason: bool,
query: str,
) -> List[SemanticResult]:
"""Transform ChainSearchEngine results to SemanticResult.
Args:
results: List of SearchResult objects
mode: Search mode
vector_weight: Vector weight used
structural_weight: Structural weight used
keyword_weight: Keyword weight used
kind_filter: Optional symbol kind filter
include_match_reason: Whether to generate match reasons
query: Original query (for match reason generation)
Returns:
List of SemanticResult objects
"""
semantic_results = []
for result in results:
# Extract symbol info
symbol_name = getattr(result, "symbol_name", None)
symbol_kind = getattr(result, "symbol_kind", None)
start_line = getattr(result, "start_line", None)
# Use symbol object if available
if hasattr(result, "symbol") and result.symbol:
symbol_name = symbol_name or result.symbol.name
symbol_kind = symbol_kind or result.symbol.kind
if hasattr(result.symbol, "range") and result.symbol.range:
start_line = start_line or result.symbol.range[0]
# Filter by kind if specified
if kind_filter and symbol_kind:
if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
continue
# Determine scores based on mode and metadata
metadata = getattr(result, "metadata", {}) or {}
fusion_score = result.score
# Try to extract source scores from metadata
source_scores = metadata.get("source_scores", {})
vector_score: Optional[float] = None
structural_score: Optional[float] = None
if mode == "vector":
# In pure vector mode, the main score is the vector score
vector_score = result.score
structural_score = None
elif mode == "structural":
# In structural mode, no vector score
vector_score = None
structural_score = result.score
else:
# Fusion mode - try to extract individual scores
if "vector" in source_scores:
vector_score = source_scores["vector"]
elif metadata.get("fusion_method") == "simple_weighted":
# From weighted fusion
vector_score = source_scores.get("vector")
# Structural score approximation (from exact/fuzzy FTS)
fts_scores = []
if "exact" in source_scores:
fts_scores.append(source_scores["exact"])
if "fuzzy" in source_scores:
fts_scores.append(source_scores["fuzzy"])
if "splade" in source_scores:
fts_scores.append(source_scores["splade"])
if fts_scores:
structural_score = max(fts_scores)
# Build snippet
        snippet = getattr(result, "excerpt", "") or getattr(result, "content", "") or ""
if len(snippet) > 500:
snippet = snippet[:500] + "..."
# Generate match reason if requested
match_reason = None
if include_match_reason:
match_reason = _generate_match_reason(
query=query,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
snippet=snippet,
vector_score=vector_score,
structural_score=structural_score,
)
semantic_result = SemanticResult(
symbol_name=symbol_name or Path(result.path).stem,
kind=symbol_kind or "unknown",
file_path=result.path,
line=start_line or 1,
vector_score=vector_score,
structural_score=structural_score,
fusion_score=fusion_score,
snippet=snippet,
match_reason=match_reason,
)
semantic_results.append(semantic_result)
# Sort by fusion_score descending
semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)
return semantic_results
def _generate_match_reason(
query: str,
symbol_name: Optional[str],
symbol_kind: Optional[str],
snippet: str,
vector_score: Optional[float],
structural_score: Optional[float],
) -> str:
"""Generate human-readable match reason heuristically.
This is a simple heuristic-based approach, not LLM-powered.
Args:
query: Original search query
symbol_name: Symbol name if available
symbol_kind: Symbol kind if available
snippet: Code snippet
vector_score: Vector similarity score
structural_score: Structural match score
Returns:
Human-readable explanation string
"""
reasons = []
# Check for direct name match
query_lower = query.lower()
query_words = set(query_lower.split())
if symbol_name:
name_lower = symbol_name.lower()
# Direct substring match
if query_lower in name_lower or name_lower in query_lower:
reasons.append(f"Symbol name '{symbol_name}' matches query")
# Word overlap
name_words = set(_split_camel_case(symbol_name).lower().split())
overlap = query_words & name_words
if overlap and not reasons:
reasons.append(f"Symbol name contains: {', '.join(overlap)}")
# Check snippet for keyword matches
snippet_lower = snippet.lower()
matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
if matching_words and len(reasons) < 2:
reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")
# Add score-based reasoning
if vector_score is not None and vector_score > 0.7:
reasons.append("High semantic similarity")
elif vector_score is not None and vector_score > 0.5:
reasons.append("Moderate semantic similarity")
if structural_score is not None and structural_score > 0.8:
reasons.append("Strong structural match")
# Symbol kind context
if symbol_kind and len(reasons) < 3:
reasons.append(f"Matched {symbol_kind}")
if not reasons:
reasons.append("Partial relevance based on content analysis")
return "; ".join(reasons[:3])
def _split_camel_case(name: str) -> str:
"""Split camelCase and PascalCase to words.
Args:
name: Symbol name in camelCase or PascalCase
Returns:
Space-separated words
"""
import re
    # Insert a space at lowercase-to-uppercase boundaries ("fooBar" -> "foo Bar")
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Split an acronym from a following capitalized word ("HTTPServer" -> "HTTP Server")
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
# Replace underscores with spaces
result = result.replace("_", " ")
return result
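# e.g. _split_camel_case("HTTPServerError") -> "HTTP Server Error"
#      _split_camel_case("parse_file")      -> "parse file"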


@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.
This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""
from __future__ import annotations
import fnmatch
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def workspace_symbols(
project_root: str,
query: str,
kind_filter: Optional[List[str]] = None,
file_pattern: Optional[str] = None,
limit: int = 50
) -> List[SymbolInfo]:
"""Search for symbols across the entire workspace.
Uses prefix matching for efficient searching.
Args:
project_root: Project root directory (for index location)
query: Search query (prefix match)
kind_filter: Optional list of symbol kinds to include
(e.g., ["class", "function"])
file_pattern: Optional glob pattern to filter by file path
(e.g., "*.py", "src/**/*.ts")
limit: Maximum number of results to return
Returns:
List of SymbolInfo sorted by score
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project(project_path)
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search with prefix matching
# If kind_filter has multiple kinds, we need to search for each
all_results: List[Symbol] = []
    if kind_filter:
# Search for each kind separately
for kind in kind_filter:
results = global_index.search(
name=query,
kind=kind,
limit=limit,
prefix_mode=True
)
all_results.extend(results)
else:
# Search without kind filter
all_results = global_index.search(
name=query,
kind=None,
limit=limit,
prefix_mode=True
)
logger.debug(f"Found {len(all_results)} symbols matching '{query}'")
# Apply file pattern filter if specified
if file_pattern:
all_results = [
sym for sym in all_results
if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
]
logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")
# Convert to SymbolInfo and sort by relevance
symbols = [
SymbolInfo(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
container=None, # Could extract from parent
score=_calculate_score(sym.name, query)
)
for sym in all_results
]
# Sort by score (exact matches first)
symbols.sort(key=lambda s: s.score, reverse=True)
return symbols[:limit]
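# Usage sketch (hypothetical project path):
#   hits = workspace_symbols("/path/to/project", "parse",
#                            kind_filter=["function", "class"], file_pattern="*.py")
#   for s in hits:
#       print(f"{s.kind} {s.name} @ {s.file_path}:{s.line} ({s.score:.2f})")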
def _calculate_score(symbol_name: str, query: str) -> float:
"""Calculate relevance score for a symbol match.
    Scoring:
    - Exact match: 1.0
    - Case-insensitive exact match: 0.9
    - Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
    - Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
    - Any other index hit: 0.5
Args:
symbol_name: The matched symbol name
query: The search query
Returns:
Score between 0.0 and 1.0
"""
if symbol_name == query:
return 1.0
if symbol_name.lower() == query.lower():
return 0.9
if symbol_name.startswith(query):
ratio = len(query) / len(symbol_name)
return 0.8 + 0.2 * ratio
if symbol_name.lower().startswith(query.lower()):
ratio = len(query) / len(symbol_name)
return 0.6 + 0.2 * ratio
return 0.5
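# Scoring sketch:
#   _calculate_score("parse", "parse")       -> 1.0  (exact)
#   _calculate_score("Parse", "parse")       -> 0.9  (case-insensitive exact)
#   _calculate_score("parse_file", "parse")  -> 0.9  (prefix, 0.8 + 0.2 * 5/10)
#   _calculate_score("Parse_file", "parse")  -> 0.7  (ci prefix, 0.6 + 0.2 * 5/10)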


@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.
This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable
from .models import DefinitionResult
# Type variable for generic ranking
T = TypeVar('T')
def resolve_project(project_root: str) -> Path:
"""Resolve and validate project root path.
Args:
project_root: Path to project root (relative or absolute)
Returns:
Resolved absolute Path
Raises:
ValueError: If path does not exist or is not a directory
"""
path = Path(project_root).resolve()
if not path.exists():
raise ValueError(f"Project root does not exist: {path}")
if not path.is_dir():
raise ValueError(f"Project root is not a directory: {path}")
return path
# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
# Plural to singular
"calls": "call",
"imports": "import",
"inherits": "inheritance",
"uses": "use",
# Already normalized (passthrough)
"call": "call",
"import": "import",
"inheritance": "inheritance",
"use": "use",
"type_annotation": "type_annotation",
}
def normalize_relationship_type(relationship: str) -> str:
"""Normalize relationship type to canonical form.
Converts plural forms and variations to standard singular forms:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
- 'uses' -> 'use'
Args:
relationship: Raw relationship type string
Returns:
Normalized relationship type
Examples:
>>> normalize_relationship_type('calls')
'call'
>>> normalize_relationship_type('inherits')
'inheritance'
>>> normalize_relationship_type('call')
'call'
"""
return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)
def rank_by_proximity(
results: List[DefinitionResult],
file_context: Optional[str] = None
) -> List[DefinitionResult]:
"""Rank results by file path proximity to context.
V1 Implementation: Uses path-based proximity scoring.
Scoring algorithm:
1. Same directory: highest score (100)
2. Otherwise: length of common path prefix
Args:
results: List of definition results to rank
file_context: Reference file path for proximity calculation.
If None, returns results unchanged.
Returns:
Results sorted by proximity score (highest first)
Examples:
>>> results = [
... DefinitionResult(name="foo", kind="function",
... file_path="/a/b/c.py", line=1, end_line=10),
... DefinitionResult(name="foo", kind="function",
... file_path="/a/x/y.py", line=1, end_line=10),
... ]
>>> ranked = rank_by_proximity(results, "/a/b/test.py")
>>> ranked[0].file_path
'/a/b/c.py'
"""
if not file_context or not results:
return results
def proximity_score(result: DefinitionResult) -> int:
"""Calculate proximity score for a result."""
result_dir = os.path.dirname(result.file_path)
context_dir = os.path.dirname(file_context)
# Same directory gets highest score
if result_dir == context_dir:
return 100
# Otherwise, score by common path prefix length
try:
common = os.path.commonpath([result.file_path, file_context])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
return sorted(results, key=proximity_score, reverse=True)
def rank_by_score(
results: List[T],
score_fn: Callable[[T], float],
reverse: bool = True
) -> List[T]:
"""Generic ranking function by custom score.
Args:
results: List of items to rank
score_fn: Function to extract score from item
reverse: If True, highest scores first (default)
Returns:
Sorted list
"""
return sorted(results, key=score_fn, reverse=reverse)


@@ -0,0 +1,27 @@
"""CLI package for CodexLens."""
from __future__ import annotations
import sys
import os
# Force UTF-8 encoding for Windows console
# This ensures Chinese characters display correctly instead of GBK garbled text
if sys.platform == "win32":
# Set environment variable for Python I/O encoding
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
# Reconfigure stdout/stderr to use UTF-8 if possible
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
# Fallback: some environments don't support reconfigure
pass
from .commands import app
__all__ = ["app"]

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,135 @@
"""Rich and JSON output helpers for CodexLens CLI."""
from __future__ import annotations
import json
import sys
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Any, Iterable, Mapping, Sequence
from rich.console import Console
from rich.table import Table
from rich.text import Text
from codexlens.entities import SearchResult, Symbol
# Force UTF-8 encoding for Windows console to properly display Chinese text
# Use force_terminal=True and legacy_windows=False to avoid GBK encoding issues
console = Console(force_terminal=True, legacy_windows=False)
def _to_jsonable(value: Any) -> Any:
if value is None:
return None
if hasattr(value, "model_dump"):
return value.model_dump()
if is_dataclass(value):
return asdict(value)
if isinstance(value, Path):
return str(value)
if isinstance(value, Mapping):
return {k: _to_jsonable(v) for k, v in value.items()}
if isinstance(value, (list, tuple, set)):
return [_to_jsonable(v) for v in value]
return value
def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
"""Print JSON output with optional additional fields.
Args:
success: Whether the operation succeeded
result: Result data (used when success=True)
error: Error message (used when success=False)
**kwargs: Additional fields to include in the payload (e.g., code, details)
"""
payload: dict[str, Any] = {"success": success}
if success:
payload["result"] = _to_jsonable(result)
else:
payload["error"] = error or "Unknown error"
# Include additional error details if provided
for key, value in kwargs.items():
payload[key] = _to_jsonable(value)
console.print_json(json.dumps(payload, ensure_ascii=False))
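# e.g. print_json(success=False, error="Project not indexed", code="E_NO_INDEX")
# emits: {"success": false, "error": "Project not indexed", "code": "E_NO_INDEX"}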
def render_search_results(
results: Sequence[SearchResult], *, title: str = "Search Results", verbose: bool = False
) -> None:
"""Render search results with optional source tags in verbose mode.
Args:
results: Search results to display
title: Table title
        verbose: If True, show search source tags ([E], [F], [V], [RRF]) and fusion scores
"""
table = Table(title=title, show_lines=False)
if verbose:
# Verbose mode: show source tags
table.add_column("Source", style="dim", width=6, justify="center")
table.add_column("Path", style="cyan", no_wrap=True)
table.add_column("Score", style="magenta", justify="right")
table.add_column("Excerpt", style="white")
for res in results:
excerpt = res.excerpt or ""
score_str = f"{res.score:.3f}"
if verbose:
# Extract search source tag if available
source = getattr(res, "search_source", None)
source_tag = ""
if source == "exact":
source_tag = "[E]"
elif source == "fuzzy":
source_tag = "[F]"
elif source == "vector":
source_tag = "[V]"
elif source == "fusion":
source_tag = "[RRF]"
table.add_row(source_tag, res.path, score_str, excerpt)
else:
table.add_row(res.path, score_str, excerpt)
console.print(table)
def render_symbols(symbols: Sequence[Symbol], *, title: str = "Symbols") -> None:
table = Table(title=title)
table.add_column("Name", style="green")
table.add_column("Kind", style="yellow")
table.add_column("Range", style="white", justify="right")
for sym in symbols:
start, end = sym.range
table.add_row(sym.name, sym.kind, f"{start}-{end}")
console.print(table)
def render_status(stats: Mapping[str, Any]) -> None:
table = Table(title="Index Status")
table.add_column("Metric", style="cyan")
table.add_column("Value", style="white")
for key, value in stats.items():
if isinstance(value, Mapping):
value_text = ", ".join(f"{k}:{v}" for k, v in value.items())
elif isinstance(value, (list, tuple)):
value_text = ", ".join(str(v) for v in value)
else:
value_text = str(value)
table.add_row(str(key), value_text)
console.print(table)
def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> None:
header = Text.assemble(("File: ", "bold"), (path, "cyan"), (" Language: ", "bold"), (language, "green"))
console.print(header)
render_symbols(list(symbols), title="Discovered Symbols")


@@ -0,0 +1,692 @@
"""Configuration system for CodexLens."""
from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Dict, List, Optional
from .errors import ConfigError
# Workspace-local directory name
WORKSPACE_DIR_NAME = ".codexlens"
# Settings file name
SETTINGS_FILE_NAME = "settings.json"
# SPLADE index database name (centralized storage)
SPLADE_DB_NAME = "_splade.db"
# Dense vector storage names (centralized storage)
VECTORS_HNSW_NAME = "_vectors.hnsw"
VECTORS_META_DB_NAME = "_vectors_meta.db"
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"
log = logging.getLogger(__name__)
def _default_global_dir() -> Path:
"""Get global CodexLens data directory."""
env_override = os.getenv("CODEXLENS_DATA_DIR")
if env_override:
return Path(env_override).expanduser().resolve()
return (Path.home() / ".codexlens").resolve()
def find_workspace_root(start_path: Path) -> Optional[Path]:
"""Find the workspace root by looking for .codexlens directory.
Searches from start_path upward to find an existing .codexlens directory.
Returns None if not found.
"""
current = start_path.resolve()
# Search up to filesystem root
while current != current.parent:
workspace_dir = current / WORKSPACE_DIR_NAME
if workspace_dir.is_dir():
return current
current = current.parent
# Check root as well
workspace_dir = current / WORKSPACE_DIR_NAME
if workspace_dir.is_dir():
return current
return None
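# e.g. with /repo/.codexlens present (and none nearer),
# find_workspace_root(Path("/repo/src/pkg")) -> Path("/repo");
# returns None when no ancestor contains a .codexlens directory.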
@dataclass
class Config:
"""Runtime configuration for CodexLens.
- data_dir: Base directory for all persistent CodexLens data.
- venv_path: Optional virtualenv used for language tooling.
- supported_languages: Language IDs and their associated file extensions.
- parsing_rules: Per-language parsing and chunking hints.
"""
data_dir: Path = field(default_factory=_default_global_dir)
venv_path: Path = field(default_factory=lambda: _default_global_dir() / "venv")
supported_languages: Dict[str, Dict[str, Any]] = field(
default_factory=lambda: {
# Source code languages (category: "code")
"python": {"extensions": [".py"], "tree_sitter_language": "python", "category": "code"},
"javascript": {"extensions": [".js", ".jsx"], "tree_sitter_language": "javascript", "category": "code"},
"typescript": {"extensions": [".ts", ".tsx"], "tree_sitter_language": "typescript", "category": "code"},
"java": {"extensions": [".java"], "tree_sitter_language": "java", "category": "code"},
"go": {"extensions": [".go"], "tree_sitter_language": "go", "category": "code"},
"zig": {"extensions": [".zig"], "tree_sitter_language": "zig", "category": "code"},
"objective-c": {"extensions": [".m", ".mm"], "tree_sitter_language": "objc", "category": "code"},
"c": {"extensions": [".c", ".h"], "tree_sitter_language": "c", "category": "code"},
"cpp": {"extensions": [".cc", ".cpp", ".hpp", ".cxx"], "tree_sitter_language": "cpp", "category": "code"},
"rust": {"extensions": [".rs"], "tree_sitter_language": "rust", "category": "code"},
}
)
parsing_rules: Dict[str, Dict[str, Any]] = field(
default_factory=lambda: {
"default": {
"max_chunk_chars": 4000,
"max_chunk_lines": 200,
"overlap_lines": 20,
}
}
)
llm_enabled: bool = False
llm_tool: str = "gemini"
llm_timeout_ms: int = 300000
llm_batch_size: int = 5
# Hybrid chunker configuration
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
# Embedding configuration
embedding_backend: str = "fastembed" # "fastembed" (local) or "litellm" (API)
embedding_model: str = "code" # For fastembed: profile (fast/code/multilingual/balanced)
# For litellm: model name from config (e.g., "qwen3-embedding")
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
# SPLADE sparse retrieval configuration
enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead)
splade_model: str = "naver/splade-cocondenser-ensembledistil"
splade_threshold: float = 0.01 # Min weight to store in index
splade_onnx_path: Optional[str] = None # Custom ONNX model path
    # FTS fallback for sparse search (enabled by default since SPLADE is disabled)
    use_fts_fallback: bool = True  # Use FTS for sparse search (fast; SPLADE off by default)
# Indexing/search optimizations
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
# Graph expansion (search-time, uses precomputed neighbors)
enable_graph_expansion: bool = False
graph_expansion_depth: int = 2
# Optional search reranking (disabled by default)
enable_reranking: bool = False
reranking_top_k: int = 50
symbol_boost_factor: float = 1.5
# Optional cross-encoder reranking (second stage; requires optional reranker deps)
enable_cross_encoder_rerank: bool = False
reranker_backend: str = "onnx"
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker_top_k: int = 50
reranker_max_input_tokens: int = 8192 # Maximum tokens for reranker API batching
reranker_chunk_type_weights: Optional[Dict[str, float]] = None # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
reranker_test_file_penalty: float = 0.0 # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)
# Chunk stripping configuration (for semantic embedding)
chunk_strip_comments: bool = True # Strip comments from code chunks
chunk_strip_docstrings: bool = True # Strip docstrings from code chunks
# Cascade search configuration (two-stage retrieval)
enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
cascade_coarse_k: int = 100 # Number of coarse candidates from first stage
cascade_fine_k: int = 10 # Number of final results after reranking
cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
# Staged cascade search configuration (4-stage pipeline)
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
staged_lsp_depth: int = 2 # LSP relationship expansion depth in Stage 2
staged_clustering_strategy: str = "auto" # "auto", "hdbscan", "dbscan", "frequency", "noop"
staged_clustering_min_size: int = 3 # Minimum cluster size for Stage 3 grouping
enable_staged_rerank: bool = True # Enable optional cross-encoder reranking in Stage 4
# RRF fusion configuration
fusion_method: str = "rrf" # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
rrf_k: int = 60 # RRF constant (default 60)
# Category-based filtering to separate code/doc results
enable_category_filter: bool = True # Enable code/doc result separation
# Multi-endpoint configuration for litellm backend
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
embedding_pool_enabled: bool = False # Enable high availability pool for embeddings
embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
# Reranker multi-endpoint configuration
reranker_pool_enabled: bool = False # Enable high availability pool for reranker
reranker_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
reranker_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
# API concurrency settings
api_max_workers: int = 4 # Max concurrent API calls for embedding/reranking
api_batch_size: int = 8 # Batch size for API requests
api_batch_size_dynamic: bool = False # Enable dynamic batch size calculation
api_batch_size_utilization_factor: float = 0.8 # Use 80% of model token capacity
api_batch_size_max: int = 2048 # Absolute upper limit for batch size
chars_per_token_estimate: int = 4 # Characters per token estimation ratio
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()
self.venv_path = self.venv_path.expanduser().resolve()
self.data_dir.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise ConfigError(
f"Permission denied initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
except OSError as exc:
raise ConfigError(
f"Filesystem error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
except Exception as exc:
raise ConfigError(
f"Unexpected error initializing paths (data_dir={self.data_dir}, venv_path={self.venv_path}) "
f"[{type(exc).__name__}]: {exc}"
) from exc
@cached_property
def cache_dir(self) -> Path:
"""Directory for transient caches."""
return self.data_dir / "cache"
@cached_property
def index_dir(self) -> Path:
"""Directory where index artifacts are stored."""
return self.data_dir / "index"
@cached_property
def db_path(self) -> Path:
"""Default SQLite index path."""
return self.index_dir / "codexlens.db"
def ensure_runtime_dirs(self) -> None:
"""Create standard runtime directories if missing."""
for directory in (self.cache_dir, self.index_dir):
try:
directory.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise ConfigError(
f"Permission denied creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
except OSError as exc:
raise ConfigError(
f"Filesystem error creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
except Exception as exc:
raise ConfigError(
f"Unexpected error creating directory {directory} [{type(exc).__name__}]: {exc}"
) from exc
def language_for_path(self, path: str | Path) -> str | None:
"""Infer a supported language ID from a file path."""
extension = Path(path).suffix.lower()
for language_id, spec in self.supported_languages.items():
extensions: List[str] = spec.get("extensions", [])
if extension in extensions:
return language_id
return None
def category_for_path(self, path: str | Path) -> str | None:
"""Get file category ('code' or 'doc') from a file path."""
language = self.language_for_path(path)
if language is None:
return None
spec = self.supported_languages.get(language, {})
return spec.get("category")
def rules_for_language(self, language_id: str) -> Dict[str, Any]:
"""Get parsing rules for a specific language, falling back to defaults."""
return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
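    # With only the "default" rules defined, rules_for_language("python") yields
    # {"max_chunk_chars": 4000, "max_chunk_lines": 200, "overlap_lines": 20}.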
@cached_property
def settings_path(self) -> Path:
"""Path to the settings file."""
return self.data_dir / SETTINGS_FILE_NAME
def save_settings(self) -> None:
"""Save embedding and other settings to file."""
embedding_config = {
"backend": self.embedding_backend,
"model": self.embedding_model,
"use_gpu": self.embedding_use_gpu,
"pool_enabled": self.embedding_pool_enabled,
"strategy": self.embedding_strategy,
"cooldown": self.embedding_cooldown,
}
# Include multi-endpoint config if present
if self.embedding_endpoints:
embedding_config["endpoints"] = self.embedding_endpoints
settings = {
"embedding": embedding_config,
"llm": {
"enabled": self.llm_enabled,
"tool": self.llm_tool,
"timeout_ms": self.llm_timeout_ms,
"batch_size": self.llm_batch_size,
},
"reranker": {
"enabled": self.enable_cross_encoder_rerank,
"backend": self.reranker_backend,
"model": self.reranker_model,
"top_k": self.reranker_top_k,
"max_input_tokens": self.reranker_max_input_tokens,
"pool_enabled": self.reranker_pool_enabled,
"strategy": self.reranker_strategy,
"cooldown": self.reranker_cooldown,
},
"cascade": {
"strategy": self.cascade_strategy,
"coarse_k": self.cascade_coarse_k,
"fine_k": self.cascade_fine_k,
},
"api": {
"max_workers": self.api_max_workers,
"batch_size": self.api_batch_size,
"batch_size_dynamic": self.api_batch_size_dynamic,
"batch_size_utilization_factor": self.api_batch_size_utilization_factor,
"batch_size_max": self.api_batch_size_max,
"chars_per_token_estimate": self.chars_per_token_estimate,
},
}
with open(self.settings_path, "w", encoding="utf-8") as f:
json.dump(settings, f, indent=2)
def load_settings(self) -> None:
"""Load settings from file if exists."""
        if not self.settings_path.exists():
            # Still apply .env overrides when no settings file exists
            self._apply_env_overrides()
            return
try:
with open(self.settings_path, "r", encoding="utf-8") as f:
settings = json.load(f)
# Load embedding settings
embedding = settings.get("embedding", {})
if "backend" in embedding:
backend = embedding["backend"]
# Support 'api' as alias for 'litellm'
if backend == "api":
backend = "litellm"
if backend in {"fastembed", "litellm"}:
self.embedding_backend = backend
else:
log.warning(
"Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
self.settings_path,
embedding["backend"],
)
if "model" in embedding:
self.embedding_model = embedding["model"]
if "use_gpu" in embedding:
self.embedding_use_gpu = embedding["use_gpu"]
# Load multi-endpoint configuration
if "endpoints" in embedding:
self.embedding_endpoints = embedding["endpoints"]
if "pool_enabled" in embedding:
self.embedding_pool_enabled = embedding["pool_enabled"]
if "strategy" in embedding:
self.embedding_strategy = embedding["strategy"]
if "cooldown" in embedding:
self.embedding_cooldown = embedding["cooldown"]
# Load LLM settings
llm = settings.get("llm", {})
if "enabled" in llm:
self.llm_enabled = llm["enabled"]
if "tool" in llm:
self.llm_tool = llm["tool"]
if "timeout_ms" in llm:
self.llm_timeout_ms = llm["timeout_ms"]
if "batch_size" in llm:
self.llm_batch_size = llm["batch_size"]
# Load reranker settings
reranker = settings.get("reranker", {})
if "enabled" in reranker:
self.enable_cross_encoder_rerank = reranker["enabled"]
if "backend" in reranker:
backend = reranker["backend"]
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
self.reranker_backend = backend
else:
log.warning(
"Invalid reranker backend in %s: %r (expected 'fastembed', 'onnx', 'api', 'litellm', or 'legacy')",
self.settings_path,
backend,
)
if "model" in reranker:
self.reranker_model = reranker["model"]
if "top_k" in reranker:
self.reranker_top_k = reranker["top_k"]
if "max_input_tokens" in reranker:
self.reranker_max_input_tokens = reranker["max_input_tokens"]
if "pool_enabled" in reranker:
self.reranker_pool_enabled = reranker["pool_enabled"]
if "strategy" in reranker:
self.reranker_strategy = reranker["strategy"]
if "cooldown" in reranker:
self.reranker_cooldown = reranker["cooldown"]
# Load cascade settings
cascade = settings.get("cascade", {})
if "strategy" in cascade:
strategy = cascade["strategy"]
if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
self.cascade_strategy = strategy
else:
log.warning(
"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
self.settings_path,
strategy,
)
if "coarse_k" in cascade:
self.cascade_coarse_k = cascade["coarse_k"]
if "fine_k" in cascade:
self.cascade_fine_k = cascade["fine_k"]
# Load API settings
api = settings.get("api", {})
if "max_workers" in api:
self.api_max_workers = api["max_workers"]
if "batch_size" in api:
self.api_batch_size = api["batch_size"]
if "batch_size_dynamic" in api:
self.api_batch_size_dynamic = api["batch_size_dynamic"]
if "batch_size_utilization_factor" in api:
self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
if "batch_size_max" in api:
self.api_batch_size_max = api["batch_size_max"]
if "chars_per_token_estimate" in api:
self.chars_per_token_estimate = api["chars_per_token_estimate"]
except Exception as exc:
log.warning(
"Failed to load settings from %s (%s): %s",
self.settings_path,
type(exc).__name__,
exc,
)
# Apply .env overrides (highest priority)
self._apply_env_overrides()
def _apply_env_overrides(self) -> None:
"""Apply environment variable overrides from .env file.
Priority: default → settings.json → .env (highest)
Supported variables (with or without CODEXLENS_ prefix):
EMBEDDING_MODEL: Override embedding model/profile
EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
EMBEDDING_STRATEGY: Load balance strategy for embedding
EMBEDDING_COOLDOWN: Rate limit cooldown for embedding
RERANKER_MODEL: Override reranker model
RERANKER_BACKEND: Override reranker backend
RERANKER_ENABLED: Override reranker enabled state (true/false)
RERANKER_POOL_ENABLED: Enable reranker high availability pool
RERANKER_STRATEGY: Load balance strategy for reranker
RERANKER_COOLDOWN: Rate limit cooldown for reranker
"""
from .env_config import load_global_env
env_vars = load_global_env()
if not env_vars:
return
def get_env(key: str) -> str | None:
"""Get env var with or without CODEXLENS_ prefix."""
# Check prefixed version first (Dashboard format), then unprefixed
return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)
# Embedding overrides
embedding_model = get_env("EMBEDDING_MODEL")
if embedding_model:
self.embedding_model = embedding_model
log.debug("Overriding embedding_model from .env: %s", self.embedding_model)
embedding_backend = get_env("EMBEDDING_BACKEND")
if embedding_backend:
backend = embedding_backend.lower()
# Support 'api' as alias for 'litellm'
if backend == "api":
backend = "litellm"
if backend in {"fastembed", "litellm"}:
self.embedding_backend = backend
log.debug("Overriding embedding_backend from .env: %s", backend)
else:
log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)
embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
if embedding_pool:
value = embedding_pool.lower()
self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)
embedding_strategy = get_env("EMBEDDING_STRATEGY")
if embedding_strategy:
strategy = embedding_strategy.lower()
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
self.embedding_strategy = strategy
log.debug("Overriding embedding_strategy from .env: %s", strategy)
else:
log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)
embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
if embedding_cooldown:
try:
self.embedding_cooldown = float(embedding_cooldown)
log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
except ValueError:
log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)
# Reranker overrides
reranker_model = get_env("RERANKER_MODEL")
if reranker_model:
self.reranker_model = reranker_model
log.debug("Overriding reranker_model from .env: %s", self.reranker_model)
reranker_backend = get_env("RERANKER_BACKEND")
if reranker_backend:
backend = reranker_backend.lower()
if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
self.reranker_backend = backend
log.debug("Overriding reranker_backend from .env: %s", backend)
else:
log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)
reranker_enabled = get_env("RERANKER_ENABLED")
if reranker_enabled:
value = reranker_enabled.lower()
self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)
reranker_pool = get_env("RERANKER_POOL_ENABLED")
if reranker_pool:
value = reranker_pool.lower()
self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)
reranker_strategy = get_env("RERANKER_STRATEGY")
if reranker_strategy:
strategy = reranker_strategy.lower()
if strategy in {"round_robin", "latency_aware", "weighted_random"}:
self.reranker_strategy = strategy
log.debug("Overriding reranker_strategy from .env: %s", strategy)
else:
log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)
reranker_cooldown = get_env("RERANKER_COOLDOWN")
if reranker_cooldown:
try:
self.reranker_cooldown = float(reranker_cooldown)
log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
except ValueError:
log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)
reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
if reranker_max_tokens:
try:
self.reranker_max_input_tokens = int(reranker_max_tokens)
log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
except ValueError:
log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)
# Reranker tuning from environment
test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
if test_penalty:
try:
self.reranker_test_file_penalty = float(test_penalty)
log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
except ValueError:
log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)
docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
if docstring_weight:
try:
weight = float(docstring_weight)
self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
log.debug("Overriding reranker docstring weight from .env: %s", weight)
except ValueError:
log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)
# Chunk stripping from environment
strip_comments = get_env("CHUNK_STRIP_COMMENTS")
if strip_comments:
self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)
strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
if strip_docstrings:
self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
@classmethod
def load(cls) -> "Config":
"""Load config with settings from file."""
config = cls()
config.load_settings()
return config
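# Usage sketch: load persisted settings, tweak, and save them back
#   cfg = Config.load()
#   cfg.enable_cross_encoder_rerank = True
#   cfg.save_settings()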
@dataclass
class WorkspaceConfig:
"""Workspace-local configuration for CodexLens.
Stores index data in project/.codexlens/ directory.
"""
workspace_root: Path
def __post_init__(self) -> None:
self.workspace_root = Path(self.workspace_root).resolve()
@property
def codexlens_dir(self) -> Path:
"""The .codexlens directory in workspace root."""
return self.workspace_root / WORKSPACE_DIR_NAME
@property
def db_path(self) -> Path:
"""SQLite index path for this workspace."""
return self.codexlens_dir / "index.db"
@property
def cache_dir(self) -> Path:
"""Cache directory for this workspace."""
return self.codexlens_dir / "cache"
@property
def env_path(self) -> Path:
"""Path to workspace .env file."""
return self.codexlens_dir / ".env"
def load_env(self, *, override: bool = False) -> int:
"""Load .env file and apply to os.environ.
Args:
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
from .env_config import apply_workspace_env
return apply_workspace_env(self.workspace_root, override=override)
def get_api_config(self, prefix: str) -> dict:
"""Get API configuration from environment.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
Returns:
Dictionary with api_key, api_base, model, etc.
"""
from .env_config import get_api_config
return get_api_config(prefix, workspace_root=self.workspace_root)
def initialize(self) -> None:
"""Create the .codexlens directory structure."""
try:
self.codexlens_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Create .gitignore to exclude cache but keep index
gitignore_path = self.codexlens_dir / ".gitignore"
if not gitignore_path.exists():
gitignore_path.write_text(
"# CodexLens workspace data\n"
"cache/\n"
"*.log\n"
".env\n" # Exclude .env from git
)
except Exception as exc:
raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc
def exists(self) -> bool:
"""Check if workspace is already initialized."""
return self.codexlens_dir.is_dir() and self.db_path.exists()
@classmethod
def from_path(cls, path: Path) -> Optional["WorkspaceConfig"]:
"""Create WorkspaceConfig from a path by finding workspace root.
Returns None if no workspace found.
"""
root = find_workspace_root(path)
if root is None:
return None
return cls(workspace_root=root)
@classmethod
def create_at(cls, path: Path) -> "WorkspaceConfig":
"""Create a new workspace at the given path."""
config = cls(workspace_root=path)
config.initialize()
return config


@@ -0,0 +1,128 @@
"""Pydantic entity models for CodexLens."""
from __future__ import annotations
import math
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field, field_validator
class Symbol(BaseModel):
"""A code symbol discovered in a file."""
name: str = Field(..., min_length=1)
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
@field_validator("range")
@classmethod
def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
if len(value) != 2:
raise ValueError("range must be a (start_line, end_line) tuple")
start_line, end_line = value
if start_line < 1 or end_line < 1:
raise ValueError("range lines must be >= 1")
if end_line < start_line:
raise ValueError("end_line must be >= start_line")
return value
class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""
content: str = Field(..., min_length=1)
embedding: Optional[List[float]] = Field(default=None, description="Vector embedding for semantic search")
metadata: Dict[str, Any] = Field(default_factory=dict)
id: Optional[int] = Field(default=None, description="Database row ID")
file_path: Optional[str] = Field(default=None, description="Source file path")
@field_validator("embedding")
@classmethod
def validate_embedding(cls, value: Optional[List[float]]) -> Optional[List[float]]:
if value is None:
return value
if not value:
raise ValueError("embedding cannot be empty when provided")
norm = math.sqrt(sum(x * x for x in value))
epsilon = 1e-10
if norm < epsilon:
raise ValueError("embedding cannot be a zero vector")
return value
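    # e.g. SemanticChunk(content="def f(): ...", embedding=[0.0, 0.0]) fails
    # validation with "embedding cannot be a zero vector" (surfaced as a
    # pydantic ValidationError).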
class IndexedFile(BaseModel):
"""An indexed source file with symbols and optional semantic chunks."""
path: str = Field(..., min_length=1)
language: str = Field(..., min_length=1)
symbols: List[Symbol] = Field(default_factory=list)
chunks: List[SemanticChunk] = Field(default_factory=list)
relationships: List["CodeRelationship"] = Field(default_factory=list)
@field_validator("path", "language")
@classmethod
def strip_and_validate_nonempty(cls, value: str) -> str:
cleaned = value.strip()
if not cleaned:
raise ValueError("value cannot be blank")
return cleaned
class RelationshipType(str, Enum):
"""Types of code relationships."""
CALL = "calls"
INHERITS = "inherits"
IMPORTS = "imports"
class CodeRelationship(BaseModel):
"""A relationship between code symbols (e.g., function calls, inheritance)."""
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
relationship_type: RelationshipType = Field(..., description="Type of relationship (call, inherits, etc.)")
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
class AdditionalLocation(BaseModel):
"""A pointer to another location where a similar result was found.
Used for grouping search results with similar scores and content,
where the primary result is stored in SearchResult and secondary
locations are stored in this model.
"""
path: str = Field(..., min_length=1)
score: float = Field(..., ge=0.0)
start_line: Optional[int] = Field(default=None, description="Start line of the result (1-based)")
end_line: Optional[int] = Field(default=None, description="End line of the result (1-based)")
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol")
class SearchResult(BaseModel):
"""A unified search result for lexical or semantic search."""
path: str = Field(..., min_length=1)
score: float = Field(..., ge=0.0)
excerpt: Optional[str] = None
content: Optional[str] = Field(default=None, description="Full content of matched code block")
symbol: Optional[Symbol] = None
chunk: Optional[SemanticChunk] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
# Additional context for complete code blocks
start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")
# Field for grouping similar results
additional_locations: List["AdditionalLocation"] = Field(
default_factory=list,
description="Other locations for grouped results with similar scores and content."
)


@@ -0,0 +1,304 @@
"""Environment configuration loader for CodexLens.
Loads .env files from workspace .codexlens directory with fallback to project root.
Provides unified access to API configurations.
Priority order (highest first):
1. Environment variables already set in the process
2. .codexlens/.env (workspace-local)
3. .env (project root)
4. ~/.codexlens/.env (global)
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
log = logging.getLogger(__name__)
# Supported environment variables with descriptions
ENV_VARS = {
# Reranker configuration (overrides settings.json)
"RERANKER_MODEL": "Reranker model name (overrides settings.json)",
"RERANKER_BACKEND": "Reranker backend: fastembed, onnx, api, litellm, legacy",
"RERANKER_ENABLED": "Enable reranker: true/false",
"RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
"RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
"RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
"RERANKER_POOL_ENABLED": "Enable reranker high availability pool: true/false",
"RERANKER_STRATEGY": "Reranker load balance strategy: round_robin, latency_aware, weighted_random",
"RERANKER_COOLDOWN": "Reranker rate limit cooldown in seconds",
# Embedding configuration (overrides settings.json)
"EMBEDDING_MODEL": "Embedding model/profile name (overrides settings.json)",
"EMBEDDING_BACKEND": "Embedding backend: fastembed, litellm",
"EMBEDDING_API_KEY": "API key for embedding service",
"EMBEDDING_API_BASE": "Base URL for embedding API",
"EMBEDDING_POOL_ENABLED": "Enable embedding high availability pool: true/false",
"EMBEDDING_STRATEGY": "Embedding load balance strategy: round_robin, latency_aware, weighted_random",
"EMBEDDING_COOLDOWN": "Embedding rate limit cooldown in seconds",
# LiteLLM configuration
"LITELLM_API_KEY": "API key for LiteLLM",
"LITELLM_API_BASE": "Base URL for LiteLLM",
"LITELLM_MODEL": "LiteLLM model name",
# General configuration
"CODEXLENS_DATA_DIR": "Custom data directory path",
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
# Chunking configuration
"CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
"CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
# Reranker tuning
"RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
"RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
}
def _parse_env_line(line: str) -> tuple[str, str] | None:
"""Parse a single .env line, returning (key, value) or None."""
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith("#"):
return None
# Handle export prefix
if line.startswith("export "):
line = line[7:].strip()
# Split on first =
if "=" not in line:
return None
key, _, value = line.partition("=")
key = key.strip()
value = value.strip()
# Remove surrounding quotes
if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \
(value.startswith("'") and value.endswith("'")):
value = value[1:-1]
return key, value
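# A worked example of the parser above (hypothetical input):
#   _parse_env_line('export RERANKER_MODEL="bge-reranker-v2"')
#   -> ("RERANKER_MODEL", "bge-reranker-v2")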
def load_env_file(env_path: Path) -> Dict[str, str]:
"""Load environment variables from a .env file.
Args:
env_path: Path to .env file
Returns:
Dictionary of environment variables
"""
if not env_path.is_file():
return {}
env_vars: Dict[str, str] = {}
try:
content = env_path.read_text(encoding="utf-8")
for line in content.splitlines():
result = _parse_env_line(line)
if result:
key, value = result
env_vars[key] = value
except Exception as exc:
log.warning("Failed to load .env file %s: %s", env_path, exc)
return env_vars
def _get_global_data_dir() -> Path:
"""Get global CodexLens data directory."""
env_override = os.environ.get("CODEXLENS_DATA_DIR")
if env_override:
return Path(env_override).expanduser().resolve()
return (Path.home() / ".codexlens").resolve()
def load_global_env() -> Dict[str, str]:
"""Load environment variables from global ~/.codexlens/.env file.
Returns:
Dictionary of environment variables from global config
"""
global_env_path = _get_global_data_dir() / ".env"
if global_env_path.is_file():
env_vars = load_env_file(global_env_path)
log.debug("Loaded %d vars from global %s", len(env_vars), global_env_path)
return env_vars
return {}
def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
"""Load environment variables from workspace .env files.
Priority (later overrides earlier):
1. Global ~/.codexlens/.env (lowest priority)
2. Project root .env
3. .codexlens/.env (highest priority)
Args:
workspace_root: Workspace root directory. If None, uses current directory.
Returns:
Merged dictionary of environment variables
"""
if workspace_root is None:
workspace_root = Path.cwd()
workspace_root = Path(workspace_root).resolve()
env_vars: Dict[str, str] = {}
# Load from global ~/.codexlens/.env (lowest priority)
global_vars = load_global_env()
if global_vars:
env_vars.update(global_vars)
# Load from project root .env (medium priority)
root_env = workspace_root / ".env"
if root_env.is_file():
loaded = load_env_file(root_env)
env_vars.update(loaded)
log.debug("Loaded %d vars from %s", len(loaded), root_env)
# Load from .codexlens/.env (highest priority)
codexlens_env = workspace_root / ".codexlens" / ".env"
if codexlens_env.is_file():
loaded = load_env_file(codexlens_env)
env_vars.update(loaded)
log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)
return env_vars
def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
"""Load .env files and apply to os.environ.
Args:
workspace_root: Workspace root directory
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
env_vars = load_workspace_env(workspace_root)
applied = 0
for key, value in env_vars.items():
if override or key not in os.environ:
os.environ[key] = value
applied += 1
log.debug("Applied env var: %s", key)
return applied
def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback.
Priority:
1. os.environ (already set)
2. .codexlens/.env
3. .env
4. ~/.codexlens/.env (global)
5. default value
Args:
key: Environment variable name
default: Default value if not found
workspace_root: Workspace root for .env file lookup
Returns:
Value or default
"""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Load from .env files
env_vars = load_workspace_env(workspace_root)
if key in env_vars:
return env_vars[key]
return default
def get_api_config(
prefix: str,
*,
workspace_root: Path | None = None,
defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
"""Get API configuration from environment.
Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
workspace_root: Workspace root for .env file lookup
defaults: Default values
Returns:
Dictionary with api_key, api_base, model, etc.
"""
defaults = defaults or {}
config: Dict[str, Any] = {}
# Standard API config fields
field_mapping = {
"api_key": f"{prefix}_API_KEY",
"api_base": f"{prefix}_API_BASE",
"model": f"{prefix}_MODEL",
"provider": f"{prefix}_PROVIDER",
"timeout": f"{prefix}_TIMEOUT",
}
for field, env_key in field_mapping.items():
value = get_env(env_key, workspace_root=workspace_root)
if value is not None:
# Type conversion for specific fields
if field == "timeout":
try:
config[field] = float(value)
except ValueError:
pass
else:
config[field] = value
elif field in defaults:
config[field] = defaults[field]
return config
def generate_env_example() -> str:
"""Generate .env.example content with all supported variables.
Returns:
String content for .env.example file
"""
lines = [
"# CodexLens Environment Configuration",
"# Copy this file to .codexlens/.env and fill in your values",
"",
]
# Group by prefix
groups: Dict[str, list] = {}
for key, desc in ENV_VARS.items():
prefix = key.split("_")[0]
if prefix not in groups:
groups[prefix] = []
groups[prefix].append((key, desc))
for prefix, items in groups.items():
lines.append(f"# {prefix} Configuration")
for key, desc in items:
lines.append(f"# {desc}")
lines.append(f"# {key}=")
lines.append("")
return "\n".join(lines)

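As a minimal usage sketch of the loader above (the workspace path, the import path, and the RERANKER values are assumptions, not part of this diff):

from pathlib import Path

from codexlens.env_config import apply_workspace_env, get_api_config  # import path assumed

# Apply .env values without clobbering variables that are already set
applied = apply_workspace_env(Path("/path/to/workspace"))

# Collect RERANKER_API_KEY / RERANKER_API_BASE / RERANKER_MODEL / ... into one dict
config = get_api_config("RERANKER", defaults={"provider": "siliconflow"})
print(applied, config.get("model"), config.get("api_base"))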
View File

@@ -0,0 +1,59 @@
"""CodexLens exception hierarchy."""
from __future__ import annotations
class CodexLensError(Exception):
"""Base class for all CodexLens errors."""
class ConfigError(CodexLensError):
"""Raised when configuration is invalid or cannot be loaded."""
class ParseError(CodexLensError):
"""Raised when parsing or indexing a file fails."""
class StorageError(CodexLensError):
"""Raised when reading/writing index storage fails.
Attributes:
message: Human-readable error description
db_path: Path to the database file (if applicable)
operation: The operation that failed (e.g., 'query', 'initialize', 'migrate')
details: Additional context for debugging
"""
def __init__(
self,
message: str,
db_path: str | None = None,
operation: str | None = None,
details: dict | None = None
) -> None:
super().__init__(message)
self.message = message
self.db_path = db_path
self.operation = operation
self.details = details or {}
def __str__(self) -> str:
parts = [self.message]
if self.db_path:
parts.append(f"[db: {self.db_path}]")
if self.operation:
parts.append(f"[op: {self.operation}]")
if self.details:
detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
parts.append(f"[{detail_str}]")
return " ".join(parts)
class SearchError(CodexLensError):
"""Raised when a search operation fails."""
class IndexNotFoundError(CodexLensError):
"""Raised when a project's index cannot be found."""

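For illustration, the structured StorageError above renders like this (the values are invented, and the import path is assumed):

from codexlens.exceptions import StorageError  # import path assumed

err = StorageError(
    "database is locked",
    db_path="/tmp/index.db",
    operation="query",
    details={"retries": 3},
)
print(err)  # database is locked [db: /tmp/index.db] [op: query] [retries=3]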
View File

@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.
This module provides core data structures for hybrid search:
- CodeSymbolNode: Graph node representing a code symbol
- CodeAssociationGraph: Graph of code relationships
- SearchResultCluster: Clustered search results
- Range: Position range in source files
- CallHierarchyItem: LSP call hierarchy item
Note: The search engine is in codexlens.search.hybrid_search
LSP-based expansion is in codexlens.lsp module
"""
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeAssociationGraph,
CodeSymbolNode,
Range,
SearchResultCluster,
)
__all__ = [
"CallHierarchyItem",
"CodeAssociationGraph",
"CodeSymbolNode",
"Range",
"SearchResultCluster",
]

View File

@@ -0,0 +1,602 @@
"""Core data structures for the hybrid search system.
This module defines the fundamental data structures used throughout the
hybrid search pipeline, including code symbol representations, association
graphs, and clustered search results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import networkx as nx
@dataclass
class Range:
"""Position range within a source file.
Attributes:
start_line: Starting line number (0-based).
start_character: Starting character offset within the line.
end_line: Ending line number (0-based).
end_character: Ending character offset within the line.
"""
start_line: int
start_character: int
end_line: int
end_character: int
def __post_init__(self) -> None:
"""Validate range values."""
if self.start_line < 0:
raise ValueError("start_line must be >= 0")
if self.start_character < 0:
raise ValueError("start_character must be >= 0")
if self.end_line < 0:
raise ValueError("end_line must be >= 0")
if self.end_character < 0:
raise ValueError("end_character must be >= 0")
if self.end_line < self.start_line:
raise ValueError("end_line must be >= start_line")
if self.end_line == self.start_line and self.end_character < self.start_character:
raise ValueError("end_character must be >= start_character on the same line")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return {
"start": {"line": self.start_line, "character": self.start_character},
"end": {"line": self.end_line, "character": self.end_character},
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Range:
"""Create Range from dictionary representation."""
return cls(
start_line=data["start"]["line"],
start_character=data["start"]["character"],
end_line=data["end"]["line"],
end_character=data["end"]["character"],
)
@classmethod
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
"""Create Range from LSP Range object.
LSP Range format:
{"start": {"line": int, "character": int},
"end": {"line": int, "character": int}}
"""
return cls(
start_line=lsp_range["start"]["line"],
start_character=lsp_range["start"]["character"],
end_line=lsp_range["end"]["line"],
end_character=lsp_range["end"]["character"],
)
@dataclass
class CodeSymbolNode:
"""Graph node representing a code symbol.
Attributes:
id: Unique identifier in format 'file_path:name:line'.
name: Symbol name (function, class, variable name).
kind: Symbol kind (function, class, method, variable, etc.).
file_path: Absolute file path where symbol is defined.
range: Start/end position in the source file.
embedding: Optional vector embedding for semantic search.
raw_code: Raw source code of the symbol.
docstring: Documentation string (if available).
score: Ranking score (used during reranking).
"""
id: str
name: str
kind: str
file_path: str
range: Range
embedding: Optional[List[float]] = None
raw_code: str = ""
docstring: str = ""
score: float = 0.0
def __post_init__(self) -> None:
"""Validate required fields."""
if not self.id:
raise ValueError("id cannot be empty")
if not self.name:
raise ValueError("name cannot be empty")
if not self.kind:
raise ValueError("kind cannot be empty")
if not self.file_path:
raise ValueError("file_path cannot be empty")
def __hash__(self) -> int:
"""Hash based on unique ID."""
return hash(self.id)
def __eq__(self, other: object) -> bool:
"""Equality based on unique ID."""
if not isinstance(other, CodeSymbolNode):
return False
return self.id == other.id
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"id": self.id,
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
"score": self.score,
}
if self.raw_code:
result["raw_code"] = self.raw_code
if self.docstring:
result["docstring"] = self.docstring
# Exclude embedding from serialization (too large for JSON responses)
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
"""Create CodeSymbolNode from dictionary representation."""
return cls(
id=data["id"],
name=data["name"],
kind=data["kind"],
file_path=data["file_path"],
range=Range.from_dict(data["range"]),
embedding=data.get("embedding"),
raw_code=data.get("raw_code", ""),
docstring=data.get("docstring", ""),
score=data.get("score", 0.0),
)
@classmethod
def from_lsp_location(
cls,
uri: str,
name: str,
kind: str,
lsp_range: Dict[str, Any],
raw_code: str = "",
docstring: str = "",
) -> CodeSymbolNode:
"""Create CodeSymbolNode from LSP location data.
Args:
uri: File URI (file:// prefix will be stripped).
name: Symbol name.
kind: Symbol kind.
lsp_range: LSP Range object.
raw_code: Optional raw source code.
docstring: Optional documentation string.
Returns:
New CodeSymbolNode instance.
"""
# Strip file:// prefix if present
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
# Handle Windows paths (file:///C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
range_obj = Range.from_lsp_range(lsp_range)
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
return cls(
id=symbol_id,
name=name,
kind=kind,
file_path=file_path,
range=range_obj,
raw_code=raw_code,
docstring=docstring,
)
@classmethod
def create_id(cls, file_path: str, name: str, line: int) -> str:
"""Generate a unique symbol ID.
Args:
file_path: Absolute file path.
name: Symbol name.
line: Start line number.
Returns:
Unique ID string in format 'file_path:name:line'.
"""
return f"{file_path}:{name}:{line}"
@dataclass
class CodeAssociationGraph:
"""Graph of code relationships between symbols.
This graph represents the association between code symbols discovered
through LSP queries (references, call hierarchy, etc.).
Attributes:
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
edges: List of (from_id, to_id, relationship_type) tuples.
relationship_type: 'calls', 'references', 'inherits', 'imports'.
"""
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
edges: List[Tuple[str, str, str]] = field(default_factory=list)
def add_node(self, node: CodeSymbolNode) -> None:
"""Add a node to the graph.
Args:
node: CodeSymbolNode to add. If a node with the same ID exists,
it will be replaced.
"""
self.nodes[node.id] = node
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge to the graph.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
Raises:
ValueError: If from_id or to_id not in graph nodes.
"""
if from_id not in self.nodes:
raise ValueError(f"Source node '{from_id}' not found in graph")
if to_id not in self.nodes:
raise ValueError(f"Target node '{to_id}' not found in graph")
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge without validating node existence.
Use this method during bulk graph construction where nodes may be
added after edges, or when performance is critical.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type.
"""
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
"""Get a node by ID.
Args:
node_id: Node ID to look up.
Returns:
CodeSymbolNode if found, None otherwise.
"""
return self.nodes.get(node_id)
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get neighboring nodes connected by outgoing edges.
Args:
node_id: Node ID to find neighbors for.
rel_type: Optional filter by relationship type.
Returns:
List of neighboring CodeSymbolNode objects.
"""
neighbors = []
for from_id, to_id, edge_rel in self.edges:
if from_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(to_id)
if node:
neighbors.append(node)
return neighbors
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get nodes connected by incoming edges.
Args:
node_id: Node ID to find incoming connections for.
rel_type: Optional filter by relationship type.
Returns:
List of CodeSymbolNode objects with edges pointing to node_id.
"""
incoming = []
for from_id, to_id, edge_rel in self.edges:
if to_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(from_id)
if node:
incoming.append(node)
return incoming
def to_networkx(self) -> "nx.DiGraph":
"""Convert to NetworkX DiGraph for graph algorithms.
Returns:
NetworkX directed graph with nodes and edges.
Raises:
ImportError: If networkx is not installed.
"""
try:
import networkx as nx
except ImportError:
raise ImportError(
"networkx is required for graph algorithms. "
"Install with: pip install networkx"
)
graph = nx.DiGraph()
# Add nodes with attributes
for node_id, node in self.nodes.items():
graph.add_node(
node_id,
name=node.name,
kind=node.kind,
file_path=node.file_path,
score=node.score,
)
# Add edges with relationship type
for from_id, to_id, rel_type in self.edges:
graph.add_edge(from_id, to_id, relationship=rel_type)
return graph
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary with 'nodes' and 'edges' keys.
"""
return {
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"edges": [
{"from": from_id, "to": to_id, "relationship": rel_type}
for from_id, to_id, rel_type in self.edges
],
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
"""Create CodeAssociationGraph from dictionary representation.
Args:
data: Dictionary with 'nodes' and 'edges' keys.
Returns:
New CodeAssociationGraph instance.
"""
graph = cls()
# Load nodes
for node_id, node_data in data.get("nodes", {}).items():
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
# Load edges
for edge_data in data.get("edges", []):
graph.edges.append((
edge_data["from"],
edge_data["to"],
edge_data["relationship"],
))
return graph
def __len__(self) -> int:
"""Return the number of nodes in the graph."""
return len(self.nodes)
@dataclass
class SearchResultCluster:
"""Clustered search result containing related code symbols.
Search results are grouped into clusters based on graph community
detection or embedding similarity. Each cluster represents a
conceptually related group of code symbols.
Attributes:
cluster_id: Unique cluster identifier.
score: Cluster relevance score (max of symbol scores).
title: Human-readable cluster title/summary.
symbols: List of CodeSymbolNode in this cluster.
metadata: Additional cluster metadata.
"""
cluster_id: str
score: float
title: str
symbols: List[CodeSymbolNode] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate cluster fields."""
if not self.cluster_id:
raise ValueError("cluster_id cannot be empty")
if self.score < 0:
raise ValueError("score must be >= 0")
def add_symbol(self, symbol: CodeSymbolNode) -> None:
"""Add a symbol to the cluster.
Args:
symbol: CodeSymbolNode to add.
"""
self.symbols.append(symbol)
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
"""Get top N symbols by score.
Args:
n: Number of symbols to return.
Returns:
List of top N CodeSymbolNode objects sorted by score descending.
"""
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
return sorted_symbols[:n]
def update_score(self) -> None:
"""Update cluster score to max of symbol scores."""
if self.symbols:
self.score = max(s.score for s in self.symbols)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary representation of the cluster.
"""
return {
"cluster_id": self.cluster_id,
"score": self.score,
"title": self.title,
"symbols": [s.to_dict() for s in self.symbols],
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
"""Create SearchResultCluster from dictionary representation.
Args:
data: Dictionary with cluster data.
Returns:
New SearchResultCluster instance.
"""
return cls(
cluster_id=data["cluster_id"],
score=data["score"],
title=data["title"],
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
metadata=data.get("metadata", {}),
)
def __len__(self) -> int:
"""Return the number of symbols in the cluster."""
return len(self.symbols)
@dataclass
class CallHierarchyItem:
"""LSP CallHierarchyItem for representing callers/callees.
Attributes:
name: Symbol name (function, method, etc.).
kind: Symbol kind (function, method, etc.).
file_path: Absolute file path.
range: Position range in the file.
detail: Optional additional detail (e.g., signature).
"""
name: str
kind: str
file_path: str
range: Range
detail: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
}
if self.detail:
result["detail"] = self.detail
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from dictionary representation."""
return cls(
name=data.get("name", "unknown"),
kind=data.get("kind", "unknown"),
file_path=data.get("file_path", data.get("uri", "")),
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)
@classmethod
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from LSP response format.
LSP uses 0-based line numbers and 'character' instead of 'char'.
"""
uri = data.get("uri", data.get("file_path", ""))
# Strip file:// prefix
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
return cls(
name=data.get("name", "unknown"),
kind=str(data.get("kind", "unknown")),
file_path=file_path,
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)

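To make the graph API above concrete, a small sketch that builds a two-node association graph and walks it in both directions; the symbols and paths are invented:

from codexlens.hybrid_search import CodeAssociationGraph, CodeSymbolNode, Range

def make_node(path: str, name: str, line: int) -> CodeSymbolNode:
    return CodeSymbolNode(
        id=CodeSymbolNode.create_id(path, name, line),
        name=name,
        kind="function",
        file_path=path,
        range=Range(line, 0, line + 3, 0),
    )

graph = CodeAssociationGraph()
caller = make_node("/src/app.py", "main", 10)
callee = make_node("/src/util.py", "helper", 42)
graph.add_node(caller)
graph.add_node(callee)
graph.add_edge(caller.id, callee.id, "calls")

assert [n.name for n in graph.get_neighbors(caller.id, "calls")] == ["helper"]
assert [n.name for n in graph.get_incoming(callee.id)] == ["main"]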
View File

@@ -0,0 +1,26 @@
"""Code indexing and symbol extraction."""
from codexlens.indexing.symbol_extractor import SymbolExtractor
from codexlens.indexing.embedding import (
BinaryEmbeddingBackend,
DenseEmbeddingBackend,
CascadeEmbeddingBackend,
get_cascade_embedder,
binarize_embedding,
pack_binary_embedding,
unpack_binary_embedding,
hamming_distance,
)
__all__ = [
"SymbolExtractor",
# Cascade embedding backends
"BinaryEmbeddingBackend",
"DenseEmbeddingBackend",
"CascadeEmbeddingBackend",
"get_cascade_embedder",
# Utility functions
"binarize_embedding",
"pack_binary_embedding",
"unpack_binary_embedding",
"hamming_distance",
]

View File

@@ -0,0 +1,582 @@
"""Multi-type embedding backends for cascade retrieval.
This module provides embedding backends optimized for cascade retrieval:
1. BinaryEmbeddingBackend - Fast coarse filtering with binary vectors
2. DenseEmbeddingBackend - High-precision dense vectors for reranking
3. CascadeEmbeddingBackend - Combined binary + dense for two-stage retrieval
Cascade retrieval workflow:
1. Binary search (fast, ~32 bytes/vector) -> top-K candidates
2. Dense rerank (precise, ~3KB/vector at the default 768 dim) -> final results
"""
from __future__ import annotations
import logging
from typing import Iterable, List, Optional, Tuple
import numpy as np
from codexlens.semantic.base import BaseEmbedder
logger = logging.getLogger(__name__)
# =============================================================================
# Utility Functions
# =============================================================================
def binarize_embedding(embedding: np.ndarray) -> np.ndarray:
"""Convert float embedding to binary vector.
Applies sign-based quantization: values > 0 become 1, values <= 0 become 0.
Args:
embedding: Float32 embedding of any dimension
Returns:
Binary vector (uint8 with values 0 or 1) of same dimension
"""
return (embedding > 0).astype(np.uint8)
def pack_binary_embedding(binary_vector: np.ndarray) -> bytes:
"""Pack binary vector into compact bytes format.
Packs 8 binary values into each byte for storage efficiency.
For a 256-dim binary vector, output is 32 bytes.
Args:
binary_vector: Binary vector (uint8 with values 0 or 1)
Returns:
Packed bytes (length = ceil(dim / 8))
"""
# Ensure vector length is multiple of 8 by padding if needed
dim = len(binary_vector)
padded_dim = ((dim + 7) // 8) * 8
if padded_dim > dim:
padded = np.zeros(padded_dim, dtype=np.uint8)
padded[:dim] = binary_vector
binary_vector = padded
# Pack 8 bits per byte
packed = np.packbits(binary_vector)
return packed.tobytes()
def unpack_binary_embedding(packed_bytes: bytes, dim: int = 256) -> np.ndarray:
"""Unpack bytes back to binary vector.
Args:
packed_bytes: Packed binary data
dim: Original vector dimension (default: 256)
Returns:
Binary vector (uint8 with values 0 or 1)
"""
unpacked = np.unpackbits(np.frombuffer(packed_bytes, dtype=np.uint8))
return unpacked[:dim]
def hamming_distance(a: bytes, b: bytes) -> int:
"""Compute Hamming distance between two packed binary vectors.
Uses XOR and popcount for efficient distance computation.
Args:
a: First packed binary vector
b: Second packed binary vector
Returns:
Hamming distance (number of differing bits)
"""
a_arr = np.frombuffer(a, dtype=np.uint8)
b_arr = np.frombuffer(b, dtype=np.uint8)
xor = np.bitwise_xor(a_arr, b_arr)
return int(np.unpackbits(xor).sum())
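# A worked round-trip of the helpers above (hypothetical 4-dim vector):
#   bits = binarize_embedding(np.array([0.3, -0.1, 0.7, 0.0]))  # -> [1, 0, 1, 0]
#   packed = pack_binary_embedding(bits)   # 1 byte after padding to 8 bits
#   unpack_binary_embedding(packed, dim=4) # -> [1, 0, 1, 0]
#   hamming_distance(packed, packed)       # -> 0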
# =============================================================================
# Binary Embedding Backend
# =============================================================================
class BinaryEmbeddingBackend(BaseEmbedder):
"""Generate 256-dimensional binary embeddings for fast coarse retrieval.
Uses a lightweight embedding model and applies sign-based quantization
to produce compact binary vectors (32 bytes per embedding).
Suitable for:
- First-stage candidate retrieval
- Hamming distance-based similarity search
- Memory-constrained environments
Default model: BAAI/bge-small-en-v1.5 (384 dim) -> projected and quantized to 256 bits
"""
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, fast
BINARY_DIM = 256
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
) -> None:
"""Initialize binary embedding backend.
Args:
model_name: Base embedding model name. Defaults to BAAI/bge-small-en-v1.5
use_gpu: Whether to use GPU acceleration
"""
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
self._model_name = model_name or self.DEFAULT_MODEL
self._use_gpu = use_gpu
self._model = None
# Projection matrix for dimension reduction (lazily initialized)
self._projection_matrix: Optional[np.ndarray] = None
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Return binary embedding dimension (256)."""
return self.BINARY_DIM
@property
def packed_bytes(self) -> int:
"""Return packed bytes size (32 bytes for 256 bits)."""
return self.BINARY_DIM // 8
def _load_model(self) -> None:
"""Lazy load the embedding model."""
if self._model is not None:
return
from fastembed import TextEmbedding
from codexlens.semantic.gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
try:
self._model = TextEmbedding(
model_name=self._model_name,
providers=providers,
)
except TypeError:
# Fallback for older fastembed versions
self._model = TextEmbedding(model_name=self._model_name)
logger.debug(f"BinaryEmbeddingBackend loaded model: {self._model_name}")
def _get_projection_matrix(self, input_dim: int) -> np.ndarray:
"""Get or create projection matrix for dimension reduction.
Uses random projection with fixed seed for reproducibility.
Args:
input_dim: Input embedding dimension from base model
Returns:
Projection matrix of shape (input_dim, BINARY_DIM)
"""
if self._projection_matrix is not None:
return self._projection_matrix
# Fixed seed for reproducibility across sessions
rng = np.random.RandomState(42)
# Gaussian random projection
self._projection_matrix = rng.randn(input_dim, self.BINARY_DIM).astype(np.float32)
# Normalize columns for consistent scale
norms = np.linalg.norm(self._projection_matrix, axis=0, keepdims=True)
self._projection_matrix /= (norms + 1e-8)
return self._projection_matrix
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate binary embeddings as numpy array.
Args:
texts: Single text or iterable of texts
Returns:
Binary embeddings of shape (n_texts, 256) with values 0 or 1
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Get base float embeddings
float_embeddings = np.array(list(self._model.embed(texts)))
input_dim = float_embeddings.shape[1]
# Project to target dimension if needed
if input_dim != self.BINARY_DIM:
projection = self._get_projection_matrix(input_dim)
float_embeddings = float_embeddings @ projection
# Binarize
return binarize_embedding(float_embeddings)
def embed_packed(self, texts: str | Iterable[str]) -> List[bytes]:
"""Generate packed binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
List of packed bytes (32 bytes each for 256-dim)
"""
binary = self.embed_to_numpy(texts)
return [pack_binary_embedding(vec) for vec in binary]
# =============================================================================
# Dense Embedding Backend
# =============================================================================
class DenseEmbeddingBackend(BaseEmbedder):
"""Generate high-dimensional dense embeddings for precise reranking.
Uses a dense embedding model to produce TARGET_DIM-dimensional float32
vectors (default 768) for maximum retrieval quality.
Suitable for:
- Second-stage reranking
- High-precision similarity search
- Quality-critical applications
Default model: BAAI/bge-small-en-v1.5 (384 dim) with optional expansion to TARGET_DIM
"""
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5" # 384 dim, use small for testing
TARGET_DIM = 768 # Reduced target for faster testing
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
expand_dim: bool = True,
) -> None:
"""Initialize dense embedding backend.
Args:
model_name: Dense embedding model name. Defaults to BAAI/bge-small-en-v1.5
use_gpu: Whether to use GPU acceleration
expand_dim: If True, expand embeddings to TARGET_DIM using learned expansion
"""
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
self._model_name = model_name or self.DEFAULT_MODEL
self._use_gpu = use_gpu
self._expand_dim = expand_dim
self._model = None
self._native_dim: Optional[int] = None
# Expansion matrix for dimension expansion (lazily initialized)
self._expansion_matrix: Optional[np.ndarray] = None
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Return embedding dimension.
Returns TARGET_DIM if expand_dim is True, otherwise native model dimension.
"""
if self._expand_dim:
return self.TARGET_DIM
# Return cached native dim or estimate based on model
if self._native_dim is not None:
return self._native_dim
# Model dimension estimates
model_dims = {
"BAAI/bge-large-en-v1.5": 1024,
"BAAI/bge-base-en-v1.5": 768,
"BAAI/bge-small-en-v1.5": 384,
"intfloat/multilingual-e5-large": 1024,
}
return model_dims.get(self._model_name, 1024)
@property
def max_tokens(self) -> int:
"""Return maximum token limit."""
return 512 # Conservative default for large models
def _load_model(self) -> None:
"""Lazy load the embedding model."""
if self._model is not None:
return
from fastembed import TextEmbedding
from codexlens.semantic.gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=self._use_gpu, with_device_options=True)
try:
self._model = TextEmbedding(
model_name=self._model_name,
providers=providers,
)
except TypeError:
self._model = TextEmbedding(model_name=self._model_name)
logger.debug(f"DenseEmbeddingBackend loaded model: {self._model_name}")
def _get_expansion_matrix(self, input_dim: int) -> np.ndarray:
"""Get or create expansion matrix for dimension expansion.
Uses random orthogonal projection for information-preserving expansion.
Args:
input_dim: Input embedding dimension from base model
Returns:
Expansion matrix of shape (input_dim, TARGET_DIM)
"""
if self._expansion_matrix is not None:
return self._expansion_matrix
# Fixed seed for reproducibility
rng = np.random.RandomState(123)
# Create semi-orthogonal expansion matrix
# First input_dim columns form identity-like structure
self._expansion_matrix = np.zeros((input_dim, self.TARGET_DIM), dtype=np.float32)
# Copy original dimensions
copy_dim = min(input_dim, self.TARGET_DIM)
self._expansion_matrix[:copy_dim, :copy_dim] = np.eye(copy_dim, dtype=np.float32)
# Fill remaining with random projections
if self.TARGET_DIM > input_dim:
random_part = rng.randn(input_dim, self.TARGET_DIM - input_dim).astype(np.float32)
# Normalize
norms = np.linalg.norm(random_part, axis=0, keepdims=True)
random_part /= (norms + 1e-8)
self._expansion_matrix[:, input_dim:] = random_part
return self._expansion_matrix
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate dense embeddings as numpy array.
Args:
texts: Single text or iterable of texts
Returns:
Dense embeddings of shape (n_texts, TARGET_DIM) as float32
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Get base float embeddings
float_embeddings = np.array(list(self._model.embed(texts)), dtype=np.float32)
self._native_dim = float_embeddings.shape[1]
# Expand to target dimension if needed
if self._expand_dim and self._native_dim < self.TARGET_DIM:
expansion = self._get_expansion_matrix(self._native_dim)
float_embeddings = float_embeddings @ expansion
return float_embeddings
# =============================================================================
# Cascade Embedding Backend
# =============================================================================
class CascadeEmbeddingBackend(BaseEmbedder):
"""Combined binary + dense embedding backend for cascade retrieval.
Generates both binary (for fast coarse filtering) and dense (for precise
reranking) embeddings in a single pass, optimized for two-stage retrieval.
Cascade workflow:
1. encode_cascade() returns (binary_embeddings, dense_embeddings)
2. Binary search: Use Hamming distance on binary vectors -> top-K candidates
3. Dense rerank: Use cosine similarity on dense vectors -> final results
Memory efficiency:
- Binary: 32 bytes per vector (256 bits)
- Dense: 4 bytes per dimension (3072 bytes per vector at the default 768 dim)
- Total: ~3KB per document for full cascade support at the defaults
"""
def __init__(
self,
binary_model: Optional[str] = None,
dense_model: Optional[str] = None,
use_gpu: bool = True,
) -> None:
"""Initialize cascade embedding backend.
Args:
binary_model: Model for binary embeddings. Defaults to BAAI/bge-small-en-v1.5
dense_model: Model for dense embeddings. Defaults to BAAI/bge-small-en-v1.5
use_gpu: Whether to use GPU acceleration
"""
self._binary_backend = BinaryEmbeddingBackend(
model_name=binary_model,
use_gpu=use_gpu,
)
self._dense_backend = DenseEmbeddingBackend(
model_name=dense_model,
use_gpu=use_gpu,
expand_dim=True,
)
self._use_gpu = use_gpu
@property
def model_name(self) -> str:
"""Return model names for both backends."""
return f"cascade({self._binary_backend.model_name}, {self._dense_backend.model_name})"
@property
def embedding_dim(self) -> int:
"""Return dense embedding dimension (for compatibility)."""
return self._dense_backend.embedding_dim
@property
def binary_dim(self) -> int:
"""Return binary embedding dimension."""
return self._binary_backend.embedding_dim
@property
def dense_dim(self) -> int:
"""Return dense embedding dimension."""
return self._dense_backend.embedding_dim
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate dense embeddings (for BaseEmbedder compatibility).
For cascade embeddings, use encode_cascade() instead.
Args:
texts: Single text or iterable of texts
Returns:
Dense embeddings of shape (n_texts, dense_dim)
"""
return self._dense_backend.embed_to_numpy(texts)
def encode_cascade(
self,
texts: str | Iterable[str],
batch_size: int = 32,
) -> Tuple[np.ndarray, np.ndarray]:
"""Generate both binary and dense embeddings.
Args:
texts: Single text or iterable of texts
batch_size: Batch size for processing
Returns:
Tuple of:
- binary_embeddings: Shape (n_texts, 256), uint8 values 0/1
- dense_embeddings: Shape (n_texts, dense_dim), float32 (default 768)
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
binary_embeddings = self._binary_backend.embed_to_numpy(texts)
dense_embeddings = self._dense_backend.embed_to_numpy(texts)
return binary_embeddings, dense_embeddings
def encode_binary(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate only binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
Binary embeddings of shape (n_texts, 256)
"""
return self._binary_backend.embed_to_numpy(texts)
def encode_dense(self, texts: str | Iterable[str]) -> np.ndarray:
"""Generate only dense embeddings.
Args:
texts: Single text or iterable of texts
Returns:
Dense embeddings of shape (n_texts, dense_dim)
"""
return self._dense_backend.embed_to_numpy(texts)
def encode_binary_packed(self, texts: str | Iterable[str]) -> List[bytes]:
"""Generate packed binary embeddings.
Args:
texts: Single text or iterable of texts
Returns:
List of packed bytes (32 bytes each)
"""
return self._binary_backend.embed_packed(texts)
# =============================================================================
# Factory Function
# =============================================================================
def get_cascade_embedder(
binary_model: Optional[str] = None,
dense_model: Optional[str] = None,
use_gpu: bool = True,
) -> CascadeEmbeddingBackend:
"""Factory function to create a cascade embedder.
Args:
binary_model: Model for binary embeddings (default: BAAI/bge-small-en-v1.5)
dense_model: Model for dense embeddings (default: BAAI/bge-small-en-v1.5)
use_gpu: Whether to use GPU acceleration
Returns:
Configured CascadeEmbeddingBackend instance
Example:
>>> embedder = get_cascade_embedder()
>>> binary, dense = embedder.encode_cascade(["hello world"])
>>> binary.shape # (1, 256)
>>> dense.shape # (1, 768)
"""
return CascadeEmbeddingBackend(
binary_model=binary_model,
dense_model=dense_model,
use_gpu=use_gpu,
)

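A minimal sketch of the two-stage cascade the module above describes: Hamming-distance filtering on packed binary codes, then cosine reranking on dense vectors. It assumes the optional semantic dependencies (fastembed) are installed; the corpus and query are invented:

import numpy as np

from codexlens.indexing.embedding import (
    get_cascade_embedder,
    hamming_distance,
    pack_binary_embedding,
)

docs = ["def parse_env(line): ...", "class SymbolExtractor: ...", "def pack_bits(v): ..."]
embedder = get_cascade_embedder(use_gpu=False)
doc_binary, doc_dense = embedder.encode_cascade(docs)
packed_docs = [pack_binary_embedding(vec) for vec in doc_binary]

query = "extract symbols from source code"
q_packed = embedder.encode_binary_packed(query)[0]
q_dense = embedder.encode_dense(query)[0]

# Stage 1: coarse top-K by Hamming distance on 32-byte binary codes
top_k = sorted(range(len(docs)), key=lambda i: hamming_distance(q_packed, packed_docs[i]))[:2]

# Stage 2: rerank the survivors by cosine similarity on dense vectors
def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

best = max(top_k, key=lambda i: cosine(q_dense, doc_dense[i]))
print(docs[best])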
View File

@@ -0,0 +1,277 @@
"""Symbol and relationship extraction from source code."""
import re
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
try:
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception: # pragma: no cover - optional dependency / platform variance
TreeSitterSymbolParser = None # type: ignore[assignment]
class SymbolExtractor:
"""Extract symbols and relationships from source code using regex patterns."""
# Pattern definitions for different languages
PATTERNS = {
'python': {
'function': r'^(?:async\s+)?def\s+(\w+)\s*\(',
'class': r'^class\s+(\w+)\s*[:\(]',
'import': r'^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)',
'call': r'(?<![.\w])(\w+)\s*\(',
},
'typescript': {
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*[<\(]',
'class': r'(?:export\s+)?class\s+(\w+)',
'import': r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
'call': r'(?<![.\w])(\w+)\s*[<\(]',
},
'javascript': {
'function': r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(',
'class': r'(?:export\s+)?class\s+(\w+)',
'import': r"(?:import|require)\s*\(?['\"]([^'\"]+)['\"]",
'call': r'(?<![.\w])(\w+)\s*\(',
}
}
LANGUAGE_MAP = {
'.py': 'python',
'.ts': 'typescript',
'.tsx': 'typescript',
'.js': 'javascript',
'.jsx': 'javascript',
}
def __init__(self, db_path: Path):
self.db_path = db_path
self.db_conn: Optional[sqlite3.Connection] = None
def connect(self) -> None:
"""Connect to database and ensure schema exists."""
self.db_conn = sqlite3.connect(str(self.db_path))
self._ensure_tables()
def __enter__(self) -> "SymbolExtractor":
"""Context manager entry: connect to database."""
self.connect()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit: close database connection."""
self.close()
def _ensure_tables(self) -> None:
"""Create symbols and relationships tables if they don't exist."""
if not self.db_conn:
return
cursor = self.db_conn.cursor()
# Create symbols table with qualified_name
cursor.execute('''
CREATE TABLE IF NOT EXISTS symbols (
id INTEGER PRIMARY KEY AUTOINCREMENT,
qualified_name TEXT NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
UNIQUE(file_path, name, start_line)
)
''')
# Create relationships table with target_symbol_fqn
cursor.execute('''
CREATE TABLE IF NOT EXISTS symbol_relationships (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_symbol_id INTEGER NOT NULL,
target_symbol_fqn TEXT NOT NULL,
relationship_type TEXT NOT NULL,
file_path TEXT NOT NULL,
line INTEGER,
FOREIGN KEY (source_symbol_id) REFERENCES symbols(id) ON DELETE CASCADE
)
''')
# Create performance indexes
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_path)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_source ON symbol_relationships(source_symbol_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_target ON symbol_relationships(target_symbol_fqn)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_rel_type ON symbol_relationships(relationship_type)')
self.db_conn.commit()
def extract_from_file(self, file_path: Path, content: str) -> Tuple[List[Dict], List[Dict]]:
"""Extract symbols and relationships from file content.
Args:
file_path: Path to the source file
content: File content as string
Returns:
Tuple of (symbols, relationships) where:
- symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
- relationships: List of relationship dicts with source_scope, target, type, file_path, line
"""
ext = file_path.suffix.lower()
lang = self.LANGUAGE_MAP.get(ext)
if not lang or lang not in self.PATTERNS:
return [], []
patterns = self.PATTERNS[lang]
symbols = []
relationships: List[Dict] = []
lines = content.split('\n')
current_scope = None
for line_num, line in enumerate(lines, 1):
# Extract function/class definitions
for kind in ['function', 'class']:
if kind in patterns:
match = re.search(patterns[kind], line)
if match:
name = match.group(1)
qualified_name = f"{file_path.stem}.{name}"
symbols.append({
'qualified_name': qualified_name,
'name': name,
'kind': kind,
'file_path': str(file_path),
'start_line': line_num,
'end_line': line_num, # Simplified - would need proper parsing for actual end
})
current_scope = name
if TreeSitterSymbolParser is not None:
try:
ts_parser = TreeSitterSymbolParser(lang, file_path)
if ts_parser.is_available():
indexed = ts_parser.parse(content, file_path)
if indexed is not None and indexed.relationships:
relationships = [
{
"source_scope": r.source_symbol,
"target": r.target_symbol,
"type": r.relationship_type.value,
"file_path": str(file_path),
"line": r.source_line,
}
for r in indexed.relationships
]
except Exception:
relationships = []
# Regex fallback for relationships (when tree-sitter is unavailable)
if not relationships:
current_scope = None
for line_num, line in enumerate(lines, 1):
for kind in ['function', 'class']:
if kind in patterns:
match = re.search(patterns[kind], line)
if match:
current_scope = match.group(1)
# Extract imports
if 'import' in patterns:
match = re.search(patterns['import'], line)
if match:
# Prefer the 'from X' module (group 1) when present, else the plain import target
import_target = (match.group(1) or match.group(2)) if (match.lastindex or 0) >= 2 else match.group(1)
if import_target and current_scope:
relationships.append({
'source_scope': current_scope,
'target': import_target.strip(),
'type': 'imports',
'file_path': str(file_path),
'line': line_num,
})
# Extract function calls (simplified)
if 'call' in patterns and current_scope:
for match in re.finditer(patterns['call'], line):
call_name = match.group(1)
# Skip common keywords and the current function
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
relationships.append({
'source_scope': current_scope,
'target': call_name,
'type': 'calls',
'file_path': str(file_path),
'line': line_num,
})
return symbols, relationships
def save_symbols(self, symbols: List[Dict]) -> Dict[str, int]:
"""Save symbols to database and return name->id mapping.
Args:
symbols: List of symbol dicts with qualified_name, name, kind, file_path, start_line, end_line
Returns:
Dictionary mapping symbol name to database id
"""
if not self.db_conn or not symbols:
return {}
cursor = self.db_conn.cursor()
name_to_id = {}
for sym in symbols:
try:
cursor.execute('''
INSERT OR IGNORE INTO symbols
(qualified_name, name, kind, file_path, start_line, end_line)
VALUES (?, ?, ?, ?, ?, ?)
''', (sym['qualified_name'], sym['name'], sym['kind'],
sym['file_path'], sym['start_line'], sym['end_line']))
# Get the id
cursor.execute('''
SELECT id FROM symbols
WHERE file_path = ? AND name = ? AND start_line = ?
''', (sym['file_path'], sym['name'], sym['start_line']))
row = cursor.fetchone()
if row:
name_to_id[sym['name']] = row[0]
except sqlite3.Error:
continue
self.db_conn.commit()
return name_to_id
def save_relationships(self, relationships: List[Dict], name_to_id: Dict[str, int]) -> None:
"""Save relationships to database.
Args:
relationships: List of relationship dicts with source_scope, target, type, file_path, line
name_to_id: Dictionary mapping symbol names to database ids
"""
if not self.db_conn or not relationships:
return
cursor = self.db_conn.cursor()
for rel in relationships:
source_id = name_to_id.get(rel['source_scope'])
if source_id:
try:
cursor.execute('''
INSERT INTO symbol_relationships
(source_symbol_id, target_symbol_fqn, relationship_type, file_path, line)
VALUES (?, ?, ?, ?, ?)
''', (source_id, rel['target'], rel['type'], rel['file_path'], rel['line']))
except sqlite3.Error:
continue
self.db_conn.commit()
def close(self) -> None:
"""Close database connection."""
if self.db_conn:
self.db_conn.close()
self.db_conn = None

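An end-to-end sketch of the extractor above; the source file and database paths are hypothetical:

from pathlib import Path

from codexlens.indexing import SymbolExtractor

src = Path("/src/app.py")
content = src.read_text(encoding="utf-8")

# The context manager opens the SQLite connection and ensures the schema exists
with SymbolExtractor(Path("/tmp/index.db")) as extractor:
    symbols, relationships = extractor.extract_from_file(src, content)
    name_to_id = extractor.save_symbols(symbols)
    extractor.save_relationships(relationships, name_to_id)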
View File

@@ -0,0 +1,34 @@
"""LSP module for real-time language server integration.
This module provides:
- LspBridge: HTTP bridge to VSCode language servers
- LspGraphBuilder: Build code association graphs via LSP
- Location: Position in a source file
Example:
>>> from codexlens.lsp import LspBridge, LspGraphBuilder
>>>
>>> async with LspBridge() as bridge:
... refs = await bridge.get_references(symbol)
... graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
"""
from codexlens.lsp.lsp_bridge import (
CacheEntry,
Location,
LspBridge,
)
from codexlens.lsp.lsp_graph_builder import (
LspGraphBuilder,
)
# Alias for backward compatibility
GraphBuilder = LspGraphBuilder
__all__ = [
"CacheEntry",
"GraphBuilder",
"Location",
"LspBridge",
"LspGraphBuilder",
]

View File

@@ -0,0 +1,551 @@
"""LSP request handlers for codex-lens.
This module contains handlers for LSP requests:
- textDocument/definition
- textDocument/completion
- workspace/symbol
- textDocument/didSave
- textDocument/hover
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote, unquote
try:
from lsprotocol import types as lsp
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.entities import Symbol
from codexlens.lsp.server import server
logger = logging.getLogger(__name__)
# Symbol kind mapping from codex-lens to LSP
SYMBOL_KIND_MAP = {
"class": lsp.SymbolKind.Class,
"function": lsp.SymbolKind.Function,
"method": lsp.SymbolKind.Method,
"variable": lsp.SymbolKind.Variable,
"constant": lsp.SymbolKind.Constant,
"property": lsp.SymbolKind.Property,
"field": lsp.SymbolKind.Field,
"interface": lsp.SymbolKind.Interface,
"module": lsp.SymbolKind.Module,
"namespace": lsp.SymbolKind.Namespace,
"package": lsp.SymbolKind.Package,
"enum": lsp.SymbolKind.Enum,
"enum_member": lsp.SymbolKind.EnumMember,
"struct": lsp.SymbolKind.Struct,
"type": lsp.SymbolKind.TypeParameter,
"type_alias": lsp.SymbolKind.TypeParameter,
}
# Completion kind mapping from codex-lens to LSP
COMPLETION_KIND_MAP = {
"class": lsp.CompletionItemKind.Class,
"function": lsp.CompletionItemKind.Function,
"method": lsp.CompletionItemKind.Method,
"variable": lsp.CompletionItemKind.Variable,
"constant": lsp.CompletionItemKind.Constant,
"property": lsp.CompletionItemKind.Property,
"field": lsp.CompletionItemKind.Field,
"interface": lsp.CompletionItemKind.Interface,
"module": lsp.CompletionItemKind.Module,
"enum": lsp.CompletionItemKind.Enum,
"enum_member": lsp.CompletionItemKind.EnumMember,
"struct": lsp.CompletionItemKind.Struct,
"type": lsp.CompletionItemKind.TypeParameter,
"type_alias": lsp.CompletionItemKind.TypeParameter,
}
def _path_to_uri(path: Union[str, Path]) -> str:
"""Convert a file path to a URI.
Args:
path: File path (string or Path object)
Returns:
File URI string
"""
path_str = str(Path(path).resolve())
# Handle Windows paths
if path_str.startswith("/"):
return f"file://{quote(path_str)}"
else:
return f"file:///{quote(path_str.replace(chr(92), '/'))}"
def _uri_to_path(uri: str) -> Path:
"""Convert a URI to a file path.
Args:
uri: File URI string
Returns:
Path object
"""
path = uri[7:] if uri.startswith("file://") else uri
# Windows URIs look like file:///C:/..., so drop the extra leading slash before the drive
if len(path) > 2 and path[0] == "/" and path[2] == ":":
    path = path[1:]
return Path(unquote(path))
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
"""Extract the word at the given position in the document.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Word at position, or None if no word found
"""
lines = document_text.splitlines()
if line >= len(lines):
return None
line_text = lines[line]
if character > len(line_text):
return None
# Find word boundaries
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
for match in word_pattern.finditer(line_text):
if match.start() <= character <= match.end():
return match.group()
return None
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
"""Extract the incomplete word prefix at the given position.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Prefix string (may be empty)
"""
lines = document_text.splitlines()
if line >= len(lines):
return ""
line_text = lines[line]
if character > len(line_text):
character = len(line_text)
# Extract text before cursor
before_cursor = line_text[:character]
# Find the start of the current word
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
if match:
return match.group()
return ""
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
"""Convert a codex-lens Symbol to an LSP Location.
Args:
symbol: codex-lens Symbol object
Returns:
LSP Location, or None if symbol has no file
"""
if not symbol.file:
return None
# LSP uses 0-based lines, codex-lens uses 1-based
start_line = max(0, symbol.range[0] - 1)
end_line = max(0, symbol.range[1] - 1)
return lsp.Location(
uri=_path_to_uri(symbol.file),
range=lsp.Range(
start=lsp.Position(line=start_line, character=0),
end=lsp.Position(line=end_line, character=0),
),
)
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
"""Map codex-lens symbol kind to LSP SymbolKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP SymbolKind
"""
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
"""Map codex-lens symbol kind to LSP CompletionItemKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP CompletionItemKind
"""
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
# -----------------------------------------------------------------------------
# LSP Request Handlers
# -----------------------------------------------------------------------------
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
def lsp_definition(
params: lsp.DefinitionParams,
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
"""Handle textDocument/definition request.
Finds the definition of the symbol at the cursor position.
"""
if not server.global_index:
logger.debug("No global index available for definition lookup")
return None
# Get document
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
# Get word at position
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
logger.debug("No word found at position")
return None
logger.debug("Looking up definition for: %s", word)
# Search for exact symbol match
try:
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=False, # Exact match preferred
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
# Fall back to prefix search
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=True,
)
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
logger.debug("No definition found for: %s", word)
return None
# Convert to LSP locations
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
if len(locations) == 1:
return locations[0]
elif locations:
return locations
else:
return None
except Exception as exc:
logger.error("Error looking up definition: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
"""Handle textDocument/references request.
Finds all references to the symbol at the cursor position using
the code_relationships table for accurate call-site tracking.
Falls back to same-name symbol search if search_engine is unavailable.
"""
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Finding references for: %s", word)
try:
# Try using search_engine.search_references() for accurate reference tracking
if server.search_engine and server.workspace_root:
references = server.search_engine.search_references(
symbol_name=word,
source_path=server.workspace_root,
limit=200,
)
if references:
locations = []
for ref in references:
locations.append(
lsp.Location(
uri=_path_to_uri(ref.file_path),
range=lsp.Range(
start=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column,
),
end=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column + len(word),
),
),
)
)
return locations if locations else None
# Fallback: search for symbols with same name using global_index
if server.global_index:
symbols = server.global_index.search(
name=word,
limit=100,
prefix_mode=False,
)
# Filter for exact matches
exact_matches = [s for s in symbols if s.name == word]
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
return locations if locations else None
return None
except Exception as exc:
logger.error("Error finding references: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
"""Handle textDocument/completion request.
Provides code completion suggestions based on indexed symbols.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
prefix = _get_prefix_at_position(
document.source,
params.position.line,
params.position.character,
)
if not prefix or len(prefix) < 2:
# Require at least 2 characters for completion
return None
logger.debug("Completing prefix: %s", prefix)
try:
symbols = server.global_index.search(
name=prefix,
limit=50,
prefix_mode=True,
)
if not symbols:
return None
# Convert to completion items
items = []
seen_names = set()
for sym in symbols:
if sym.name in seen_names:
continue
seen_names.add(sym.name)
items.append(
lsp.CompletionItem(
label=sym.name,
kind=_symbol_kind_to_completion_kind(sym.kind),
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
sort_text=sym.name.lower(),
)
)
return lsp.CompletionList(
is_incomplete=len(symbols) >= 50,
items=items,
)
except Exception as exc:
logger.error("Error getting completions: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
"""Handle textDocument/hover request.
Provides hover information for the symbol at the cursor position
using HoverProvider for rich symbol information including
signature, documentation, and location.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Hover for: %s", word)
try:
# Use HoverProvider for rich symbol information
from codexlens.lsp.providers import HoverProvider
provider = HoverProvider(server.global_index, server.registry)
info = provider.get_hover_info(word)
if not info:
return None
# Format as markdown with signature and location
content = provider.format_hover_markdown(info)
return lsp.Hover(
contents=lsp.MarkupContent(
kind=lsp.MarkupKind.Markdown,
value=content,
),
)
except Exception as exc:
logger.error("Error getting hover info: %s", exc)
return None
@server.feature(lsp.WORKSPACE_SYMBOL)
def lsp_workspace_symbol(
params: lsp.WorkspaceSymbolParams,
) -> Optional[List[lsp.SymbolInformation]]:
"""Handle workspace/symbol request.
Searches for symbols across the workspace.
"""
if not server.global_index:
return None
query = params.query
if not query or len(query) < 2:
return None
logger.debug("Workspace symbol search: %s", query)
try:
symbols = server.global_index.search(
name=query,
limit=100,
prefix_mode=True,
)
if not symbols:
return None
result = []
for sym in symbols:
loc = symbol_to_location(sym)
if loc:
result.append(
lsp.SymbolInformation(
name=sym.name,
kind=_symbol_kind_to_lsp(sym.kind),
location=loc,
container_name=Path(sym.file).parent.name if sym.file else None,
)
)
return result if result else None
except Exception as exc:
logger.error("Error searching workspace symbols: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
"""Handle textDocument/didSave notification.
Triggers incremental re-indexing of the saved file.
Note: Full incremental indexing requires WatcherManager integration,
which is planned for Phase 2.
"""
file_path = _uri_to_path(params.text_document.uri)
logger.info("File saved: %s", file_path)
# Phase 1: Just log the save event
# Phase 2 will integrate with WatcherManager for incremental indexing
# if server.watcher_manager:
# server.watcher_manager.trigger_reindex(file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
"""Handle textDocument/didOpen notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File opened: %s", file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
"""Handle textDocument/didClose notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File closed: %s", file_path)


@@ -0,0 +1,834 @@
"""LspBridge service for real-time LSP communication with caching.
This module provides a bridge to communicate with language servers either via:
1. Standalone LSP Manager (direct subprocess communication - default)
2. VSCode Bridge extension (HTTP-based, legacy mode)
Features:
- Direct communication with language servers (no VSCode dependency)
- Cache with TTL and file modification time invalidation
- Graceful error handling with empty results on failure
- Support for definition, references, hover, and call hierarchy
"""
from __future__ import annotations
import asyncio
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from codexlens.lsp.standalone_manager import StandaloneLspManager
# Check for optional dependencies
try:
import aiohttp
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeSymbolNode,
Range,
)
@dataclass
class Location:
"""A location in a source file (LSP response format)."""
file_path: str
line: int
character: int
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"file_path": self.file_path,
"line": self.line,
"character": self.character,
}
@classmethod
def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
"""Create Location from LSP response format.
Handles both direct format and VSCode URI format.
"""
# Handle VSCode URI format (file:///path/to/file)
uri = data.get("uri", data.get("file_path", ""))
if uri.startswith("file:///"):
# Windows: file:///C:/path -> C:/path
# Unix: file:///path -> /path
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
elif uri.startswith("file://"):
file_path = uri[7:]
else:
file_path = uri
# Get position from range or direct fields
if "range" in data:
range_data = data["range"]
start = range_data.get("start", {})
line = start.get("line", 0) + 1 # LSP is 0-based, convert to 1-based
character = start.get("character", 0) + 1
else:
line = data.get("line", 1)
character = data.get("character", 1)
return cls(file_path=file_path, line=line, character=character)
@dataclass
class CacheEntry:
"""A cached LSP response with expiration metadata.
Attributes:
data: The cached response data
file_mtime: File modification time when cached (for invalidation)
cached_at: Unix timestamp when entry was cached
"""
data: Any
file_mtime: float
cached_at: float
class LspBridge:
"""Bridge for real-time LSP communication with language servers.
By default, uses StandaloneLspManager to directly spawn and communicate
with language servers via JSON-RPC over stdio. No VSCode dependency required.
For legacy mode, can use VSCode Bridge HTTP server (set use_vscode_bridge=True).
Features:
- Direct language server communication (default)
- Response caching with TTL and file modification invalidation
- Timeout handling
- Graceful error handling returning empty results
Example:
# Default: standalone mode (no VSCode needed)
async with LspBridge() as bridge:
refs = await bridge.get_references(symbol)
definition = await bridge.get_definition(symbol)
# Legacy: VSCode Bridge mode
async with LspBridge(use_vscode_bridge=True) as bridge:
refs = await bridge.get_references(symbol)
"""
DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
DEFAULT_TIMEOUT = 30.0 # seconds (increased for standalone mode)
DEFAULT_CACHE_TTL = 300 # 5 minutes
DEFAULT_MAX_CACHE_SIZE = 1000 # Maximum cache entries
def __init__(
self,
bridge_url: str = DEFAULT_BRIDGE_URL,
timeout: float = DEFAULT_TIMEOUT,
cache_ttl: int = DEFAULT_CACHE_TTL,
max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
use_vscode_bridge: bool = False,
workspace_root: Optional[str] = None,
config_file: Optional[str] = None,
):
"""Initialize LspBridge.
Args:
bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
timeout: Request timeout in seconds
cache_ttl: Cache time-to-live in seconds
max_cache_size: Maximum number of cache entries (LRU eviction)
use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
workspace_root: Root directory for standalone LSP manager
config_file: Path to lsp-servers.json configuration file
"""
self.bridge_url = bridge_url
self.timeout = timeout
self.cache_ttl = cache_ttl
self.max_cache_size = max_cache_size
self.use_vscode_bridge = use_vscode_bridge
self.workspace_root = workspace_root
self.config_file = config_file
self.cache: OrderedDict[str, CacheEntry] = OrderedDict()
# VSCode Bridge mode (legacy)
self._session: Optional["aiohttp.ClientSession"] = None
# Standalone mode (default)
self._manager: Optional["StandaloneLspManager"] = None
self._manager_started = False
# Validate dependencies
if use_vscode_bridge and not HAS_AIOHTTP:
raise ImportError(
"aiohttp is required for VSCode Bridge mode: pip install aiohttp"
)
async def _ensure_manager(self) -> "StandaloneLspManager":
"""Ensure standalone LSP manager is started."""
if self._manager is None:
from codexlens.lsp.standalone_manager import StandaloneLspManager
self._manager = StandaloneLspManager(
workspace_root=self.workspace_root,
config_file=self.config_file,
timeout=self.timeout,
)
if not self._manager_started:
await self._manager.start()
self._manager_started = True
return self._manager
async def _get_session(self) -> "aiohttp.ClientSession":
"""Get or create the aiohttp session (VSCode Bridge mode only)."""
if not HAS_AIOHTTP:
raise ImportError("aiohttp required for VSCode Bridge mode")
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def close(self) -> None:
"""Close connections and cleanup resources."""
# Close VSCode Bridge session
if self._session and not self._session.closed:
await self._session.close()
self._session = None
# Stop standalone manager
if self._manager and self._manager_started:
await self._manager.stop()
self._manager_started = False
def _get_file_mtime(self, file_path: str) -> float:
"""Get file modification time, or 0 if file doesn't exist."""
try:
return os.path.getmtime(file_path)
except OSError:
return 0.0
def _is_cached(self, cache_key: str, file_path: str) -> bool:
"""Check if cache entry is valid.
Cache is invalid if:
- Entry doesn't exist
- TTL has expired
- File has been modified since caching
Args:
cache_key: The cache key to check
file_path: Path to source file for mtime check
Returns:
True if cache is valid and can be used
"""
if cache_key not in self.cache:
return False
entry = self.cache[cache_key]
now = time.time()
# Check TTL
if now - entry.cached_at > self.cache_ttl:
del self.cache[cache_key]
return False
# Check file modification time
current_mtime = self._get_file_mtime(file_path)
if current_mtime != entry.file_mtime:
del self.cache[cache_key]
return False
# Move to end on access (LRU behavior)
self.cache.move_to_end(cache_key)
return True
def _cache(self, key: str, file_path: str, data: Any) -> None:
"""Store data in cache with LRU eviction.
Args:
key: Cache key
file_path: Path to source file (for mtime tracking)
data: Data to cache
"""
# Remove oldest entries if at capacity
while len(self.cache) >= self.max_cache_size:
self.cache.popitem(last=False)  # Evict least-recently-used entry
# Move to end if key exists (update access order)
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = CacheEntry(
data=data,
file_mtime=self._get_file_mtime(file_path),
cached_at=time.time(),
)
def clear_cache(self) -> None:
"""Clear all cached entries."""
self.cache.clear()
async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
"""Make HTTP request to VSCode Bridge (legacy mode).
Args:
action: The endpoint/action name (e.g., "get_definition")
params: Request parameters
Returns:
Response data on success, None on failure
"""
url = f"{self.bridge_url}/{action}"
try:
session = await self._get_session()
async with session.post(url, json=params) as response:
if response.status != 200:
return None
data = await response.json()
if data.get("success") is False:
return None
return data.get("result")
except asyncio.TimeoutError:
return None
except Exception:
return None
async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
"""Get all references to a symbol via real-time LSP.
Args:
symbol: The code symbol to find references for
Returns:
List of Location objects where the symbol is referenced.
Returns empty list on error or timeout.
"""
cache_key = f"refs:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
locations: List[Location] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_references", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
# Don't cache on connection error (result is None)
if result is None:
return locations
if isinstance(result, list):
for item in result:
try:
locations.append(Location.from_lsp_response(item))
except (KeyError, TypeError):
continue
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_references(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
for item in result:
try:
locations.append(Location.from_lsp_response(item))
except (KeyError, TypeError):
continue
self._cache(cache_key, symbol.file_path, locations)
return locations
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
"""Get symbol definition location.
Args:
symbol: The code symbol to find definition for
Returns:
Location of the definition, or None if not found
"""
cache_key = f"def:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
location: Optional[Location] = None
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_definition", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result:
if isinstance(result, list) and len(result) > 0:
try:
location = Location.from_lsp_response(result[0])
except (KeyError, TypeError):
pass
elif isinstance(result, dict):
try:
location = Location.from_lsp_response(result)
except (KeyError, TypeError):
pass
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_definition(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
if result:
try:
location = Location.from_lsp_response(result)
except (KeyError, TypeError):
pass
self._cache(cache_key, symbol.file_path, location)
return location
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
"""Get incoming/outgoing calls for a symbol.
If call hierarchy is not supported by the language server,
falls back to using references.
Args:
symbol: The code symbol to get call hierarchy for
Returns:
List of CallHierarchyItem representing callers/callees.
Returns empty list on error or if not supported.
"""
cache_key = f"calls:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
items: List[CallHierarchyItem] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_call_hierarchy", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result is None:
# Fallback: use references
refs = await self.get_references(symbol)
for ref in refs:
items.append(CallHierarchyItem(
name=f"caller@{ref.line}",
kind="reference",
file_path=ref.file_path,
range=Range(
start_line=ref.line,
start_character=ref.character,
end_line=ref.line,
end_character=ref.character,
),
detail="Inferred from reference",
))
elif isinstance(result, list):
for item in result:
try:
range_data = item.get("range", {})
start = range_data.get("start", {})
end = range_data.get("end", {})
items.append(CallHierarchyItem(
name=item.get("name", "unknown"),
kind=item.get("kind", "unknown"),
file_path=item.get("file_path", item.get("uri", "")),
range=Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
),
detail=item.get("detail"),
))
except (KeyError, TypeError):
continue
else:
# Default: Standalone mode
manager = await self._ensure_manager()
# Try to get call hierarchy items
hierarchy_items = await manager.get_call_hierarchy_items(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
if hierarchy_items:
# Get incoming calls for each item
for h_item in hierarchy_items:
incoming = await manager.get_incoming_calls(h_item)
for call in incoming:
from_item = call.get("from", {})
range_data = from_item.get("range", {})
start = range_data.get("start", {})
end = range_data.get("end", {})
# Parse URI
uri = from_item.get("uri", "")
if uri.startswith("file:///"):
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
elif uri.startswith("file://"):
fp = uri[7:]
else:
fp = uri
items.append(CallHierarchyItem(
name=from_item.get("name", "unknown"),
kind=str(from_item.get("kind", "unknown")),
file_path=fp,
range=Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
),
detail=from_item.get("detail"),
))
else:
# Fallback: use references
refs = await self.get_references(symbol)
for ref in refs:
items.append(CallHierarchyItem(
name=f"caller@{ref.line}",
kind="reference",
file_path=ref.file_path,
range=Range(
start_line=ref.line,
start_character=ref.character,
end_line=ref.line,
end_character=ref.character,
),
detail="Inferred from reference",
))
self._cache(cache_key, symbol.file_path, items)
return items
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
"""Get all symbols in a document (batch operation).
This is more efficient than individual hover queries when processing
multiple locations in the same file.
Args:
file_path: Path to the source file
Returns:
List of symbol dictionaries with name, kind, range, etc.
Returns empty list on error or timeout.
"""
cache_key = f"symbols:{file_path}"
if self._is_cached(cache_key, file_path):
return self.cache[cache_key].data
symbols: List[Dict[str, Any]] = []
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_document_symbols", {
"file_path": file_path,
})
if isinstance(result, list):
symbols = self._flatten_document_symbols(result)
else:
# Default: Standalone mode
manager = await self._ensure_manager()
result = await manager.get_document_symbols(file_path)
if result:
symbols = self._flatten_document_symbols(result)
self._cache(cache_key, file_path, symbols)
return symbols
def _flatten_document_symbols(
self, symbols: List[Dict[str, Any]], parent_name: str = ""
) -> List[Dict[str, Any]]:
"""Flatten nested document symbols into a flat list.
Document symbols can be nested (e.g., methods inside classes).
This flattens them for easier lookup by line number.
Args:
symbols: List of symbol dictionaries (may be nested)
parent_name: Name of parent symbol for qualification
Returns:
Flat list of all symbols with their ranges
"""
flat: List[Dict[str, Any]] = []
for sym in symbols:
# Add the symbol itself
symbol_entry = {
"name": sym.get("name", "unknown"),
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
"range": sym.get("range", sym.get("location", {}).get("range", {})),
"selection_range": sym.get("selectionRange", {}),
"detail": sym.get("detail", ""),
"parent": parent_name,
}
flat.append(symbol_entry)
# Recursively process children
children = sym.get("children", [])
if children:
qualified_name = sym.get("name", "")
if parent_name:
qualified_name = f"{parent_name}.{qualified_name}"
flat.extend(self._flatten_document_symbols(children, qualified_name))
return flat
def _symbol_kind_to_string(self, kind: int) -> str:
"""Convert LSP SymbolKind integer to string.
Args:
kind: LSP SymbolKind enum value
Returns:
Human-readable string representation
"""
# LSP SymbolKind enum (1-indexed)
kinds = {
1: "file",
2: "module",
3: "namespace",
4: "package",
5: "class",
6: "method",
7: "property",
8: "field",
9: "constructor",
10: "enum",
11: "interface",
12: "function",
13: "variable",
14: "constant",
15: "string",
16: "number",
17: "boolean",
18: "array",
19: "object",
20: "key",
21: "null",
22: "enum_member",
23: "struct",
24: "event",
25: "operator",
26: "type_parameter",
}
return kinds.get(kind, "unknown")
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
"""Get hover documentation for a symbol.
Args:
symbol: The code symbol to get hover info for
Returns:
Hover documentation as string, or None if not available
"""
cache_key = f"hover:{symbol.id}"
if self._is_cached(cache_key, symbol.file_path):
return self.cache[cache_key].data
hover_text: Optional[str] = None
if self.use_vscode_bridge:
# Legacy: VSCode Bridge HTTP mode
result = await self._request_vscode_bridge("get_hover", {
"file_path": symbol.file_path,
"line": symbol.range.start_line,
"character": symbol.range.start_character,
})
if result:
hover_text = self._parse_hover_result(result)
else:
# Default: Standalone mode
manager = await self._ensure_manager()
hover_text = await manager.get_hover(
file_path=symbol.file_path,
line=symbol.range.start_line,
character=symbol.range.start_character,
)
self._cache(cache_key, symbol.file_path, hover_text)
return hover_text
def _parse_hover_result(self, result: Any) -> Optional[str]:
"""Parse hover result into string."""
if isinstance(result, str):
return result
elif isinstance(result, list):
parts = []
for item in result:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
value = item.get("value", item.get("contents", ""))
if value:
parts.append(str(value))
return "\n\n".join(parts) if parts else None
elif isinstance(result, dict):
contents = result.get("contents", result.get("value", ""))
if isinstance(contents, str):
return contents
elif isinstance(contents, list):
parts = []
for c in contents:
if isinstance(c, str):
parts.append(c)
elif isinstance(c, dict):
parts.append(str(c.get("value", "")))
return "\n\n".join(parts) if parts else None
return None
async def __aenter__(self) -> "LspBridge":
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit - close connections."""
await self.close()
# Simple test
if __name__ == "__main__":
async def test_lsp_bridge():
"""Simple test of LspBridge functionality."""
print("Testing LspBridge (Standalone Mode)...")
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
print()
# Create a test symbol pointing to this file
test_file = os.path.abspath(__file__)
test_symbol = CodeSymbolNode(
id=f"{test_file}:LspBridge:96",
name="LspBridge",
kind="class",
file_path=test_file,
range=Range(
start_line=96,
start_character=1,
end_line=200,
end_character=1,
),
)
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
print()
# Use standalone mode (default)
async with LspBridge(
workspace_root=str(Path(__file__).parent.parent.parent.parent),
) as bridge:
print("1. Testing get_document_symbols...")
try:
symbols = await bridge.get_document_symbols(test_file)
print(f" Found {len(symbols)} symbols")
for sym in symbols[:5]:
print(f" - {sym.get('name')} ({sym.get('kind')})")
except Exception as e:
print(f" Error: {e}")
print()
print("2. Testing get_definition...")
try:
definition = await bridge.get_definition(test_symbol)
if definition:
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
else:
print(" No definition found")
except Exception as e:
print(f" Error: {e}")
print()
print("3. Testing get_references...")
try:
refs = await bridge.get_references(test_symbol)
print(f" Found {len(refs)} references")
for ref in refs[:3]:
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
except Exception as e:
print(f" Error: {e}")
print()
print("4. Testing get_hover...")
try:
hover = await bridge.get_hover(test_symbol)
if hover:
print(f" Hover: {hover[:100]}...")
else:
print(" No hover info found")
except Exception as e:
print(f" Error: {e}")
print()
print("5. Testing get_call_hierarchy...")
try:
calls = await bridge.get_call_hierarchy(test_symbol)
print(f" Found {len(calls)} call hierarchy items")
for call in calls[:3]:
print(f" - {call.name} in {os.path.basename(call.file_path)}")
except Exception as e:
print(f" Error: {e}")
print()
print("6. Testing cache...")
print(f" Cache entries: {len(bridge.cache)}")
for key in list(bridge.cache.keys())[:5]:
print(f" - {key}")
print()
print("Test complete!")
# Run the test
# Note: On Windows, use default ProactorEventLoop (supports subprocess creation)
asyncio.run(test_lsp_bridge())
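
The cache logic above combines two independent staleness signals: a TTL and the source file's modification time. A minimal standalone sketch of that rule, outside the class (the `Entry` dataclass mirrors `CacheEntry` for illustration only):

```python
import os
import time
from dataclasses import dataclass
from typing import Any


@dataclass
class Entry:
    data: Any
    file_mtime: float
    cached_at: float


def is_valid(entry: Entry, file_path: str, ttl: float) -> bool:
    """An entry is usable only if its TTL has not expired AND the file is unchanged."""
    if time.time() - entry.cached_at > ttl:
        return False  # TTL expired
    try:
        current_mtime = os.path.getmtime(file_path)
    except OSError:
        current_mtime = 0.0
    return current_mtime == entry.file_mtime  # mtime drift means stale data


entry = Entry(data=["ref@12"], file_mtime=os.path.getmtime(__file__), cached_at=time.time())
print(is_valid(entry, __file__, ttl=300))   # True: fresh and file untouched
os.utime(__file__)                          # simulate an edit by bumping mtime
print(is_valid(entry, __file__, ttl=300))   # False: file changed since caching
```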


@@ -0,0 +1,375 @@
"""Graph builder for code association graphs via LSP."""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, List, Optional, Set, Tuple
from codexlens.hybrid_search.data_structures import (
CallHierarchyItem,
CodeAssociationGraph,
CodeSymbolNode,
Range,
)
from codexlens.lsp.lsp_bridge import (
Location,
LspBridge,
)
logger = logging.getLogger(__name__)
class LspGraphBuilder:
"""Builds code association graph by expanding from seed symbols using LSP."""
def __init__(
self,
max_depth: int = 2,
max_nodes: int = 100,
max_concurrent: int = 10,
):
"""Initialize GraphBuilder.
Args:
max_depth: Maximum depth for BFS expansion from seeds.
max_nodes: Maximum number of nodes in the graph.
max_concurrent: Maximum concurrent LSP requests.
"""
self.max_depth = max_depth
self.max_nodes = max_nodes
self.max_concurrent = max_concurrent
# Cache for document symbols per file (avoids per-location hover queries)
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
async def build_from_seeds(
self,
seeds: List[CodeSymbolNode],
lsp_bridge: LspBridge,
) -> CodeAssociationGraph:
"""Build association graph by BFS expansion from seeds.
For each seed:
1. Get references via LSP
2. Get call hierarchy via LSP
3. Add nodes and edges to graph
4. Continue expanding until max_depth or max_nodes reached
Args:
seeds: Initial seed symbols to expand from.
lsp_bridge: LSP bridge for querying language servers.
Returns:
CodeAssociationGraph with expanded nodes and relationships.
"""
graph = CodeAssociationGraph()
visited: Set[str] = set()
semaphore = asyncio.Semaphore(self.max_concurrent)
# Initialize queue with seeds at depth 0
queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]
# Add seed nodes to graph
for seed in seeds:
graph.add_node(seed)
# BFS expansion
while queue and len(graph.nodes) < self.max_nodes:
# Take a batch of nodes from queue
batch_size = min(self.max_concurrent, len(queue))
batch = queue[:batch_size]
queue = queue[batch_size:]
# Expand nodes in parallel
tasks = [
self._expand_node(
node, depth, graph, lsp_bridge, visited, semaphore
)
for node, depth in batch
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results and add new nodes to queue
for result in results:
if isinstance(result, Exception):
logger.warning("Error expanding node: %s", result)
continue
if result:
# Add new nodes to queue if not at max depth
for new_node, new_depth in result:
if (
new_depth <= self.max_depth
and len(graph.nodes) < self.max_nodes
):
queue.append((new_node, new_depth))
return graph
async def _expand_node(
self,
node: CodeSymbolNode,
depth: int,
graph: CodeAssociationGraph,
lsp_bridge: LspBridge,
visited: Set[str],
semaphore: asyncio.Semaphore,
) -> List[Tuple[CodeSymbolNode, int]]:
"""Expand a single node, return new nodes to process.
Args:
node: Node to expand.
depth: Current depth in BFS.
graph: Graph to add nodes and edges to.
lsp_bridge: LSP bridge for queries.
visited: Set of visited node IDs.
semaphore: Semaphore for concurrency control.
Returns:
List of (new_node, new_depth) tuples to add to queue.
"""
# Skip if already visited or at max depth
if node.id in visited:
return []
if depth > self.max_depth:
return []
if len(graph.nodes) >= self.max_nodes:
return []
visited.add(node.id)
new_nodes: List[Tuple[CodeSymbolNode, int]] = []
async with semaphore:
# Get relationships in parallel
try:
refs_task = lsp_bridge.get_references(node)
calls_task = lsp_bridge.get_call_hierarchy(node)
refs, calls = await asyncio.gather(
refs_task, calls_task, return_exceptions=True
)
# Handle reference results
if isinstance(refs, Exception):
logger.debug(
"Failed to get references for %s: %s", node.id, refs
)
refs = []
# Handle call hierarchy results
if isinstance(calls, Exception):
logger.debug(
"Failed to get call hierarchy for %s: %s",
node.id,
calls,
)
calls = []
# Process references
for ref in refs:
if len(graph.nodes) >= self.max_nodes:
break
ref_node = await self._location_to_node(ref, lsp_bridge)
if ref_node and ref_node.id != node.id:
if ref_node.id not in graph.nodes:
graph.add_node(ref_node)
new_nodes.append((ref_node, depth + 1))
# Use add_edge since both nodes should exist now
graph.add_edge(node.id, ref_node.id, "references")
# Process call hierarchy (incoming calls)
for call in calls:
if len(graph.nodes) >= self.max_nodes:
break
call_node = await self._call_hierarchy_to_node(
call, lsp_bridge
)
if call_node and call_node.id != node.id:
if call_node.id not in graph.nodes:
graph.add_node(call_node)
new_nodes.append((call_node, depth + 1))
# Incoming call: call_node calls node
graph.add_edge(call_node.id, node.id, "calls")
except Exception as e:
logger.warning(
"Error during node expansion for %s: %s", node.id, e
)
return new_nodes
def clear_cache(self) -> None:
"""Clear the document symbols cache.
Call this between searches to free memory and ensure fresh data.
"""
self._document_symbols_cache.clear()
async def _get_symbol_at_location(
self,
file_path: str,
line: int,
lsp_bridge: LspBridge,
) -> Optional[Dict[str, Any]]:
"""Find symbol at location using cached document symbols.
This is much more efficient than individual hover queries because
document symbols are fetched once per file and cached.
Args:
file_path: Path to the source file.
line: Line number (1-based).
lsp_bridge: LSP bridge for fetching document symbols.
Returns:
Symbol dictionary with name, kind, range, etc., or None if not found.
"""
# Get or fetch document symbols for this file
if file_path not in self._document_symbols_cache:
symbols = await lsp_bridge.get_document_symbols(file_path)
self._document_symbols_cache[file_path] = symbols
symbols = self._document_symbols_cache[file_path]
# Find symbol containing this line (best match = smallest range)
best_match: Optional[Dict[str, Any]] = None
best_range_size = float("inf")
for symbol in symbols:
sym_range = symbol.get("range", {})
start = sym_range.get("start", {})
end = sym_range.get("end", {})
# LSP ranges are 0-based, our line is 1-based
start_line = start.get("line", 0) + 1
end_line = end.get("line", 0) + 1
if start_line <= line <= end_line:
range_size = end_line - start_line
if range_size < best_range_size:
best_match = symbol
best_range_size = range_size
return best_match
async def _location_to_node(
self,
location: Location,
lsp_bridge: LspBridge,
) -> Optional[CodeSymbolNode]:
"""Convert LSP location to CodeSymbolNode.
Uses cached document symbols instead of individual hover queries
for better performance.
Args:
location: LSP location to convert.
lsp_bridge: LSP bridge for additional queries.
Returns:
CodeSymbolNode or None if conversion fails.
"""
try:
file_path = location.file_path
start_line = location.line
# Try to find symbol info from cached document symbols (fast)
symbol_info = await self._get_symbol_at_location(
file_path, start_line, lsp_bridge
)
if symbol_info:
name = symbol_info.get("name", f"symbol_L{start_line}")
kind = symbol_info.get("kind", "unknown")
# Extract range from symbol if available
sym_range = symbol_info.get("range", {})
start = sym_range.get("start", {})
end = sym_range.get("end", {})
location_range = Range(
start_line=start.get("line", start_line - 1) + 1,
start_character=start.get("character", location.character - 1) + 1,
end_line=end.get("line", start_line - 1) + 1,
end_character=end.get("character", location.character - 1) + 1,
)
else:
# Fallback to basic node without symbol info
name = f"symbol_L{start_line}"
kind = "unknown"
location_range = Range(
start_line=location.line,
start_character=location.character,
end_line=location.line,
end_character=location.character,
)
node_id = self._create_node_id(file_path, name, start_line)
return CodeSymbolNode(
id=node_id,
name=name,
kind=kind,
file_path=file_path,
range=location_range,
docstring="", # Skip hover for performance
)
except Exception as e:
logger.debug("Failed to convert location to node: %s", e)
return None
async def _call_hierarchy_to_node(
self,
call_item: CallHierarchyItem,
lsp_bridge: LspBridge,
) -> Optional[CodeSymbolNode]:
"""Convert CallHierarchyItem to CodeSymbolNode.
Args:
call_item: Call hierarchy item to convert.
lsp_bridge: LSP bridge (unused, kept for API consistency).
Returns:
CodeSymbolNode or None if conversion fails.
"""
try:
file_path = call_item.file_path
name = call_item.name
start_line = call_item.range.start_line
# CallHierarchyItem.kind is already a string
kind = call_item.kind
node_id = self._create_node_id(file_path, name, start_line)
return CodeSymbolNode(
id=node_id,
name=name,
kind=kind,
file_path=file_path,
range=call_item.range,
docstring=call_item.detail or "",
)
except Exception as e:
logger.debug(
"Failed to convert call hierarchy item to node: %s", e
)
return None
def _create_node_id(
self, file_path: str, name: str, line: int
) -> str:
"""Create unique node ID.
Args:
file_path: Path to the file.
name: Symbol name.
line: Line number (1-based).
Returns:
Unique node ID string.
"""
return f"{file_path}:{name}:{line}"


@@ -0,0 +1,177 @@
"""LSP feature providers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
@dataclass
class HoverInfo:
"""Hover information for a symbol."""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: tuple # (start_line, end_line)
class HoverProvider:
"""Provides hover information for symbols."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
registry: Optional["RegistryStore"] = None,
) -> None:
"""Initialize hover provider.
Args:
global_index: Global symbol index for lookups
registry: Optional registry store for index path resolution
"""
self.global_index = global_index
self.registry = registry
def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
"""Get hover information for a symbol.
Args:
symbol_name: Name of the symbol to look up
Returns:
HoverInfo or None if symbol not found
"""
# Look up symbol in global index using exact match
symbols = self.global_index.search(
name=symbol_name,
limit=1,
prefix_mode=False,
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == symbol_name]
if not exact_matches:
return None
symbol = exact_matches[0]
# Extract signature from source file
signature = self._extract_signature(symbol)
# Symbol uses 'file' attribute and 'range' tuple
file_path = symbol.file or ""
start_line, end_line = symbol.range
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=signature,
documentation=None, # Symbol doesn't have docstring field
file_path=file_path,
line_range=(start_line, end_line),
)
def _extract_signature(self, symbol) -> str:
"""Extract function/class signature from source file.
Args:
symbol: Symbol object with file and range information
Returns:
Extracted signature string or fallback kind + name
"""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return f"{symbol.kind} {symbol.name}"
content = file_path.read_text(encoding="utf-8", errors="ignore")
lines = content.split("\n")
# Extract signature lines (first line of definition + continuation)
start_line = symbol.range[0] - 1 # Convert 1-based to 0-based
if start_line >= len(lines) or start_line < 0:
return f"{symbol.kind} {symbol.name}"
signature_lines = []
first_line = lines[start_line]
signature_lines.append(first_line)
# Continue if multiline signature (no closing paren + colon yet)
# Look for patterns like "def func(", "class Foo(", etc.
i = start_line + 1
max_lines = min(start_line + 5, len(lines))
while i < max_lines:
line = signature_lines[-1]
# Stop if we see closing pattern
if "):" in line or line.rstrip().endswith(":"):
break
signature_lines.append(lines[i])
i += 1
return "\n".join(signature_lines)
except Exception as e:
logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
return f"{symbol.kind} {symbol.name}"
def format_hover_markdown(self, info: HoverInfo) -> str:
"""Format hover info as Markdown.
Args:
info: HoverInfo object to format
Returns:
Markdown-formatted hover content
"""
parts = []
# Detect language for code fence based on file extension
ext = Path(info.file_path).suffix.lower() if info.file_path else ""
lang_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".tsx": "typescript",
".jsx": "javascript",
".java": "java",
".go": "go",
".rs": "rust",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
".cs": "csharp",
".rb": "ruby",
".php": "php",
}
lang = lang_map.get(ext, "")
# Code block with signature
parts.append(f"```{lang}\n{info.signature}\n```")
# Documentation if available
if info.documentation:
parts.append(f"\n---\n\n{info.documentation}")
# Location info
file_name = Path(info.file_path).name if info.file_path else "unknown"
parts.append(
f"\n---\n\n*{info.kind}* defined in "
f"`{file_name}` "
f"(line {info.line_range[0]})"
)
return "\n".join(parts)


@@ -0,0 +1,263 @@
"""codex-lens LSP Server implementation using pygls.
This module provides the main Language Server class and entry point.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Optional
try:
from lsprotocol import types as lsp
from pygls.lsp.server import LanguageServer
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.config import Config
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
class CodexLensLanguageServer(LanguageServer):
"""Language Server for codex-lens code indexing.
Provides IDE features using codex-lens symbol index:
- Go to Definition
- Find References
- Code Completion
- Hover Information
- Workspace Symbol Search
Attributes:
registry: Global project registry for path lookups
mapper: Path mapper for source/index conversions
global_index: Project-wide symbol index
search_engine: Chain search engine for symbol search
workspace_root: Current workspace root path
"""
def __init__(self) -> None:
super().__init__(name="codexlens-lsp", version="0.1.0")
self.registry: Optional[RegistryStore] = None
self.mapper: Optional[PathMapper] = None
self.global_index: Optional[GlobalSymbolIndex] = None
self.search_engine: Optional[ChainSearchEngine] = None
self.workspace_root: Optional[Path] = None
self._config: Optional[Config] = None
def initialize_components(self, workspace_root: Path) -> bool:
"""Initialize codex-lens components for the workspace.
Args:
workspace_root: Root path of the workspace
Returns:
True if initialization succeeded, False otherwise
"""
self.workspace_root = workspace_root.resolve()
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
try:
# Initialize registry
self.registry = RegistryStore()
self.registry.initialize()
# Initialize path mapper
self.mapper = PathMapper()
# Try to find project in registry
project_info = self.registry.find_by_source_path(str(self.workspace_root))
if project_info:
project_id = int(project_info["id"])
index_root = Path(project_info["index_root"])
# Initialize global symbol index
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
self.global_index = GlobalSymbolIndex(global_db, project_id)
self.global_index.initialize()
# Initialize search engine
self._config = Config()
self.search_engine = ChainSearchEngine(
registry=self.registry,
mapper=self.mapper,
config=self._config,
)
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
return True
else:
logger.warning(
"Workspace not indexed by codex-lens: %s. "
"Run 'codexlens index %s' to index first.",
self.workspace_root,
self.workspace_root,
)
return False
except Exception as exc:
logger.error("Failed to initialize codex-lens: %s", exc)
return False
def shutdown_components(self) -> None:
"""Clean up codex-lens components."""
if self.global_index:
try:
self.global_index.close()
except Exception as exc:
logger.debug("Error closing global index: %s", exc)
self.global_index = None
if self.search_engine:
try:
self.search_engine.close()
except Exception as exc:
logger.debug("Error closing search engine: %s", exc)
self.search_engine = None
if self.registry:
try:
self.registry.close()
except Exception as exc:
logger.debug("Error closing registry: %s", exc)
self.registry = None
# Create server instance
server = CodexLensLanguageServer()
@server.feature(lsp.INITIALIZE)
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
"""Handle LSP initialize request."""
logger.info("LSP initialize request received")
# Get workspace root
workspace_root: Optional[Path] = None
if params.root_uri:
workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
elif params.root_path:
workspace_root = Path(params.root_path)
if workspace_root:
server.initialize_components(workspace_root)
# Declare server capabilities
return lsp.InitializeResult(
capabilities=lsp.ServerCapabilities(
text_document_sync=lsp.TextDocumentSyncOptions(
open_close=True,
change=lsp.TextDocumentSyncKind.Incremental,
save=lsp.SaveOptions(include_text=False),
),
definition_provider=True,
references_provider=True,
completion_provider=lsp.CompletionOptions(
trigger_characters=[".", ":"],
resolve_provider=False,
),
hover_provider=True,
workspace_symbol_provider=True,
),
server_info=lsp.ServerInfo(
name="codexlens-lsp",
version="0.1.0",
),
)
@server.feature(lsp.SHUTDOWN)
def lsp_shutdown(params: None) -> None:
"""Handle LSP shutdown request."""
logger.info("LSP shutdown request received")
server.shutdown_components()
def main() -> int:
"""Entry point for codexlens-lsp command.
Returns:
Exit code (0 for success)
"""
# Import handlers to register them with the server
# This must be done before starting the server
import codexlens.lsp.handlers # noqa: F401
parser = argparse.ArgumentParser(
description="codex-lens Language Server",
prog="codexlens-lsp",
)
parser.add_argument(
"--stdio",
action="store_true",
default=True,
help="Use stdio for communication (default)",
)
parser.add_argument(
"--tcp",
action="store_true",
help="Use TCP for communication",
)
parser.add_argument(
"--host",
default="127.0.0.1",
help="TCP host (default: 127.0.0.1)",
)
parser.add_argument(
"--port",
type=int,
default=2087,
help="TCP port (default: 2087)",
)
parser.add_argument(
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Log level (default: INFO)",
)
parser.add_argument(
"--log-file",
help="Log file path (optional)",
)
args = parser.parse_args()
# Configure logging
log_handlers = []
if args.log_file:
log_handlers.append(logging.FileHandler(args.log_file))
else:
log_handlers.append(logging.StreamHandler(sys.stderr))
logging.basicConfig(
level=getattr(logging, args.log_level),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=log_handlers,
)
logger.info("Starting codexlens-lsp server")
if args.tcp:
logger.info("Starting TCP server on %s:%d", args.host, args.port)
server.start_tcp(args.host, args.port)
else:
logger.info("Starting stdio server")
server.start_io()
return 0
if __name__ == "__main__":
sys.exit(main())
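
For reference, the same server can be started programmatically instead of via the CLI; this short sketch mirrors the `--tcp` branch of `main()`. The module path `codexlens.lsp.server` is assumed, since file paths are not shown in this diff.

```python
import codexlens.lsp.handlers  # noqa: F401  (registers the @server.feature handlers)
from codexlens.lsp.server import server  # module path assumed

server.start_tcp("127.0.0.1", 2087)  # blocks; serves LSP clients over TCP
```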

File diff suppressed because it is too large


@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
__all__ = [
"MCPContext",
"SymbolInfo",
"ReferenceInfo",
"RelatedSymbol",
"MCPProvider",
"HookManager",
"create_context_for_prompt",
]


@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING
from codexlens.mcp.schema import MCPContext
if TYPE_CHECKING:
from codexlens.mcp.provider import MCPProvider
logger = logging.getLogger(__name__)
class HookManager:
"""Manages hook registration and execution."""
def __init__(self, mcp_provider: "MCPProvider") -> None:
self.mcp_provider = mcp_provider
self._pre_hooks: Dict[str, Callable] = {}
self._post_hooks: Dict[str, Callable] = {}
# Register default hooks
self._register_default_hooks()
def _register_default_hooks(self) -> None:
"""Register built-in hooks."""
self._pre_hooks["explain"] = self._pre_explain_hook
self._pre_hooks["refactor"] = self._pre_refactor_hook
self._pre_hooks["document"] = self._pre_document_hook
def execute_pre_hook(
self,
action: str,
params: Dict[str, Any],
) -> Optional[MCPContext]:
"""Execute pre-tool hook to gather context.
Args:
action: The action being performed (e.g., "explain", "refactor")
params: Parameters for the action
Returns:
MCPContext to inject into prompt, or None
"""
hook = self._pre_hooks.get(action)
if not hook:
logger.debug(f"No pre-hook for action: {action}")
return None
try:
return hook(params)
except Exception as e:
logger.error(f"Pre-hook failed for {action}: {e}")
return None
def execute_post_hook(
self,
action: str,
result: Any,
) -> None:
"""Execute post-tool hook for proactive caching.
Args:
action: The action that was performed
result: Result of the action
"""
hook = self._post_hooks.get(action)
if not hook:
return
try:
hook(result)
except Exception as e:
logger.error(f"Post-hook failed for {action}: {e}")
def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'explain' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="symbol_explanation",
include_references=True,
include_related=True,
)
def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'refactor' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="refactor_context",
include_references=True,
include_related=True,
max_references=20,
)
def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'document' action."""
symbol_name = params.get("symbol")
file_path = params.get("file_path")
if symbol_name:
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="documentation_context",
include_references=False,
include_related=True,
)
elif file_path:
return self.mcp_provider.build_context_for_file(
Path(file_path),
context_type="file_documentation",
)
return None
def register_pre_hook(
self,
action: str,
hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
) -> None:
"""Register a custom pre-tool hook."""
self._pre_hooks[action] = hook
def register_post_hook(
self,
action: str,
hook: Callable[[Any], None],
) -> None:
"""Register a custom post-tool hook."""
self._post_hooks[action] = hook
def create_context_for_prompt(
mcp_provider: "MCPProvider",
action: str,
params: Dict[str, Any],
) -> str:
"""Create context string for prompt injection.
This is the main entry point for Claude Code hook integration.
Args:
mcp_provider: The MCP provider instance
action: Action being performed
params: Action parameters
Returns:
Formatted context string for prompt injection
"""
manager = HookManager(mcp_provider)
context = manager.execute_pre_hook(action, params)
if context:
return context.to_prompt_injection()
return ""


@@ -0,0 +1,202 @@
"""MCP context provider."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, List, TYPE_CHECKING
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
from codexlens.search.chain_search import ChainSearchEngine
logger = logging.getLogger(__name__)
class MCPProvider:
"""Builds MCP context objects from codex-lens data."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
search_engine: "ChainSearchEngine",
registry: "RegistryStore",
) -> None:
self.global_index = global_index
self.search_engine = search_engine
self.registry = registry
def build_context(
self,
symbol_name: str,
context_type: str = "symbol_explanation",
include_references: bool = True,
include_related: bool = True,
max_references: int = 10,
) -> Optional[MCPContext]:
"""Build comprehensive context for a symbol.
Args:
symbol_name: Name of the symbol to contextualize
context_type: Type of context being requested
include_references: Whether to include reference locations
include_related: Whether to include related symbols
max_references: Maximum number of references to include
Returns:
MCPContext object or None if symbol not found
"""
# Look up symbol
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
if not symbols:
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
return None
symbol = symbols[0]
# Build SymbolInfo
symbol_info = SymbolInfo(
name=symbol.name,
kind=symbol.kind,
file_path=symbol.file or "",
line_start=symbol.range[0],
line_end=symbol.range[1],
signature=None, # Symbol entity doesn't have signature
documentation=None, # Symbol entity doesn't have docstring
)
# Extract definition source code
definition = self._extract_definition(symbol)
# Get references
references = []
if include_references:
refs = self.search_engine.search_references(
symbol_name,
limit=max_references,
)
references = [
ReferenceInfo(
file_path=r.file_path,
line=r.line,
column=r.column,
context=r.context,
relationship_type=r.relationship_type,
)
for r in refs
]
# Get related symbols
related_symbols = []
if include_related:
related_symbols = self._get_related_symbols(symbol)
return MCPContext(
context_type=context_type,
symbol=symbol_info,
definition=definition,
references=references,
related_symbols=related_symbols,
metadata={
"source": "codex-lens",
},
)
def _extract_definition(self, symbol) -> Optional[str]:
"""Extract source code for symbol definition."""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return None
content = file_path.read_text(encoding='utf-8', errors='ignore')
lines = content.split("\n")
start = symbol.range[0] - 1
end = symbol.range[1]
if start >= len(lines):
return None
return "\n".join(lines[start:end])
except Exception as e:
logger.debug(f"Failed to extract definition: {e}")
return None
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
"""Get symbols related to the given symbol."""
related = []
try:
# Search for symbols that might be related by name patterns
# This is a simplified implementation - could be enhanced with relationship data
# Look for imports/callers via reference search
refs = self.search_engine.search_references(symbol.name, limit=20)
seen_names = set()
for ref in refs:
# Extract potential symbol name from context
if ref.relationship_type and ref.relationship_type not in seen_names:
related.append(RelatedSymbol(
name=f"{Path(ref.file_path).stem}",
kind="module",
relationship=ref.relationship_type,
file_path=ref.file_path,
))
seen_names.add(ref.relationship_type)
if len(related) >= 10:
break
except Exception as e:
logger.debug(f"Failed to get related symbols: {e}")
return related
def build_context_for_file(
self,
file_path: Path,
context_type: str = "file_overview",
) -> MCPContext:
"""Build context for an entire file."""
# Try to get symbols by searching with file path
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
symbols = []
# Search for common symbols that might be in this file
# This is a simplified approach - a full implementation would query by file path
try:
# Use the global index to search for symbols from this file
file_str = str(file_path.resolve())
# Get all symbols and filter by file path (not efficient but works)
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
except Exception as e:
logger.debug(f"Failed to get file symbols: {e}")
related = [
RelatedSymbol(
name=s.name,
kind=s.kind,
relationship="defines",
)
for s in symbols
]
return MCPContext(
context_type=context_type,
related_symbols=related,
metadata={
"file_path": str(file_path),
"symbol_count": len(symbols),
},
)
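
Wiring a provider for an already-indexed workspace follows the same steps as `CodexLensLanguageServer.initialize_components`. A hedged sketch, assuming the current directory was previously indexed with `codexlens index`:

```python
from pathlib import Path

from codexlens.config import Config
from codexlens.mcp import MCPProvider
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

registry = RegistryStore()
registry.initialize()
info = registry.find_by_source_path(str(Path.cwd()))
if info is None:
    print("workspace not indexed; run `codexlens index .` first")
else:
    db_path = Path(info["index_root"]) / GlobalSymbolIndex.DEFAULT_DB_NAME
    index = GlobalSymbolIndex(db_path, int(info["id"]))
    index.initialize()
    engine = ChainSearchEngine(registry=registry, mapper=PathMapper(), config=Config())
    provider = MCPProvider(global_index=index, search_engine=engine, registry=registry)
    ctx = provider.build_context("LspBridge", max_references=5)
    print(ctx.to_json() if ctx else "symbol not found in index")
```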


@@ -0,0 +1,113 @@
"""MCP data models."""
from __future__ import annotations
import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional
@dataclass
class SymbolInfo:
"""Information about a code symbol."""
name: str
kind: str
file_path: str
line_start: int
line_end: int
signature: Optional[str] = None
documentation: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class ReferenceInfo:
"""Information about a symbol reference."""
file_path: str
line: int
column: int
context: str
relationship_type: str
def to_dict(self) -> dict:
return asdict(self)
@dataclass
class RelatedSymbol:
"""Related symbol (import, call target, etc.)."""
name: str
kind: str
relationship: str # "imports", "calls", "inherits", "uses"
file_path: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MCPContext:
"""Model Context Protocol context object.
This is the structured context that gets injected into
LLM prompts to provide code understanding.
"""
version: str = "1.0"
context_type: str = "code_context"
symbol: Optional[SymbolInfo] = None
definition: Optional[str] = None
references: List[ReferenceInfo] = field(default_factory=list)
related_symbols: List[RelatedSymbol] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
result = {
"version": self.version,
"context_type": self.context_type,
"metadata": self.metadata,
}
if self.symbol:
result["symbol"] = self.symbol.to_dict()
if self.definition:
result["definition"] = self.definition
if self.references:
result["references"] = [r.to_dict() for r in self.references]
if self.related_symbols:
result["related_symbols"] = [s.to_dict() for s in self.related_symbols]
return result
def to_json(self, indent: int = 2) -> str:
"""Serialize to JSON string."""
return json.dumps(self.to_dict(), indent=indent)
def to_prompt_injection(self) -> str:
"""Format for injection into LLM prompt."""
parts = ["<code_context>"]
if self.symbol:
parts.append(f"## Symbol: {self.symbol.name}")
parts.append(f"Type: {self.symbol.kind}")
parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")
if self.definition:
parts.append("\n## Definition")
parts.append(f"```\n{self.definition}\n```")
if self.references:
parts.append(f"\n## References ({len(self.references)} found)")
for ref in self.references[:5]: # Limit to 5
parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
parts.append(f" ```\n {ref.context}\n ```")
if self.related_symbols:
parts.append("\n## Related Symbols")
for sym in self.related_symbols[:10]: # Limit to 10
parts.append(f"- {sym.name} ({sym.relationship})")
parts.append("</code_context>")
return "\n".join(parts)


@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""
from __future__ import annotations
from .factory import ParserFactory
__all__ = ["ParserFactory"]


@@ -0,0 +1,202 @@
"""Optional encoding detection module for CodexLens.
Provides automatic encoding detection with graceful fallback to UTF-8.
Install with: pip install codexlens[encoding]
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Tuple, Optional
log = logging.getLogger(__name__)
# Feature flag for encoding detection availability
ENCODING_DETECTION_AVAILABLE = False
_import_error: Optional[str] = None
def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
"""Detect if chardet or charset-normalizer is available."""
try:
import chardet
return True, None
except ImportError:
pass
try:
from charset_normalizer import from_bytes
return True, None
except ImportError:
pass
return False, "chardet not available. Install with: pip install codexlens[encoding]"
# Initialize on module load
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()
def check_encoding_available() -> Tuple[bool, Optional[str]]:
"""Check if encoding detection dependencies are available.
Returns:
Tuple of (available, error_message)
"""
return ENCODING_DETECTION_AVAILABLE, _import_error
def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
"""Detect encoding from file content bytes.
Uses chardet or charset-normalizer with configurable confidence threshold.
Falls back to UTF-8 if confidence is too low or detection unavailable.
Args:
content_bytes: Raw file content as bytes
confidence_threshold: Minimum confidence (0.0-1.0) to accept detection
Returns:
Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk')
Returns 'utf-8' as fallback if detection fails or confidence too low
"""
if not ENCODING_DETECTION_AVAILABLE:
log.debug("Encoding detection not available, using UTF-8 fallback")
return "utf-8"
if not content_bytes:
return "utf-8"
try:
# Try chardet first
try:
import chardet
result = chardet.detect(content_bytes)
encoding = result.get("encoding")
confidence = result.get("confidence", 0.0)
if encoding and confidence >= confidence_threshold:
log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
# Normalize encoding name: replace underscores with hyphens
return encoding.lower().replace('_', '-')
else:
log.debug(
f"Low confidence encoding detection: {encoding} "
f"(confidence: {confidence:.2f}), using UTF-8 fallback"
)
return "utf-8"
except ImportError:
pass
# Fallback to charset-normalizer
try:
from charset_normalizer import from_bytes
results = from_bytes(content_bytes)
if results:
best = results.best()
if best and best.encoding:
log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
# Normalize encoding name: replace underscores with hyphens
return best.encoding.lower().replace('_', '-')
except ImportError:
pass
except Exception as e:
log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")
return "utf-8"
def read_file_safe(
path: Path | str,
confidence_threshold: float = 0.7,
max_detection_bytes: int = 100_000
) -> Tuple[str, str]:
"""Read file with automatic encoding detection and safe decoding.
Reads file bytes, detects encoding, and decodes with error replacement
to preserve file structure even with encoding issues.
Args:
path: Path to file to read
confidence_threshold: Minimum confidence for encoding detection
max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)
Returns:
Tuple of (content, detected_encoding)
- content: Decoded file content (with "�" replacement characters for unmappable bytes)
- detected_encoding: Detected encoding name
Raises:
OSError: If file cannot be read
IsADirectoryError: If path is a directory
"""
file_path = Path(path) if isinstance(path, str) else path
# Read file bytes
try:
content_bytes = file_path.read_bytes()
except Exception as e:
log.error(f"Failed to read file {file_path}: {e}")
raise
# Detect encoding from first N bytes for performance
detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
encoding = detect_encoding(detection_sample, confidence_threshold)
# Decode with error replacement to preserve structure
try:
content = content_bytes.decode(encoding, errors='replace')
log.debug(f"Successfully decoded {file_path} using {encoding}")
return content, encoding
except Exception as e:
# Final fallback to UTF-8 with replacement
log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
content = content_bytes.decode('utf-8', errors='replace')
return content, 'utf-8'
def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
"""Check if file is likely binary by sampling first bytes.
Uses a heuristic: if more than 30% of sampled bytes are null, or more than 50% are non-printable control characters, the file is considered binary.
Args:
path: Path to file to check
sample_size: Number of bytes to sample (default 8KB)
Returns:
True if file appears to be binary, False otherwise
"""
file_path = Path(path) if isinstance(path, str) else path
try:
with file_path.open('rb') as f:
sample = f.read(sample_size)
if not sample:
return False
# Count null bytes and non-printable characters
null_count = sample.count(b'\x00')
non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))
# If >30% null bytes or >50% non-text, consider binary
null_ratio = null_count / len(sample)
non_text_ratio = non_text_count / len(sample)
return null_ratio > 0.3 or non_text_ratio > 0.5
except Exception as e:
log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
return False
__all__ = [
"ENCODING_DETECTION_AVAILABLE",
"check_encoding_available",
"detect_encoding",
"read_file_safe",
"is_binary_file",
]
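Putting the helpers together, a sketch of the intended read path (the import path `codexlens.encoding` is assumed from the package layout):

```python
from pathlib import Path

from codexlens.encoding import is_binary_file, read_file_safe  # path assumed

path = Path("src/example.py")
if not is_binary_file(path):
    content, encoding = read_file_safe(path, confidence_threshold=0.7)
    print(f"decoded {len(content)} chars as {encoding}")
else:
    print(f"skipping binary file: {path}")
```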


@@ -0,0 +1,385 @@
"""Parser factory for CodexLens.
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class Parser(Protocol):
def parse(self, text: str, path: Path) -> IndexedFile: ...
@dataclass
class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
# Try tree-sitter first for supported languages
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
indexed = ts_parser.parse(text, path)
if indexed is not None:
return indexed
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols_regex(text)
relationships = _parse_python_relationships_regex(text, path)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols_regex(text)
relationships = _parse_js_ts_relationships_regex(text, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
relationships = []
elif self.language_id == "go":
symbols = _parse_go_symbols(text)
relationships = []
elif self.language_id == "markdown":
symbols = _parse_markdown_symbols(text)
relationships = []
elif self.language_id == "text":
symbols = _parse_text_symbols(text)
relationships = []
else:
symbols = _parse_generic_symbols(text)
relationships = []
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
class ParserFactory:
def __init__(self, config: Config) -> None:
self.config = config
self._parsers: Dict[str, Parser] = {}
def get_parser(self, language_id: str) -> Parser:
if language_id not in self._parsers:
self._parsers[language_id] = SimpleRegexParser(language_id)
return self._parsers[language_id]
# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_python_symbols(text: str) -> List[Symbol]:
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser("python")
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Optional[Path] = None,
) -> List[Symbol]:
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser(language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
current_class_indent: Optional[int] = None
for i, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_class_indent = len(line) - len(line.lstrip(" "))
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
indent = len(line) - len(line.lstrip(" "))
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
continue
if current_class_indent is not None:
indent = len(line) - len(line.lstrip(" "))
if line.strip() and indent <= current_class_indent:
current_class_indent = None
return symbols
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
current_scope = def_match.group(1)
continue
if current_scope is None:
continue
import_match = _PY_IMPORT_RE.search(line)
if import_match:
import_target = import_match.group(1) or import_match.group(2)
if import_target:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_target.strip(),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _PY_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {
"if",
"for",
"while",
"return",
"print",
"len",
"str",
"int",
"float",
"list",
"dict",
"set",
"tuple",
current_scope,
}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
class_brace_depth = 0
brace_depth = 0
for i, line in enumerate(text.splitlines(), start=1):
brace_depth += line.count("{") - line.count("}")
class_match = _JS_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
in_class = True
class_brace_depth = brace_depth
continue
if in_class and brace_depth < class_brace_depth:
in_class = False
func_match = _JS_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
continue
if in_class:
method_match = _JS_METHOD_RE.match(line)
if method_match:
name = method_match.group(1)
if name != "constructor":
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
return symbols
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _JS_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
func_match = _JS_FUNC_RE.match(line)
if func_match:
current_scope = func_match.group(1)
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
current_scope = arrow_match.group(1)
continue
if current_scope is None:
continue
import_match = _JS_IMPORT_RE.search(line)
if import_match:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_match.group(1),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _JS_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {current_scope}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
)
def _parse_java_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _JAVA_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
method_match = _JAVA_METHOD_RE.match(line)
if method_match:
symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
return symbols
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
def _parse_go_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
type_match = _GO_TYPE_RE.match(line)
if type_match:
symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
continue
func_match = _GO_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
return symbols
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
def _parse_generic_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _GENERIC_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _GENERIC_DEF_RE.match(line)
if def_match:
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
return symbols
# Markdown heading regex: # Heading, ## Heading, etc.
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
def _parse_markdown_symbols(text: str) -> List[Symbol]:
"""Parse Markdown headings as symbols.
Extracts # headings as 'section' symbols with heading level as kind suffix.
"""
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
heading_match = _MD_HEADING_RE.match(line)
if heading_match:
level = len(heading_match.group(1))
title = heading_match.group(2).strip()
# Use 'section' kind with level indicator
kind = f"h{level}"
symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
return symbols
def _parse_text_symbols(text: str) -> List[Symbol]:
"""Parse plain text files - no symbols, just index content."""
# Text files don't have structured symbols, return empty list
# The file content will still be indexed for FTS search
return []
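A sketch of the factory in use. `Config()` is assumed to be default-constructible; the rest uses only names defined above:

```python
from pathlib import Path

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory

config = Config()  # assumption: default construction is valid
parser = ParserFactory(config).get_parser("python")

source = "class Greeter:\n    def hello(self):\n        return 'hi'\n"
indexed = parser.parse(source, Path("greeter.py"))
for sym in indexed.symbols:
    print(sym.name, sym.kind)  # Greeter/class, hello/method
```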


@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.
Provides accurate token counting using tiktoken with character count fallback.
"""
from __future__ import annotations
from typing import Optional
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
class Tokenizer:
"""Token counter with tiktoken primary and character count fallback."""
def __init__(self, encoding_name: str = "cl100k_base") -> None:
"""Initialize tokenizer.
Args:
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
"""
self._encoding: Optional[object] = None
self._encoding_name = encoding_name
if TIKTOKEN_AVAILABLE:
try:
self._encoding = tiktoken.get_encoding(encoding_name)
except Exception:
# Fallback to character counting if encoding fails
self._encoding = None
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Uses tiktoken if available, otherwise falls back to character count / 4.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if not text:
return 0
if self._encoding is not None:
try:
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
except Exception:
# Fall through to character count fallback
pass
# Fallback: rough estimate using character count
# Average of ~4 characters per token for English text
return max(1, len(text) // 4)
def is_using_tiktoken(self) -> bool:
"""Check if tiktoken is being used.
Returns:
True if tiktoken is available and initialized
"""
return self._encoding is not None
# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None
def get_default_tokenizer() -> Tokenizer:
"""Get the global default tokenizer instance.
Returns:
Shared Tokenizer instance
"""
global _default_tokenizer
if _default_tokenizer is None:
_default_tokenizer = Tokenizer()
return _default_tokenizer
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
"""Count tokens in text using default or provided tokenizer.
Args:
text: Text to count tokens for
tokenizer: Optional tokenizer instance (uses default if None)
Returns:
Estimated token count
"""
if tokenizer is None:
tokenizer = get_default_tokenizer()
return tokenizer.count_tokens(text)
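Usage is a one-liner either way; the class form also exposes whether tiktoken was actually loaded:

```python
from codexlens.parsers.tokenizer import Tokenizer, count_tokens

print(count_tokens("def add(a, b):\n    return a + b"))  # tiktoken count, or ~chars/4

tok = Tokenizer("cl100k_base")
if not tok.is_using_tiktoken():
    print("tiktoken unavailable; token counts are character-based estimates")
```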


@@ -0,0 +1,809 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing via tree-sitter.
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
"""Initialize tree-sitter parser for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx)
"""
self.language_id = language_id
self.path = path
self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer()
if TREE_SITTER_AVAILABLE:
self._initialize_parser()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
return
try:
# Load language grammar
if self.language_id == "python":
import tree_sitter_python
self._language = TreeSitterLanguage(tree_sitter_python.language())
elif self.language_id == "javascript":
import tree_sitter_javascript
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
elif self.language_id == "typescript":
import tree_sitter_typescript
# Detect TSX files by extension
if self.path is not None and self.path.suffix.lower() == ".tsx":
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
else:
return
# Create parser
self._parser = TreeSitterParser()
if hasattr(self._parser, "set_language"):
self._parser.set_language(self._language) # type: ignore[attr-defined]
else:
self._parser.language = self._language # type: ignore[assignment]
except Exception:
# Gracefully handle missing language bindings
self._parser = None
self._language = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
Returns:
True if parser is initialized and ready
"""
return self._parser is not None and self._language is not None
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
return source_bytes, tree.root_node
except Exception:
return None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
Args:
text: Source code text
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle extraction errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
"""Parse source code and extract symbols.
Args:
text: Source code text
path: File path
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
except Exception:
# Gracefully handle parsing errors
return None
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of extracted symbols
"""
if self.language_id == "python":
return self._extract_python_symbols(source_bytes, root)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_symbols(source_bytes, root)
else:
return []
def _extract_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"self", "cls"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "class_definition" and pushed_scope:
superclasses = node.child_by_field_name("superclasses")
if superclasses is not None:
for child in superclasses.children:
dotted = self._python_expression_to_dotted(source_bytes, child)
if not dotted:
continue
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type in {"import_statement", "import_from_statement"}:
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
if node.type == "call":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _extract_js_ts_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"this", "super"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if pushed_scope:
superclass = node.child_by_field_name("superclass")
if superclass is not None:
dotted = self._js_expression_to_dotted(source_bytes, superclass)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type in {"identifier", "property_identifier"}
and value_node.type == "arrow_function"
):
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name and scope_name != "constructor":
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"import_declaration", "import_statement"}:
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
# Best-effort support for CommonJS require() imports:
# const fs = require("fs")
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type == "identifier"
and value_node.type == "call_expression"
):
callee = value_node.child_by_field_name("function")
args = value_node.child_by_field_name("arguments")
if (
callee is not None
and self._node_text(source_bytes, callee).strip() == "require"
and args is not None
):
module_name = self._js_first_string_argument(source_bytes, args)
if module_name:
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
record_import(module_name, self._node_start_line(node))
if node.type == "call_expression":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _node_start_line(self, node: TreeSitterNode) -> int:
return node.start_point[0] + 1
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
dotted = (dotted or "").strip()
if not dotted:
return ""
base, sep, rest = dotted.partition(".")
resolved_base = aliases.get(base, base)
if not rest:
return resolved_base
if resolved_base and rest:
return f"{resolved_base}.{rest}"
return resolved_base
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"identifier", "dotted_name"}:
return self._node_text(source_bytes, node).strip()
if node.type == "attribute":
obj = node.child_by_field_name("object")
attr = node.child_by_field_name("attribute")
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
if obj_text and attr_text:
return f"{obj_text}.{attr_text}"
return obj_text or attr_text
return ""
def _python_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
if node.type == "import_statement":
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
module_name = self._node_text(source_bytes, name_node).strip()
if not module_name:
continue
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else module_name.split(".", 1)[0]
)
if bound_name:
aliases[bound_name] = module_name
targets.append(module_name)
elif child.type == "dotted_name":
module_name = self._node_text(source_bytes, child).strip()
if not module_name:
continue
bound_name = module_name.split(".", 1)[0]
if bound_name:
aliases[bound_name] = bound_name
targets.append(module_name)
if node.type == "import_from_statement":
module_name = ""
module_node = node.child_by_field_name("module_name")
if module_node is None:
for child in node.children:
if child.type == "dotted_name":
module_node = child
break
if module_node is not None:
module_name = self._node_text(source_bytes, module_node).strip()
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
imported_name = self._node_text(source_bytes, name_node).strip()
if not imported_name or imported_name == "*":
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported_name
)
if bound_name:
aliases[bound_name] = target
targets.append(target)
elif child.type == "identifier":
imported_name = self._node_text(source_bytes, child).strip()
if not imported_name or imported_name in {"from", "import", "*"}:
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
aliases[imported_name] = target
targets.append(target)
return aliases, targets
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"this", "super"}:
return node.type
if node.type in {"identifier", "property_identifier"}:
return self._node_text(source_bytes, node).strip()
if node.type == "member_expression":
obj = node.child_by_field_name("object")
prop = node.child_by_field_name("property")
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
if obj_text and prop_text:
return f"{obj_text}.{prop_text}"
return obj_text or prop_text
return ""
def _js_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
module_name = ""
source_node = node.child_by_field_name("source")
if source_node is not None:
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
if module_name:
targets.append(module_name)
for child in node.children:
if child.type == "import_clause":
for clause_child in child.children:
if clause_child.type == "identifier":
# Default import: import React from "react"
local = self._node_text(source_bytes, clause_child).strip()
if local and module_name:
aliases[local] = module_name
if clause_child.type == "namespace_import":
# Namespace import: import * as fs from "fs"
name_node = clause_child.child_by_field_name("name")
if name_node is not None and module_name:
local = self._node_text(source_bytes, name_node).strip()
if local:
aliases[local] = module_name
if clause_child.type == "named_imports":
for spec in clause_child.children:
if spec.type != "import_specifier":
continue
name_node = spec.child_by_field_name("name")
alias_node = spec.child_by_field_name("alias")
if name_node is None:
continue
imported = self._node_text(source_bytes, name_node).strip()
if not imported:
continue
local = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported
)
if local and module_name:
aliases[local] = f"{module_name}.{imported}"
targets.append(f"{module_name}.{imported}")
return aliases, targets
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
for child in args_node.children:
if child.type == "string":
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
return ""
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of Python symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind=self._python_function_kind(node),
range=self._node_range(node),
))
return symbols
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract JavaScript/TypeScript symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of JS/TS symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = self._node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=self._node_range(node),
))
return symbols
def _python_function_kind(self, node: TreeSitterNode) -> str:
"""Determine if Python function is a method or standalone function.
Args:
node: Function definition node
Returns:
'method' if inside a class, 'function' otherwise
"""
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
"""Check if node has a class ancestor.
Args:
node: AST node to check
Returns:
True if node is inside a class
"""
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
"""Get line range for a node.
Args:
node: AST node
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Args:
text: Text to count tokens for
Returns:
Token count
"""
return self._tokenizer.count_tokens(text)
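As the module docstring notes, callers own the fallback. A sketch of that contract, reusing `_parse_python_symbols_regex` from factory.py above:

```python
from pathlib import Path

from codexlens.parsers.factory import _parse_python_symbols_regex  # regex fallback
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser

source = "class A:\n    def m(self):\n        return helper()\n"
parser = TreeSitterSymbolParser("python", Path("a.py"))

symbols = parser.parse_symbols(source) if parser.is_available() else None
if symbols is None:
    symbols = _parse_python_symbols_regex(source)  # tree-sitter missing or failed
for sym in symbols:
    print(sym.name, sym.kind, sym.range)
```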


@@ -0,0 +1,53 @@
from .chain_search import (
ChainSearchEngine,
SearchOptions,
SearchStats,
ChainSearchResult,
quick_search,
)
# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None
try:
from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
from .clustering import check_clustering_available
CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
_clustering_import_error = str(e)
def check_clustering_available() -> tuple[bool, str | None]:
"""Fallback when clustering module not loadable."""
return False, _clustering_import_error
# Clustering module exports (conditional)
try:
from .clustering import (
BaseClusteringStrategy,
ClusteringConfig,
ClusteringStrategyFactory,
get_strategy,
)
_clustering_exports = [
"BaseClusteringStrategy",
"ClusteringConfig",
"ClusteringStrategyFactory",
"get_strategy",
]
except ImportError:
_clustering_exports = []
__all__ = [
"ChainSearchEngine",
"SearchOptions",
"SearchStats",
"ChainSearchResult",
"quick_search",
# Clustering
"CLUSTERING_AVAILABLE",
"check_clustering_available",
*_clustering_exports,
]
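Downstream code is expected to gate on the flag rather than import clustering directly (the `codexlens.search` package path is assumed):

```python
from codexlens.search import CLUSTERING_AVAILABLE, check_clustering_available  # path assumed

available, error = check_clustering_available()
if not available:
    print(f"clustering disabled: {error}")
```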


@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.
This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from .builder import AssociationTreeBuilder
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
from .deduplicator import ResultDeduplicator
__all__ = [
"AssociationTreeBuilder",
"CallTree",
"TreeNode",
"UniqueNode",
"ResultDeduplicator",
]


@@ -0,0 +1,450 @@
"""Association tree builder using LSP call hierarchy.
Builds call relationship trees by recursively expanding from seed locations
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.lsp.standalone_manager import StandaloneLspManager
from .data_structures import CallTree, TreeNode
logger = logging.getLogger(__name__)
class AssociationTreeBuilder:
"""Builds association trees from seed locations using LSP call hierarchy.
Uses depth-first recursive expansion to build a tree of code relationships
starting from seed locations (typically from vector search results).
Strategy:
- Start from seed locations (vector search results)
- For each seed, get call hierarchy items via LSP
- Recursively expand incoming calls (callers) if expand_callers=True
- Recursively expand outgoing calls (callees) if expand_callees=True
- Track visited nodes to prevent cycles
- Stop at max_depth or when no more relations found
Attributes:
lsp_manager: StandaloneLspManager for LSP communication
visited: Set of visited node IDs to prevent cycles
timeout: Timeout for individual LSP requests (seconds)
"""
def __init__(
self,
lsp_manager: StandaloneLspManager,
timeout: float = 5.0,
analysis_wait: float = 2.0,
):
"""Initialize AssociationTreeBuilder.
Args:
lsp_manager: StandaloneLspManager instance for LSP communication
timeout: Timeout for individual LSP requests in seconds
analysis_wait: Time to wait for LSP analysis on first file (seconds)
"""
self.lsp_manager = lsp_manager
self.timeout = timeout
self.analysis_wait = analysis_wait
self.visited: Set[str] = set()
self._analyzed_files: Set[str] = set() # Track files already analyzed
async def build_tree(
self,
seed_file_path: str,
seed_line: int,
seed_character: int = 1,
max_depth: int = 5,
expand_callers: bool = True,
expand_callees: bool = True,
) -> CallTree:
"""Build call tree from a single seed location.
Args:
seed_file_path: Path to the seed file
seed_line: Line number of the seed symbol (1-based)
seed_character: Character position (1-based, default 1)
max_depth: Maximum recursion depth (default 5)
expand_callers: Whether to expand incoming calls (callers)
expand_callees: Whether to expand outgoing calls (callees)
Returns:
CallTree containing all discovered nodes and relationships
"""
tree = CallTree()
self.visited.clear()
# Determine wait time - only wait for analysis on first encounter of file
wait_time = 0.0
if seed_file_path not in self._analyzed_files:
wait_time = self.analysis_wait
self._analyzed_files.add(seed_file_path)
# Get call hierarchy items for the seed position
try:
hierarchy_items = await asyncio.wait_for(
self.lsp_manager.get_call_hierarchy_items(
file_path=seed_file_path,
line=seed_line,
character=seed_character,
wait_for_analysis=wait_time,
),
timeout=self.timeout + wait_time,
)
except asyncio.TimeoutError:
logger.warning(
"Timeout getting call hierarchy items for %s:%d",
seed_file_path,
seed_line,
)
return tree
except Exception as e:
logger.error(
"Error getting call hierarchy items for %s:%d: %s",
seed_file_path,
seed_line,
e,
)
return tree
if not hierarchy_items:
logger.debug(
"No call hierarchy items found for %s:%d",
seed_file_path,
seed_line,
)
return tree
# Create root nodes from hierarchy items
for item_dict in hierarchy_items:
# Convert LSP dict to CallHierarchyItem
item = self._dict_to_call_hierarchy_item(item_dict)
if not item:
continue
root_node = TreeNode(
item=item,
depth=0,
path_from_root=[self._create_node_id(item)],
)
tree.roots.append(root_node)
tree.add_node(root_node)
# Mark as visited
self.visited.add(root_node.node_id)
# Recursively expand the tree
await self._expand_node(
node=root_node,
node_dict=item_dict,
tree=tree,
current_depth=0,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
tree.depth_reached = max((n.depth for n in tree.node_list), default=0)
return tree
async def _expand_node(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Recursively expand a node by fetching its callers and callees.
Args:
node: TreeNode to expand
node_dict: LSP CallHierarchyItem dict (for LSP requests)
tree: CallTree to add discovered nodes to
current_depth: Current recursion depth
max_depth: Maximum allowed depth
expand_callers: Whether to expand incoming calls
expand_callees: Whether to expand outgoing calls
"""
# Stop if max depth reached
if current_depth >= max_depth:
return
# Prepare tasks for parallel expansion
tasks = []
if expand_callers:
tasks.append(
self._expand_incoming_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
if expand_callees:
tasks.append(
self._expand_outgoing_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
# Execute expansions in parallel
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _expand_incoming_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand incoming calls (callers) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to continue expanding callers
expand_callees: Whether to expand callees
"""
try:
incoming_calls = await asyncio.wait_for(
self.lsp_manager.get_incoming_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting incoming calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
return
if not incoming_calls:
return
# Process each incoming call
for call_dict in incoming_calls:
caller_dict = call_dict.get("from")
if not caller_dict:
continue
# Convert to CallHierarchyItem
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
if not caller_item:
continue
caller_id = self._create_node_id(caller_item)
# Check for cycles
if caller_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [caller_id],
)
node.parents.append(cycle_node)
continue
# Create new caller node
caller_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [caller_id],
)
# Add to tree
tree.add_node(caller_node)
tree.add_edge(caller_node, node)
# Update relationships
node.parents.append(caller_node)
caller_node.children.append(node)
# Mark as visited
self.visited.add(caller_id)
# Recursively expand the caller
await self._expand_node(
node=caller_node,
node_dict=caller_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
async def _expand_outgoing_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand outgoing calls (callees) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to expand callers
expand_callees: Whether to continue expanding callees
"""
try:
outgoing_calls = await asyncio.wait_for(
self.lsp_manager.get_outgoing_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
return
if not outgoing_calls:
return
# Process each outgoing call
for call_dict in outgoing_calls:
callee_dict = call_dict.get("to")
if not callee_dict:
continue
# Convert to CallHierarchyItem
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
if not callee_item:
continue
callee_id = self._create_node_id(callee_item)
# Check for cycles
if callee_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [callee_id],
)
node.children.append(cycle_node)
continue
# Create new callee node
callee_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [callee_id],
)
# Add to tree
tree.add_node(callee_node)
tree.add_edge(node, callee_node)
# Update relationships
node.children.append(callee_node)
callee_node.parents.append(node)
# Mark as visited
self.visited.add(callee_id)
# Recursively expand the callee
await self._expand_node(
node=callee_node,
node_dict=callee_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
def _dict_to_call_hierarchy_item(
self, item_dict: Dict
) -> Optional[CallHierarchyItem]:
"""Convert LSP dict to CallHierarchyItem.
Args:
item_dict: LSP CallHierarchyItem dictionary
Returns:
CallHierarchyItem or None if conversion fails
"""
try:
# Extract URI and convert to file path
uri = item_dict.get("uri", "")
file_path = uri.replace("file://", "")
# Handle Windows drive paths (file:///C:/... yields /C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
# Extract range
range_dict = item_dict.get("range", {})
start = range_dict.get("start", {})
end = range_dict.get("end", {})
# Create Range (convert from 0-based to 1-based)
item_range = Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
)
return CallHierarchyItem(
name=item_dict.get("name", "unknown"),
kind=str(item_dict.get("kind", "unknown")),
file_path=file_path,
range=item_range,
detail=item_dict.get("detail"),
)
except Exception as e:
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
return None
def _create_node_id(self, item: CallHierarchyItem) -> str:
"""Create unique node ID from CallHierarchyItem.
Args:
item: CallHierarchyItem
Returns:
Unique node ID string
"""
return f"{item.file_path}:{item.name}:{item.range.start_line}"


@@ -0,0 +1,191 @@
"""Data structures for association tree building.
Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
@dataclass
class TreeNode:
"""Node in the call association tree.
Represents a single function/method in the tree, including its position
in the hierarchy and relationships.
Attributes:
item: LSP CallHierarchyItem containing symbol information
depth: Distance from the root node (seed) - 0 for roots
children: List of child nodes (functions called by this node)
parents: List of parent nodes (functions that call this node)
is_cycle: Whether this node creates a circular reference
path_from_root: Path (list of node IDs) from root to this node
"""
item: CallHierarchyItem
depth: int = 0
children: List[TreeNode] = field(default_factory=list)
parents: List[TreeNode] = field(default_factory=list)
is_cycle: bool = False
path_from_root: List[str] = field(default_factory=list)
@property
def node_id(self) -> str:
"""Unique identifier for this node."""
return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"
def __hash__(self) -> int:
"""Hash based on node ID."""
return hash(self.node_id)
def __eq__(self, other: object) -> bool:
"""Equality based on node ID."""
if not isinstance(other, TreeNode):
return False
return self.node_id == other.node_id
def __repr__(self) -> str:
"""String representation of the node."""
cycle_marker = " [CYCLE]" if self.is_cycle else ""
return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"
@dataclass
class CallTree:
"""Complete call tree structure built from seeds.
Contains all nodes discovered through recursive expansion and
the relationships between them.
Attributes:
roots: List of root nodes (seed symbols)
all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
node_list: Flat list of all nodes in tree order
edges: List of (from_node_id, to_node_id) tuples representing calls
depth_reached: Maximum depth achieved in expansion
"""
roots: List[TreeNode] = field(default_factory=list)
all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
node_list: List[TreeNode] = field(default_factory=list)
edges: List[tuple[str, str]] = field(default_factory=list)
depth_reached: int = 0
def add_node(self, node: TreeNode) -> None:
"""Add a node to the tree.
Args:
node: TreeNode to add
"""
if node.node_id not in self.all_nodes:
self.all_nodes[node.node_id] = node
self.node_list.append(node)
def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
"""Add an edge between two nodes.
Args:
from_node: Source node
to_node: Target node
"""
edge = (from_node.node_id, to_node.node_id)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[TreeNode]:
"""Get a node by ID.
Args:
node_id: Node identifier
Returns:
TreeNode if found, None otherwise
"""
return self.all_nodes.get(node_id)
def __len__(self) -> int:
"""Return total number of nodes in tree."""
return len(self.all_nodes)
def __repr__(self) -> str:
"""String representation of the tree."""
return (
f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
f"depth={self.depth_reached})"
)
@dataclass
class UniqueNode:
"""Deduplicated unique code symbol from the tree.
Represents a single unique code location that may appear multiple times
in the tree under different contexts. Contains aggregated information
about all occurrences.
Attributes:
file_path: Absolute path to the file
name: Symbol name (function, method, class, etc.)
kind: Symbol kind (function, method, class, etc.)
range: Code range in the file
min_depth: Minimum depth at which this node appears in the tree
occurrences: Number of times this node appears in the tree
paths: List of paths from roots to this node
context_nodes: Related nodes from the tree
score: Composite relevance score (higher is better)
"""
file_path: str
name: str
kind: str
range: Range
min_depth: int = 0
occurrences: int = 1
paths: List[List[str]] = field(default_factory=list)
context_nodes: List[str] = field(default_factory=list)
score: float = 0.0
@property
def node_key(self) -> tuple[str, int, int]:
"""Unique key for deduplication.
Uses (file_path, start_line, end_line) as the unique identifier
for this symbol across all occurrences.
"""
return (
self.file_path,
self.range.start_line,
self.range.end_line,
)
def add_path(self, path: List[str]) -> None:
"""Add a path from root to this node.
Args:
path: List of node IDs from root to this node
"""
if path not in self.paths:
self.paths.append(path)
def __hash__(self) -> int:
"""Hash based on node key."""
return hash(self.node_key)
def __eq__(self, other: object) -> bool:
"""Equality based on node key."""
if not isinstance(other, UniqueNode):
return False
return self.node_key == other.node_key
def __repr__(self) -> str:
"""String representation of the unique node."""
return (
f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
)
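
A minimal usage sketch for these dataclasses (assuming TreeNode and CallTree are imported from this module, and that CallHierarchyItem accepts the fields used in the converter above, with detail optional):

from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range

rng = Range(start_line=10, start_character=1, end_line=20, end_character=1)
caller = TreeNode(item=CallHierarchyItem(name="main", kind="function",
                                         file_path="src/app.py", range=rng))
callee = TreeNode(item=CallHierarchyItem(name="helper", kind="function",
                                         file_path="src/util.py", range=rng),
                  depth=1)
caller.children.append(callee)
callee.parents.append(caller)

tree = CallTree(roots=[caller], depth_reached=1)
tree.add_node(caller)
tree.add_node(callee)
tree.add_edge(caller, callee)
assert len(tree) == 2 and tree.get_node(callee.node_id) is callee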

View File

@@ -0,0 +1,301 @@
"""Result deduplication for association tree nodes.
Provides functionality to extract unique nodes from a call tree and assign
relevance scores based on various factors.
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
logger = logging.getLogger(__name__)
# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
# Functions and methods are primary targets
"function": 1.0,
"method": 1.0,
"12": 1.0, # LSP SymbolKind.Function
"6": 1.0, # LSP SymbolKind.Method
# Classes are important but secondary
"class": 0.8,
"5": 0.8, # LSP SymbolKind.Class
# Interfaces and types
"interface": 0.7,
"11": 0.7, # LSP SymbolKind.Interface
"type": 0.6,
# Constructors
"constructor": 0.9,
"9": 0.9, # LSP SymbolKind.Constructor
# Variables and constants
"variable": 0.4,
"13": 0.4, # LSP SymbolKind.Variable
"constant": 0.5,
"14": 0.5, # LSP SymbolKind.Constant
# Default for unknown kinds
"unknown": 0.3,
}
class ResultDeduplicator:
"""Extracts and scores unique nodes from call trees.
Processes a CallTree to extract unique code locations, merging duplicates
and assigning relevance scores based on:
- Depth: Shallower nodes (closer to seeds) score higher
- Frequency: Nodes appearing multiple times score higher
- Kind: Function/method > class > variable
Attributes:
depth_weight: Weight for depth factor in scoring (default 0.4)
frequency_weight: Weight for frequency factor (default 0.3)
kind_weight: Weight for symbol kind factor (default 0.3)
max_depth_penalty: Maximum depth before full penalty applied
"""
def __init__(
self,
depth_weight: float = 0.4,
frequency_weight: float = 0.3,
kind_weight: float = 0.3,
max_depth_penalty: int = 10,
):
"""Initialize ResultDeduplicator.
Args:
depth_weight: Weight for depth factor (0.0-1.0)
frequency_weight: Weight for frequency factor (0.0-1.0)
kind_weight: Weight for symbol kind factor (0.0-1.0)
max_depth_penalty: Depth at which score becomes 0 for depth factor
"""
self.depth_weight = depth_weight
self.frequency_weight = frequency_weight
self.kind_weight = kind_weight
self.max_depth_penalty = max_depth_penalty
def deduplicate(
self,
tree: CallTree,
max_results: Optional[int] = None,
) -> List[UniqueNode]:
"""Extract unique nodes from the call tree.
Traverses the tree, groups nodes by their unique key (file_path,
start_line, end_line), and merges duplicate occurrences.
Args:
tree: CallTree to process
max_results: Maximum number of results to return (None = all)
Returns:
List of UniqueNode objects, sorted by score descending
"""
if not tree.node_list:
return []
# Group nodes by unique key
unique_map: Dict[tuple, UniqueNode] = {}
for node in tree.node_list:
if node.is_cycle:
# Skip cycle markers - they point to already-counted nodes
continue
key = self._get_node_key(node)
if key in unique_map:
# Update existing unique node
unique_node = unique_map[key]
unique_node.occurrences += 1
unique_node.min_depth = min(unique_node.min_depth, node.depth)
unique_node.add_path(node.path_from_root)
# Collect context from relationships
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
else:
# Create new unique node
unique_node = UniqueNode(
file_path=node.item.file_path,
name=node.item.name,
kind=node.item.kind,
range=node.item.range,
min_depth=node.depth,
occurrences=1,
paths=[node.path_from_root.copy()],
context_nodes=[],
score=0.0,
)
# Collect initial context
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
unique_map[key] = unique_node
# Calculate scores for all unique nodes
unique_nodes = list(unique_map.values())
# Find max frequency for normalization
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
for node in unique_nodes:
node.score = self._score_node(node, max_frequency)
# Sort by score descending
unique_nodes.sort(key=lambda n: n.score, reverse=True)
# Apply max_results limit
if max_results is not None and max_results > 0:
unique_nodes = unique_nodes[:max_results]
logger.debug(
"Deduplicated %d tree nodes to %d unique nodes",
len(tree.node_list),
len(unique_nodes),
)
return unique_nodes
def _score_node(
self,
node: UniqueNode,
max_frequency: int,
) -> float:
"""Calculate composite score for a unique node.
Score = depth_weight * depth_score +
frequency_weight * frequency_score +
kind_weight * kind_score
Args:
node: UniqueNode to score
max_frequency: Maximum occurrence count for normalization
Returns:
Composite score between 0.0 and 1.0
"""
# Depth score: closer to root = higher score
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
depth_score = max(
0.0,
1.0 - (node.min_depth / self.max_depth_penalty),
)
# Frequency score: more occurrences = higher score
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
# Kind score: function/method > class > variable
kind_str = str(node.kind).lower()
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
# Composite score
score = (
self.depth_weight * depth_score
+ self.frequency_weight * frequency_score
+ self.kind_weight * kind_score
)
return score
def _get_node_key(self, node: TreeNode) -> tuple:
"""Get unique key for a tree node.
Uses (file_path, start_line, end_line) as the unique identifier.
Args:
node: TreeNode
Returns:
Tuple key for deduplication
"""
return (
node.item.file_path,
node.item.range.start_line,
node.item.range.end_line,
)
def filter_by_kind(
self,
nodes: List[UniqueNode],
kinds: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by symbol kind.
Args:
nodes: List of UniqueNode to filter
kinds: List of allowed kinds (e.g., ["function", "method"])
Returns:
Filtered list of UniqueNode
"""
kinds_lower = [k.lower() for k in kinds]
return [
node
for node in nodes
if str(node.kind).lower() in kinds_lower
]
def filter_by_file(
self,
nodes: List[UniqueNode],
file_patterns: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by file path patterns.
Args:
nodes: List of UniqueNode to filter
file_patterns: List of path substrings to match
Returns:
Filtered list of UniqueNode
"""
return [
node
for node in nodes
if any(pattern in node.file_path for pattern in file_patterns)
]
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
"""Convert list of UniqueNode to JSON-serializable dicts.
Args:
nodes: List of UniqueNode
Returns:
List of dictionaries
"""
return [
{
"file_path": node.file_path,
"name": node.name,
"kind": node.kind,
"range": {
"start_line": node.range.start_line,
"start_character": node.range.start_character,
"end_line": node.range.end_line,
"end_character": node.range.end_character,
},
"min_depth": node.min_depth,
"occurrences": node.occurrences,
"path_count": len(node.paths),
"score": round(node.score, 4),
}
for node in nodes
]
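
A worked example of the composite score with the default weights, mirroring _score_node as a standalone sketch:

# Node at depth 1 (of max_depth_penalty 10), 3 occurrences out of a max of 5,
# kind "function" (weight 1.0):
depth_score = max(0.0, 1.0 - 1 / 10)     # 0.9
frequency_score = 3 / 5                  # 0.6
kind_score = 1.0                         # KIND_WEIGHTS["function"]
score = 0.4 * depth_score + 0.3 * frequency_score + 0.3 * kind_score
assert abs(score - 0.84) < 1e-9          # 0.36 + 0.18 + 0.30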

View File

@@ -0,0 +1,277 @@
"""Binary vector searcher for cascade search.
This module provides fast binary vector search using Hamming distance
for the first stage of cascade search (coarse filtering).
Supports two loading modes:
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
2. Database loading (fallback): Loads all vectors into RAM
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
# Pre-computed popcount lookup table for vectorized Hamming distance
# Each byte value (0-255) maps to its bit count
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
class BinarySearcher:
"""Fast binary vector search using Hamming distance.
This class implements the first stage of cascade search:
fast, approximate retrieval using binary vectors and Hamming distance.
The binary vectors are derived from dense embeddings by thresholding:
binary[i] = 1 if dense[i] > 0 else 0
Hamming distance between two binary vectors counts the number of
differing bits, which can be computed very efficiently using XOR
and population count.
Supports two loading modes:
- Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
- Database (fallback): Loads all vectors into memory from SQLite
"""
def __init__(self, index_root_or_meta_path: Path) -> None:
"""Initialize BinarySearcher.
Args:
index_root_or_meta_path: Either:
- Path to index root directory (containing _binary_vectors.mmap)
- Path to _vectors_meta.db (legacy mode, loads from DB)
"""
path = Path(index_root_or_meta_path)
# Determine if this is an index root or a specific DB path
if path.suffix == '.db':
# Legacy mode: specific DB path
self.index_root = path.parent
self.meta_store_path = path
else:
# New mode: index root directory
self.index_root = path
self.meta_store_path = path / "_vectors_meta.db"
self._chunk_ids: Optional[np.ndarray] = None
self._binary_matrix: Optional[np.ndarray] = None
self._is_memmap = False
self._loaded = False
def load(self) -> bool:
"""Load binary vectors using memory-mapped file or database fallback.
Tries to load from memory-mapped file first (preferred for large indexes),
falls back to database loading if mmap file doesn't exist.
Returns:
True if vectors were loaded successfully.
"""
if self._loaded:
return True
# Try memory-mapped file first (preferred)
mmap_path = self.index_root / "_binary_vectors.mmap"
meta_path = mmap_path.with_suffix('.meta.json')
if mmap_path.exists() and meta_path.exists():
try:
with open(meta_path, 'r') as f:
meta = json.load(f)
shape = tuple(meta['shape'])
self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)
# Memory-map the binary matrix (read-only)
self._binary_matrix = np.memmap(
str(mmap_path),
dtype=np.uint8,
mode='r',
shape=shape
)
self._is_memmap = True
self._loaded = True
logger.info(
"Memory-mapped %d binary vectors (%d bytes each)",
len(self._chunk_ids), shape[1]
)
return True
except Exception as e:
logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)
# Fallback: load from database
return self._load_from_db()
def _load_from_db(self) -> bool:
"""Load binary vectors from database (legacy/fallback mode).
Returns:
True if vectors were loaded successfully.
"""
try:
from codexlens.storage.vector_meta_store import VectorMetadataStore
with VectorMetadataStore(self.meta_store_path) as store:
rows = store.get_all_binary_vectors()
if not rows:
logger.warning("No binary vectors found in %s", self.meta_store_path)
return False
# Convert to numpy arrays for fast computation
self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)
# Unpack bytes to numpy array
binary_arrays = []
for _, vec_bytes in rows:
arr = np.frombuffer(vec_bytes, dtype=np.uint8)
binary_arrays.append(arr)
self._binary_matrix = np.vstack(binary_arrays)
self._is_memmap = False
self._loaded = True
logger.info(
"Loaded %d binary vectors from DB (%d bytes each)",
len(self._chunk_ids), self._binary_matrix.shape[1]
)
return True
except Exception as e:
logger.error("Failed to load binary vectors: %s", e)
return False
def search(
self,
query_vector: np.ndarray,
top_k: int = 100
) -> List[Tuple[int, int]]:
"""Search for similar vectors using Hamming distance.
Args:
query_vector: Dense query vector (will be binarized).
top_k: Number of top results to return.
Returns:
List of (chunk_id, hamming_distance) tuples sorted by distance.
"""
if not self._loaded and not self.load():
return []
# Binarize query vector
query_binary = (query_vector > 0).astype(np.uint8)
query_packed = np.packbits(query_binary)
# Compute Hamming distances using XOR and popcount
# XOR gives 1 for differing bits
xor_result = np.bitwise_xor(self._binary_matrix, query_packed)
# Vectorized popcount using lookup table (orders of magnitude faster)
# Sum the bit counts for each byte across all columns
distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)
# Get top-k with smallest distances
if top_k >= len(distances):
top_indices = np.argsort(distances)
else:
# Partial sort for efficiency
top_indices = np.argpartition(distances, top_k)[:top_k]
top_indices = top_indices[np.argsort(distances[top_indices])]
results = [
(int(self._chunk_ids[i]), int(distances[i]))
for i in top_indices
]
return results
def search_with_rerank(
self,
query_dense: np.ndarray,
dense_vectors: np.ndarray,
dense_chunk_ids: np.ndarray,
top_k: int = 10,
candidates: int = 100
) -> List[Tuple[int, float]]:
"""Two-stage cascade search: binary filter + dense rerank.
Args:
query_dense: Dense query vector.
dense_vectors: Dense vectors for reranking (from HNSW or stored).
dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
top_k: Final number of results.
candidates: Number of candidates from binary search.
Returns:
List of (chunk_id, cosine_similarity) tuples.
"""
# Stage 1: Binary filtering
binary_results = self.search(query_dense, top_k=candidates)
if not binary_results:
return []
candidate_ids = {r[0] for r in binary_results}
# Stage 2: Dense reranking
# Find indices of candidates in dense_vectors
candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
# Fallback: return binary results with normalized distance
max_dist = max(r[1] for r in binary_results) or 1  # guard against all-zero distances
return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]
# Compute cosine similarities for candidates
candidate_vectors = dense_vectors[candidate_indices]
candidate_ids_array = dense_chunk_ids[candidate_indices]
# Normalize vectors
query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
cand_norms = candidate_vectors / (
np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
)
# Cosine similarities
similarities = np.dot(cand_norms, query_norm)
# Sort by similarity (descending)
sorted_indices = np.argsort(-similarities)[:top_k]
results = [
(int(candidate_ids_array[i]), float(similarities[i]))
for i in sorted_indices
]
return results
@property
def vector_count(self) -> int:
"""Get number of loaded binary vectors."""
return len(self._chunk_ids) if self._chunk_ids is not None else 0
@property
def is_memmap(self) -> bool:
"""Check if using memory-mapped file (vs in-memory array)."""
return self._is_memmap
def clear(self) -> None:
"""Clear loaded vectors from memory."""
# For memmap, just delete the reference (OS will handle cleanup)
if self._is_memmap and self._binary_matrix is not None:
del self._binary_matrix
self._chunk_ids = None
self._binary_matrix = None
self._is_memmap = False
self._loaded = False
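
A self-contained sketch of the binarize, packbits, XOR, popcount pipeline used in search(), with plain NumPy and no codexlens imports:

import numpy as np

popcount = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

dense_query = np.array([0.3, -0.1, 0.0, 0.8, -0.5, 0.2, 0.9, -0.7])
query_bits = np.packbits((dense_query > 0).astype(np.uint8))  # 8 dims -> 1 byte

stored = np.packbits(np.array([[1, 0, 0, 1, 0, 1, 1, 0],
                               [0, 1, 1, 0, 1, 0, 0, 1]], dtype=np.uint8), axis=1)
distances = popcount[np.bitwise_xor(stored, query_bits)].sum(axis=1)
# Row 0 equals the binarized query; row 1 is its bitwise complement.
assert distances.tolist() == [0, 8]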

File diff suppressed because it is too large

View File

@@ -0,0 +1,124 @@
"""Clustering strategies for the staged hybrid search pipeline.
This module provides extensible clustering infrastructure for grouping
similar search results and selecting representative results.
Install with: pip install codexlens[clustering]
Example:
>>> from codexlens.search.clustering import (
... CLUSTERING_AVAILABLE,
... ClusteringConfig,
... get_strategy,
... )
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy with fallback
>>> strategy = get_strategy("auto", config)
>>> representatives = strategy.fit_predict(embeddings, results)
>>>
>>> # Or explicitly use a specific strategy
>>> if CLUSTERING_AVAILABLE:
... from codexlens.search.clustering import HDBSCANStrategy
... strategy = HDBSCANStrategy(config)
... representatives = strategy.fit_predict(embeddings, results)
"""
from __future__ import annotations
# Always export base classes and factory (no heavy dependencies)
from .base import BaseClusteringStrategy, ClusteringConfig
from .factory import (
ClusteringStrategyFactory,
check_clustering_strategy_available,
get_strategy,
)
from .noop_strategy import NoOpStrategy
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# Feature flag for clustering availability (hdbscan + sklearn)
CLUSTERING_AVAILABLE = False
HDBSCAN_AVAILABLE = False
DBSCAN_AVAILABLE = False
_import_error: str | None = None
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
"""Detect if clustering dependencies are available.
Returns:
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
"""
hdbscan_ok = False
dbscan_ok = False
try:
import hdbscan # noqa: F401
hdbscan_ok = True
except ImportError:
pass
try:
from sklearn.cluster import DBSCAN # noqa: F401
dbscan_ok = True
except ImportError:
pass
all_ok = hdbscan_ok and dbscan_ok
error = None
if not all_ok:
missing = []
if not hdbscan_ok:
missing.append("hdbscan")
if not dbscan_ok:
missing.append("scikit-learn")
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
return all_ok, hdbscan_ok, dbscan_ok, error
# Initialize on module load
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
_detect_clustering_available()
)
def check_clustering_available() -> tuple[bool, str | None]:
"""Check if all clustering dependencies are available.
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
return CLUSTERING_AVAILABLE, _import_error
# Conditionally export strategy implementations
__all__ = [
# Feature flags
"CLUSTERING_AVAILABLE",
"HDBSCAN_AVAILABLE",
"DBSCAN_AVAILABLE",
"check_clustering_available",
# Base classes
"BaseClusteringStrategy",
"ClusteringConfig",
# Factory
"ClusteringStrategyFactory",
"get_strategy",
"check_clustering_strategy_available",
# Always-available strategies
"NoOpStrategy",
"FrequencyStrategy",
"FrequencyConfig",
]
# Conditionally add strategy classes to __all__ and module namespace
if HDBSCAN_AVAILABLE:
from .hdbscan_strategy import HDBSCANStrategy
__all__.append("HDBSCANStrategy")
if DBSCAN_AVAILABLE:
from .dbscan_strategy import DBSCANStrategy
__all__.append("DBSCANStrategy")
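
A short usage sketch of the feature flags and factory exported by this package:

from codexlens.search.clustering import (
    ClusteringConfig,
    check_clustering_available,
    get_strategy,
)

available, err = check_clustering_available()
if not available:
    print(f"embedding clustering unavailable: {err}")

# "auto" resolves to HDBSCAN, then DBSCAN, then NoOp - whichever imports.
strategy = get_strategy("auto", ClusteringConfig(min_cluster_size=2))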

View File

@@ -0,0 +1,153 @@
"""Base classes for clustering strategies in the hybrid search pipeline.
This module defines the abstract base class for clustering strategies used
in the staged hybrid search pipeline. Strategies cluster search results
based on their embeddings and select representative results from each cluster.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class ClusteringConfig:
"""Configuration parameters for clustering strategies.
Attributes:
min_cluster_size: Minimum number of results to form a cluster.
HDBSCAN default is 5, but for search results 2-3 is often better.
min_samples: Number of samples in a neighborhood for a point to be
considered a core point. Lower values allow more clusters.
metric: Distance metric for clustering. Common options:
- 'euclidean': Standard L2 distance
- 'cosine': Cosine distance (1 - cosine_similarity)
- 'manhattan': L1 distance
cluster_selection_epsilon: Distance threshold for cluster selection.
Results within this distance may be merged into the same cluster.
allow_single_cluster: If True, allow all results to form one cluster.
Useful when results are very similar.
prediction_data: If True, generate prediction data for new points.
"""
min_cluster_size: int = 3
min_samples: int = 2
metric: str = "cosine"
cluster_selection_epsilon: float = 0.0
allow_single_cluster: bool = True
prediction_data: bool = False
def __post_init__(self) -> None:
"""Validate configuration parameters."""
if self.min_cluster_size < 2:
raise ValueError("min_cluster_size must be >= 2")
if self.min_samples < 1:
raise ValueError("min_samples must be >= 1")
if self.metric not in ("euclidean", "cosine", "manhattan"):
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
if self.cluster_selection_epsilon < 0:
raise ValueError("cluster_selection_epsilon must be >= 0")
class BaseClusteringStrategy(ABC):
"""Abstract base class for clustering strategies.
Clustering strategies are used in the staged hybrid search pipeline to
group similar search results and select representative results from each
cluster, reducing redundancy while maintaining diversity.
Subclasses must implement:
- cluster(): Group results into clusters based on embeddings
- select_representatives(): Choose best result(s) from each cluster
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize the clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
"""
self.config = config or ClusteringConfig()
@abstractmethod
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results based on their embeddings.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Used for additional metadata during clustering.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Results not assigned to any cluster
(noise points) should be returned as single-element clusters.
Example:
>>> strategy = HDBSCANStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
>>> # Result indices 0, 2, 5 are in cluster 0
>>> # Result indices 1, 3 are in cluster 1
>>> # Result index 4 is a noise point (singleton cluster)
>>> # Result indices 6, 7, 8 are in cluster 2
"""
...
@abstractmethod
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
This method chooses the best result(s) from each cluster to include
in the final search results. The selection can be based on:
- Highest score within cluster
- Closest to cluster centroid
- Custom selection logic
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings array for centroid-based selection.
Returns:
List of representative SearchResult objects, one or more per cluster,
ordered by relevance (highest score first).
Example:
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns best result from each cluster
"""
...
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)
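
To illustrate the contract, a minimal hypothetical strategy that buckets results by their integer score and keeps the best result per bucket (embeddings unused):

from codexlens.search.clustering import BaseClusteringStrategy

class ScoreBucketStrategy(BaseClusteringStrategy):
    """Toy strategy: one cluster per integer score bucket."""

    def cluster(self, embeddings, results):
        buckets: dict[int, list[int]] = {}
        for idx, r in enumerate(results):
            buckets.setdefault(int(r.score), []).append(idx)
        return list(buckets.values())

    def select_representatives(self, clusters, results, embeddings=None):
        best = [max(c, key=lambda i: results[i].score) for c in clusters if c]
        return sorted((results[i] for i in best),
                      key=lambda r: r.score, reverse=True)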

View File

@@ -0,0 +1,197 @@
"""DBSCAN-based clustering strategy for search results.
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
is the fallback clustering strategy when HDBSCAN is not available.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class DBSCANStrategy(BaseClusteringStrategy):
"""DBSCAN-based clustering strategy.
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
DBSCAN requires an explicit eps parameter, which is auto-computed from the
distance distribution if not provided.
Example:
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = DBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
# Default eps percentile for auto-computation
DEFAULT_EPS_PERCENTILE: float = 15.0
def __init__(
self,
config: Optional[ClusteringConfig] = None,
eps: Optional[float] = None,
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
) -> None:
"""Initialize DBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
from the distance distribution.
eps_percentile: Percentile of pairwise distances to use for
auto-computing eps. Default is 15th percentile.
Raises:
ImportError: If sklearn is not installed.
"""
super().__init__(config)
self.eps = eps
self.eps_percentile = eps_percentile
# Validate sklearn is available
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError as exc:
raise ImportError(
"scikit-learn package is required for DBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def _compute_eps(self, embeddings: "np.ndarray") -> float:
"""Auto-compute eps from pairwise distance distribution.
Uses the specified percentile of pairwise distances as eps,
which typically captures local density well.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Returns:
Computed eps value.
"""
import numpy as np
from sklearn.metrics import pairwise_distances
# Compute pairwise distances
distances = pairwise_distances(embeddings, metric=self.config.metric)
# Get upper triangle (excluding diagonal)
upper_tri = distances[np.triu_indices_from(distances, k=1)]
if len(upper_tri) == 0:
# Only one point, return a default small eps
return 0.1
# Use percentile of distances as eps
eps = float(np.percentile(upper_tri, self.eps_percentile))
# Ensure eps is positive
return max(eps, 1e-6)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using DBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
from sklearn.cluster import DBSCAN
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: single result
if n_results == 1:
return [[0]]
# Determine eps value
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
# Configure DBSCAN clusterer
# Note: DBSCAN has no min_cluster_size; min_samples is its closest analogue
clusterer = DBSCAN(
eps=eps,
min_samples=self.config.min_samples,
metric=self.config.metric,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives
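
A standalone sketch of the percentile-based eps heuristic implemented by _compute_eps:

import numpy as np
from sklearn.metrics import pairwise_distances

embeddings = np.random.default_rng(0).normal(size=(20, 8))
d = pairwise_distances(embeddings, metric="cosine")
upper = d[np.triu_indices_from(d, k=1)]  # unique pairs, diagonal excluded
eps = max(float(np.percentile(upper, 15.0)), 1e-6)
# Roughly 15% of result pairs lie closer than eps, a proxy for local density.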

View File

@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.
Provides a unified interface for instantiating different clustering backends
with automatic fallback chain: hdbscan -> dbscan -> noop.
"""
from __future__ import annotations
from typing import Any, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
from .noop_strategy import NoOpStrategy
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
"""Check whether a specific clustering strategy can be used.
Args:
strategy: Strategy name to check. Options:
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
- "dbscan": DBSCAN clustering (requires sklearn)
- "frequency": Frequency-based clustering (always available)
- "noop": No-op strategy (always available)
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
strategy = (strategy or "").strip().lower()
if strategy == "hdbscan":
try:
import hdbscan # noqa: F401
except ImportError:
return False, (
"hdbscan package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "dbscan":
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError:
return False, (
"scikit-learn package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "frequency":
# Frequency strategy is always available (no external deps)
return True, None
if strategy == "noop":
return True, None
return False, (
f"Invalid clustering strategy: {strategy}. "
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
)
def get_strategy(
strategy: str = "hdbscan",
config: Optional[ClusteringConfig] = None,
*,
fallback: bool = True,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Factory function to create clustering strategy with fallback chain.
The automatic fallback chain is: hdbscan -> dbscan -> noop; "frequency" is only used when requested explicitly.
Args:
strategy: Clustering strategy to use. Options:
- "hdbscan": HDBSCAN clustering (default, recommended)
- "dbscan": DBSCAN clustering (fallback)
- "frequency": Frequency-based clustering (groups by symbol occurrence)
- "noop": No-op strategy (returns all results ungrouped)
- "auto": Try hdbscan, then dbscan, then noop
config: Clustering configuration. Uses defaults if not provided.
For frequency strategy, pass FrequencyConfig for full control.
fallback: If True (default), automatically fall back to next strategy
in the chain when primary is unavailable. If False, raise ImportError
when requested strategy is unavailable.
**kwargs: Additional strategy-specific arguments.
For DBSCANStrategy: eps, eps_percentile
For FrequencyStrategy: group_by, min_frequency, etc.
Returns:
BaseClusteringStrategy: Configured clustering strategy instance.
Raises:
ValueError: If strategy is not recognized.
ImportError: If required dependencies are not installed and fallback=False.
Example:
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy
>>> strategy = get_strategy("auto", config)
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
>>> strategy = get_strategy("hdbscan", config)
>>> # Use frequency-based strategy
>>> from codexlens.search.clustering import FrequencyConfig
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = get_strategy("frequency", freq_config)
"""
strategy = (strategy or "").strip().lower()
# Handle "auto" - try strategies in order
if strategy == "auto":
return _get_best_available_strategy(config, **kwargs)
if strategy == "hdbscan":
ok, err = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
if fallback:
# Try dbscan fallback
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
if ok_dbscan:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Final fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "dbscan":
ok, err = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
if fallback:
# Fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "frequency":
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
if config is None or not isinstance(config, FrequencyConfig):
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
else:
freq_config = config
return FrequencyStrategy(freq_config)
if strategy == "noop":
return NoOpStrategy(config)
raise ValueError(
f"Unknown clustering strategy: {strategy}. "
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
)
def _get_best_available_strategy(
config: Optional[ClusteringConfig] = None,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Get the best available clustering strategy.
Tries strategies in order: hdbscan -> dbscan -> noop
Args:
config: Clustering configuration.
**kwargs: Additional strategy-specific arguments.
Returns:
Best available clustering strategy instance.
"""
# Try HDBSCAN first
ok, _ = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
# Try DBSCAN second
ok, _ = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Fallback to NoOp
return NoOpStrategy(config)
# Alias for backward compatibility
ClusteringStrategyFactory = type(
"ClusteringStrategyFactory",
(),
{
"get_strategy": staticmethod(get_strategy),
"check_available": staticmethod(check_clustering_strategy_available),
},
)
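
A sketch of strict, non-fallback selection, where a missing backend raises instead of silently degrading:

from codexlens.search.clustering import get_strategy

try:
    strategy = get_strategy("hdbscan", fallback=False)
except ImportError as exc:
    print(f"clustering backend unavailable: {exc}")  # surfaces the install hint
    strategy = get_strategy("noop")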

View File

@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol/method name and prunes based on
occurrence frequency. High-frequency symbols (frequently referenced methods)
are considered more important and retained, while low-frequency results
(potentially noise) can be filtered out.
Use cases:
- Prioritize commonly called methods/functions
- Filter out one-off results that may be less relevant
- Deduplicate results pointing to the same symbol from different locations
"""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class FrequencyConfig(ClusteringConfig):
"""Configuration for frequency-based clustering strategy.
Attributes:
group_by: Field to group results by for frequency counting.
- 'symbol': Group by symbol_name (default, for method/function dedup)
- 'file': Group by file path
- 'symbol_kind': Group by symbol type (function, class, etc.)
min_frequency: Minimum occurrence count to keep a result.
Results appearing less than this are considered noise and pruned.
max_representatives_per_group: Maximum results to keep per symbol group.
frequency_weight: How much to boost score based on frequency.
Final score = original_score * (1 + frequency_weight * log(frequency + 1))
keep_mode: How to handle low-frequency results.
- 'filter': Remove results below min_frequency
- 'demote': Keep but lower their score ranking
"""
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
max_representatives_per_group: int = 3
frequency_weight: float = 0.1 # Boost factor for frequency
keep_mode: Literal["filter", "demote"] = "demote"
def __post_init__(self) -> None:
"""Validate configuration parameters."""
# Skip parent validation since we don't use HDBSCAN params
if self.min_frequency < 1:
raise ValueError("min_frequency must be >= 1")
if self.max_representatives_per_group < 1:
raise ValueError("max_representatives_per_group must be >= 1")
if self.frequency_weight < 0:
raise ValueError("frequency_weight must be >= 0")
if self.group_by not in ("symbol", "file", "symbol_kind"):
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
if self.keep_mode not in ("filter", "demote"):
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
class FrequencyStrategy(BaseClusteringStrategy):
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol name (or file/kind) and:
1. Counts how many times each symbol appears in results
2. Higher frequency = more important (frequently referenced method)
3. Filters or demotes low-frequency results
4. Selects top representatives from each frequency group
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
- Does NOT require embeddings (works with metadata only)
- Is very fast (O(n) complexity)
- Is deterministic (no random initialization)
- Works well for symbol-level deduplication
Example:
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = FrequencyStrategy(config)
>>> # Results with symbol "authenticate" appearing 5 times
>>> # will be prioritized over "helper_func" appearing once
>>> representatives = strategy.fit_predict(embeddings, results)
"""
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
"""Initialize the frequency strategy.
Args:
config: Frequency configuration. Uses defaults if not provided.
"""
self.config: FrequencyConfig = config or FrequencyConfig()
def _get_group_key(self, result: "SearchResult") -> str:
"""Extract grouping key from a search result.
Args:
result: SearchResult to extract key from.
Returns:
String key for grouping (symbol name, file path, or kind).
"""
if self.config.group_by == "symbol":
# Use symbol_name if available, otherwise fall back to file:line
symbol = getattr(result, "symbol_name", None)
if symbol:
return str(symbol)
# Fallback: use file path + start_line as pseudo-symbol
start_line = getattr(result, "start_line", 0) or 0
return f"{result.path}:{start_line}"
elif self.config.group_by == "file":
return str(result.path)
elif self.config.group_by == "symbol_kind":
kind = getattr(result, "symbol_kind", None)
return str(kind) if kind else "unknown"
return str(result.path) # Default fallback
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Group search results by frequency of occurrence.
Note: This method ignores embeddings and groups by metadata only.
The embeddings parameter is kept for interface compatibility.
Args:
embeddings: Ignored (kept for interface compatibility).
results: List of SearchResult objects to cluster.
Returns:
List of clusters (groups), where each cluster contains indices
of results with the same grouping key. Clusters are ordered by
frequency (highest frequency first).
"""
if not results:
return []
# Group results by key
groups: Dict[str, List[int]] = defaultdict(list)
for idx, result in enumerate(results):
key = self._get_group_key(result)
groups[key].append(idx)
# Sort groups by frequency (descending) then by key (for stability)
sorted_groups = sorted(
groups.items(),
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
)
# Convert to list of clusters
clusters = [indices for _, indices in sorted_groups]
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results based on frequency and score.
For each frequency group:
1. If frequency < min_frequency: filter or demote based on keep_mode
2. Sort by score within group
3. Apply frequency boost to scores
4. Select top N representatives
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (used for tie-breaking if provided).
Returns:
List of representative SearchResult objects, ordered by
frequency-adjusted score (highest first).
"""
import math
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
demoted: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
frequency = len(cluster_indices)
# Get results in this cluster, sorted by score
cluster_results = [results[i] for i in cluster_indices]
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
# Check frequency threshold
if frequency < self.config.min_frequency:
if self.config.keep_mode == "filter":
# Skip low-frequency results entirely
continue
else: # demote mode
# Keep but add to demoted list (lower priority)
for result in cluster_results[: self.config.max_representatives_per_group]:
demoted.append(result)
continue
# Apply frequency boost and select top representatives
for result in cluster_results[: self.config.max_representatives_per_group]:
# Calculate frequency-boosted score
original_score = getattr(result, "score", 0.0)
# log(frequency + 1) to handle frequency=1 case smoothly
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
boosted_score = original_score * frequency_boost
# SearchResult may be immutable, so rather than replacing its score we
# record the frequency and boosted score in its mutable metadata dict
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
result.metadata["frequency"] = frequency
result.metadata["frequency_boosted_score"] = boosted_score
representatives.append(result)
# Sort representatives by boosted score (or original score as fallback)
def get_sort_score(r: "SearchResult") -> float:
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
return getattr(r, "score", 0.0)
representatives.sort(key=get_sort_score, reverse=True)
# Add demoted results at the end
if demoted:
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
representatives.extend(demoted)
return representatives
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array (may be ignored for frequency-based clustering).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)
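
A usage sketch with a stand-in result type (SearchResult is defined elsewhere; only the path, score, symbol_name, and metadata attributes read above are assumed):

from dataclasses import dataclass, field

@dataclass
class FakeResult:  # stand-in for codexlens.entities.SearchResult
    path: str
    score: float
    symbol_name: str
    metadata: dict = field(default_factory=dict)

results = [
    FakeResult("a.py", 0.9, "authenticate"),
    FakeResult("b.py", 0.8, "authenticate"),
    FakeResult("c.py", 0.7, "helper_func"),
]
strategy = FrequencyStrategy(FrequencyConfig(min_frequency=2, keep_mode="filter"))
reps = strategy.fit_predict(embeddings=None, results=results)  # embeddings ignored
assert all(r.symbol_name == "authenticate" for r in reps)  # singleton filtered out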

View File

@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.
HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class HDBSCANStrategy(BaseClusteringStrategy):
"""HDBSCAN-based clustering strategy.
Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
HDBSCAN is preferred over DBSCAN because it:
- Automatically determines the number of clusters
- Handles varying density clusters well
- Identifies noise points (outliers) effectively
Example:
>>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = HDBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize HDBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
Raises:
ImportError: If hdbscan package is not installed.
"""
super().__init__(config)
# Validate hdbscan is available
try:
import hdbscan # noqa: F401
except ImportError as exc:
raise ImportError(
"hdbscan package is required for HDBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using HDBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
import hdbscan
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: fewer results than min_cluster_size
if n_results < self.config.min_cluster_size:
# Return each result as its own singleton cluster
return [[i] for i in range(n_results)]
# Configure HDBSCAN clusterer
clusterer = hdbscan.HDBSCAN(
min_cluster_size=self.config.min_cluster_size,
min_samples=self.config.min_samples,
metric=self.config.metric,
cluster_selection_epsilon=self.config.cluster_selection_epsilon,
allow_single_cluster=self.config.allow_single_cluster,
prediction_data=self.config.prediction_data,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives

View File

@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.
NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class NoOpStrategy(BaseClusteringStrategy):
"""No-op clustering strategy that returns all results ungrouped.
This strategy is used as a final fallback when no clustering dependencies
are available, or when clustering is explicitly disabled. Each result
is treated as its own singleton cluster.
Example:
>>> from codexlens.search.clustering import NoOpStrategy
>>> strategy = NoOpStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # Returns [[0], [1], [2], ...] - each result in its own cluster
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns all results sorted by score
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize NoOp clustering strategy.
Args:
config: Clustering configuration. Ignored for NoOpStrategy
but accepted for interface compatibility.
"""
super().__init__(config)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Return each result as its own singleton cluster.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Not used but accepted for interface compatibility.
results: List of SearchResult objects.
Returns:
List of singleton clusters, one per result.
"""
return [[i] for i in range(len(results))]
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Return all results sorted by score.
Since each cluster is a singleton, this effectively returns all
results sorted by score descending.
Args:
clusters: List of singleton clusters.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used).
Returns:
All SearchResult objects sorted by score (highest first).
"""
if not results:
return []
# Return all results sorted by score
return sorted(results, key=lambda r: r.score, reverse=True)

View File

@@ -0,0 +1,171 @@
# codex-lens/src/codexlens/search/enrichment.py
"""Relationship enrichment for search results."""
import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper
class RelationshipEnricher:
"""Enriches search results with code graph relationships."""
def __init__(self, index_path: Path):
"""Initialize with path to index database.
Args:
index_path: Path to _index.db SQLite database
"""
self.index_path = index_path
self.db_conn: Optional[sqlite3.Connection] = None
self._connect()
def _connect(self) -> None:
"""Establish read-only database connection."""
if self.index_path.exists():
self.db_conn = sqlite3.connect(
f"file:{self.index_path}?mode=ro",
uri=True,
check_same_thread=False
)
self.db_conn.row_factory = sqlite3.Row
def enrich(self, results: List[Dict[str, Any]], limit: int = 10) -> List[Dict[str, Any]]:
"""Add relationship data to search results.
Args:
results: List of search result dictionaries
limit: Maximum number of results to enrich
Returns:
Results with relationships field added
"""
if not self.db_conn:
return results
for result in results[:limit]:
file_path = result.get('file') or result.get('path')
symbol_name = result.get('symbol')
result['relationships'] = self._find_relationships(file_path, symbol_name)
return results
def _find_relationships(self, file_path: Optional[str], symbol_name: Optional[str]) -> List[Dict[str, Any]]:
"""Query relationships for a symbol.
Args:
file_path: Path to file containing the symbol
symbol_name: Name of the symbol
Returns:
List of relationship dictionaries with type, direction, target/source, file, line
"""
if not self.db_conn or not symbol_name:
return []
relationships = []
cursor = self.db_conn.cursor()
try:
# Find symbol ID(s) by name and optionally file
if file_path:
cursor.execute(
'SELECT id FROM symbols WHERE name = ? AND file_path = ?',
(symbol_name, file_path)
)
else:
cursor.execute('SELECT id FROM symbols WHERE name = ?', (symbol_name,))
symbol_ids = [row[0] for row in cursor.fetchall()]
if not symbol_ids:
return []
# Query outgoing relationships (symbol is source)
placeholders = ','.join('?' * len(symbol_ids))
cursor.execute(f'''
SELECT sr.relationship_type, sr.target_symbol_fqn, sr.file_path, sr.line
FROM symbol_relationships sr
WHERE sr.source_symbol_id IN ({placeholders})
''', symbol_ids)
for row in cursor.fetchall():
relationships.append({
'type': row[0],
'direction': 'outgoing',
'target': row[1],
'file': row[2],
'line': row[3],
})
# Query incoming relationships (symbol is target)
# Match against symbol name or qualified name patterns
cursor.execute('''
SELECT sr.relationship_type, s.name AS source_name, sr.file_path, sr.line
FROM symbol_relationships sr
JOIN symbols s ON sr.source_symbol_id = s.id
WHERE sr.target_symbol_fqn = ? OR sr.target_symbol_fqn LIKE ?
''', (symbol_name, f'%.{symbol_name}'))
for row in cursor.fetchall():
rel_type = row[0]
# Convert to incoming type
incoming_type = self._to_incoming_type(rel_type)
relationships.append({
'type': incoming_type,
'direction': 'incoming',
'source': row[1],
'file': row[2],
'line': row[3],
})
except sqlite3.Error:
return []
return relationships
def _to_incoming_type(self, outgoing_type: str) -> str:
"""Convert outgoing relationship type to incoming type.
Args:
outgoing_type: The outgoing relationship type (e.g., 'calls', 'imports')
Returns:
Corresponding incoming type (e.g., 'called_by', 'imported_by')
"""
type_map = {
'calls': 'called_by',
'imports': 'imported_by',
'extends': 'extended_by',
}
return type_map.get(outgoing_type, f'{outgoing_type}_by')
def close(self) -> None:
"""Close database connection."""
if self.db_conn:
self.db_conn.close()
self.db_conn = None
def __enter__(self) -> 'RelationshipEnricher':
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
class SearchEnrichmentPipeline:
"""Search post-processing pipeline (optional enrichments)."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._config = config
self._graph_expander = GraphExpander(mapper, config=config)
def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
"""Expand base results with related symbols when enabled in config."""
if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
return []
depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
return self._graph_expander.expand(results, depth=depth)
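# Illustrative usage sketch, not part of the module. The RelationshipEnricher
# constructor signature is an assumption inferred from the attributes used above
# (index_path, db_conn); mapper/config for the pipeline are project-specific.
#
#   from pathlib import Path
#   with RelationshipEnricher(Path(".codexlens/index.db")) as enricher:
#       hits = [{"file": "src/auth.py", "symbol": "login"}]
#       for hit in enricher.enrich(hits, limit=5):
#           for rel in hit["relationships"]:
#               print(rel["type"], rel["direction"], rel.get("target") or rel.get("source"))
#
#   pipeline = SearchEnrichmentPipeline(mapper, config=config)
#   related = pipeline.expand_related_results(base_results)  # [] unless enable_graph_expansion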

View File

@@ -0,0 +1,264 @@
"""Graph expansion for search results using precomputed neighbors.
Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""
from __future__ import annotations
import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper
logger = logging.getLogger(__name__)
def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
return (result.path, result.symbol_name, result.start_line, result.end_line)
def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
if content is None:
return None
if start_line is None or end_line is None:
return None
if start_line < 1 or end_line < start_line:
return None
lines = content.splitlines()
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
if start_idx >= len(lines):
return None
return "\n".join(lines[start_idx:end_idx])
class GraphExpander:
"""Expands SearchResult lists with related symbols from the code graph."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._mapper = mapper
self._config = config
self._logger = logging.getLogger(__name__)
def expand(
self,
results: Sequence[SearchResult],
*,
depth: Optional[int] = None,
max_expand: int = 10,
max_related: int = 50,
) -> List[SearchResult]:
"""Expand top results with related symbols.
Args:
results: Base ranked results.
depth: Maximum relationship depth to include (defaults to Config or 2).
max_expand: Only expand the top-N base results to bound cost.
max_related: Maximum related results to return.
Returns:
A list of related SearchResult objects with relationship_depth metadata.
"""
if not results:
return []
configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
max_depth = int(depth if depth is not None else configured_depth)
if max_depth <= 0:
return []
max_depth = min(max_depth, 2)
expand_count = max(0, int(max_expand))
related_limit = max(0, int(max_related))
if expand_count == 0 or related_limit == 0:
return []
seen = {_result_key(r) for r in results}
related_results: List[SearchResult] = []
conn_cache: Dict[Path, sqlite3.Connection] = {}
try:
for base in list(results)[:expand_count]:
if len(related_results) >= related_limit:
break
if not base.symbol_name or not base.path:
continue
index_path = self._mapper.source_to_index_db(Path(base.path).parent)
conn = conn_cache.get(index_path)
if conn is None:
conn = self._connect_readonly(index_path)
if conn is None:
continue
conn_cache[index_path] = conn
source_ids = self._resolve_source_symbol_ids(
conn,
file_path=base.path,
symbol_name=base.symbol_name,
symbol_kind=base.symbol_kind,
)
if not source_ids:
continue
for source_id in source_ids:
neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
for neighbor_id, rel_depth in neighbors:
if len(related_results) >= related_limit:
break
row = self._get_symbol_details(conn, neighbor_id)
if row is None:
continue
path = str(row["full_path"])
symbol_name = str(row["name"])
symbol_kind = str(row["kind"])
start_line = int(row["start_line"]) if row["start_line"] is not None else None
end_line = int(row["end_line"]) if row["end_line"] is not None else None
content_block = _slice_content_block(
str(row["content"]) if row["content"] is not None else "",
start_line,
end_line,
)
score = float(base.score) * (0.5 ** int(rel_depth))
candidate = SearchResult(
path=path,
score=max(0.0, score),
excerpt=None,
content=content_block,
start_line=start_line,
end_line=end_line,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
metadata={"relationship_depth": int(rel_depth)},
)
key = _result_key(candidate)
if key in seen:
continue
seen.add(key)
related_results.append(candidate)
finally:
for conn in conn_cache.values():
try:
conn.close()
except Exception:
pass
return related_results
def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
try:
if not index_path.exists() or index_path.stat().st_size == 0:
return None
except OSError:
return None
try:
conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
except Exception as exc:
self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
return None
def _resolve_source_symbol_ids(
self,
conn: sqlite3.Connection,
*,
file_path: str,
symbol_name: str,
symbol_kind: Optional[str],
) -> List[int]:
try:
if symbol_kind:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
""",
(file_path, symbol_name, symbol_kind),
).fetchall()
else:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ?
""",
(file_path, symbol_name),
).fetchall()
except sqlite3.Error:
return []
ids: List[int] = []
for row in rows:
try:
ids.append(int(row["id"]))
except Exception:
continue
return ids
def _get_neighbors(
self,
conn: sqlite3.Connection,
source_symbol_id: int,
*,
max_depth: int,
limit: int,
) -> List[Tuple[int, int]]:
try:
rows = conn.execute(
"""
SELECT neighbor_symbol_id, relationship_depth
FROM graph_neighbors
WHERE source_symbol_id = ? AND relationship_depth <= ?
ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
LIMIT ?
""",
(int(source_symbol_id), int(max_depth), int(limit)),
).fetchall()
except sqlite3.Error:
return []
neighbors: List[Tuple[int, int]] = []
for row in rows:
try:
neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
except Exception:
continue
return neighbors
def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
try:
return conn.execute(
"""
SELECT
s.id,
s.name,
s.kind,
s.start_line,
s.end_line,
f.full_path,
f.content
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE s.id = ?
""",
(int(symbol_id),),
).fetchone()
except sqlite3.Error:
return None
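# Illustrative usage sketch, not part of the module. PathMapper construction is
# project-specific and assumed here; expand() only needs base results carrying
# path and symbol_name.
#
#   expander = GraphExpander(mapper, config=config)
#   related = expander.expand(base_results, depth=2, max_expand=10, max_related=50)
#   for r in related:
#       print(r.path, r.symbol_name, r.metadata["relationship_depth"], round(r.score, 3))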

File diff suppressed because it is too large

View File

@@ -0,0 +1,242 @@
"""Query preprocessing for CodexLens search.
Provides query expansion for better identifier matching:
- CamelCase splitting: UserAuth → User OR Auth
- snake_case splitting: user_auth → user OR auth
- Preserves original query for exact matching
"""
from __future__ import annotations
import logging
import re
from typing import Set, List
log = logging.getLogger(__name__)
class QueryParser:
"""Parser for preprocessing search queries before FTS5 execution.
Expands identifier-style queries (CamelCase, snake_case) into OR queries
to improve recall when searching for code symbols.
Example transformations:
- 'UserAuth' → 'UserAuth OR User OR Auth'
- 'user_auth' → 'user_auth OR user OR auth'
- 'getUserData' → 'getUserData OR get OR User OR Data'
"""
# Patterns for identifier splitting
CAMEL_CASE_PATTERN = re.compile(r'([a-z])([A-Z])')
SNAKE_CASE_PATTERN = re.compile(r'_+')
KEBAB_CASE_PATTERN = re.compile(r'-+')
# Minimum token length to include in expansion (avoid noise from single chars)
MIN_TOKEN_LENGTH = 2
# All-caps acronyms pattern (e.g., HTTP, SQL, API)
ALL_CAPS_PATTERN = re.compile(r'^[A-Z]{2,}$')
def __init__(self, enable: bool = True, min_token_length: int = 2):
"""Initialize query parser.
Args:
enable: Whether to enable query preprocessing
min_token_length: Minimum token length to include in expansion
"""
self.enable = enable
self.min_token_length = min_token_length
def preprocess_query(self, query: str) -> str:
"""Preprocess query with identifier expansion.
Args:
query: Original search query
Returns:
Expanded query with OR operator connecting original and split tokens
Example:
>>> parser = QueryParser()
>>> parser.preprocess_query('UserAuth')
'UserAuth OR User OR Auth'
>>> parser.preprocess_query('get_user_data')
'get_user_data OR get OR user OR data'
"""
if not self.enable:
return query
query = query.strip()
if not query:
return query
# Extract tokens from query (handle multiple words/terms)
# For simple queries, just process the whole thing
# For complex FTS5 queries with operators, preserve structure
if self._is_simple_query(query):
return self._expand_simple_query(query)
else:
# Complex query with FTS5 operators, don't expand
log.debug(f"Skipping expansion for complex FTS5 query: {query}")
return query
def _is_simple_query(self, query: str) -> bool:
"""Check if query is simple (no FTS5 operators).
Args:
query: Search query
Returns:
True if query is simple (safe to expand), False otherwise
"""
# Check for FTS5 operators that indicate complex query
fts5_operators = ['OR', 'AND', 'NOT', 'NEAR', '*', '^', '"']
return not any(op in query for op in fts5_operators)
def _expand_simple_query(self, query: str) -> str:
"""Expand a simple query with identifier splitting.
Args:
query: Simple search query
Returns:
Expanded query with OR operators
"""
tokens: Set[str] = set()
# Always include original query
tokens.add(query)
# Split on whitespace first
words = query.split()
for word in words:
# Extract tokens from this word
word_tokens = self._extract_tokens(word)
tokens.update(word_tokens)
# Filter out short tokens and duplicates
filtered_tokens = [
t for t in tokens
if len(t) >= self.min_token_length
]
# Remove duplicates while preserving original query first
unique_tokens: List[str] = []
seen: Set[str] = set()
# Always put original query first
if query not in seen and len(query) >= self.min_token_length:
unique_tokens.append(query)
seen.add(query)
# Add other tokens
for token in filtered_tokens:
if token not in seen:
unique_tokens.append(token)
seen.add(token)
# Join with OR operator (only if we have multiple tokens)
if len(unique_tokens) > 1:
expanded = ' OR '.join(unique_tokens)
log.debug(f"Expanded query: '{query}''{expanded}'")
return expanded
else:
return query
def _extract_tokens(self, word: str) -> Set[str]:
"""Extract tokens from a single word using various splitting strategies.
Args:
word: Single word/identifier to split
Returns:
Set of extracted tokens
"""
tokens: Set[str] = set()
# Add original word
tokens.add(word)
# Handle all-caps acronyms (don't split)
if self.ALL_CAPS_PATTERN.match(word):
return tokens
# CamelCase splitting
camel_tokens = self._split_camel_case(word)
tokens.update(camel_tokens)
# snake_case splitting
snake_tokens = self._split_snake_case(word)
tokens.update(snake_tokens)
# kebab-case splitting
kebab_tokens = self._split_kebab_case(word)
tokens.update(kebab_tokens)
return tokens
def _split_camel_case(self, word: str) -> List[str]:
"""Split CamelCase identifier into tokens.
Args:
word: CamelCase identifier (e.g., 'getUserData')
Returns:
List of tokens (e.g., ['get', 'User', 'Data'])
"""
# Insert space before uppercase letters preceded by lowercase
spaced = self.CAMEL_CASE_PATTERN.sub(r'\1 \2', word)
# Split on spaces and filter empty
return [t for t in spaced.split() if t]
def _split_snake_case(self, word: str) -> List[str]:
"""Split snake_case identifier into tokens.
Args:
word: snake_case identifier (e.g., 'get_user_data')
Returns:
List of tokens (e.g., ['get', 'user', 'data'])
"""
# Split on underscores
return [t for t in self.SNAKE_CASE_PATTERN.split(word) if t]
def _split_kebab_case(self, word: str) -> List[str]:
"""Split kebab-case identifier into tokens.
Args:
word: kebab-case identifier (e.g., 'get-user-data')
Returns:
List of tokens (e.g., ['get', 'user', 'data'])
"""
# Split on hyphens
return [t for t in self.KEBAB_CASE_PATTERN.split(word) if t]
# Global default parser instance
_default_parser = QueryParser(enable=True)
def preprocess_query(query: str, enable: bool = True) -> str:
"""Convenience function for query preprocessing.
Args:
query: Original search query
enable: Whether to enable preprocessing
Returns:
Preprocessed query with identifier expansion
"""
if not enable:
return query
return _default_parser.preprocess_query(query)
__all__ = [
"QueryParser",
"preprocess_query",
]
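# Illustrative usage sketch, not part of the module. Note that expansion tokens are
# collected in a set, so the order of OR terms after the original query may vary.
#
#   parser = QueryParser(enable=True, min_token_length=2)
#   parser.preprocess_query("getUserData")     # e.g. 'getUserData OR get OR User OR Data'
#   parser.preprocess_query('"exact phrase"')  # returned unchanged: quoted FTS5 query
#   preprocess_query("user_auth")              # module-level convenience wrapper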

View File

@@ -0,0 +1,942 @@
"""Ranking algorithms for hybrid search result fusion.
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
"""
from __future__ import annotations
import re
import math
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from codexlens.entities import SearchResult, AdditionalLocation
# Default RRF weights for SPLADE-based hybrid search
DEFAULT_WEIGHTS = {
"splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1)
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
}
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
"exact": 0.25,
"fuzzy": 0.1,
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
}
class QueryIntent(str, Enum):
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
KEYWORD = "keyword"
SEMANTIC = "semantic"
MIXED = "mixed"
def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
"""Normalize weights to sum to 1.0 (best-effort)."""
total = sum(float(v) for v in weights.values() if v is not None)
# NaN total: do not attempt to normalize (division would propagate NaNs).
if math.isnan(total):
return dict(weights)
# Infinite total: do not attempt to normalize (division yields 0 or NaN).
if not math.isfinite(total):
return dict(weights)
# Zero/negative total: do not attempt to normalize (invalid denominator).
if total <= 0:
return dict(weights)
return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}
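# Worked example (illustrative): weights are rescaled to sum to 1.0; None values
# are preserved, and zero/NaN/infinite totals return the input unchanged.
#   normalize_weights({"splade": 0.7, "vector": 0.7})  # -> {"splade": 0.5, "vector": 0.5}
#   normalize_weights({"splade": 0.0, "vector": 0.0})  # -> unchanged (invalid denominator)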
def detect_query_intent(query: str) -> QueryIntent:
"""Detect whether a query is code-like, natural-language, or mixed.
Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
"""
trimmed = (query or "").strip()
if not trimmed:
return QueryIntent.MIXED
lower = trimmed.lower()
word_count = len([w for w in re.split(r"\s+", trimmed) if w])
has_code_signals = bool(
re.search(r"(::|->|\.)", trimmed)
or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
or re.search(r"\b\w+_\w+\b", trimmed)
or re.search(
r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
lower,
flags=re.IGNORECASE,
)
)
has_natural_signals = bool(
word_count > 5
or "?" in trimmed
or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
or re.search(
r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
trimmed,
flags=re.IGNORECASE,
)
)
if has_code_signals and has_natural_signals:
return QueryIntent.MIXED
if has_code_signals:
return QueryIntent.KEYWORD
if has_natural_signals:
return QueryIntent.SEMANTIC
return QueryIntent.MIXED
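# Examples (illustrative):
#   detect_query_intent("getUserData")           # -> QueryIntent.KEYWORD (CamelCase signal)
#   detect_query_intent("how does login work?")  # -> QueryIntent.SEMANTIC (question words)
#   detect_query_intent("fix UserAuth.refresh")  # -> QueryIntent.MIXED (both signal kinds)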
def adjust_weights_by_intent(
intent: QueryIntent,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Adjust RRF weights based on query intent."""
# Check if using SPLADE or FTS mode
use_splade = "splade" in base_weights
if intent == QueryIntent.KEYWORD:
if use_splade:
target = {"splade": 0.6, "vector": 0.4}
else:
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
elif intent == QueryIntent.SEMANTIC:
if use_splade:
target = {"splade": 0.3, "vector": 0.7}
else:
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
else:
target = dict(base_weights)
# Filter to active backends
keys = list(base_weights.keys())
filtered = {k: float(target.get(k, 0.0)) for k in keys}
return normalize_weights(filtered)
def get_rrf_weights(
query: str,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Compute adaptive RRF weights from query intent."""
return adjust_weights_by_intent(detect_query_intent(query), base_weights)
# File extensions to category mapping for fast lookup
_EXT_TO_CATEGORY: Dict[str, str] = {
# Code extensions
".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
".rs": "code",
# Doc extensions
".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
}
def get_file_category(path: str) -> Optional[str]:
"""Get file category ('code' or 'doc') from path extension.
Args:
path: File path string
Returns:
'code', 'doc', or None if unknown
"""
ext = Path(path).suffix.lower()
return _EXT_TO_CATEGORY.get(ext)
def filter_results_by_category(
results: List[SearchResult],
intent: QueryIntent,
allow_mixed: bool = True,
) -> List[SearchResult]:
"""Filter results by category based on query intent.
Strategy:
- KEYWORD (code intent): Only return code files
- SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
- MIXED: Return all results
Args:
results: List of SearchResult objects
intent: Query intent from detect_query_intent()
allow_mixed: If True, SEMANTIC intent includes code files with lower priority
Returns:
Filtered and re-ranked list of SearchResult objects
"""
if not results or intent == QueryIntent.MIXED:
return results
code_results = []
doc_results = []
unknown_results = []
for r in results:
category = get_file_category(r.path)
if category == "code":
code_results.append(r)
elif category == "doc":
doc_results.append(r)
else:
unknown_results.append(r)
if intent == QueryIntent.KEYWORD:
# Code intent: return only code files + unknown (might be code)
filtered = code_results + unknown_results
elif intent == QueryIntent.SEMANTIC:
if allow_mixed:
# Semantic intent with mixed: docs first, then code
filtered = doc_results + code_results + unknown_results
else:
# Semantic intent strict: only docs
filtered = doc_results + unknown_results
else:
filtered = results
return filtered
def simple_weighted_fusion(
results_map: Dict[str, List[SearchResult]],
weights: Optional[Dict[str, float]] = None,
) -> List[SearchResult]:
"""Combine search results using simple weighted sum of normalized scores.
This is an alternative to RRF that preserves score magnitude information.
Scores are min-max normalized per source before weighted combination.
Formula: score(d) = Σ weight_source * normalized_score_source(d)
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
Returns:
List of SearchResult objects sorted by fused score (descending)
Examples:
>>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
>>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
>>> results_map = {'exact': fts_results, 'vector': vector_results}
>>> fused = simple_weighted_fusion(results_map)
"""
if not results_map:
return []
# Default equal weights if not provided
if weights is None:
num_sources = len(results_map)
weights = {source: 1.0 / num_sources for source in results_map}
# Normalize weights to sum to 1.0
weight_sum = sum(weights.values())
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
weights = {source: w / weight_sum for source, w in weights.items()}
# Compute min-max normalization parameters per source
source_stats: Dict[str, tuple] = {}
for source_name, results in results_map.items():
if not results:
continue
scores = [r.score for r in results]
min_s, max_s = min(scores), max(scores)
source_stats[source_name] = (min_s, max_s)
def normalize_score(score: float, source: str) -> float:
"""Normalize score to [0, 1] range using min-max scaling."""
if source not in source_stats:
return 0.0
min_s, max_s = source_stats[source]
if max_s == min_s:
return 1.0 if score >= min_s else 0.0
return (score - min_s) / (max_s - min_s)
# Build unified result set with weighted scores
path_to_result: Dict[str, SearchResult] = {}
path_to_fusion_score: Dict[str, float] = {}
path_to_source_scores: Dict[str, Dict[str, float]] = {}
for source_name, results in results_map.items():
weight = weights.get(source_name, 0.0)
if weight == 0:
continue
for result in results:
path = result.path
normalized = normalize_score(result.score, source_name)
contribution = weight * normalized
if path not in path_to_fusion_score:
path_to_fusion_score[path] = 0.0
path_to_result[path] = result
path_to_source_scores[path] = {}
path_to_fusion_score[path] += contribution
path_to_source_scores[path][source_name] = normalized
# Create final results with fusion scores
fused_results = []
for path, base_result in path_to_result.items():
fusion_score = path_to_fusion_score[path]
fused_result = SearchResult(
path=base_result.path,
score=fusion_score,
excerpt=base_result.excerpt,
content=base_result.content,
symbol=base_result.symbol,
chunk=base_result.chunk,
metadata={
**base_result.metadata,
"fusion_method": "simple_weighted",
"fusion_score": fusion_score,
"original_score": base_result.score,
"source_scores": path_to_source_scores[path],
},
start_line=base_result.start_line,
end_line=base_result.end_line,
symbol_name=base_result.symbol_name,
symbol_kind=base_result.symbol_kind,
)
fused_results.append(fused_result)
fused_results.sort(key=lambda r: r.score, reverse=True)
return fused_results
def reciprocal_rank_fusion(
results_map: Dict[str, List[SearchResult]],
weights: Optional[Dict[str, float]] = None,
k: int = 60,
) -> List[SearchResult]:
"""Combine search results from multiple sources using Reciprocal Rank Fusion.
RRF formula: score(d) = Σ weight_source / (k + rank_source(d))
Supports three-way fusion with FTS, Vector, and SPLADE sources.
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
Or: {'splade': 0.4, 'vector': 0.6}
k: Constant to avoid division by zero and control rank influence (default 60)
Returns:
List of SearchResult objects sorted by fused score (descending)
Examples:
>>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
>>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
>>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
>>> fused = reciprocal_rank_fusion(results_map)
# Three-way fusion with SPLADE
>>> results_map = {
... 'exact': exact_results,
... 'vector': vector_results,
... 'splade': splade_results
... }
>>> fused = reciprocal_rank_fusion(results_map, k=60)
"""
if not results_map:
return []
# Default equal weights if not provided
if weights is None:
num_sources = len(results_map)
weights = {source: 1.0 / num_sources for source in results_map}
# Validate weights sum to 1.0
weight_sum = sum(weights.values())
if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
# Normalize weights to sum to 1.0 (skip a zero sum to avoid division by zero)
weights = {source: w / weight_sum for source, w in weights.items()}
# Build unified result set with RRF scores
path_to_result: Dict[str, SearchResult] = {}
path_to_fusion_score: Dict[str, float] = {}
path_to_source_ranks: Dict[str, Dict[str, int]] = {}
for source_name, results in results_map.items():
weight = weights.get(source_name, 0.0)
if weight == 0:
continue
for rank, result in enumerate(results, start=1):
path = result.path
rrf_contribution = weight / (k + rank)
# Initialize or accumulate fusion score
if path not in path_to_fusion_score:
path_to_fusion_score[path] = 0.0
path_to_result[path] = result
path_to_source_ranks[path] = {}
path_to_fusion_score[path] += rrf_contribution
path_to_source_ranks[path][source_name] = rank
# Create final results with fusion scores
fused_results = []
for path, base_result in path_to_result.items():
fusion_score = path_to_fusion_score[path]
# Create new SearchResult with fusion_score in metadata
fused_result = SearchResult(
path=base_result.path,
score=fusion_score,
excerpt=base_result.excerpt,
content=base_result.content,
symbol=base_result.symbol,
chunk=base_result.chunk,
metadata={
**base_result.metadata,
"fusion_method": "rrf",
"fusion_score": fusion_score,
"original_score": base_result.score,
"rrf_k": k,
"source_ranks": path_to_source_ranks[path],
},
start_line=base_result.start_line,
end_line=base_result.end_line,
symbol_name=base_result.symbol_name,
symbol_kind=base_result.symbol_kind,
)
fused_results.append(fused_result)
# Sort by fusion score descending
fused_results.sort(key=lambda r: r.score, reverse=True)
return fused_results
def apply_symbol_boost(
results: List[SearchResult],
boost_factor: float = 1.5,
) -> List[SearchResult]:
"""Boost fused scores for results that include an explicit symbol match.
The boost is multiplicative on the current result.score (typically the RRF fusion score).
When boosted, the original score is preserved in metadata["original_fusion_score"] and
metadata["boosted"] is set to True.
"""
if not results:
return []
if boost_factor <= 1.0:
# Still return new objects to follow immutable transformation pattern.
return [
SearchResult(
path=r.path,
score=r.score,
excerpt=r.excerpt,
content=r.content,
symbol=r.symbol,
chunk=r.chunk,
metadata={**r.metadata},
start_line=r.start_line,
end_line=r.end_line,
symbol_name=r.symbol_name,
symbol_kind=r.symbol_kind,
additional_locations=list(r.additional_locations),
)
for r in results
]
boosted_results: List[SearchResult] = []
for result in results:
has_symbol = bool(result.symbol_name)
original_score = float(result.score)
boosted_score = original_score * boost_factor if has_symbol else original_score
metadata = {**result.metadata}
if has_symbol:
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
metadata["boosted"] = True
metadata["symbol_boost_factor"] = boost_factor
boosted_results.append(
SearchResult(
path=result.path,
score=boosted_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata=metadata,
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
boosted_results.sort(key=lambda r: r.score, reverse=True)
return boosted_results
def rerank_results(
query: str,
results: List[SearchResult],
embedder: Any,
top_k: int = 50,
) -> List[SearchResult]:
"""Re-rank results with embedding cosine similarity, combined with current score.
Combined score formula:
0.5 * rrf_score + 0.5 * cosine_similarity
If embedder is None or embedding fails, returns results as-is.
"""
if not results:
return []
if embedder is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
# Defensive: handle mismatched lengths and zero vectors.
n = min(len(vec_a), len(vec_b))
if n == 0:
return 0.0
dot = 0.0
norm_a = 0.0
norm_b = 0.0
for i in range(n):
a = float(vec_a[i])
b = float(vec_b[i])
dot += a * b
norm_a += a * a
norm_b += b * b
if norm_a <= 0.0 or norm_b <= 0.0:
return 0.0
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
return max(0.0, min(1.0, sim))
def text_for_embedding(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
# Fallback: stable, non-empty text.
return r.symbol_name or r.path
try:
if hasattr(embedder, "embed_single"):
query_vec = embedder.embed_single(query)
else:
query_vec = embedder.embed(query)[0]
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
doc_vecs = embedder.embed(doc_texts)
except Exception:
return results
reranked_results: List[SearchResult] = []
for idx, result in enumerate(results):
if idx < rerank_count:
rrf_score = float(result.score)
sim = cosine_similarity(query_vec, doc_vecs[idx])
combined_score = 0.5 * rrf_score + 0.5 * sim
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"rrf_score": rrf_score,
"cosine_similarity": sim,
"reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
# Preserve remaining results without re-ranking, but keep immutability.
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
def cross_encoder_rerank(
query: str,
results: List[SearchResult],
reranker: Any,
top_k: int = 50,
batch_size: int = 32,
chunk_type_weights: Optional[Dict[str, float]] = None,
test_file_penalty: float = 0.0,
) -> List[SearchResult]:
"""Second-stage reranking using a cross-encoder model.
This function is dependency-agnostic: callers can pass any object that exposes
a compatible `score_pairs(pairs, batch_size=...)` method.
Args:
query: Search query string
results: List of search results to rerank
reranker: Cross-encoder model with score_pairs or predict method
top_k: Number of top results to rerank
batch_size: Batch size for reranking
chunk_type_weights: Optional weights for different chunk types.
Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
test_file_penalty: Penalty applied to test files (0.0-1.0).
Example: 0.2 means test files get 20% score reduction
"""
if not results:
return []
if reranker is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def text_for_pair(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
return r.symbol_name or r.path
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
try:
if hasattr(reranker, "score_pairs"):
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
elif hasattr(reranker, "predict"):
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
else:
return results
except Exception:
return results
# Avoid truthiness checks on possibly-array score containers (e.g. numpy arrays).
if raw_scores is None or len(raw_scores) != rerank_count:
return results
scores = [float(s) for s in raw_scores]
min_s = min(scores)
max_s = max(scores)
def sigmoid(x: float) -> float:
# Clamp to keep exp() stable.
x = max(-50.0, min(50.0, x))
return 1.0 / (1.0 + math.exp(-x))
if 0.0 <= min_s and max_s <= 1.0:
probs = scores
else:
probs = [sigmoid(s) for s in scores]
reranked_results: List[SearchResult] = []
# Helper to detect test files
def is_test_file(path: str) -> bool:
if not path:
return False
basename = path.split("/")[-1].split("\\")[-1]
return (
basename.startswith("test_") or
basename.endswith("_test.py") or
basename.endswith(".test.ts") or
basename.endswith(".test.js") or
basename.endswith(".spec.ts") or
basename.endswith(".spec.js") or
"/tests/" in path or
"\\tests\\" in path or
"/test/" in path or
"\\test\\" in path
)
for idx, result in enumerate(results):
if idx < rerank_count:
prev_score = float(result.score)
ce_score = scores[idx]
ce_prob = probs[idx]
# Base combined score
combined_score = 0.5 * prev_score + 0.5 * ce_prob
# Apply chunk_type weight adjustment
if chunk_type_weights:
chunk_type = None
if result.chunk and hasattr(result.chunk, "metadata"):
chunk_type = result.chunk.metadata.get("chunk_type")
elif result.metadata:
chunk_type = result.metadata.get("chunk_type")
if chunk_type and chunk_type in chunk_type_weights:
weight = chunk_type_weights[chunk_type]
# Apply weight to CE contribution only
combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight
# Apply test file penalty
if test_file_penalty > 0 and is_test_file(result.path):
combined_score = combined_score * (1.0 - test_file_penalty)
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"pre_cross_encoder_score": prev_score,
"cross_encoder_score": ce_score,
"cross_encoder_prob": ce_prob,
"cross_encoder_reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
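# Worked example (illustrative): with prev_score 0.6 and ce_prob 0.8 the combined
# score is 0.5*0.6 + 0.5*0.8 = 0.70; a chunk_type weight of 0.7 on that chunk gives
# 0.5*0.6 + 0.5*0.8*0.7 = 0.58, and test_file_penalty=0.2 on a test file scales the
# result to 0.70 * (1 - 0.2) = 0.56.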
def normalize_bm25_score(score: float) -> float:
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
SQLite FTS5 returns negative BM25 scores (more negative = better match).
Uses sigmoid transformation for normalization.
Args:
score: Raw BM25 score from SQLite (typically negative)
Returns:
Normalized score in range [0, 1]
Examples:
>>> round(normalize_bm25_score(-10.5), 2)  # Good match
0.74
>>> round(normalize_bm25_score(-1.2), 2)  # Weak match
0.53
"""
# Take absolute value (BM25 is negative in SQLite)
abs_score = abs(score)
# Sigmoid transformation: 1 / (1 + e^(-x))
# Scale factor of 0.1 maps typical BM25 range (-20 to 0) to (0, 1)
normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))
return normalized
def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
"""Tag search results with their source for RRF tracking.
Args:
results: List of SearchResult objects
source: Source identifier ('exact', 'fuzzy', 'vector')
Returns:
List of SearchResult objects with 'search_source' in metadata
"""
tagged_results = []
for result in results:
tagged_result = SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata, "search_source": source},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
)
tagged_results.append(tagged_result)
return tagged_results
def group_similar_results(
results: List[SearchResult],
score_threshold_abs: float = 0.01,
content_field: str = "excerpt"
) -> List[SearchResult]:
"""Group search results by content and score similarity.
Groups results that have similar content and similar scores into a single
representative result, with other locations stored in additional_locations.
Algorithm:
1. Group results by content (using excerpt or content field)
2. Within each content group, create subgroups based on score similarity
3. Select highest-scoring result as representative for each subgroup
4. Store other results in subgroup as additional_locations
Args:
results: A list of SearchResult objects (typically sorted by score)
score_threshold_abs: Absolute score difference to consider results similar.
Results with |score_a - score_b| <= threshold are grouped.
Default 0.01 is suitable for RRF fusion scores.
content_field: The field to use for content grouping ('excerpt' or 'content')
Returns:
A new list of SearchResult objects where similar items are grouped.
The list is sorted by score descending.
Examples:
>>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
... SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
>>> grouped = group_similar_results(results)
>>> len(grouped) # Two results merged into one
1
>>> len(grouped[0].additional_locations) # One additional location
1
"""
if not results:
return []
# Group results by content
content_map: Dict[str, List[SearchResult]] = {}
unidentifiable_results: List[SearchResult] = []
for r in results:
key = getattr(r, content_field, None)
if key and key.strip():
content_map.setdefault(key, []).append(r)
else:
# Results without content can't be grouped by content
unidentifiable_results.append(r)
final_results: List[SearchResult] = []
# Process each content group
for content_group in content_map.values():
# Sort by score descending within group
content_group.sort(key=lambda r: r.score, reverse=True)
while content_group:
# Take highest scoring as representative
representative = content_group.pop(0)
others_in_group = []
remaining_for_next_pass = []
# Find results with similar scores
for item in content_group:
if abs(representative.score - item.score) <= score_threshold_abs:
others_in_group.append(item)
else:
remaining_for_next_pass.append(item)
# Create grouped result with additional locations
if others_in_group:
# Build new result with additional_locations populated
grouped_result = SearchResult(
path=representative.path,
score=representative.score,
excerpt=representative.excerpt,
content=representative.content,
symbol=representative.symbol,
chunk=representative.chunk,
metadata={
**representative.metadata,
"grouped_count": len(others_in_group) + 1,
},
start_line=representative.start_line,
end_line=representative.end_line,
symbol_name=representative.symbol_name,
symbol_kind=representative.symbol_kind,
additional_locations=[
AdditionalLocation(
path=other.path,
score=other.score,
start_line=other.start_line,
end_line=other.end_line,
symbol_name=other.symbol_name,
) for other in others_in_group
],
)
final_results.append(grouped_result)
else:
final_results.append(representative)
content_group = remaining_for_next_pass
# Add ungroupable results
final_results.extend(unidentifiable_results)
# Sort final results by score descending
final_results.sort(key=lambda r: r.score, reverse=True)
return final_results
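# End-to-end sketch (illustrative, not part of the module): tag each backend's
# results, fuse with intent-adapted RRF weights, boost symbol matches, then group
# near-duplicates. Only SearchResult fields already used in this module are assumed.
#
#   exact = tag_search_source([SearchResult(path="auth.py", score=12.0, excerpt="def login()")], "exact")
#   vector = tag_search_source([SearchResult(path="auth.py", score=0.91, excerpt="def login()")], "vector")
#   weights = get_rrf_weights("how does login work?", FTS_FALLBACK_WEIGHTS)
#   fused = reciprocal_rank_fusion({"exact": exact, "vector": vector}, weights=weights, k=60)
#   ranked = group_similar_results(apply_symbol_boost(fused))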

View File

@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.
Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)
GPU Acceleration:
- Automatic GPU detection and usage when available
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
"""
from __future__ import annotations
SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None
def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
"""Detect if fastembed and GPU are available."""
try:
import numpy as np
except ImportError as e:
return False, None, False, f"numpy not available: {e}"
try:
from fastembed import TextEmbedding
except ImportError:
return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"
# Check GPU availability
gpu_available = False
try:
from .gpu_support import is_gpu_available
gpu_available = is_gpu_available()
except ImportError:
pass
return True, "fastembed", gpu_available, None
# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()
def check_semantic_available() -> tuple[bool, str | None]:
"""Check if semantic search dependencies are available."""
return SEMANTIC_AVAILABLE, _import_error
def check_gpu_available() -> tuple[bool, str]:
"""Check if GPU acceleration is available.
Returns:
Tuple of (is_available, status_message)
"""
if not SEMANTIC_AVAILABLE:
return False, "Semantic search not available"
try:
from .gpu_support import is_gpu_available, get_gpu_summary
if is_gpu_available():
return True, get_gpu_summary()
return False, "No GPU detected (using CPU)"
except ImportError:
return False, "GPU support module not available"
# Export embedder components
# BaseEmbedder is always available (abstract base class)
from .base import BaseEmbedder
# Factory function for creating embedders
from .factory import get_embedder as get_embedder_factory
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
import ccw_litellm # noqa: F401
from .litellm_embedder import LiteLLMEmbedderWrapper
LITELLM_AVAILABLE = True
except ImportError:
LiteLLMEmbedderWrapper = None
LITELLM_AVAILABLE = False
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific embedding backend can be used.
Notes:
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
- "litellm" requires ccw-litellm to be installed in the same environment.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
if SEMANTIC_AVAILABLE:
return True, None
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
if backend == "litellm":
if LITELLM_AVAILABLE:
return True, None
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
__all__ = [
"SEMANTIC_AVAILABLE",
"SEMANTIC_BACKEND",
"GPU_AVAILABLE",
"LITELLM_AVAILABLE",
"check_semantic_available",
"is_embedding_backend_available",
"check_gpu_available",
"BaseEmbedder",
"get_embedder_factory",
"LiteLLMEmbedderWrapper",
]
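# Illustrative usage sketch, not part of the module:
#   ok, err = check_semantic_available()
#   if not ok:
#       print(f"Semantic search unavailable: {err}")
#   gpu_ok, gpu_msg = check_gpu_available()
#   print(gpu_msg)
#   lite_ok, lite_err = is_embedding_backend_available("litellm")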

File diff suppressed because it is too large

View File

@@ -0,0 +1,61 @@
"""Base class for embedders.
Defines the interface that all embedders must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable
import numpy as np
class BaseEmbedder(ABC):
"""Base class for all embedders.
All embedder implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
@abstractmethod
def embedding_dim(self) -> int:
"""Return embedding dimensions.
Returns:
int: Dimension of the embedding vectors.
"""
...
@property
@abstractmethod
def model_name(self) -> str:
"""Return model name.
Returns:
str: Name or identifier of the underlying model.
"""
...
@property
def max_tokens(self) -> int:
"""Return maximum token limit for embeddings.
Returns:
int: Maximum number of tokens that can be embedded at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Embed texts to numpy array.
Args:
texts: Single text or iterable of texts to embed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
...
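# Minimal subclass sketch (illustrative; a toy deterministic embedder, not a real
# backend shipped with CodexLens):
#
#   class ToyEmbedder(BaseEmbedder):
#       @property
#       def embedding_dim(self) -> int:
#           return 8
#       @property
#       def model_name(self) -> str:
#           return "toy-hash"
#       def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
#           items = [texts] if isinstance(texts, str) else list(texts)
#           vecs = np.zeros((len(items), self.embedding_dim), dtype=np.float32)
#           for row, text in enumerate(items):
#               for i, byte in enumerate(text.encode("utf-8")):
#                   vecs[row, i % self.embedding_dim] += byte
#           norms = np.linalg.norm(vecs, axis=1, keepdims=True)
#           return vecs / np.maximum(norms, 1e-9)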

View File

@@ -0,0 +1,821 @@
"""Code chunking strategies for semantic search.
This module provides various chunking strategies for breaking down source code
into semantic chunks suitable for embedding and search.
Lightweight Mode:
The ChunkConfig supports a `skip_token_count` option for performance optimization.
When enabled, token counting uses a fast character-based estimation (char/4)
instead of expensive tiktoken encoding.
Use cases for lightweight mode:
- Large-scale indexing where speed is critical
- Scenarios where approximate token counts are acceptable
- Memory-constrained environments
- Initial prototyping and development
Example:
# Default mode (accurate tiktoken encoding)
config = ChunkConfig()
chunker = Chunker(config)
# Lightweight mode (fast char/4 estimation)
config = ChunkConfig(skip_token_count=True)
chunker = Chunker(config)
chunks = chunker.chunk_file(content, symbols, path, language)
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
strip_comments: bool = True # Remove comments from chunk content for embedding
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
preserve_original: bool = True # Store original content in metadata when stripping
class CommentStripper:
"""Remove comments from source code while preserving structure."""
@staticmethod
def strip_python_comments(content: str) -> str:
"""Strip Python comments (# style) but preserve docstrings.
Args:
content: Python source code
Returns:
Code with comments removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
in_string = False
string_char = None
for line in lines:
new_line = []
i = 0
while i < len(line):
char = line[i]
# Handle string literals
if char in ('"', "'") and not in_string:
# Check for triple quotes
if line[i:i+3] in ('"""', "'''"):
in_string = True
string_char = line[i:i+3]
new_line.append(line[i:i+3])
i += 3
continue
else:
in_string = True
string_char = char
elif in_string:
if string_char and len(string_char) == 3:
if line[i:i+3] == string_char:
in_string = False
new_line.append(line[i:i+3])
i += 3
string_char = None
continue
elif char == string_char:
# Check for escape
if i > 0 and line[i-1] != '\\':
in_string = False
string_char = None
# Handle comments (only outside strings)
if char == '#' and not in_string:
# Rest of line is comment, skip it
new_line.append('\n' if line.endswith('\n') else '')
break
new_line.append(char)
i += 1
result_lines.append(''.join(new_line))
return ''.join(result_lines)
@staticmethod
def strip_c_style_comments(content: str) -> str:
"""Strip C-style comments (// and /* */) from code.
Args:
content: Source code with C-style comments
Returns:
Code with comments removed
"""
result = []
i = 0
in_string = False
string_char = None
in_multiline_comment = False
while i < len(content):
# Handle multi-line comment end
if in_multiline_comment:
if content[i:i+2] == '*/':
in_multiline_comment = False
i += 2
continue
i += 1
continue
char = content[i]
# Handle string literals
if char in ('"', "'", '`') and not in_string:
in_string = True
string_char = char
result.append(char)
i += 1
continue
elif in_string:
result.append(char)
if char == string_char and (i == 0 or content[i-1] != '\\'):
in_string = False
string_char = None
i += 1
continue
# Handle comments
if content[i:i+2] == '//':
# Single line comment - skip to end of line
while i < len(content) and content[i] != '\n':
i += 1
if i < len(content):
result.append('\n')
i += 1
continue
if content[i:i+2] == '/*':
in_multiline_comment = True
i += 2
continue
result.append(char)
i += 1
return ''.join(result)
@classmethod
def strip_comments(cls, content: str, language: str) -> str:
"""Strip comments based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with comments removed
"""
if language == "python":
return cls.strip_python_comments(content)
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
return cls.strip_c_style_comments(content)
return content
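# Example (illustrative): trailing comments are removed while string contents are
# preserved.
#   CommentStripper.strip_comments('x = "#not a comment"  # counter', "python")
#   # -> 'x = "#not a comment"  '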
class DocstringStripper:
"""Remove docstrings from source code."""
@staticmethod
def strip_python_docstrings(content: str) -> str:
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
Args:
content: Python source code
Returns:
Code with docstrings removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
# Check for docstring start
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
# Single line docstring
if stripped.count(quote_type) >= 2:
# Skip this line (docstring)
i += 1
continue
# Multi-line docstring - skip until closing
i += 1
while i < len(lines):
if quote_type in lines[i]:
i += 1
break
i += 1
continue
result_lines.append(line)
i += 1
return ''.join(result_lines)
@staticmethod
def strip_jsdoc_comments(content: str) -> str:
"""Strip JSDoc comments (/** ... */) from code.
Args:
content: JavaScript/TypeScript source code
Returns:
Code with JSDoc comments removed
"""
result = []
i = 0
in_jsdoc = False
while i < len(content):
if in_jsdoc:
if content[i:i+2] == '*/':
in_jsdoc = False
i += 2
continue
i += 1
continue
# Check for JSDoc start (/** but not /*)
if content[i:i+3] == '/**':
in_jsdoc = True
i += 3
continue
result.append(content[i])
i += 1
return ''.join(result)
@classmethod
def strip_docstrings(cls, content: str, language: str) -> str:
"""Strip docstrings based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with docstrings removed
"""
if language == "python":
return cls.strip_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.strip_jsdoc_comments(content)
return content
class Chunker:
"""Chunk code files for semantic embedding."""
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
self._comment_stripper = CommentStripper()
self._docstring_stripper = DocstringStripper()
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
"""Process chunk content by stripping comments/docstrings if configured.
Args:
content: Original chunk content
language: Programming language
Returns:
Tuple of (processed_content, original_content_if_preserved)
"""
original = content if self.config.preserve_original else None
processed = content
if self.config.strip_comments:
processed = self._comment_stripper.strip_comments(processed, language)
if self.config.strip_docstrings:
processed = self._docstring_stripper.strip_docstrings(processed, language)
# If nothing changed, don't store original
if processed == content:
original = None
return processed, original
def _estimate_token_count(self, text: str) -> int:
"""Estimate token count based on config.
If skip_token_count is True, uses character-based estimation (char/4).
Otherwise, uses accurate tiktoken encoding.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if self.config.skip_token_count:
# Fast character-based estimation: ~4 chars per token
return max(1, len(text) // 4)
return self._tokenizer.count_tokens(text)
def chunk_by_symbol(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Large symbols exceeding max_chunk_size are recursively split using sliding window.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
for symbol in symbols:
start_line, end_line = symbol.range
# Convert to 0-indexed
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
chunk_content = "".join(lines[start_idx:end_idx])
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Check if symbol content exceeds max_chunk_size
if len(chunk_content) > self.config.max_chunk_size:
# Create line mapping for correct line number tracking
line_mapping = list(range(start_line, end_line + 1))
# Use sliding window to split large symbol
sub_chunks = self.chunk_sliding_window(
chunk_content,
file_path=file_path,
language=language,
line_mapping=line_mapping
)
# Update sub_chunks with parent symbol metadata
for sub_chunk in sub_chunks:
sub_chunk.metadata["symbol_name"] = symbol.name
sub_chunk.metadata["symbol_kind"] = symbol.kind
sub_chunk.metadata["strategy"] = "symbol_split"
sub_chunk.metadata["chunk_type"] = "code"
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
chunks.extend(sub_chunks)
else:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(processed_content)
metadata = {
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
return chunks
def chunk_sliding_window(
self,
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries or very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
if not lines:
return chunks
# Calculate lines per chunk based on average line length
avg_line_len = len(content) / max(len(lines), 1)
lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
# Ensure overlap is less than chunk size to prevent infinite loop
overlap_lines = min(overlap_lines, lines_per_chunk - 1)
start = 0
chunk_idx = 0
while start < len(lines):
end = min(start + lines_per_chunk, len(lines))
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
# Move window forward
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1
start += step
continue
token_count = self._estimate_token_count(processed_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
metadata = {
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
chunk_idx += 1
# Move window, accounting for overlap
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1 # Failsafe to prevent infinite loop
start += step
# Break if we've reached the end
if end >= len(lines):
break
return chunks
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking if symbols available,
falls back to sliding window for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
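# Usage sketch (illustrative, not part of the module):
#   config = ChunkConfig(max_chunk_size=1000, skip_token_count=True)
#   chunker = Chunker(config)
#   chunks = chunker.chunk_file(source, symbols, Path("src/app.py"), "python")
#   for c in chunks:
#       print(c.metadata["strategy"], c.metadata["start_line"], c.metadata["end_line"])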
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
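# Example (illustrative):
#
#     >>> src = '"""Top doc."""\ndef f():\n    """Inner doc."""\n'
#     >>> DocstringExtractor.extract_docstrings(src, "python")
#     [('"""Top doc."""\n', 1, 1), ('    """Inner doc."""\n', 3, 3)]
#
# Line numbers are 1-indexed and inclusive; unsupported languages return [].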
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. For remaining code, uses base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def _find_parent_symbol(
self,
start_line: int,
end_line: int,
symbols: List[Symbol],
) -> Optional[Symbol]:
"""Find the smallest symbol range that fully contains a docstring span."""
candidates: List[Symbol] = []
for symbol in symbols:
sym_start, sym_end = symbol.range
if sym_start <= start_line and end_line <= sym_end:
candidates.append(symbol)
if not candidates:
return None
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
# Step 1: Extract docstrings as dedicated chunks
docstrings: List[Tuple[str, int, int]] = []
if language == "python":
# Fast path: avoid expensive docstring extraction if delimiters are absent.
if '"""' in content or "'''" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
elif language in {"javascript", "typescript"}:
if "/**" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
else:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
# Fast path: no docstrings -> delegate to base chunker directly.
if not docstrings:
if symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, symbols, file_path, language, symbol_token_counts
)
else:
base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
return base_chunks
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
# Use base chunker's token estimation method
token_count = self.base_chunker._estimate_token_count(docstring_content)
metadata = {
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
if parent_symbol is not None:
metadata["parent_symbol"] = parent_symbol.name
metadata["parent_symbol_kind"] = parent_symbol.kind
metadata["parent_symbol_range"] = parent_symbol.range
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata=metadata
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks
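# Minimal usage sketch (illustrative): the sample source below is made up, and
# whether short inputs produce chunks depends on ChunkConfig.min_chunk_size.
if __name__ == "__main__":
    sample = '"""Module docstring that is long enough to keep."""\n\ndef add(a, b):\n    return a + b\n'
    chunker = HybridChunker()
    for chunk in chunker.chunk_file(sample, symbols=[], file_path="sample.py", language="python"):
        meta = chunk.metadata
        print(meta["chunk_type"], meta["start_line"], meta["end_line"], meta["strategy"])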

View File

@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SearchResult, Symbol
def extract_complete_code_block(
result: SearchResult,
source_file_path: Optional[str] = None,
context_lines: int = 0,
) -> str:
"""Extract complete code block from a search result.
Args:
result: SearchResult from semantic search.
source_file_path: Optional path to source file for re-reading.
context_lines: Additional lines of context to include above/below.
Returns:
Complete code block as string.
"""
# If we have full content stored, use it
if result.content:
if context_lines == 0:
return result.content
# Need to add context, read from file
# Try to read from source file
file_path = source_file_path or result.path
    if not file_path or not Path(file_path).exists():
        # Fall back to stored content or excerpt when the file is unavailable
        return result.content or result.excerpt or ""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
# Get line range
start_line = result.start_line or 1
end_line = result.end_line or len(lines)
# Add context
start_idx = max(0, start_line - 1 - context_lines)
end_idx = min(len(lines), end_line + context_lines)
return "\n".join(lines[start_idx:end_idx])
except Exception:
return result.excerpt or result.content or ""
def extract_symbol_with_context(
file_path: str,
symbol: Symbol,
include_docstring: bool = True,
include_decorators: bool = True,
) -> str:
"""Extract a symbol (function/class) with its docstring and decorators.
Args:
file_path: Path to source file.
symbol: Symbol to extract.
include_docstring: Include docstring if present.
include_decorators: Include decorators/annotations above symbol.
Returns:
Complete symbol code with context.
"""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
start_line, end_line = symbol.range
start_idx = start_line - 1
end_idx = end_line
# Look for decorators above the symbol
if include_decorators and start_idx > 0:
decorator_start = start_idx
# Search backwards for decorators
i = start_idx - 1
while i >= 0 and i >= start_idx - 20: # Look up to 20 lines back
line = lines[i].strip()
if line.startswith("@"):
decorator_start = i
i -= 1
elif line == "" or line.startswith("#"):
# Skip empty lines and comments, continue looking
i -= 1
elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
# JavaScript/Java style comments
decorator_start = i
i -= 1
else:
# Found non-decorator, non-comment line, stop
break
start_idx = decorator_start
return "\n".join(lines[start_idx:end_idx])
except Exception:
return ""
def format_search_result_code(
result: SearchResult,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
highlight_match: bool = False,
) -> str:
"""Format search result code for display.
Args:
result: SearchResult to format.
max_lines: Maximum lines to show (None for all).
show_line_numbers: Include line numbers in output.
highlight_match: Add markers for matched region.
Returns:
Formatted code string.
"""
content = result.content or result.excerpt or ""
if not content:
return ""
lines = content.splitlines()
# Truncate if needed
truncated = False
if max_lines and len(lines) > max_lines:
lines = lines[:max_lines]
truncated = True
# Format with line numbers
if show_line_numbers:
start = result.start_line or 1
formatted_lines = []
for i, line in enumerate(lines):
line_num = start + i
formatted_lines.append(f"{line_num:4d} | {line}")
output = "\n".join(formatted_lines)
else:
output = "\n".join(lines)
if truncated:
output += "\n... (truncated)"
return output
def get_code_block_summary(result: SearchResult) -> str:
"""Get a concise summary of a code block.
Args:
result: SearchResult to summarize.
Returns:
Summary string like "function hello_world (lines 10-25)"
"""
parts = []
if result.symbol_kind:
parts.append(result.symbol_kind)
if result.symbol_name:
parts.append(f"`{result.symbol_name}`")
elif result.excerpt:
# Extract first meaningful identifier
first_line = result.excerpt.split("\n")[0][:50]
parts.append(f'"{first_line}..."')
if result.start_line and result.end_line:
if result.start_line == result.end_line:
parts.append(f"(line {result.start_line})")
else:
parts.append(f"(lines {result.start_line}-{result.end_line})")
if result.path:
file_name = Path(result.path).name
parts.append(f"in {file_name}")
return " ".join(parts) if parts else "unknown code block"
class CodeBlockResult:
"""Enhanced search result with complete code block."""
def __init__(self, result: SearchResult, source_path: Optional[str] = None):
self.result = result
self.source_path = source_path or result.path
self._full_code: Optional[str] = None
@property
def score(self) -> float:
return self.result.score
@property
def path(self) -> str:
return self.result.path
@property
def file_name(self) -> str:
return Path(self.result.path).name
@property
def symbol_name(self) -> Optional[str]:
return self.result.symbol_name
@property
def symbol_kind(self) -> Optional[str]:
return self.result.symbol_kind
@property
def line_range(self) -> Tuple[int, int]:
return (
self.result.start_line or 1,
self.result.end_line or 1
)
@property
def full_code(self) -> str:
"""Get full code block content."""
if self._full_code is None:
self._full_code = extract_complete_code_block(self.result, self.source_path)
return self._full_code
@property
def excerpt(self) -> str:
"""Get short excerpt."""
return self.result.excerpt or ""
@property
def summary(self) -> str:
"""Get code block summary."""
return get_code_block_summary(self.result)
def format(
self,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
) -> str:
"""Format code for display."""
# Use full code if available
display_result = SearchResult(
path=self.result.path,
score=self.result.score,
content=self.full_code,
start_line=self.result.start_line,
end_line=self.result.end_line,
)
return format_search_result_code(
display_result,
max_lines=max_lines,
show_line_numbers=show_line_numbers
)
def __repr__(self) -> str:
return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"
def enhance_search_results(
results: List[SearchResult],
) -> List[CodeBlockResult]:
"""Enhance search results with complete code block access.
Args:
results: List of SearchResult from semantic search.
Returns:
List of CodeBlockResult with full code access.
"""
return [CodeBlockResult(r) for r in results]
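# Usage sketch (illustrative): the keyword arguments below mirror the fields
# this module reads from SearchResult (path, score, content, start_line,
# end_line); other optional fields are assumed to default to None.
if __name__ == "__main__":
    result = SearchResult(
        path="sample.py",
        score=0.91,
        content="def add(a, b):\n    return a + b",
        start_line=10,
        end_line=11,
    )
    block = enhance_search_results([result])[0]
    print(block.summary)             # e.g. "(lines 10-11) in sample.py"
    print(block.format(max_lines=5))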

View File

@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.
Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""
from __future__ import annotations
import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional
import numpy as np
from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id
logger = logging.getLogger(__name__)
# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()
def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
"""Get or create a cached Embedder instance (thread-safe singleton).
This function provides significant performance improvement by reusing
Embedder instances across multiple searches, avoiding repeated model
loading overhead (~0.8s per load).
Args:
profile: Model profile ("fast", "code", "multilingual", "balanced")
use_gpu: If True, use GPU acceleration when available (default: True)
Returns:
Cached Embedder instance for the given profile
"""
global _embedder_cache
# Cache key includes GPU preference to support mixed configurations
cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"
# All cache access is protected by _cache_lock to avoid races with
# clear_embedder_cache() during concurrent access.
with _cache_lock:
embedder = _embedder_cache.get(cache_key)
if embedder is not None:
return embedder
# Create new embedder and cache it
embedder = Embedder(profile=profile, use_gpu=use_gpu)
# Pre-load model to ensure it's ready
embedder._load_model()
_embedder_cache[cache_key] = embedder
# Log GPU status on first embedder creation
if use_gpu and is_gpu_available():
logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
elif use_gpu:
logger.debug("GPU not available, using CPU for embeddings")
return embedder
def clear_embedder_cache() -> None:
"""Clear the embedder cache and release ONNX resources.
This method ensures proper cleanup of ONNX model resources to prevent
memory leaks when embedders are no longer needed.
"""
global _embedder_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for embedder in _embedder_cache.values():
if embedder._model is not None:
del embedder._model
embedder._model = None
_embedder_cache.clear()
gc.collect()
class Embedder(BaseEmbedder):
"""Generate embeddings for code chunks using fastembed (ONNX-based).
Supported Model Profiles:
- fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
- code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
- multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
- balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
"""
# Model profiles for different use cases
MODELS = {
"fast": "BAAI/bge-small-en-v1.5", # 384 dim - Fast, lightweight
"code": "jinaai/jina-embeddings-v2-base-code", # 768 dim - Code-optimized
"multilingual": "intfloat/multilingual-e5-large", # 1024 dim - Multilingual
"balanced": "mixedbread-ai/mxbai-embed-large-v1", # 1024 dim - High accuracy
}
# Dimension mapping for each model
MODEL_DIMS = {
"BAAI/bge-small-en-v1.5": 384,
"jinaai/jina-embeddings-v2-base-code": 768,
"intfloat/multilingual-e5-large": 1024,
"mixedbread-ai/mxbai-embed-large-v1": 1024,
}
# Default model (fast profile)
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_PROFILE = "fast"
def __init__(
self,
model_name: str | None = None,
profile: str | None = None,
use_gpu: bool = True,
providers: List[str] | None = None,
) -> None:
"""Initialize embedder with model or profile.
Args:
model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
If both provided, model_name takes precedence.
use_gpu: If True, use GPU acceleration when available (default: True)
providers: Explicit ONNX providers list (overrides use_gpu if provided)
"""
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
# Resolve model name from profile or use explicit name
if model_name:
self._model_name = model_name
elif profile and profile in self.MODELS:
self._model_name = self.MODELS[profile]
else:
self._model_name = self.DEFAULT_MODEL
# Configure ONNX execution providers with device_id options for GPU selection
# Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
if providers is not None:
self._providers = providers
else:
self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)
self._use_gpu = use_gpu
self._model = None
@property
def model_name(self) -> str:
"""Get model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Get embedding dimension for current model."""
return self.MODEL_DIMS.get(self._model_name, 768) # Default to 768 if unknown
@property
def max_tokens(self) -> int:
"""Get maximum token limit for current model.
Returns:
int: Maximum number of tokens based on model profile.
- fast: 512 (lightweight, optimized for speed)
- code: 8192 (code-optimized, larger context)
- multilingual: 512 (standard multilingual model)
- balanced: 512 (general purpose)
"""
# Determine profile from model name
profile = None
for prof, model in self.MODELS.items():
if model == self._model_name:
profile = prof
break
# Return token limit based on profile
if profile == "code":
return 8192
elif profile in ("fast", "multilingual", "balanced"):
return 512
else:
# Default for unknown models
return 512
@property
def providers(self) -> List[str]:
"""Get configured ONNX execution providers."""
return self._providers
@property
def is_gpu_enabled(self) -> bool:
"""Check if GPU acceleration is enabled for this embedder."""
gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
"DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
# Handle both string providers and tuple providers (name, options)
for p in self._providers:
provider_name = p[0] if isinstance(p, tuple) else p
if provider_name in gpu_providers:
return True
return False
def _load_model(self) -> None:
"""Lazy load the embedding model with configured providers."""
if self._model is not None:
return
from fastembed import TextEmbedding
# providers already include device_id options via get_optimal_providers(with_device_options=True)
# DO NOT pass device_ids separately - fastembed ignores it when providers is specified
# See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
try:
self._model = TextEmbedding(
model_name=self.model_name,
providers=self._providers,
)
logger.debug(f"Model loaded with providers: {self._providers}")
except TypeError:
# Fallback for older fastembed versions without providers parameter
logger.warning(
"fastembed version doesn't support 'providers' parameter. "
"Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
)
self._model = TextEmbedding(model_name=self.model_name)
def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
"""Generate embeddings for one or more texts.
Args:
texts: Single text or iterable of texts to embed.
Returns:
List of embedding vectors (each is a list of floats).
Note:
This method converts numpy arrays to Python lists for backward compatibility.
For memory-efficient processing, use embed_to_numpy() instead.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
embeddings = list(self._model.embed(texts))
return [emb.tolist() for emb in embeddings]
def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
"""Generate embeddings for one or more texts (returns numpy arrays).
This method is more memory-efficient than embed() as it avoids converting
numpy arrays to Python lists, which can significantly reduce memory usage
during batch processing.
Args:
texts: Single text or iterable of texts to embed.
batch_size: Optional batch size for fastembed processing.
Larger values improve GPU utilization but use more memory.
Returns:
numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Pass batch_size to fastembed for optimal GPU utilization
# Default batch_size in fastembed is 256, but larger values can improve throughput
if batch_size is not None:
embeddings = list(self._model.embed(texts, batch_size=batch_size))
else:
embeddings = list(self._model.embed(texts))
return np.array(embeddings)
def embed_single(self, text: str) -> List[float]:
"""Generate embedding for a single text."""
return self.embed(text)[0]
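# Usage sketch (illustrative): requires the optional semantic extra
# (pip install codexlens[semantic]); the first call downloads the model.
if __name__ == "__main__":
    embedder = get_embedder(profile="fast", use_gpu=False)
    vectors = embedder.embed_to_numpy(["def add(a, b): return a + b"])
    print(embedder.model_name, vectors.shape)  # expected shape: (1, 384)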

View File

@@ -0,0 +1,158 @@
"""Factory for creating embedders.
Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Dict, List, Optional
from .base import BaseEmbedder
# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)
def get_embedder(
backend: str = "fastembed",
profile: str = "code",
model: str = "default",
use_gpu: bool = True,
endpoints: Optional[List[Dict[str, Any]]] = None,
strategy: str = "latency_aware",
cooldown: float = 60.0,
**kwargs: Any,
) -> BaseEmbedder:
"""Factory function to create embedder based on backend.
Args:
backend: Embedder backend to use. Options:
- "fastembed": Use fastembed (ONNX-based) embedder (default)
- "litellm": Use ccw-litellm embedder
profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
Used only when backend="fastembed". Default: "code"
model: Model identifier for litellm backend.
Used only when backend="litellm". Default: "default"
use_gpu: Whether to use GPU acceleration when available (default: True).
Used only when backend="fastembed".
endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
Each endpoint is a dict with keys: model, api_key, api_base, weight.
Used only when backend="litellm" and multiple endpoints provided.
strategy: Selection strategy for multi-endpoint mode:
"round_robin", "latency_aware", "weighted_random".
Default: "latency_aware"
cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
**kwargs: Additional backend-specific arguments
Returns:
BaseEmbedder: Configured embedder instance
Raises:
ValueError: If backend is not recognized
ImportError: If required backend dependencies are not installed
Examples:
Create fastembed embedder with code profile:
>>> embedder = get_embedder(backend="fastembed", profile="code")
Create fastembed embedder with fast profile and CPU only:
>>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
Create litellm embedder:
>>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
Create rotational embedder with multiple endpoints:
>>> endpoints = [
... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
... ]
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
"""
# Build cache key from immutable configuration
if backend == "fastembed":
cache_key = ("fastembed", profile, None, use_gpu)
elif backend == "litellm":
# For litellm, use model as part of cache key
# Multi-endpoint mode is not cached as it's more complex
if endpoints and len(endpoints) > 1:
cache_key = None # Skip cache for multi-endpoint
else:
effective_model = endpoints[0]["model"] if endpoints else model
cache_key = ("litellm", None, effective_model, None)
else:
cache_key = None
# Check cache first (thread-safe)
if cache_key is not None:
with _cache_lock:
if cache_key in _embedder_cache:
_logger.debug("Returning cached embedder for %s", cache_key)
return _embedder_cache[cache_key]
# Create new embedder instance
embedder: Optional[BaseEmbedder] = None
if backend == "fastembed":
from .embedder import Embedder
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
elif backend == "litellm":
# Check if multi-endpoint mode is requested
if endpoints and len(endpoints) > 1:
from .rotational_embedder import create_rotational_embedder
# Multi-endpoint is not cached
return create_rotational_embedder(
endpoints_config=endpoints,
strategy=strategy,
default_cooldown=cooldown,
)
elif endpoints and len(endpoints) == 1:
# Single endpoint in list - use it directly
ep = endpoints[0]
ep_kwargs = {**kwargs}
if "api_key" in ep:
ep_kwargs["api_key"] = ep["api_key"]
if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
else:
# No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
else:
raise ValueError(
f"Unknown backend: {backend}. "
f"Supported backends: 'fastembed', 'litellm'"
)
# Cache the embedder for future use (thread-safe)
if cache_key is not None and embedder is not None:
with _cache_lock:
# Double-check to avoid race condition
if cache_key not in _embedder_cache:
_embedder_cache[cache_key] = embedder
_logger.debug("Cached new embedder for %s", cache_key)
else:
# Another thread created it already, use that one
embedder = _embedder_cache[cache_key]
return embedder # type: ignore
def clear_embedder_cache() -> int:
"""Clear the embedder cache.
Returns:
Number of embedders cleared from cache
"""
with _cache_lock:
count = len(_embedder_cache)
_embedder_cache.clear()
_logger.debug("Cleared %d embedders from cache", count)
return count
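# Usage sketch (illustrative): repeated calls with the same configuration
# return the same cached instance, so model loading happens only once.
if __name__ == "__main__":
    first = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
    second = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
    print(first is second)           # True - served from the module-level cache
    print(clear_embedder_cache())    # 1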

View File

@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.
This module provides GPU detection, initialization, and fallback handling
for ONNX-based embedding generation.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional
logger = logging.getLogger(__name__)
@dataclass
class GPUDevice:
"""Individual GPU device info."""
device_id: int
name: str
is_discrete: bool # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
vendor: str # "nvidia", "amd", "intel", "unknown"
@dataclass
class GPUInfo:
"""GPU availability and configuration info."""
gpu_available: bool = False
cuda_available: bool = False
gpu_count: int = 0
gpu_name: Optional[str] = None
    onnx_providers: Optional[List[str]] = None
    devices: Optional[List[GPUDevice]] = None  # List of detected GPU devices
preferred_device_id: Optional[int] = None # Preferred GPU for embedding
def __post_init__(self):
if self.onnx_providers is None:
self.onnx_providers = ["CPUExecutionProvider"]
if self.devices is None:
self.devices = []
_gpu_info_cache: Optional[GPUInfo] = None
def _enumerate_gpus() -> List[GPUDevice]:
"""Enumerate available GPU devices using WMI on Windows.
Returns:
List of GPUDevice with device info, ordered by device_id.
"""
devices = []
try:
import subprocess
import sys
if sys.platform == "win32":
# Use PowerShell to query GPU information via WMI
cmd = [
"powershell", "-NoProfile", "-Command",
"Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0 and result.stdout.strip():
import json
gpu_data = json.loads(result.stdout)
# Handle single GPU case (returns dict instead of list)
if isinstance(gpu_data, dict):
gpu_data = [gpu_data]
for idx, gpu in enumerate(gpu_data):
name = gpu.get("Name", "Unknown GPU")
compat = gpu.get("AdapterCompatibility", "").lower()
# Determine vendor
name_lower = name.lower()
if "nvidia" in name_lower or "nvidia" in compat:
vendor = "nvidia"
is_discrete = True
elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
vendor = "amd"
is_discrete = True
elif "intel" in name_lower or "intel" in compat:
vendor = "intel"
# Intel UHD/Iris are integrated, Intel Arc is discrete
is_discrete = "arc" in name_lower
else:
vendor = "unknown"
is_discrete = False
devices.append(GPUDevice(
device_id=idx,
name=name,
is_discrete=is_discrete,
vendor=vendor
))
logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")
except Exception as e:
logger.debug(f"GPU enumeration failed: {e}")
return devices
def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
"""Determine the preferred GPU device_id for embedding.
Preference order:
1. NVIDIA discrete GPU (best DirectML/CUDA support)
2. AMD discrete GPU
3. Intel Arc (discrete)
4. Intel integrated (fallback)
Returns:
device_id of preferred GPU, or None to use default.
"""
if not devices:
return None
# Priority: NVIDIA > AMD > Intel Arc > Intel integrated
priority_order = [
("nvidia", True), # NVIDIA discrete
("amd", True), # AMD discrete
("intel", True), # Intel Arc (discrete)
("intel", False), # Intel integrated (fallback)
]
for target_vendor, target_discrete in priority_order:
for device in devices:
if device.vendor == target_vendor and device.is_discrete == target_discrete:
logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
return device.device_id
# If no match, use first device
if devices:
return devices[0].device_id
return None
def detect_gpu(force_refresh: bool = False) -> GPUInfo:
"""Detect available GPU resources for embedding acceleration.
Args:
force_refresh: If True, re-detect GPU even if cached.
Returns:
GPUInfo with detection results.
"""
global _gpu_info_cache
if _gpu_info_cache is not None and not force_refresh:
return _gpu_info_cache
info = GPUInfo()
# Enumerate GPU devices first
info.devices = _enumerate_gpus()
info.gpu_count = len(info.devices)
if info.devices:
# Set preferred device (discrete GPU preferred over integrated)
info.preferred_device_id = _get_preferred_device_id(info.devices)
# Set gpu_name to preferred device name
for dev in info.devices:
if dev.device_id == info.preferred_device_id:
info.gpu_name = dev.name
break
# Check PyTorch CUDA availability (most reliable detection)
try:
import torch
if torch.cuda.is_available():
info.cuda_available = True
info.gpu_available = True
info.gpu_count = torch.cuda.device_count()
if info.gpu_count > 0:
info.gpu_name = torch.cuda.get_device_name(0)
logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
except ImportError:
logger.debug("PyTorch not available for GPU detection")
# Check ONNX Runtime providers with validation
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
# Build provider list with priority order
providers = []
        # Best-effort provider self-test. Note: this helper only verifies that
        # SessionOptions can be constructed; it does not build a dummy session,
        # and the detection logic below does not currently call it.
        def test_provider(provider_name: str) -> bool:
            """Check that ONNX SessionOptions can be created (placeholder test)."""
            try:
                sess_options = ort.SessionOptions()
                sess_options.log_severity_level = 4  # Suppress warnings
                return True
            except Exception:
                return False
# CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
if "CUDAExecutionProvider" in available_providers:
# Verify CUDA is actually usable by checking for cuBLAS
cuda_works = False
try:
import ctypes
# Try to load cuBLAS to verify CUDA installation
try:
ctypes.CDLL("cublas64_12.dll")
cuda_works = True
except OSError:
try:
ctypes.CDLL("cublas64_11.dll")
cuda_works = True
except OSError:
pass
except Exception:
pass
if cuda_works:
providers.append("CUDAExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CUDAExecutionProvider available and working")
else:
logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")
# TensorRT provider (optimized NVIDIA inference)
if "TensorrtExecutionProvider" in available_providers:
# TensorRT requires additional libraries, skip for now
logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")
# DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
if "DmlExecutionProvider" in available_providers:
providers.append("DmlExecutionProvider")
info.gpu_available = True
logger.debug("ONNX DmlExecutionProvider available (DirectML)")
# ROCm provider (AMD GPU on Linux)
if "ROCMExecutionProvider" in available_providers:
providers.append("ROCMExecutionProvider")
info.gpu_available = True
logger.debug("ONNX ROCMExecutionProvider available (AMD)")
# CoreML provider (Apple Silicon)
if "CoreMLExecutionProvider" in available_providers:
providers.append("CoreMLExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CoreMLExecutionProvider available (Apple)")
# Always include CPU as fallback
providers.append("CPUExecutionProvider")
info.onnx_providers = providers
except ImportError:
logger.debug("ONNX Runtime not available")
info.onnx_providers = ["CPUExecutionProvider"]
_gpu_info_cache = info
return info
def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
"""Get optimal ONNX execution providers based on availability.
Args:
use_gpu: If True, include GPU providers when available.
If False, force CPU-only execution.
with_device_options: If True, return providers as tuples with device_id options
for proper GPU device selection (required for DirectML).
Returns:
List of provider names or tuples (provider_name, options_dict) in priority order.
"""
if not use_gpu:
return ["CPUExecutionProvider"]
gpu_info = detect_gpu()
# Check if GPU was requested but not available - log warning
if not gpu_info.gpu_available:
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
except ImportError:
available_providers = []
logger.warning(
"GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
f"was found. Available providers: {available_providers}. Falling back to CPU."
)
else:
# Log which GPU provider is being used
gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")
if not with_device_options:
return gpu_info.onnx_providers
# Build providers with device_id options for GPU providers
device_id = get_selected_device_id()
providers = []
for provider in gpu_info.onnx_providers:
if provider == "DmlExecutionProvider" and device_id is not None:
# DirectML requires device_id in provider_options tuple
providers.append(("DmlExecutionProvider", {"device_id": device_id}))
logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
elif provider == "CUDAExecutionProvider" and device_id is not None:
# CUDA also supports device_id in provider_options
providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
elif provider == "ROCMExecutionProvider" and device_id is not None:
# ROCm supports device_id
providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
else:
# CPU and other providers don't need device_id
providers.append(provider)
return providers
def is_gpu_available() -> bool:
"""Check if any GPU acceleration is available."""
return detect_gpu().gpu_available
def get_gpu_summary() -> str:
"""Get human-readable GPU status summary."""
info = detect_gpu()
if not info.gpu_available:
return "GPU: Not available (using CPU)"
parts = []
if info.gpu_name:
parts.append(f"GPU: {info.gpu_name}")
if info.gpu_count > 1:
parts.append(f"({info.gpu_count} devices)")
# Show active providers (excluding CPU fallback)
gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
parts.append(f"Providers: {', '.join(gpu_providers)}")
return " | ".join(parts) if parts else "GPU: Available"
def clear_gpu_cache() -> None:
"""Clear cached GPU detection info."""
global _gpu_info_cache
_gpu_info_cache = None
# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None
def get_gpu_devices() -> List[dict]:
"""Get list of available GPU devices for frontend selection.
Returns:
List of dicts with device info for each GPU.
"""
info = detect_gpu()
devices = []
for dev in info.devices:
devices.append({
"device_id": dev.device_id,
"name": dev.name,
"vendor": dev.vendor,
"is_discrete": dev.is_discrete,
"is_preferred": dev.device_id == info.preferred_device_id,
"is_selected": dev.device_id == get_selected_device_id(),
})
return devices
def get_selected_device_id() -> Optional[int]:
"""Get the user-selected GPU device_id.
Returns:
User-selected device_id, or auto-detected preferred device_id if not set.
"""
global _selected_device_id
if _selected_device_id is not None:
return _selected_device_id
# Fall back to auto-detected preferred device
info = detect_gpu()
return info.preferred_device_id
def set_selected_device_id(device_id: Optional[int]) -> bool:
"""Set the GPU device_id to use for embeddings.
Args:
device_id: GPU device_id to use, or None to use auto-detection.
Returns:
True if device_id is valid, False otherwise.
"""
global _selected_device_id
if device_id is None:
_selected_device_id = None
logger.info("GPU selection reset to auto-detection")
return True
# Validate device_id exists
info = detect_gpu()
valid_ids = [dev.device_id for dev in info.devices]
if device_id in valid_ids:
_selected_device_id = device_id
device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
logger.info(f"GPU selection set to device {device_id}: {device_name}")
return True
else:
logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
return False
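# Usage sketch (illustrative): inspect detection results and optionally pin
# a specific device before creating an embedder.
if __name__ == "__main__":
    print(get_gpu_summary())
    for dev in get_gpu_devices():
        marker = " (selected)" if dev["is_selected"] else ""
        print(dev["device_id"], dev["vendor"], dev["name"], marker)
    print(get_optimal_providers(use_gpu=True, with_device_options=True))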

View File

@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.
Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
"""
from __future__ import annotations
from typing import Iterable
import numpy as np
from .base import BaseEmbedder
class LiteLLMEmbedderWrapper(BaseEmbedder):
"""Wrapper for ccw-litellm LiteLLMEmbedder.
This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
BaseEmbedder interface, enabling seamless integration with CodexLens
semantic search functionality.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
def __init__(self, model: str = "default", **kwargs) -> None:
"""Initialize LiteLLM embedder wrapper.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
try:
from ccw_litellm import LiteLLMEmbedder
self._embedder = LiteLLMEmbedder(model=model, **kwargs)
except ImportError as e:
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from e
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions from LiteLLMEmbedder.
Returns:
int: Dimension of the embedding vectors.
"""
return self._embedder.dimensions
@property
def model_name(self) -> str:
"""Return model name from LiteLLMEmbedder.
Returns:
str: Name or identifier of the underlying model.
"""
return self._embedder.model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit for the embedding model.
Returns:
int: Maximum number of tokens that can be embedded at once.
Reads from LiteLLM config's max_input_tokens property.
"""
# Get from LiteLLM embedder's max_input_tokens property (now exposed)
if hasattr(self._embedder, 'max_input_tokens'):
return self._embedder.max_input_tokens
# Fallback: infer from model name
model_name_lower = self.model_name.lower()
# Large models (8B or "large" in name)
if '8b' in model_name_lower or 'large' in model_name_lower:
return 32768
# OpenAI text-embedding-3-* models
if 'text-embedding-3' in model_name_lower:
return 8191
# Default fallback
return 8192
def _sanitize_text(self, text: str) -> str:
"""Sanitize text to work around ModelScope API routing bug.
ModelScope incorrectly routes text starting with lowercase 'import'
to an Ollama endpoint, causing failures. This adds a leading space
to work around the issue without affecting embedding quality.
Args:
text: Text to sanitize.
Returns:
Sanitized text safe for embedding API.
"""
if text.startswith('import'):
return ' ' + text
return text
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts to numpy array using LiteLLMEmbedder.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments (ignored for LiteLLM backend).
Accepts batch_size for API compatibility with fastembed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Sanitize texts to avoid ModelScope routing bug
texts = [self._sanitize_text(t) for t in texts]
# LiteLLM handles batching internally, ignore batch_size parameter
return self._embedder.embed(texts)
def embed_single(self, text: str) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text to embed.
Returns:
list[float]: Embedding vector as a list of floats.
"""
# Sanitize text before embedding
sanitized = self._sanitize_text(text)
embedding = self._embedder.embed([sanitized])
return embedding[0].tolist()
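# Usage sketch (illustrative): requires ccw-litellm and a configured provider;
# the model name below is an assumption, not a tested endpoint.
if __name__ == "__main__":
    embedder = LiteLLMEmbedderWrapper(model="text-embedding-3-small")
    vec = embedder.embed_single("import numpy as np")  # sanitized internally
    print(embedder.model_name, len(vec))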

View File

@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker",
"check_onnx_reranker_available",
]

View File

@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
max_input_tokens: int | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
# Store max_input_tokens with model-aware defaults
if max_input_tokens is not None:
self._max_input_tokens = max_input_tokens
else:
# Infer from model name
model_lower = self.model_name.lower()
if '8b' in model_lower or 'large' in model_lower:
self._max_input_tokens = 32768
else:
self._max_input_tokens = 8192
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking."""
return self._max_input_tokens
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count using fast heuristic.
Uses len(text) // 4 as approximation (~4 chars per token for English).
Not perfectly accurate for all models/languages but sufficient for
batch sizing decisions where exact counts aren't critical.
"""
return len(text) // 4
def _create_token_aware_batches(
self,
query: str,
documents: Sequence[str],
) -> list[list[tuple[int, str]]]:
"""Split documents into batches that fit within token limits.
Uses 90% of max_input_tokens as safety margin.
Each batch includes the query tokens overhead.
"""
max_tokens = int(self._max_input_tokens * 0.9)
query_tokens = self._estimate_tokens(query)
batches: list[list[tuple[int, str]]] = []
current_batch: list[tuple[int, str]] = []
current_tokens = query_tokens # Start with query overhead
for idx, doc in enumerate(documents):
doc_tokens = self._estimate_tokens(doc)
# Warn if single document exceeds token limit (will be truncated by API)
if doc_tokens > max_tokens - query_tokens:
logger.warning(
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
f"(limit: {max_tokens - query_tokens} after query overhead). "
"Document will likely be truncated by the API."
)
# If batch would exceed limit, start new batch
if current_tokens + doc_tokens > max_tokens and current_batch:
batches.append(current_batch)
current_batch = []
current_tokens = query_tokens
current_batch.append((idx, doc))
current_tokens += doc_tokens
if current_batch:
batches.append(current_batch)
return batches
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
# Create token-aware batches
batches = self._create_token_aware_batches(query, documents)
if len(batches) == 1:
# Single batch - original behavior
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
# Multiple batches - process each and merge results
logger.info(
f"Splitting {len(documents)} documents into {len(batches)} batches "
f"(max_input_tokens: {self._max_input_tokens})"
)
all_scores: list[float] = [0.0] * len(documents)
for batch in batches:
batch_docs = [doc for _, doc in batch]
payload = self._build_payload(query=query, documents=batch_docs)
data = self._request_json(payload)
results = data.get("results")
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
# Map scores back to original indices
for (orig_idx, _), score in zip(batch, batch_scores):
all_scores[orig_idx] = score
return all_scores
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
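# Usage sketch (illustrative): needs a real API key (e.g. $RERANKER_API_KEY)
# and makes network calls, so treat this as a sketch rather than a test.
if __name__ == "__main__":
    reranker = APIReranker(provider="siliconflow")
    pairs = [
        ("how to parse json", "json.loads turns a string into a dict"),
        ("how to parse json", "the capital of France is Paris"),
    ]
    print(reranker.score_pairs(pairs))  # the relevant doc should score higher
    reranker.close()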

View File

@@ -0,0 +1,46 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking.
Returns:
int: Maximum number of tokens that can be processed at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
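# Sketch of a custom backend (illustrative): a reranker only needs to
# implement score_pairs(). The token-overlap rule below is a toy heuristic.
class _ToyOverlapReranker(BaseReranker):
    """Scores each (query, doc) pair by naive token overlap."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            q_tokens = set(query.lower().split())
            d_tokens = set(doc.lower().split())
            scores.append(len(q_tokens & d_tokens) / max(len(q_tokens), 1))
        return scores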

View File

@@ -0,0 +1,159 @@
"""Factory for creating rerankers.
Provides a unified interface for instantiating different reranker backends.
"""
from __future__ import annotations
from typing import Any
from .base import BaseReranker
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
- "onnx" redirects to "fastembed" for backward compatibility.
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
backend = (backend or "").strip().lower()
if backend == "legacy":
from .legacy import check_cross_encoder_available
return check_cross_encoder_available()
if backend == "fastembed":
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "onnx":
# Redirect to fastembed for backward compatibility
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "litellm":
try:
import ccw_litellm # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
)
try:
from .litellm_reranker import LiteLLMReranker # noqa: F401
except Exception as exc: # pragma: no cover - defensive
return False, f"LiteLLM reranker backend not available: {exc}"
return True, None
if backend == "api":
from .api_reranker import check_httpx_available
return check_httpx_available()
return False, (
f"Invalid reranker backend: {backend}. "
"Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
)
def get_reranker(
backend: str = "fastembed",
model_name: str | None = None,
*,
device: str | None = None,
**kwargs: Any,
) -> BaseReranker:
"""Factory function to create reranker based on backend.
Args:
backend: Reranker backend to use. Options:
- "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
- "onnx": Redirects to fastembed for backward compatibility
- "api": HTTP API backend (remote providers)
- "litellm": LiteLLM backend (LLM-based, for API mode)
- "legacy": sentence-transformers CrossEncoder backend (optional)
model_name: Model identifier for model-based backends. Defaults depend on backend:
- fastembed: Xenova/ms-marco-MiniLM-L-6-v2
- onnx: (redirects to fastembed)
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
- litellm: "default" (model alias from the ccw-litellm configuration)
device: Optional device string for backends that support it (legacy only).
**kwargs: Additional backend-specific arguments.
Returns:
BaseReranker: Configured reranker instance.
Raises:
ValueError: If backend is not recognized.
ImportError: If required backend dependencies are not installed or backend is unavailable.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "onnx":
# Redirect to fastembed for backward compatibility
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
if not ok:
raise ImportError(err)
from .legacy import CrossEncoderReranker
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
if backend == "litellm":
ok, err = check_reranker_available("litellm")
if not ok:
raise ImportError(err)
from .litellm_reranker import LiteLLMReranker
_ = device # Device selection is not applicable to remote LLM backends.
resolved_model_name = (model_name or "").strip() or "default"
return LiteLLMReranker(model=resolved_model_name, **kwargs)
if backend == "api":
ok, err = check_reranker_available("api")
if not ok:
raise ImportError(err)
from .api_reranker import APIReranker
_ = device # Device selection is not applicable to remote HTTP backends.
resolved_model_name = (model_name or "").strip() or None
return APIReranker(model_name=resolved_model_name, **kwargs)
raise ValueError(
f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
)
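
A hedged usage sketch of this factory: probe backends in preference order and fall back when optional dependencies are missing (the helper name below is illustrative):

def build_reranker() -> BaseReranker:
    last_err = None
    for backend in ("fastembed", "legacy"):
        ok, err = check_reranker_available(backend)
        if ok:
            return get_reranker(backend)
        last_err = err
    raise ImportError(f"No reranker backend available: {last_err}")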

View File

@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
Install:
pip install fastembed>=0.4.0
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
"""Check whether fastembed reranker dependencies are available."""
try:
import fastembed # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
)
try:
from fastembed.rerank.cross_encoder import TextCrossEncoder # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed TextCrossEncoder not available: {exc}. "
"Upgrade with: pip install fastembed>=0.4.0",
)
return True, None
class FastEmbedReranker(BaseReranker):
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
# Alternative models supported by fastembed:
# - "BAAI/bge-reranker-base"
# - "BAAI/bge-reranker-large"
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
cache_dir: str | None = None,
threads: int | None = None,
) -> None:
"""Initialize FastEmbed reranker.
Args:
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
use_gpu: Whether to use GPU acceleration when available.
cache_dir: Optional directory for caching downloaded models.
threads: Optional number of threads for ONNX Runtime.
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.cache_dir = cache_dir
self.threads = threads
self._encoder: Any | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
"""Lazy-load the TextCrossEncoder model."""
if self._encoder is not None:
return
ok, err = check_fastembed_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._encoder is not None:
return
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Determine providers based on GPU preference
providers: list[str] | None = None
if self.use_gpu:
try:
from ..gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
except Exception:
# Fallback: let fastembed decide
providers = None
# Build initialization kwargs
init_kwargs: dict[str, Any] = {}
if self.cache_dir:
init_kwargs["cache_dir"] = self.cache_dir
if self.threads is not None:
init_kwargs["threads"] = self.threads
if providers:
init_kwargs["providers"] = providers
logger.debug(
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
self.model_name,
self.use_gpu,
)
self._encoder = TextCrossEncoder(
model_name=self.model_name,
**init_kwargs,
)
logger.debug("FastEmbed reranker model loaded successfully")
@staticmethod
def _sigmoid(x: float) -> float:
"""Numerically stable sigmoid function."""
if x < -709:
return 0.0
if x > 709:
return 1.0
import math
return 1.0 / (1.0 + math.exp(-x))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair), normalized to [0, 1] range.
"""
if not pairs:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
# FastEmbed's TextCrossEncoder.rerank() expects one query and a list of
# documents, so (query, doc) pairs are grouped by query and each group is
# scored in a single call, avoiding redundant passes for repeated queries.
query_to_docs: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
if query not in query_to_docs:
query_to_docs[query] = []
query_to_docs[query].append((idx, doc))
# Score each query group
scores: list[float] = [0.0] * len(pairs)
for query, indexed_docs in query_to_docs.items():
docs = [doc for _, doc in indexed_docs]
indices = [idx for idx, _ in indexed_docs]
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=docs,
batch_size=batch_size,
)
)
# Map scores back to original positions and normalize with sigmoid
for i, raw_score in enumerate(raw_scores):
if i < len(indices):
original_idx = indices[i]
# Normalize score to [0, 1] using stable sigmoid
scores[original_idx] = self._sigmoid(float(raw_score))
except Exception as e:
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
# Leave scores as 0.0 for failed queries
return scores
def rerank(
self,
query: str,
documents: Sequence[str],
*,
top_k: int | None = None,
batch_size: int = 32,
) -> list[tuple[float, str, int]]:
"""Rerank documents for a single query.
This is a convenience method that provides results in ranked order.
Args:
query: The query string.
documents: List of documents to rerank.
top_k: Return only top K results. None returns all.
batch_size: Batch size for scoring.
Returns:
List of (score, document, original_index) tuples, sorted by score descending.
"""
if not documents:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=list(documents),
batch_size=batch_size,
)
)
# Convert to our format: (normalized_score, document, original_index)
ranked = []
for idx, raw_score in enumerate(raw_scores):
if idx < len(documents):
# Normalize score to [0, 1] using stable sigmoid
normalized = self._sigmoid(float(raw_score))
ranked.append((normalized, documents[idx], idx))
# Sort by score descending
ranked.sort(key=lambda x: x[0], reverse=True)
if top_k is not None and top_k > 0:
ranked = ranked[:top_k]
return ranked
except Exception as e:
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
return []
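
A usage sketch, assuming fastembed>=0.4.0 is installed (the model downloads on first use; the documents below are illustrative):

reranker = FastEmbedReranker()  # defaults to Xenova/ms-marco-MiniLM-L-6-v2
ranked = reranker.rerank(
    "how to read a file in python",
    [
        "open() returns a file object",
        "binary search trees",
        "use pathlib.Path.read_text()",
    ],
    top_k=2,
)
for score, doc, idx in ranked:
    print(f"{score:.3f}  [{idx}] {doc}")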

View File

@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.
Install with: pip install codexlens[reranker-legacy]
"""
from __future__ import annotations
import logging
import threading
from typing import List, Sequence, Tuple
from .base import BaseReranker
logger = logging.getLogger(__name__)
try:
from sentence_transformers import CrossEncoder as _CrossEncoder
CROSS_ENCODER_AVAILABLE = True
_import_error: str | None = None
except ImportError as exc: # pragma: no cover - optional dependency
_CrossEncoder = None # type: ignore[assignment]
CROSS_ENCODER_AVAILABLE = False
_import_error = str(exc)
def check_cross_encoder_available() -> tuple[bool, str | None]:
if CROSS_ENCODER_AVAILABLE:
return True, None
return (
False,
_import_error
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
)
class CrossEncoderReranker(BaseReranker):
"""Cross-encoder reranker with lazy model loading."""
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = (model_name or "").strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.device = (device or "").strip() or None
self._model = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None:
return
ok, err = check_cross_encoder_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None:
return
try:
if self.device:
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
else:
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
except Exception as exc:
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
raise
def score_pairs(
self,
pairs: Sequence[Tuple[str, str]],
*,
batch_size: int = 32,
) -> List[float]:
"""Score (query, doc) pairs using the cross-encoder.
Returns:
List of scores (one per pair) in the model's native scale (usually logits).
"""
if not pairs:
return []
self._load_model()
if self._model is None: # pragma: no cover - defensive
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
return [float(s) for s in scores]
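
Because this backend returns raw logits while the fastembed backend returns sigmoid-normalized scores, mixing the two requires a common scale; a minimal sketch:

import math

def normalize_logits(scores: list[float]) -> list[float]:
    # Clamp to avoid math.exp overflow, then map logits into [0, 1].
    return [
        1.0 / (1.0 + math.exp(-max(-709.0, min(709.0, s))))
        for s in scores
    ]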

View File

@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.
Notes:
- This backend is experimental and may be slow/expensive compared to local
rerankers.
- It relies on `ccw-litellm` for a unified LLM API across providers.
"""
from __future__ import annotations
import json
import logging
import re
import threading
import time
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
def _coerce_score_to_unit_interval(score: float) -> float:
"""Coerce a numeric score into [0, 1].
The prompt asks for a float in [0, 1], but some models may respond with 0-10
or 0-100 scales. This function attempts a conservative normalization.
"""
if 0.0 <= score <= 1.0:
return score
if 0.0 <= score <= 10.0:
return score / 10.0
if 0.0 <= score <= 100.0:
return score / 100.0
return max(0.0, min(1.0, score))
def _extract_score(text: str) -> float | None:
"""Extract a numeric relevance score from an LLM response."""
content = (text or "").strip()
if not content:
return None
# Prefer JSON if present.
if "{" in content and "}" in content:
try:
start = content.index("{")
end = content.rindex("}") + 1
payload = json.loads(content[start:end])
if isinstance(payload, dict) and "score" in payload:
return float(payload["score"])
except Exception:
pass
match = _NUMBER_RE.search(content)
if not match:
return None
try:
return float(match.group(0))
except ValueError:
return None
class LiteLLMReranker(BaseReranker):
"""Experimental reranker that uses a LiteLLM-compatible model.
This reranker scores each (query, doc) pair in isolation (single-pair mode)
to improve prompt reliability across providers.
"""
_SYSTEM_PROMPT = (
"You are a relevance scoring assistant.\n"
"Given a search query and a document snippet, output a single numeric "
"relevance score between 0 and 1.\n\n"
"Scoring guidance:\n"
"- 1.0: The document directly answers the query.\n"
"- 0.5: The document is partially relevant.\n"
"- 0.0: The document is unrelated.\n\n"
"Output requirements:\n"
"- Output ONLY the number (e.g., 0.73).\n"
"- Do not include any other text."
)
def __init__(
self,
model: str = "default",
*,
requests_per_minute: float | None = None,
min_interval_seconds: float | None = None,
default_score: float = 0.0,
max_doc_chars: int = 8000,
**litellm_kwargs: Any,
) -> None:
"""Initialize the reranker.
Args:
model: Model name from ccw-litellm configuration (default: "default").
requests_per_minute: Optional rate limit in requests per minute.
min_interval_seconds: Optional minimum interval between requests. If set,
it takes precedence over requests_per_minute.
default_score: Score to use when an API call fails or parsing fails.
max_doc_chars: Maximum number of document characters to include in the prompt.
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
Raises:
ImportError: If ccw-litellm is not installed.
ValueError: If model is blank.
"""
self.model_name = (model or "").strip()
if not self.model_name:
raise ValueError("model cannot be blank")
self.default_score = float(default_score)
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
if min_interval_seconds is not None:
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
elif requests_per_minute is not None and float(requests_per_minute) > 0:
self._min_interval_seconds = 60.0 / float(requests_per_minute)
else:
self._min_interval_seconds = 0.0
# Prefer deterministic output by default; allow overrides via kwargs.
litellm_kwargs = dict(litellm_kwargs)
litellm_kwargs.setdefault("temperature", 0.0)
litellm_kwargs.setdefault("max_tokens", 16)
try:
from ccw_litellm import ChatMessage, LiteLLMClient
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from exc
self._ChatMessage = ChatMessage
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
self._lock = threading.RLock()
self._last_request_at = 0.0
def _sanitize_text(self, text: str) -> str:
# Keep consistent with LiteLLMEmbedderWrapper workaround.
if text.startswith("import"):
return " " + text
return text
def _rate_limit(self) -> None:
if self._min_interval_seconds <= 0:
return
with self._lock:
now = time.monotonic()
elapsed = now - self._last_request_at
if elapsed < self._min_interval_seconds:
time.sleep(self._min_interval_seconds - elapsed)
self._last_request_at = time.monotonic()
def _build_user_prompt(self, query: str, doc: str) -> str:
sanitized_query = self._sanitize_text(query or "")
sanitized_doc = self._sanitize_text(doc or "")
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
sanitized_doc = sanitized_doc[: self.max_doc_chars]
return (
"Query:\n"
f"{sanitized_query}\n\n"
"Document:\n"
f"{sanitized_doc}\n\n"
"Return the relevance score (0 to 1) as a single number:"
)
def _score_single_pair(self, query: str, doc: str) -> float:
messages = [
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
]
try:
self._rate_limit()
response = self._client.chat(messages)
except Exception as exc:
logger.debug("LiteLLM reranker request failed: %s", exc)
return self.default_score
raw = getattr(response, "content", "") or ""
score = _extract_score(raw)
if score is None:
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
return self.default_score
return _coerce_score_to_unit_interval(float(score))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with per-pair LLM calls."""
if not pairs:
return []
# batch_size is accepted for BaseReranker compatibility; this backend still
# issues one LLM request per pair, so slicing into batches adds nothing.
scores: list[float] = []
for query, doc in pairs:
    scores.append(self._score_single_pair(query, doc))
return scores
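
The parsing helpers above are module-level pure functions, so the scoring pipeline can be sanity-checked without any API calls; for example:

assert _extract_score('{"score": 0.73}') == 0.73
assert _extract_score("Relevance: 7/10") == 7.0    # first number wins
assert _coerce_score_to_unit_interval(7.0) == 0.7   # 0-10 scale detected
assert _coerce_score_to_unit_interval(85.0) == 0.85 # 0-100 scale detected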

View File

@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
providers: list[Any] | None = None,
max_length: int | None = None,
) -> None:
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.providers = providers
self.max_length = int(max_length) if max_length is not None else None
self._tokenizer: Any | None = None
self._model: Any | None = None
self._model_input_names: set[str] | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_onnx_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
if self.providers is None:
from ..gpu_support import get_optimal_providers
# Include device_id options for DirectML/CUDA selection when available.
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
)
# Some Optimum versions accept `providers`, others accept a single `provider`.
# Prefer passing the full providers list, with a conservative fallback.
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
model_kwargs = {}
try:
self._model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
**model_kwargs,
)
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments.
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache model input names to filter tokenizer outputs defensively.
input_names: set[str] | None = None
for attr in ("input_names", "model_input_names"):
names = getattr(self._model, attr, None)
if isinstance(names, (list, tuple)) and names:
input_names = {str(n) for n in names}
break
if input_names is None:
try:
session = getattr(self._model, "model", None)
if session is not None and hasattr(session, "get_inputs"):
input_names = {i.name for i in session.get_inputs()}
except Exception:
input_names = None
self._model_input_names = input_names
@staticmethod
def _sigmoid(x: "Any") -> "Any":
import numpy as np
x = np.clip(x, -50.0, 50.0)
return 1.0 / (1.0 + np.exp(-x))
@staticmethod
def _select_relevance_logit(logits: "Any") -> "Any":
import numpy as np
arr = np.asarray(logits)
if arr.ndim == 0:
return arr.reshape(1)
if arr.ndim == 1:
return arr
if arr.ndim >= 2:
# Common cases:
# - Regression: (batch, 1)
# - Binary classification: (batch, 2)
if arr.shape[-1] == 1:
return arr[..., 0]
if arr.shape[-1] == 2:
# Convert 2-logit softmax into a single logit via difference.
return arr[..., 1] - arr[..., 0]
return arr.max(axis=-1)
return arr.reshape(-1)
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
queries = [q for q, _ in batch]
docs = [d for _, d in batch]
tokenizer_kwargs: dict[str, Any] = {
"text": queries,
"text_pair": docs,
"padding": True,
"truncation": True,
"return_tensors": "np",
}
max_len = self.max_length
if max_len is None:
try:
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
if 0 < model_max < 10_000:
max_len = model_max
else:
max_len = 512
except Exception:
max_len = 512
if max_len is not None and max_len > 0:
tokenizer_kwargs["max_length"] = int(max_len)
encoded = self._tokenizer(**tokenizer_kwargs)
inputs = dict(encoded)
# Some models do not accept token_type_ids; filter to known input names if available.
if self._model_input_names:
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
return inputs
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
if self._model is None:
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
outputs = self._model(**inputs)
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, dict) and "logits" in outputs:
return outputs["logits"]
if isinstance(outputs, (list, tuple)) and outputs:
return outputs[0]
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
if not pairs:
return []
self._load_model()
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
return []
import numpy as np
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for batch in _iter_batches(list(pairs), bs):
inputs = self._tokenize_batch(batch)
logits = self._forward_logits(inputs)
rel_logits = self._select_relevance_logit(logits)
probs = self._sigmoid(rel_logits)
probs = np.clip(probs, 0.0, 1.0)
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
if len(scores) != len(pairs):
    logger.debug(
        "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
    )
    # Trim or pad defensively so callers always receive one score per pair.
    scores = scores[: len(pairs)] + [0.0] * (len(pairs) - len(scores))
return scores
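
The logit-selection rule above is a static method, so it can be checked directly with numpy:

import numpy as np

logits_reg = np.array([[0.3], [1.2]])   # regression head: (batch, 1)
logits_bin = np.array([[0.1, 2.1]])     # binary head: (batch, 2)
assert ONNXReranker._select_relevance_logit(logits_reg).tolist() == [0.3, 1.2]
assert np.isclose(ONNXReranker._select_relevance_logit(logits_bin)[0], 2.0)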

View File

@@ -0,0 +1,434 @@
"""Rotational embedder for multi-endpoint API load balancing.
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
to maximize throughput while respecting rate limits.
"""
from __future__ import annotations
import logging
import random
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
from .base import BaseEmbedder
logger = logging.getLogger(__name__)
class EndpointStatus(Enum):
"""Status of an API endpoint."""
AVAILABLE = "available"
COOLING = "cooling" # Rate limited, temporarily unavailable
FAILED = "failed" # Permanent failure (auth error, etc.)
class SelectionStrategy(Enum):
"""Strategy for selecting endpoints."""
ROUND_ROBIN = "round_robin"
LATENCY_AWARE = "latency_aware"
WEIGHTED_RANDOM = "weighted_random"
@dataclass
class EndpointConfig:
"""Configuration for a single API endpoint."""
model: str
api_key: Optional[str] = None
api_base: Optional[str] = None
weight: float = 1.0 # Higher weight = more requests
max_concurrent: int = 4 # Max concurrent requests to this endpoint
@dataclass
class EndpointState:
"""Runtime state for an endpoint."""
config: EndpointConfig
embedder: Any = None # LiteLLMEmbedderWrapper instance
# Health metrics
status: EndpointStatus = EndpointStatus.AVAILABLE
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
# Performance metrics
total_requests: int = 0
total_failures: int = 0
avg_latency_ms: float = 0.0
last_latency_ms: float = 0.0
# Concurrency tracking
active_requests: int = 0
lock: threading.Lock = field(default_factory=threading.Lock)
def is_available(self) -> bool:
"""Check if endpoint is available for requests."""
if self.status == EndpointStatus.FAILED:
return False
if self.status == EndpointStatus.COOLING:
if time.time() >= self.cooldown_until:
self.status = EndpointStatus.AVAILABLE
return True
return False
return True
def set_cooldown(self, seconds: float) -> None:
"""Put endpoint in cooldown state."""
self.status = EndpointStatus.COOLING
self.cooldown_until = time.time() + seconds
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
def mark_failed(self) -> None:
"""Mark endpoint as permanently failed."""
self.status = EndpointStatus.FAILED
logger.error(f"Endpoint {self.config.model} marked as failed")
def record_success(self, latency_ms: float) -> None:
"""Record successful request."""
self.total_requests += 1
self.last_latency_ms = latency_ms
# Exponential moving average for latency
alpha = 0.3
if self.avg_latency_ms == 0:
self.avg_latency_ms = latency_ms
else:
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
def record_failure(self) -> None:
"""Record failed request."""
self.total_requests += 1
self.total_failures += 1
@property
def health_score(self) -> float:
"""Calculate health score (0-1) based on metrics."""
if not self.is_available():
return 0.0
# Base score from success rate
if self.total_requests > 0:
success_rate = 1 - (self.total_failures / self.total_requests)
else:
success_rate = 1.0
# Latency factor (faster = higher score)
# Normalize: 100ms = 1.0, 1000ms = 0.1
if self.avg_latency_ms > 0:
latency_factor = min(1.0, 100 / self.avg_latency_ms)
else:
latency_factor = 1.0
# Availability factor (less concurrent = more available)
if self.config.max_concurrent > 0:
availability = 1 - (self.active_requests / self.config.max_concurrent)
else:
availability = 1.0
# Combined score with weights
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
class RotationalEmbedder(BaseEmbedder):
"""Embedder that load balances across multiple API endpoints.
Features:
- Intelligent endpoint selection based on latency and health
- Automatic failover on rate limits (429) and server errors
- Cooldown management to respect rate limits
- Thread-safe concurrent request handling
Args:
endpoints: List of endpoint configurations
strategy: Selection strategy (default: latency_aware)
default_cooldown: Default cooldown seconds for rate limits (default: 60)
max_retries: Maximum retry attempts across all endpoints (default: 3)
"""
def __init__(
self,
endpoints: List[EndpointConfig],
strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
default_cooldown: float = 60.0,
max_retries: int = 3,
) -> None:
if not endpoints:
raise ValueError("At least one endpoint must be provided")
self.strategy = strategy
self.default_cooldown = default_cooldown
self.max_retries = max_retries
# Initialize endpoint states
self._endpoints: List[EndpointState] = []
self._lock = threading.Lock()
self._round_robin_index = 0
# Create embedder instances for each endpoint
from .litellm_embedder import LiteLLMEmbedderWrapper
for config in endpoints:
# Build kwargs for LiteLLMEmbedderWrapper
kwargs: Dict[str, Any] = {}
if config.api_key:
kwargs["api_key"] = config.api_key
if config.api_base:
kwargs["api_base"] = config.api_base
try:
embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
state = EndpointState(config=config, embedder=embedder)
self._endpoints.append(state)
logger.info(f"Initialized endpoint: {config.model}")
except Exception as e:
logger.error(f"Failed to initialize endpoint {config.model}: {e}")
if not self._endpoints:
raise ValueError("Failed to initialize any endpoints")
# Cache embedding properties from first endpoint
self._embedding_dim = self._endpoints[0].embedder.embedding_dim
self._model_name = f"rotational({len(self._endpoints)} endpoints)"
self._max_tokens = self._endpoints[0].embedder.max_tokens
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions."""
return self._embedding_dim
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit."""
return self._max_tokens
@property
def endpoint_count(self) -> int:
"""Return number of configured endpoints."""
return len(self._endpoints)
@property
def available_endpoint_count(self) -> int:
"""Return number of available endpoints."""
return sum(1 for ep in self._endpoints if ep.is_available())
def get_endpoint_stats(self) -> List[Dict[str, Any]]:
"""Get statistics for all endpoints."""
stats = []
for ep in self._endpoints:
stats.append({
"model": ep.config.model,
"status": ep.status.value,
"total_requests": ep.total_requests,
"total_failures": ep.total_failures,
"avg_latency_ms": round(ep.avg_latency_ms, 2),
"health_score": round(ep.health_score, 3),
"active_requests": ep.active_requests,
})
return stats
def _select_endpoint(self) -> Optional[EndpointState]:
"""Select best available endpoint based on strategy."""
available = [ep for ep in self._endpoints if ep.is_available()]
if not available:
return None
if self.strategy == SelectionStrategy.ROUND_ROBIN:
with self._lock:
self._round_robin_index = (self._round_robin_index + 1) % len(available)
return available[self._round_robin_index]
elif self.strategy == SelectionStrategy.LATENCY_AWARE:
# Sort by health score (descending) and pick top candidate
# Add small random factor to prevent thundering herd
scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
scored.sort(key=lambda x: x[1], reverse=True)
return scored[0][0]
elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
# Weighted random selection based on health scores
scores = [ep.health_score for ep in available]
total = sum(scores)
if total == 0:
return random.choice(available)
weights = [s / total for s in scores]
return random.choices(available, weights=weights, k=1)[0]
return available[0]
def _parse_retry_after(self, error: Exception) -> Optional[float]:
"""Extract Retry-After value from error if available."""
error_str = str(error)
# Try to find Retry-After in error message
import re
match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
if match:
return float(match.group(1))
return None
def _is_rate_limit_error(self, error: Exception) -> bool:
"""Check if error is a rate limit error."""
error_str = str(error).lower()
return any(x in error_str for x in ["429", "rate limit", "too many requests"])
def _is_retryable_error(self, error: Exception) -> bool:
"""Check if error is retryable (not auth/config error)."""
error_str = str(error).lower()
# Retryable errors
if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
"timeout", "connection", "service unavailable"]):
return True
# Non-retryable errors (auth, config)
if any(x in error_str for x in ["401", "403", "invalid", "authentication",
"unauthorized", "api key"]):
return False
# Default to retryable for unknown errors
return True
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts using load-balanced endpoint selection.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments passed to underlying embedder.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
Raises:
RuntimeError: If all endpoints fail after retries.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
last_error: Optional[Exception] = None
tried_endpoints: set = set()
for attempt in range(self.max_retries + 1):
endpoint = self._select_endpoint()
if endpoint is None:
# All endpoints unavailable, wait for shortest cooldown
min_cooldown = min(
(ep.cooldown_until - time.time() for ep in self._endpoints
if ep.status == EndpointStatus.COOLING),
default=self.default_cooldown
)
if min_cooldown > 0 and attempt < self.max_retries:
wait_time = min(min_cooldown, 30) # Cap wait at 30s
logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
time.sleep(wait_time)
continue
break
# Track tried endpoints to avoid infinite loops
endpoint_id = id(endpoint)
if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
# Already tried all endpoints
break
tried_endpoints.add(endpoint_id)
# Acquire slot
with endpoint.lock:
endpoint.active_requests += 1
try:
start_time = time.time()
result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
latency_ms = (time.time() - start_time) * 1000
# Record success
endpoint.record_success(latency_ms)
return result
except Exception as e:
last_error = e
endpoint.record_failure()
if self._is_rate_limit_error(e):
# Rate limited - set cooldown
retry_after = self._parse_retry_after(e) or self.default_cooldown
endpoint.set_cooldown(retry_after)
logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
f"cooling for {retry_after}s")
elif not self._is_retryable_error(e):
# Permanent failure (auth error, etc.)
endpoint.mark_failed()
logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")
else:
# Temporary error - short cooldown
endpoint.set_cooldown(5.0)
logger.warning(f"Endpoint {endpoint.config.model} error: {e}")
finally:
with endpoint.lock:
endpoint.active_requests -= 1
# All retries exhausted
available = self.available_endpoint_count
raise RuntimeError(
f"All embedding attempts failed after {self.max_retries + 1} tries. "
f"Available endpoints: {available}/{len(self._endpoints)}. "
f"Last error: {last_error}"
)
def create_rotational_embedder(
endpoints_config: List[Dict[str, Any]],
strategy: str = "latency_aware",
default_cooldown: float = 60.0,
) -> RotationalEmbedder:
"""Factory function to create RotationalEmbedder from config dicts.
Args:
endpoints_config: List of endpoint configuration dicts with keys:
- model: Model identifier (required)
- api_key: API key (optional)
- api_base: API base URL (optional)
- weight: Request weight (optional, default 1.0)
- max_concurrent: Max concurrent requests (optional, default 4)
strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
default_cooldown: Default cooldown seconds for rate limits
Returns:
Configured RotationalEmbedder instance
Example config:
endpoints_config = [
{"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
{"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
]
"""
endpoints = []
for cfg in endpoints_config:
endpoints.append(EndpointConfig(
model=cfg["model"],
api_key=cfg.get("api_key"),
api_base=cfg.get("api_base"),
weight=cfg.get("weight", 1.0),
max_concurrent=cfg.get("max_concurrent", 4),
))
strategy_enum = SelectionStrategy[strategy.upper()]
return RotationalEmbedder(
endpoints=endpoints,
strategy=strategy_enum,
default_cooldown=default_cooldown,
)
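
A hedged usage sketch (model names and keys are placeholders; this assumes ccw-litellm is installed with working credentials):

embedder = create_rotational_embedder(
    endpoints_config=[
        {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 2.0},
        {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "max_concurrent": 2},
    ],
    strategy="round_robin",
)
vectors = embedder.embed_to_numpy(["def hello(): pass", "class Foo: ..."])
print(vectors.shape)
print(embedder.get_endpoint_stats())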

View File

@@ -0,0 +1,567 @@
"""ONNX-optimized SPLADE sparse encoder for code search.
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
that combine the interpretability of BM25 with neural relevance modeling.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
Install (GPU):
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
"""
from __future__ import annotations
import logging
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def check_splade_available() -> Tuple[bool, Optional[str]]:
"""Check whether SPLADE dependencies are available.
Returns:
Tuple of (available: bool, error_message: Optional[str])
"""
try:
import numpy # noqa: F401
except ImportError as exc:
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc:
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
except ImportError as exc:
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc:
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
# Global cache for SPLADE encoders (singleton pattern)
_splade_cache: Dict[str, "SpladeEncoder"] = {}
_cache_lock = threading.RLock()
def get_splade_encoder(
model_name: str = "naver/splade-cocondenser-ensembledistil",
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
cache_dir: Optional[str] = None,
) -> "SpladeEncoder":
"""Get or create cached SPLADE encoder (thread-safe singleton).
This function provides significant performance improvement by reusing
SpladeEncoder instances across multiple searches, avoiding repeated model
loading overhead.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
Returns:
Cached SpladeEncoder instance for the given configuration
"""
global _splade_cache
# Cache key includes all configuration parameters (cache_dir too, since it
# determines where ONNX files are stored and loaded from)
cache_key = (
    f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:"
    f"{sparsity_threshold}:{cache_dir or 'default'}"
)
with _cache_lock:
encoder = _splade_cache.get(cache_key)
if encoder is not None:
return encoder
# Create new encoder and cache it
encoder = SpladeEncoder(
model_name=model_name,
use_gpu=use_gpu,
max_length=max_length,
sparsity_threshold=sparsity_threshold,
cache_dir=cache_dir,
)
# Pre-load model to ensure it's ready
encoder._load_model()
_splade_cache[cache_key] = encoder
return encoder
def clear_splade_cache() -> None:
"""Clear the SPLADE encoder cache and release ONNX resources.
This method ensures proper cleanup of ONNX model resources to prevent
memory leaks when encoders are no longer needed.
"""
global _splade_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for encoder in _splade_cache.values():
if encoder._model is not None:
del encoder._model
encoder._model = None
if encoder._tokenizer is not None:
del encoder._tokenizer
encoder._tokenizer = None
_splade_cache.clear()
class SpladeEncoder:
"""ONNX-optimized SPLADE sparse encoder.
Produces sparse vectors with vocabulary-aligned dimensions.
Output: Dict[int, float] mapping token_id to weight.
SPLADE activation formula:
splade_repr = log(1 + ReLU(logits)) * attention_mask
splade_vec = max_pooling(splade_repr, axis=sequence_length)
References:
- SPLADE: https://arxiv.org/abs/2107.05720
- SPLADE v2: https://arxiv.org/abs/2109.10086
"""
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
def __init__(
self,
model_name: str = DEFAULT_MODEL,
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
providers: Optional[List[Any]] = None,
cache_dir: Optional[str] = None,
) -> None:
"""Initialize SPLADE encoder.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
providers: Explicit ONNX providers list (overrides use_gpu)
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.max_length = int(max_length) if max_length > 0 else 512
self.sparsity_threshold = float(sparsity_threshold)
self.providers = providers
# Setup ONNX cache directory
if cache_dir:
self._cache_dir = Path(cache_dir)
else:
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
self._tokenizer: Any | None = None
self._model: Any | None = None
self._vocab_size: int | None = None
self._lock = threading.RLock()
def _get_local_cache_path(self) -> Path:
"""Get local cache path for this model's ONNX files.
Returns:
Path to the local ONNX cache directory for this model
"""
# Replace / with -- for filesystem-safe naming
safe_name = self.model_name.replace("/", "--")
return self._cache_dir / safe_name
def _load_model(self) -> None:
"""Lazy load ONNX model and tokenizer.
First checks local cache for ONNX model, falling back to
HuggingFace download and conversion if not cached.
"""
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_splade_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForMaskedLM
from transformers import AutoTokenizer
if self.providers is None:
from .gpu_support import get_optimal_providers, get_selected_device_id
# Get providers as pure string list (cache-friendly)
# NOTE: with_device_options=False to avoid tuple-based providers
# which break optimum's caching mechanism
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=False
)
# Get device_id separately for provider_options
self._device_id = get_selected_device_id() if self.use_gpu else None
# Some Optimum versions accept `providers`, others accept a single `provider`
# Prefer passing the full providers list, with a conservative fallback
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
# Pass device_id via provider_options for GPU selection
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
# Build provider_options dict for each GPU provider
provider_options = {}
for p in self.providers:
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
provider_options[p] = {"device_id": self._device_id}
if provider_options:
model_kwargs["provider_options"] = provider_options
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception as e:
logger.debug(f"Failed to inspect ORTModel signature: {e}")
model_kwargs = {}
# Check for local ONNX cache first
local_cache = self._get_local_cache_path()
onnx_model_path = local_cache / "model.onnx"
if onnx_model_path.exists():
# Load from local cache
logger.info(f"Loading SPLADE from local cache: {local_cache}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
str(local_cache),
**model_kwargs,
)
self._tokenizer = AutoTokenizer.from_pretrained(
str(local_cache), use_fast=True
)
self._vocab_size = len(self._tokenizer)
logger.info(
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
)
return
except Exception as e:
logger.warning(f"Failed to load from cache, redownloading: {e}")
# Download and convert from HuggingFace
logger.info(f"Downloading SPLADE model: {self.model_name}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True, # Export to ONNX
**model_kwargs,
)
logger.debug(f"SPLADE model loaded: {self.model_name}")
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True,
)
logger.warning(
"Optimum version doesn't support provider parameters. "
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache vocabulary size
self._vocab_size = len(self._tokenizer)
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
# Save to local cache for future use
try:
local_cache.mkdir(parents=True, exist_ok=True)
self._model.save_pretrained(str(local_cache))
self._tokenizer.save_pretrained(str(local_cache))
logger.info(f"SPLADE model cached to: {local_cache}")
except Exception as e:
logger.warning(f"Failed to cache SPLADE model: {e}")
@staticmethod
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
"""Apply SPLADE activation function to model outputs.
Formula: log(1 + ReLU(logits)) * attention_mask
Args:
logits: Model output logits (batch, seq_len, vocab_size)
attention_mask: Attention mask (batch, seq_len)
Returns:
SPLADE representations (batch, seq_len, vocab_size)
"""
import numpy as np
# ReLU activation
relu_logits = np.maximum(0, logits)
# Log(1 + x) transformation
log_relu = np.log1p(relu_logits)
# Apply attention mask (expand to match vocab dimension)
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
mask_expanded = np.expand_dims(attention_mask, axis=-1)
# Element-wise multiplication
splade_repr = log_relu * mask_expanded
return splade_repr
@staticmethod
def _max_pooling(splade_repr: Any) -> Any:
"""Max pooling over sequence length dimension.
Args:
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
Returns:
Pooled sparse vectors (batch, vocab_size)
"""
import numpy as np
# Max pooling over sequence dimension (axis=1)
return np.max(splade_repr, axis=1)
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
"""Convert dense vector to sparse dictionary.
Args:
dense_vec: Dense vector (vocab_size,)
Returns:
Sparse dictionary {token_id: weight} with weights above threshold
"""
import numpy as np
# Find non-zero indices above threshold
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
# Create sparse dictionary
sparse_dict = {
int(idx): float(dense_vec[idx])
for idx in nonzero_indices
}
return sparse_dict
def warmup(self, text: str = "warmup query") -> None:
"""Warmup the encoder by running a dummy inference.
First-time model inference includes initialization overhead.
Call this method once before the first real search to avoid
latency spikes.
Args:
text: Dummy text for warmup (default: "warmup query")
"""
logger.info("Warming up SPLADE encoder...")
# Trigger model loading and first inference
_ = self.encode_text(text)
logger.info("SPLADE encoder warmup complete")
def encode_text(self, text: str) -> Dict[int, float]:
"""Encode text to sparse vector {token_id: weight}.
Args:
text: Input text to encode
Returns:
Sparse vector as dictionary mapping token_id to weight
"""
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
import numpy as np
# Tokenize input
encoded = self._tokenizer(
text,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vec = self._max_pooling(splade_repr)
# Convert to sparse dictionary (single item batch)
sparse_dict = self._to_sparse_dict(splade_vec[0])
return sparse_dict
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
"""Batch encode texts to sparse vectors.
Args:
texts: List of input texts to encode
batch_size: Batch size for encoding (default: 32)
Returns:
List of sparse vectors as dictionaries
"""
if not texts:
return []
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
import numpy as np
results: List[Dict[int, float]] = []
# Process in batches
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# Tokenize batch
encoded = self._tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vecs = self._max_pooling(splade_repr)
# Convert each vector to sparse dictionary
for vec in splade_vecs:
sparse_dict = self._to_sparse_dict(vec)
results.append(sparse_dict)
return results
@property
def vocab_size(self) -> int:
"""Return vocabulary size (~30k for BERT-based models).
Returns:
Vocabulary size (number of tokens in tokenizer)
"""
if self._vocab_size is not None:
return self._vocab_size
self._load_model()
return self._vocab_size or 0
def get_token(self, token_id: int) -> str:
"""Convert token_id to string (for debugging).
Args:
token_id: Token ID to convert
Returns:
Token string
"""
self._load_model()
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded")
return self._tokenizer.decode([token_id])
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
"""Get top-k tokens with highest weights from sparse vector.
Useful for debugging and understanding what the model is focusing on.
Args:
sparse_vec: Sparse vector as {token_id: weight}
top_k: Number of top tokens to return
Returns:
List of (token_string, weight) tuples, sorted by weight descending
"""
self._load_model()
if not sparse_vec:
return []
# Sort by weight descending
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
# Take top-k and convert token_ids to strings
top_items = sorted_items[:top_k]
return [
(self.get_token(token_id), weight)
for token_id, weight in top_items
]
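
A usage sketch, assuming the CPU dependencies are installed (the model downloads and converts to ONNX on first call):

encoder = get_splade_encoder(use_gpu=False)
vec = encoder.encode_text("binary search in a sorted array")
print(len(vec), "active dimensions of", encoder.vocab_size)
for token, weight in encoder.get_top_tokens(vec, top_k=5):
    print(f"{weight:.3f}  {token}")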

File diff suppressed because it is too large

View File

@@ -0,0 +1,32 @@
"""Storage backends for CodexLens."""
from __future__ import annotations
from .sqlite_store import SQLiteStore
from .path_mapper import PathMapper
from .registry import RegistryStore, ProjectInfo, DirMapping
from .dir_index import DirIndexStore, SubdirLink, FileEntry
from .index_tree import IndexTreeBuilder, BuildResult, DirBuildResult
from .vector_meta_store import VectorMetadataStore
__all__ = [
# Legacy (workspace-local)
"SQLiteStore",
# Path mapping
"PathMapper",
# Global registry
"RegistryStore",
"ProjectInfo",
"DirMapping",
# Directory index
"DirIndexStore",
"SubdirLink",
"FileEntry",
# Tree builder
"IndexTreeBuilder",
"BuildResult",
"DirBuildResult",
# Vector metadata
"VectorMetadataStore",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,32 @@
"""Simple filesystem cache helpers."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class FileCache:
"""Caches file mtimes for incremental indexing."""
cache_path: Path
def load_mtime(self, path: Path) -> Optional[float]:
try:
key = self._key_for(path)
record = (self.cache_path / key).read_text(encoding="utf-8")
return float(record)
except Exception:
return None
def store_mtime(self, path: Path, mtime: float) -> None:
self.cache_path.mkdir(parents=True, exist_ok=True)
key = self._key_for(path)
(self.cache_path / key).write_text(str(mtime), encoding="utf-8")
def _key_for(self, path: Path) -> str:
safe = str(path).replace(":", "_").replace("\\", "_").replace("/", "_")
return f"{safe}.mtime"
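
A sketch of the incremental-indexing pattern this cache supports: skip files whose mtime has not changed since the last run (paths below are illustrative).

from pathlib import Path

cache = FileCache(cache_path=Path(".cache/mtimes"))
src = Path("example.py")
mtime = src.stat().st_mtime
if cache.load_mtime(src) != mtime:
    # ... reindex the file here ...
    cache.store_mtime(src, mtime)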

View File

@@ -0,0 +1,398 @@
"""Global cross-directory symbol index for fast lookups.
Stores symbols for an entire project in a single SQLite database so symbol search
does not require traversing every directory _index.db.
This index is updated incrementally during file indexing (delete+insert per file)
to avoid expensive batch rebuilds.
"""
from __future__ import annotations
import logging
import sqlite3
import threading
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import Symbol
from codexlens.errors import StorageError
class GlobalSymbolIndex:
"""Project-wide symbol index with incremental updates."""
SCHEMA_VERSION = 1
DEFAULT_DB_NAME = "_global_symbols.db"
def __init__(self, db_path: str | Path, project_id: int) -> None:
self.db_path = Path(db_path).resolve()
self.project_id = int(project_id)
self._lock = threading.RLock()
self._conn: Optional[sqlite3.Connection] = None
self.logger = logging.getLogger(__name__)
def initialize(self) -> None:
"""Create database and schema if not exists."""
with self._lock:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = self._get_connection()
current_version = self._get_schema_version(conn)
if current_version > self.SCHEMA_VERSION:
raise StorageError(
f"Database schema version {current_version} is newer than "
f"supported version {self.SCHEMA_VERSION}. "
f"Please update the application or use a compatible database.",
db_path=str(self.db_path),
operation="initialize",
details={
"current_version": current_version,
"supported_version": self.SCHEMA_VERSION,
},
)
if current_version == 0:
self._create_schema(conn)
self._set_schema_version(conn, self.SCHEMA_VERSION)
elif current_version < self.SCHEMA_VERSION:
self._apply_migrations(conn, current_version)
self._set_schema_version(conn, self.SCHEMA_VERSION)
conn.commit()
def close(self) -> None:
"""Close database connection."""
with self._lock:
if self._conn is not None:
try:
self._conn.close()
except Exception:
pass
finally:
self._conn = None
def __enter__(self) -> "GlobalSymbolIndex":
self.initialize()
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
self.close()
def add_symbol(self, symbol: Symbol, file_path: str | Path, index_path: str | Path) -> None:
"""Insert a single symbol (idempotent) for incremental updates."""
file_path_str = str(Path(file_path).resolve())
index_path_str = str(Path(index_path).resolve())
with self._lock:
conn = self._get_connection()
try:
conn.execute(
"""
INSERT INTO global_symbols(
project_id, symbol_name, symbol_kind,
file_path, start_line, end_line, index_path
)
VALUES(?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(
project_id, symbol_name, symbol_kind,
file_path, start_line, end_line
)
DO UPDATE SET
index_path=excluded.index_path
""",
(
self.project_id,
symbol.name,
symbol.kind,
file_path_str,
symbol.range[0],
symbol.range[1],
index_path_str,
),
)
conn.commit()
except sqlite3.DatabaseError as exc:
conn.rollback()
raise StorageError(
f"Failed to add symbol {symbol.name}: {exc}",
db_path=str(self.db_path),
operation="add_symbol",
) from exc
def update_file_symbols(
self,
file_path: str | Path,
symbols: List[Symbol],
index_path: str | Path | None = None,
) -> None:
"""Replace all symbols for a file atomically (delete + insert)."""
file_path_str = str(Path(file_path).resolve())
index_path_str: Optional[str]
if index_path is not None:
index_path_str = str(Path(index_path).resolve())
else:
index_path_str = self._get_existing_index_path(file_path_str)
with self._lock:
conn = self._get_connection()
try:
conn.execute("BEGIN")
conn.execute(
"DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
(self.project_id, file_path_str),
)
if symbols:
if not index_path_str:
raise StorageError(
"index_path is required when inserting symbols for a new file",
db_path=str(self.db_path),
operation="update_file_symbols",
details={"file_path": file_path_str},
)
rows = [
(
self.project_id,
s.name,
s.kind,
file_path_str,
s.range[0],
s.range[1],
index_path_str,
)
for s in symbols
]
conn.executemany(
"""
INSERT INTO global_symbols(
project_id, symbol_name, symbol_kind,
file_path, start_line, end_line, index_path
)
VALUES(?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(
project_id, symbol_name, symbol_kind,
file_path, start_line, end_line
)
DO UPDATE SET
index_path=excluded.index_path
""",
rows,
)
conn.commit()
except sqlite3.DatabaseError as exc:
conn.rollback()
raise StorageError(
f"Failed to update symbols for {file_path_str}: {exc}",
db_path=str(self.db_path),
operation="update_file_symbols",
) from exc
def delete_file_symbols(self, file_path: str | Path) -> int:
"""Remove all symbols for a file. Returns number of rows deleted."""
file_path_str = str(Path(file_path).resolve())
with self._lock:
conn = self._get_connection()
try:
cur = conn.execute(
"DELETE FROM global_symbols WHERE project_id=? AND file_path=?",
(self.project_id, file_path_str),
)
conn.commit()
return int(cur.rowcount or 0)
except sqlite3.DatabaseError as exc:
conn.rollback()
raise StorageError(
f"Failed to delete symbols for {file_path_str}: {exc}",
db_path=str(self.db_path),
operation="delete_file_symbols",
) from exc
def search(
self,
name: str,
kind: Optional[str] = None,
limit: int = 50,
prefix_mode: bool = True,
) -> List[Symbol]:
"""Search symbols and return full Symbol objects."""
if prefix_mode:
pattern = f"{name}%"
else:
pattern = f"%{name}%"
with self._lock:
conn = self._get_connection()
if kind:
rows = conn.execute(
"""
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
FROM global_symbols
WHERE project_id=? AND symbol_name LIKE ? AND symbol_kind=?
ORDER BY symbol_name
LIMIT ?
""",
(self.project_id, pattern, kind, limit),
).fetchall()
else:
rows = conn.execute(
"""
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
FROM global_symbols
WHERE project_id=? AND symbol_name LIKE ?
ORDER BY symbol_name
LIMIT ?
""",
(self.project_id, pattern, limit),
).fetchall()
return [
Symbol(
name=row["symbol_name"],
kind=row["symbol_kind"],
range=(row["start_line"], row["end_line"]),
file=row["file_path"],
)
for row in rows
]
def search_symbols(
self,
name: str,
kind: Optional[str] = None,
limit: int = 50,
prefix_mode: bool = True,
) -> List[Tuple[str, Tuple[int, int]]]:
"""Search symbols and return only (file_path, (start_line, end_line))."""
symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
return [(s.file or "", s.range) for s in symbols]
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
"""Get all symbols in a specific file, sorted by start_line.
Args:
file_path: Full path to the file
Returns:
List of Symbol objects sorted by start_line
"""
file_path_str = str(Path(file_path).resolve())
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
FROM global_symbols
WHERE project_id=? AND file_path=?
ORDER BY start_line
""",
(self.project_id, file_path_str),
).fetchall()
return [
Symbol(
name=row["symbol_name"],
kind=row["symbol_kind"],
range=(row["start_line"], row["end_line"]),
file=row["file_path"],
)
for row in rows
]
def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
with self._lock:
conn = self._get_connection()
row = conn.execute(
"""
SELECT index_path
FROM global_symbols
WHERE project_id=? AND file_path=?
LIMIT 1
""",
(self.project_id, file_path_str),
).fetchone()
return str(row["index_path"]) if row else None
def _get_schema_version(self, conn: sqlite3.Connection) -> int:
try:
row = conn.execute("PRAGMA user_version").fetchone()
return int(row[0]) if row else 0
except Exception:
return 0
def _set_schema_version(self, conn: sqlite3.Connection, version: int) -> None:
conn.execute(f"PRAGMA user_version = {int(version)}")
def _apply_migrations(self, conn: sqlite3.Connection, from_version: int) -> None:
# No migrations yet (v1).
_ = (conn, from_version)
return
def _get_connection(self) -> sqlite3.Connection:
if self._conn is None:
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA synchronous=NORMAL")
self._conn.execute("PRAGMA foreign_keys=ON")
self._conn.execute("PRAGMA mmap_size=30000000000")
return self._conn
def _create_schema(self, conn: sqlite3.Connection) -> None:
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS global_symbols (
id INTEGER PRIMARY KEY,
project_id INTEGER NOT NULL,
symbol_name TEXT NOT NULL,
symbol_kind TEXT NOT NULL,
file_path TEXT NOT NULL,
start_line INTEGER,
end_line INTEGER,
index_path TEXT NOT NULL,
UNIQUE(
project_id, symbol_name, symbol_kind,
file_path, start_line, end_line
)
)
"""
)
# Required by optimization spec.
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_global_symbols_name_kind
ON global_symbols(symbol_name, symbol_kind)
"""
)
# Used by common queries (project-scoped name lookups).
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_name_kind
ON global_symbols(project_id, symbol_name, symbol_kind)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_file
ON global_symbols(project_id, file_path)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_global_symbols_project_index_path
ON global_symbols(project_id, index_path)
"""
)
except sqlite3.DatabaseError as exc:
raise StorageError(
f"Failed to initialize global symbol schema: {exc}",
db_path=str(self.db_path),
operation="_create_schema",
) from exc
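A usage sketch for the incremental path, assuming Symbol accepts the keyword fields used above (paths and the project id are illustrative):

from codexlens.entities import Symbol

with GlobalSymbolIndex("_global_symbols.db", project_id=1) as index:
    symbols = [Symbol(name="parse_config", kind="function", range=(10, 42))]
    # delete+insert replaces the file's rows atomically in one transaction
    index.update_file_symbols("src/config.py", symbols, index_path="src/_index.db")
    hits = index.search("parse", kind="function", limit=10)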

File diff suppressed because it is too large

View File

@@ -0,0 +1,136 @@
"""Merkle tree utilities for change detection.
This module provides a generic, file-system based Merkle tree implementation
that can be used to efficiently diff directory states.
"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional
def sha256_bytes(data: bytes) -> str:
return hashlib.sha256(data).hexdigest()
def sha256_text(text: str) -> str:
return sha256_bytes(text.encode("utf-8", errors="ignore"))
@dataclass
class MerkleNode:
"""A Merkle node representing either a file (leaf) or directory (internal)."""
name: str
rel_path: str
hash: str
is_dir: bool
children: Dict[str, "MerkleNode"] = field(default_factory=dict)
def iter_files(self) -> Iterable["MerkleNode"]:
if not self.is_dir:
yield self
return
for child in self.children.values():
yield from child.iter_files()
@dataclass
class MerkleTree:
"""Merkle tree for a directory snapshot."""
root: MerkleNode
@classmethod
def build_from_directory(cls, root_dir: Path) -> "MerkleTree":
root_dir = Path(root_dir).resolve()
node = cls._build_node(root_dir, base=root_dir)
return cls(root=node)
@classmethod
def _build_node(cls, path: Path, *, base: Path) -> MerkleNode:
if path.is_file():
rel = str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(
name=path.name,
rel_path=rel,
hash=sha256_bytes(path.read_bytes()),
is_dir=False,
)
if not path.is_dir():
rel = str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False)
children: Dict[str, MerkleNode] = {}
for child in sorted(path.iterdir(), key=lambda p: p.name):
child_node = cls._build_node(child, base=base)
children[child_node.name] = child_node
items = [
f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}"
for name, n in sorted(children.items(), key=lambda kv: kv[0])
]
dir_hash = sha256_text("\n".join(items))
rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(
name="." if path == base else path.name,
rel_path=rel_path,
hash=dir_hash,
is_dir=True,
children=children,
)
@staticmethod
def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]:
"""Find changed/added/removed files between two trees.
Returns:
List of relative file paths (POSIX-style separators).
"""
if old is None and new is None:
return []
if old is None:
return sorted({n.rel_path for n in new.root.iter_files()}) # type: ignore[union-attr]
if new is None:
return sorted({n.rel_path for n in old.root.iter_files()})
changed: set[str] = set()
def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None:
if old_node is None and new_node is None:
return
if old_node is None and new_node is not None:
changed.update(n.rel_path for n in new_node.iter_files())
return
if new_node is None and old_node is not None:
changed.update(n.rel_path for n in old_node.iter_files())
return
assert old_node is not None and new_node is not None
if old_node.hash == new_node.hash:
return
if not old_node.is_dir and not new_node.is_dir:
changed.add(new_node.rel_path)
return
if old_node.is_dir != new_node.is_dir:
changed.update(n.rel_path for n in old_node.iter_files())
changed.update(n.rel_path for n in new_node.iter_files())
return
names = set(old_node.children.keys()) | set(new_node.children.keys())
for name in names:
walk(old_node.children.get(name), new_node.children.get(name))
walk(old.root, new.root)
return sorted(changed)
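A sketch of the diffing flow, with two snapshots of the same directory taken before and after an edit:

from pathlib import Path

before = MerkleTree.build_from_directory(Path("project"))
# ... files are edited here ...
after = MerkleTree.build_from_directory(Path("project"))
for rel_path in MerkleTree.find_changed_files(before, after):
    print("changed:", rel_path)  # POSIX-style relative paths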

View File

@@ -0,0 +1,154 @@
"""
Manages database schema migrations.
This module provides a framework for applying versioned migrations to the SQLite
database. Migrations are discovered from the `codexlens.storage.migrations`
package and applied sequentially. The database schema version is tracked using
the `user_version` pragma.
"""
import importlib
import logging
import pkgutil
from pathlib import Path
from sqlite3 import Connection
from typing import Callable, List, NamedTuple
log = logging.getLogger(__name__)
class Migration(NamedTuple):
"""Represents a single database migration."""
version: int
name: str
    upgrade: Callable[[Connection], None]
def discover_migrations() -> List[Migration]:
"""
Discovers and returns a sorted list of database migrations.
Migrations are expected to be in the `codexlens.storage.migrations` package,
with filenames in the format `migration_XXX_description.py`, where XXX is
the version number. Each migration module must contain an `upgrade` function
that takes a `sqlite3.Connection` object as its argument.
Returns:
A list of Migration objects, sorted by version.
"""
import codexlens.storage.migrations
migrations = []
package_path = Path(codexlens.storage.migrations.__file__).parent
for _, name, _ in pkgutil.iter_modules([str(package_path)]):
if name.startswith("migration_"):
try:
version = int(name.split("_")[1])
module = importlib.import_module(f"codexlens.storage.migrations.{name}")
if hasattr(module, "upgrade"):
migrations.append(
Migration(version=version, name=name, upgrade=module.upgrade)
)
else:
log.warning(f"Migration {name} is missing 'upgrade' function.")
except (ValueError, IndexError) as e:
log.warning(f"Could not parse migration name {name}: {e}")
except ImportError as e:
log.warning(f"Could not import migration {name}: {e}")
migrations.sort(key=lambda m: m.version)
return migrations
class MigrationManager:
"""
Manages the application of migrations to a database.
"""
def __init__(self, db_conn: Connection):
"""
Initializes the MigrationManager.
Args:
db_conn: The SQLite database connection.
"""
self.db_conn = db_conn
self.migrations = discover_migrations()
def get_current_version(self) -> int:
"""
Gets the current version of the database schema.
Returns:
The current schema version number.
"""
return self.db_conn.execute("PRAGMA user_version").fetchone()[0]
def set_version(self, version: int):
"""
Sets the database schema version.
Args:
version: The version number to set.
"""
self.db_conn.execute(f"PRAGMA user_version = {version}")
log.info(f"Database schema version set to {version}")
def apply_migrations(self):
"""
Applies all pending migrations to the database.
This method checks the current database version and applies all
subsequent migrations in order. Each migration is applied within
a transaction, unless the migration manages its own transactions.
"""
current_version = self.get_current_version()
log.info(f"Current database schema version: {current_version}")
for migration in self.migrations:
if migration.version > current_version:
log.info(f"Applying migration {migration.version}: {migration.name}...")
try:
# Check if a transaction is already in progress
in_transaction = self.db_conn.in_transaction
# Only start transaction if not already in one
if not in_transaction:
self.db_conn.execute("BEGIN")
migration.upgrade(self.db_conn)
self.set_version(migration.version)
# Only commit if we started the transaction and it's still active
if not in_transaction and self.db_conn.in_transaction:
self.db_conn.execute("COMMIT")
log.info(
f"Successfully applied migration {migration.version}: {migration.name}"
)
except Exception as e:
log.error(
f"Failed to apply migration {migration.version}: {migration.name}. Error: {e}",
exc_info=True,
)
# Try to rollback if transaction is active
try:
if self.db_conn.in_transaction:
self.db_conn.execute("ROLLBACK")
except Exception:
pass # Ignore rollback errors
raise
latest_migration_version = self.migrations[-1].version if self.migrations else 0
if current_version < latest_migration_version:
            # Sanity check: re-read the version and warn if it does not match
            # the latest known migration (e.g. set_version was skipped).
final_version = self.get_current_version()
if final_version != latest_migration_version:
log.warning(f"Database version ({final_version}) is not the latest migration version ({latest_migration_version}). This may indicate a problem.")
log.info("All pending migrations applied successfully.")

View File

@@ -0,0 +1 @@
# This file makes the 'migrations' directory a Python package.

View File

@@ -0,0 +1,123 @@
"""
Migration 001: Normalize keywords into separate tables.
This migration introduces two new tables, `keywords` and `file_keywords`, to
store semantic keywords in a normalized fashion. It then migrates the existing
keywords from the `keywords` JSON column of the `semantic_metadata` table into these
new tables. This is intended to speed up keyword-based searches significantly.
"""
import json
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to normalize keywords.
- Creates `keywords` and `file_keywords` tables.
- Creates indexes for efficient querying.
- Migrates data from `files.semantic_data` to the new tables.
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating 'keywords' and 'file_keywords' tables...")
# Create a table to store unique keywords
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS keywords (
id INTEGER PRIMARY KEY,
keyword TEXT NOT NULL UNIQUE
)
"""
)
# Create a join table to link files and keywords (many-to-many)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS file_keywords (
file_id INTEGER NOT NULL,
keyword_id INTEGER NOT NULL,
PRIMARY KEY (file_id, keyword_id),
FOREIGN KEY (file_id) REFERENCES files (id) ON DELETE CASCADE,
FOREIGN KEY (keyword_id) REFERENCES keywords (id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for new keyword tables...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords (keyword)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords (file_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_keyword_id ON file_keywords (keyword_id)")
log.info("Migrating existing keywords from 'semantic_metadata' table...")
# Check if semantic_metadata table exists before querying
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'")
if not cursor.fetchone():
log.info("No 'semantic_metadata' table found, skipping data migration.")
return
# Check if 'keywords' column exists in semantic_metadata table
# (current schema may already use normalized tables without this column)
cursor.execute("PRAGMA table_info(semantic_metadata)")
columns = {row[1] for row in cursor.fetchall()}
if "keywords" not in columns:
log.info("No 'keywords' column in semantic_metadata table, skipping data migration.")
return
cursor.execute("SELECT file_id, keywords FROM semantic_metadata WHERE keywords IS NOT NULL AND keywords != ''")
files_to_migrate = cursor.fetchall()
if not files_to_migrate:
log.info("No existing files with semantic metadata to migrate.")
return
log.info(f"Found {len(files_to_migrate)} files with semantic metadata to migrate.")
for file_id, keywords_json in files_to_migrate:
if not keywords_json:
continue
try:
keywords = json.loads(keywords_json)
if not isinstance(keywords, list):
log.warning(f"Keywords for file_id {file_id} is not a list, skipping.")
continue
for keyword in keywords:
if not isinstance(keyword, str):
log.warning(f"Non-string keyword '{keyword}' found for file_id {file_id}, skipping.")
continue
keyword = keyword.strip()
if not keyword:
continue
# Get or create keyword_id
cursor.execute("INSERT OR IGNORE INTO keywords (keyword) VALUES (?)", (keyword,))
cursor.execute("SELECT id FROM keywords WHERE keyword = ?", (keyword,))
keyword_id_result = cursor.fetchone()
if keyword_id_result:
keyword_id = keyword_id_result[0]
# Link file to keyword
cursor.execute(
"INSERT OR IGNORE INTO file_keywords (file_id, keyword_id) VALUES (?, ?)",
(file_id, keyword_id),
)
else:
log.error(f"Failed to retrieve or create keyword_id for keyword: {keyword}")
except json.JSONDecodeError as e:
log.warning(f"Could not parse keywords for file_id {file_id}: {e}")
except Exception as e:
log.error(f"An unexpected error occurred during migration for file_id {file_id}: {e}", exc_info=True)
log.info("Finished migrating keywords.")

View File

@@ -0,0 +1,48 @@
"""
Migration 002: Add token_count and symbol_type to symbols table.
This migration adds token counting metadata to symbols for accurate chunk
splitting and performance optimization. It also adds symbol_type for better
filtering in searches.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add token metadata to symbols.
- Adds token_count column to symbols table
- Adds symbol_type column to symbols table (for future use)
- Creates index on symbol_type for efficient filtering
- Backfills existing symbols with NULL token_count (to be calculated lazily)
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Adding token_count column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
log.info("Successfully added token_count column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add token_count column (might already exist): {e}")
log.info("Adding symbol_type column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
log.info("Successfully added symbol_type column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add symbol_type column (might already exist): {e}")
log.info("Creating index on symbol_type for efficient filtering...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
log.info("Migration 002 completed successfully.")

View File

@@ -0,0 +1,232 @@
"""
Migration 004: Add dual FTS tables for exact and fuzzy matching.
This migration introduces two FTS5 tables:
- files_fts_exact: Uses unicode61 tokenizer for exact token matching
- files_fts_fuzzy: Uses trigram tokenizer (or extended unicode61) for substring/fuzzy matching
Both tables are synchronized with the files table via triggers for automatic updates.
"""
import logging
from sqlite3 import Connection
from codexlens.storage.sqlite_utils import check_trigram_support, get_sqlite_version
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add dual FTS tables.
- Drops old files_fts table and triggers
- Creates files_fts_exact with unicode61 tokenizer
- Creates files_fts_fuzzy with trigram or extended unicode61 tokenizer
- Creates synchronized triggers for both tables
- Rebuilds FTS indexes from files table
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
try:
# Check trigram support
has_trigram = check_trigram_support(db_conn)
version = get_sqlite_version(db_conn)
log.info(f"SQLite version: {'.'.join(map(str, version))}")
if has_trigram:
log.info("Trigram tokenizer available, using for fuzzy FTS table")
fuzzy_tokenizer = "trigram"
else:
log.warning(
f"Trigram tokenizer not available (requires SQLite >= 3.34), "
f"using extended unicode61 tokenizer for fuzzy matching"
)
fuzzy_tokenizer = "unicode61 tokenchars '_-.'"
        # Start a transaction only if one is not already active
        # (MigrationManager may have opened it for us).
        own_transaction = not db_conn.in_transaction
        if own_transaction:
            cursor.execute("BEGIN TRANSACTION")
# Check if files table has 'name' column (v2 schema doesn't have it)
cursor.execute("PRAGMA table_info(files)")
columns = {row[1] for row in cursor.fetchall()}
if 'name' not in columns:
log.info("Adding 'name' column to files table (v2 schema upgrade)...")
# Add name column
cursor.execute("ALTER TABLE files ADD COLUMN name TEXT")
# Populate name from path (extract filename from last '/')
# Use Python to do the extraction since SQLite doesn't have reverse()
cursor.execute("SELECT rowid, path FROM files")
rows = cursor.fetchall()
for rowid, path in rows:
# Extract filename from path
name = path.split('/')[-1] if '/' in path else path
cursor.execute("UPDATE files SET name = ? WHERE rowid = ?", (name, rowid))
# Rename 'path' column to 'full_path' if needed
if 'path' in columns and 'full_path' not in columns:
log.info("Renaming 'path' to 'full_path' (v2 schema upgrade)...")
# Check if indexed_at column exists in v2 schema
has_indexed_at = 'indexed_at' in columns
has_mtime = 'mtime' in columns
# SQLite doesn't support RENAME COLUMN before 3.25, so use table recreation
cursor.execute("""
CREATE TABLE files_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
full_path TEXT NOT NULL UNIQUE,
content TEXT,
language TEXT,
mtime REAL,
indexed_at TEXT
)
""")
# Build INSERT statement based on available columns
        # Note: v2 schema used path as its PRIMARY KEY; omitting id here
        # lets AUTOINCREMENT assign fresh rowids
if has_indexed_at and has_mtime:
cursor.execute("""
INSERT INTO files_new (name, full_path, content, language, mtime, indexed_at)
SELECT name, path, content, language, mtime, indexed_at FROM files
""")
elif has_indexed_at:
cursor.execute("""
INSERT INTO files_new (name, full_path, content, language, indexed_at)
SELECT name, path, content, language, indexed_at FROM files
""")
elif has_mtime:
cursor.execute("""
INSERT INTO files_new (name, full_path, content, language, mtime)
SELECT name, path, content, language, mtime FROM files
""")
else:
cursor.execute("""
INSERT INTO files_new (name, full_path, content, language)
SELECT name, path, content, language FROM files
""")
cursor.execute("DROP TABLE files")
cursor.execute("ALTER TABLE files_new RENAME TO files")
log.info("Dropping old FTS triggers and table...")
# Drop old triggers
cursor.execute("DROP TRIGGER IF EXISTS files_ai")
cursor.execute("DROP TRIGGER IF EXISTS files_ad")
cursor.execute("DROP TRIGGER IF EXISTS files_au")
# Drop old FTS table
cursor.execute("DROP TABLE IF EXISTS files_fts")
# Create exact FTS table (unicode61 with underscores/hyphens/dots as token chars)
# Note: tokenchars includes '.' to properly tokenize qualified names like PortRole.FLOW
log.info("Creating files_fts_exact table with unicode61 tokenizer...")
cursor.execute(
"""
CREATE VIRTUAL TABLE files_fts_exact USING fts5(
name, full_path UNINDEXED, content,
content='files',
content_rowid='id',
tokenize="unicode61 tokenchars '_-.'"
)
"""
)
# Create fuzzy FTS table (trigram or extended unicode61)
log.info(f"Creating files_fts_fuzzy table with {fuzzy_tokenizer} tokenizer...")
cursor.execute(
f"""
CREATE VIRTUAL TABLE files_fts_fuzzy USING fts5(
name, full_path UNINDEXED, content,
content='files',
content_rowid='id',
tokenize="{fuzzy_tokenizer}"
)
"""
)
# Create synchronized triggers for files_fts_exact
log.info("Creating triggers for files_fts_exact...")
cursor.execute(
"""
CREATE TRIGGER files_exact_ai AFTER INSERT ON files BEGIN
INSERT INTO files_fts_exact(rowid, name, full_path, content)
VALUES(new.id, new.name, new.full_path, new.content);
END
"""
)
cursor.execute(
"""
CREATE TRIGGER files_exact_ad AFTER DELETE ON files BEGIN
INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
VALUES('delete', old.id, old.name, old.full_path, old.content);
END
"""
)
cursor.execute(
"""
CREATE TRIGGER files_exact_au AFTER UPDATE ON files BEGIN
INSERT INTO files_fts_exact(files_fts_exact, rowid, name, full_path, content)
VALUES('delete', old.id, old.name, old.full_path, old.content);
INSERT INTO files_fts_exact(rowid, name, full_path, content)
VALUES(new.id, new.name, new.full_path, new.content);
END
"""
)
# Create synchronized triggers for files_fts_fuzzy
log.info("Creating triggers for files_fts_fuzzy...")
cursor.execute(
"""
CREATE TRIGGER files_fuzzy_ai AFTER INSERT ON files BEGIN
INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
VALUES(new.id, new.name, new.full_path, new.content);
END
"""
)
cursor.execute(
"""
CREATE TRIGGER files_fuzzy_ad AFTER DELETE ON files BEGIN
INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
VALUES('delete', old.id, old.name, old.full_path, old.content);
END
"""
)
cursor.execute(
"""
CREATE TRIGGER files_fuzzy_au AFTER UPDATE ON files BEGIN
INSERT INTO files_fts_fuzzy(files_fts_fuzzy, rowid, name, full_path, content)
VALUES('delete', old.id, old.name, old.full_path, old.content);
INSERT INTO files_fts_fuzzy(rowid, name, full_path, content)
VALUES(new.id, new.name, new.full_path, new.content);
END
"""
)
# Rebuild FTS indexes from files table
log.info("Rebuilding FTS indexes from files table...")
cursor.execute("INSERT INTO files_fts_exact(files_fts_exact) VALUES('rebuild')")
cursor.execute("INSERT INTO files_fts_fuzzy(files_fts_fuzzy) VALUES('rebuild')")
        # Commit only if we opened the transaction ourselves
        if own_transaction and db_conn.in_transaction:
            cursor.execute("COMMIT")
log.info("Migration 004 completed successfully")
# Vacuum to reclaim space (outside transaction)
try:
log.info("Running VACUUM to reclaim space...")
cursor.execute("VACUUM")
except Exception as e:
log.warning(f"VACUUM failed (non-critical): {e}")
except Exception as e:
log.error(f"Migration 004 failed: {e}")
        try:
            if db_conn.in_transaction:
                cursor.execute("ROLLBACK")
        except Exception:
            pass
raise
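A sketch of how the two tables split query duty after this migration: exact token lookups hit files_fts_exact, substring-style lookups hit files_fts_fuzzy. Query strings are illustrative, and if the trigram tokenizer was unavailable the fuzzy table degrades to extended token matching rather than true substring search:

# Exact token match (tokenchars keeps PortRole.FLOW as one token)
exact = db_conn.execute(
    "SELECT full_path FROM files_fts_exact WHERE files_fts_exact MATCH ?",
    ('"PortRole.FLOW"',),
).fetchall()

# Substring match via the trigram tokenizer (needs >= 3 characters)
fuzzy = db_conn.execute(
    "SELECT full_path FROM files_fts_fuzzy WHERE files_fts_fuzzy MATCH ?",
    ('"ortRole"',),
).fetchall()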

View File

@@ -0,0 +1,196 @@
"""
Migration 005: Remove unused and redundant database fields.
This migration removes four problematic fields identified by Gemini analysis:
1. **semantic_metadata.keywords** (deprecated - replaced by file_keywords table)
- Data: Migrated to normalized file_keywords table in migration 001
- Impact: Column now redundant, remove to prevent sync issues
2. **symbols.token_count** (unused - always NULL)
- Data: Never populated, always NULL
- Impact: No data loss, just removes unused column
3. **symbols.symbol_type** (redundant - duplicates kind)
- Data: Redundant with symbols.kind field
- Impact: No data loss, kind field contains same information
4. **subdirs.direct_files** (unused - never displayed)
- Data: Never used in queries or display logic
- Impact: No data loss, just removes unused column
Schema changes use table recreation pattern (SQLite best practice):
- Create new table without deprecated columns
- Copy data from old table
- Drop old table
- Rename new table
- Recreate indexes
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""Remove unused and redundant fields from schema.
Note: Transaction management is handled by MigrationManager.
This migration should NOT start its own transaction.
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
# Step 1: Remove semantic_metadata.keywords (if column exists)
log.info("Checking semantic_metadata.keywords column...")
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_metadata'"
)
if cursor.fetchone():
# Check if keywords column exists
cursor.execute("PRAGMA table_info(semantic_metadata)")
columns = {row[1] for row in cursor.fetchall()}
if "keywords" in columns:
log.info("Removing semantic_metadata.keywords column...")
cursor.execute("""
CREATE TABLE semantic_metadata_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL UNIQUE,
summary TEXT,
purpose TEXT,
llm_tool TEXT,
generated_at REAL,
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
)
""")
cursor.execute("""
INSERT INTO semantic_metadata_new (id, file_id, summary, purpose, llm_tool, generated_at)
SELECT id, file_id, summary, purpose, llm_tool, generated_at
FROM semantic_metadata
""")
cursor.execute("DROP TABLE semantic_metadata")
cursor.execute("ALTER TABLE semantic_metadata_new RENAME TO semantic_metadata")
# Recreate index
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)"
)
log.info("Removed semantic_metadata.keywords column")
else:
log.info("semantic_metadata.keywords column does not exist, skipping")
else:
log.info("semantic_metadata table does not exist, skipping")
# Step 2: Remove symbols.token_count and symbols.symbol_type (if columns exist)
log.info("Checking symbols.token_count and symbols.symbol_type columns...")
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'"
)
if cursor.fetchone():
# Check if token_count or symbol_type columns exist
cursor.execute("PRAGMA table_info(symbols)")
columns = {row[1] for row in cursor.fetchall()}
if "token_count" in columns or "symbol_type" in columns:
log.info("Removing symbols.token_count and symbols.symbol_type columns...")
cursor.execute("""
CREATE TABLE symbols_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT,
start_line INTEGER,
end_line INTEGER,
FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
)
""")
cursor.execute("""
INSERT INTO symbols_new (id, file_id, name, kind, start_line, end_line)
SELECT id, file_id, name, kind, start_line, end_line
FROM symbols
""")
cursor.execute("DROP TABLE symbols")
cursor.execute("ALTER TABLE symbols_new RENAME TO symbols")
# Recreate indexes
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
log.info("Removed symbols.token_count and symbols.symbol_type columns")
else:
log.info("symbols.token_count/symbol_type columns do not exist, skipping")
else:
log.info("symbols table does not exist, skipping")
# Step 3: Remove subdirs.direct_files (if column exists)
log.info("Checking subdirs.direct_files column...")
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='subdirs'"
)
if cursor.fetchone():
# Check if direct_files column exists
cursor.execute("PRAGMA table_info(subdirs)")
columns = {row[1] for row in cursor.fetchall()}
if "direct_files" in columns:
log.info("Removing subdirs.direct_files column...")
cursor.execute("""
CREATE TABLE subdirs_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
index_path TEXT NOT NULL,
files_count INTEGER DEFAULT 0,
last_updated REAL
)
""")
cursor.execute("""
INSERT INTO subdirs_new (id, name, index_path, files_count, last_updated)
SELECT id, name, index_path, files_count, last_updated
FROM subdirs
""")
cursor.execute("DROP TABLE subdirs")
cursor.execute("ALTER TABLE subdirs_new RENAME TO subdirs")
# Recreate index
cursor.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
log.info("Removed subdirs.direct_files column")
else:
log.info("subdirs.direct_files column does not exist, skipping")
else:
log.info("subdirs table does not exist, skipping")
log.info("Migration 005 completed successfully")
# Vacuum to reclaim space (outside transaction, optional)
# Note: VACUUM cannot run inside a transaction, so we skip it here
# The caller can run VACUUM separately if desired
def downgrade(db_conn: Connection):
"""Restore removed fields (data will be lost for keywords, token_count, symbol_type, direct_files).
This is a placeholder - true downgrade is not feasible as data is lost.
The migration is designed to be one-way since removed fields are unused/redundant.
Args:
db_conn: The SQLite database connection.
"""
log.warning(
"Migration 005 downgrade not supported - removed fields are unused/redundant. "
"Data cannot be restored."
)
raise NotImplementedError(
"Migration 005 downgrade not supported - this is a one-way migration"
)

View File

@@ -0,0 +1,37 @@
"""
Migration 006: Ensure relationship tables and indexes exist.
This migration is intentionally idempotent. It creates the `code_relationships`
table (used for graph visualization) and its indexes if missing.
"""
from __future__ import annotations
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
cursor = db_conn.cursor()
log.info("Ensuring code_relationships table exists...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
)
"""
)
log.info("Ensuring relationship indexes exist...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")

View File

@@ -0,0 +1,47 @@
"""
Migration 007: Add precomputed graph neighbor table for search expansion.
Adds:
- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids)
This table is derived data (a cache) and is safe to rebuild at any time.
The migration is intentionally idempotent.
"""
from __future__ import annotations
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
cursor = db_conn.cursor()
log.info("Creating graph_neighbors table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS graph_neighbors (
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
relationship_depth INTEGER NOT NULL,
PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
)
"""
)
log.info("Creating indexes for graph_neighbors...")
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth
ON graph_neighbors(source_symbol_id, relationship_depth)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor
ON graph_neighbors(neighbor_symbol_id)
"""
)
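A sketch of the intended read path: a search hit is expanded to its cached neighbors within a depth budget. The symbol id and depth are illustrative:

neighbor_ids = [
    row[0]
    for row in db_conn.execute(
        """
        SELECT neighbor_symbol_id
        FROM graph_neighbors
        WHERE source_symbol_id = ? AND relationship_depth <= ?
        """,
        (42, 2),
    )
]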

Some files were not shown because too many files have changed in this diff