feat(cli): add --rule option with template auto-discovery

Refactor the ccw CLI template system:

- Add a template-discovery.ts module supporting flat template auto-discovery
- Add a --rule <template> option that auto-loads the protocol and template
- Migrate the template directory from a nested layout (prompts/category/file.txt) to a flat layout (prompts/category-function.txt)
- Update all agent/command files to use the $PROTO and $TMPL environment variables instead of the $(cat ...) pattern
- Support fuzzy matching: --rule 02-review-architecture matches analysis-review-architecture.txt

Other updates:
- Dashboard: add Claude Manager and Issue Manager pages
- Codex-lens: enhance the chain_search and clustering modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

View File

@@ -0,0 +1,676 @@
# Codexlens LSP API Specification
**Version**: 1.1
**Status**: ✅ APPROVED (Gemini review)
**Architecture**: codexlens provides the Python API; CCW implements the MCP endpoints
**Analysis sources**: Gemini (architecture review) + Codex (implementation review)
**Last updated**: 2025-01-17
---
## 1. Overview
### 1.1 Background
Based on an analysis of the cclsp MCP server implementation, this document specifies LSP-style search APIs for codexlens, giving AI agents code-intelligence capabilities.
### 1.2 Architecture Decision
**CCW implements the MCP endpoints; codexlens only provides a Python API.**
```
┌─────────────────────────────────────────────────────────────┐
│ Claude Code │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ MCP Client │ │
│ └───────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ CCW MCP Server │ │
│ │ ┌─────────────────────────────────────────────────┐ │ │
│ │ │ MCP Tool Handlers │ │ │
│ │ │ • codexlens_file_context │ │ │
│ │ │ • codexlens_find_definition │ │ │
│ │ │ • codexlens_find_references │ │ │
│ │ │ • codexlens_semantic_search │ │ │
│ │ └──────────────────────┬──────────────────────────┘ │ │
│ └─────────────────────────┼─────────────────────────────┘ │
└────────────────────────────┼────────────────────────────────┘
│ Python API calls
┌─────────────────────────────────────────────────────────────┐
│ codexlens │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Public API Layer │ │
│ │ codexlens.api.file_context() │ │
│ │ codexlens.api.find_definition() │ │
│ │ codexlens.api.find_references() │ │
│ │ codexlens.api.semantic_search() │ │
│ └──────────────────────┬────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼────────────────────────────────┐ │
│ │ Core Components │ │
│ │ GlobalSymbolIndex | ChainSearchEngine | HoverProvider │ │
│ └───────────────────────────────────────────────────────┘ │
│ │ │
│ ┌──────────────────────▼────────────────────────────────┐ │
│ │ SQLite Index Databases │ │
│ │ global_symbols.db | *.index.db (per-directory) │ │
│ └───────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
### 1.3 Separation of Responsibilities
| Component | Responsibilities |
|------|------|
| **codexlens** | Python API, index queries, search algorithms, result aggregation, graceful degradation |
| **CCW** | MCP protocol, parameter validation, result serialization, error handling, project_root inference |
### 1.4 codexlens vs cclsp
| Feature | cclsp | codexlens |
|------|-------|-----------|
| Data source | Live LSP servers | Prebuilt SQLite indexes |
| Startup time | 200-3000ms | <50ms |
| Response time | 50-500ms | <5ms |
| Cross-language | One LSP server per language | Unified Python/TS/JS/Go index |
| Dependencies | Requires language servers | No external dependencies |
| Accuracy | 100% (compiler-grade) | 95%+ (tree-sitter) |
| Rename support | Yes | No (read-only index) |
| Live diagnostics | Yes | Via IDE MCP |
**Recommendation**: use codexlens for fast search and cclsp for precise refactoring.
---
## 2. cclsp Design Patterns (for Reference)
### 2.1 MCP Tool Interface Design
| Pattern | Description | Code location |
|------|------|----------|
| **Name-based** | Accepts `symbol_name` rather than file coordinates | `index.ts:70` |
| **Safe disambiguation** | Two-step `rename_symbol` / `rename_symbol_strict` | `index.ts:133, 172` |
| **Complexity abstraction** | Hides LSP protocol details | `index.ts:211` |
| **Graceful failure** | Returns useful text responses | global |
### 2.2 Symbol Resolution Algorithm
```
1. getDocumentSymbols (lsp-client.ts:1406)
   └─ fetch all symbols in the file
2. Handle both result formats:
   ├─ DocumentSymbol[] → flatten
   └─ SymbolInformation[] → secondary lookup
3. Filter: symbol.name === symbolName && symbol.kind
4. Fallback: retry without the kind constraint when nothing matches
5. Aggregate: walk all matches and aggregate definition locations
```
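A condensed Python sketch of steps 3-4, the name/kind filter with its kind-free retry (illustrative only; cclsp's actual implementation lives in `lsp-client.ts`):
```python
from typing import List, Optional

def filter_symbols(symbols: List, name: str, kind: Optional[str] = None) -> List:
    """Step 3: match on name and, when given, kind."""
    matches = [s for s in symbols if s.name == name and (kind is None or s.kind == kind)]
    if matches or kind is None:
        return matches
    # Step 4: no results -- retry with the kind constraint removed
    return [s for s in symbols if s.name == name]
```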
---
## 3. Requirements
### Requirement 1: File Context Query (`file_context`)
**Purpose**: read a code file and return a summary of the call relationships of every method in it.
**Example output**:
```markdown
## src/auth/login.py (3 methods)
### login_user (line 15-45)
- Calls: validate_password (auth/utils.py:23), create_session (session/manager.py:89)
- Called by: handle_login_request (api/routes.py:156), test_login (tests/test_auth.py:34)
### validate_token (line 47-62)
- Calls: decode_jwt (auth/jwt.py:12)
- Called by: auth_middleware (middleware/auth.py:28)
```
### Requirement 2: Generic LSP Search (cclsp-compatible)
| Endpoint | Purpose |
|------|------|
| `find_definition` | Find definition locations by symbol name |
| `find_references` | Find all references to a symbol |
| `workspace_symbols` | Workspace-wide symbol search |
| `get_hover` | Get hover information for a symbol |
### Requirement 3: Vector + LSP Fusion Search
**Purpose**: combine vector-based semantic search with structural LSP search.
**Fusion strategies** (a minimal RRF sketch follows the list):
- **RRF** (preferred): simple, robust, no score normalization required
- **Cascade**: for specific scenarios; vector first, then LSP
- **Adaptive**: long-term goal; pick the strategy automatically per query type
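For reference, a minimal sketch of RRF over several ranked result lists (the constant `k=60` is the conventional default; the production implementation lives in `chain_search.py`):
```python
from collections import defaultdict
from typing import Dict, List

def rrf_fuse(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse ranked lists: each item scores 1/(k + rank) per list it appears in."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, item in enumerate(ranking, start=1):
            scores[item] += 1.0 / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)
```
RRF consumes only ranks, never raw scores, which is why no cross-engine score normalization is needed.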
---
## 4. API Specification
### 4.1 Module Structure
```
src/codexlens/
├─ api/                       [new] public API layer
│   ├─ __init__.py            exports all APIs
│   ├─ file_context.py        file context
│   ├─ definition.py          definition lookup
│   ├─ references.py          reference lookup
│   ├─ symbols.py             symbol search
│   ├─ hover.py               hover information
│   └─ semantic.py            semantic search
├─ storage/
│   ├─ global_index.py        [extended] get_file_symbols()
│   └─ relationship_query.py  [new] directed call queries
└─ search/
    └─ chain_search.py        [fixed] schema compatibility
```
### 4.2 `codexlens.api.file_context()`
```python
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple

@dataclass
class CallInfo:
    """Call relationship information."""
    symbol_name: str
    file_path: Optional[str]  # target file (may be None)
    line: int
    relationship: str  # call | import | inheritance

@dataclass
class MethodContext:
    """Method context."""
    name: str
    kind: str  # function | method | class
    line_range: Tuple[int, int]
    signature: Optional[str]
    calls: List[CallInfo]    # outgoing calls
    callers: List[CallInfo]  # incoming callers

@dataclass
class FileContextResult:
    """File context result."""
    file_path: str
    language: str
    methods: List[MethodContext]
    summary: str  # human-readable summary
    discovery_status: Dict[str, bool] = field(default_factory=lambda: {
        "outgoing_resolved": False,
        "incoming_resolved": True,
        "targets_resolved": False
    })

def file_context(
    project_root: str,
    file_path: str,
    include_calls: bool = True,
    include_callers: bool = True,
    max_depth: int = 1,
    format: str = "brief"  # brief | detailed | tree
) -> FileContextResult:
    """
    Get the method-call context of a code file.

    Args:
        project_root: project root directory (used to locate the index)
        file_path: path to the code file
        include_calls: whether to include outgoing calls
        include_callers: whether to include incoming callers
        max_depth: call-chain depth (1 = direct calls)
            ⚠️ V1 limitation: only max_depth=1 is supported;
            deep call-chain analysis is planned for V2
        format: output format

    Returns:
        FileContextResult

    Raises:
        IndexNotFoundError: project is not indexed
        FileNotFoundError: file does not exist

    Note:
        V1 implementation limits:
        - max_depth only supports 1 (direct calls)
        - outgoing call target files may be None (unresolved)
        - deep call-chain analysis is planned as a V2 feature
    """
```
### 4.3 `codexlens.api.find_definition()`
```python
@dataclass
class DefinitionResult:
    """Definition lookup result."""
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str]
    container: Optional[str]  # containing class/module
    score: float

def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """
    Find definition locations by symbol name.

    Fallback strategy:
    1. Exact match + kind filter
    2. Exact match (kind removed)
    3. Prefix match
    """
```
### 4.4 `codexlens.api.find_references()`
```python
@dataclass
class ReferenceResult:
    """Reference result."""
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

@dataclass
class GroupedReferences:
    """References grouped by definition."""
    definition: DefinitionResult
    references: List[ReferenceResult]

def find_references(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    include_definition: bool = True,
    group_by_definition: bool = True,
    limit: int = 100
) -> List[GroupedReferences]:
    """
    Find all reference locations for a symbol.

    When a symbol has multiple definitions, results are grouped per
    definition to avoid mixing unrelated references.
    """
```
### 4.5 `codexlens.api.workspace_symbols()`
```python
@dataclass
class SymbolInfo:
    """Symbol information."""
    name: str
    kind: str
    file_path: str
    line: int
    container: Optional[str]
    score: float

def workspace_symbols(
    project_root: str,
    query: str,
    kind_filter: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    limit: int = 50
) -> List[SymbolInfo]:
    """Search symbols across the whole workspace (prefix matching)."""
```
### 4.6 `codexlens.api.get_hover()`
```python
@dataclass
class HoverInfo:
    """Hover information."""
    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: Tuple[int, int]
    type_info: Optional[str]

def get_hover(
    project_root: str,
    symbol_name: str,
    file_path: Optional[str] = None
) -> Optional[HoverInfo]:
    """Get detailed hover information for a symbol."""
```
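A short usage sketch for these two endpoints (the project path is a placeholder):
```python
from codexlens.api import workspace_symbols, get_hover

symbols = workspace_symbols("/path/to/project", query="User", kind_filter=["class"])
for sym in symbols:
    print(f"{sym.name} ({sym.kind}) at {sym.file_path}:{sym.line}")

hover = get_hover("/path/to/project", "UserService")
if hover is not None:
    print(hover.signature)
```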
### 4.7 `codexlens.api.semantic_search()`
```python
@dataclass
class SemanticResult:
    """Semantic search result."""
    symbol_name: str
    kind: str
    file_path: str
    line: int
    vector_score: Optional[float]
    structural_score: Optional[float]
    fusion_score: float
    snippet: str
    match_reason: Optional[str]

def semantic_search(
    project_root: str,
    query: str,
    mode: str = "fusion",  # vector | structural | fusion
    vector_weight: float = 0.5,
    structural_weight: float = 0.3,
    keyword_weight: float = 0.2,
    fusion_strategy: str = "rrf",  # rrf | staged | binary | hybrid
    kind_filter: Optional[List[str]] = None,
    limit: int = 20,
    include_match_reason: bool = False
) -> List[SemanticResult]:
    """
    Semantic search combining vector and structural search.

    Args:
        project_root: project root directory
        query: natural-language query
        mode: search mode
            - vector: vector search only
            - structural: structural search only (symbols + relationships)
            - fusion: fused search (default)
        vector_weight: vector search weight [0, 1]
        structural_weight: structural search weight [0, 1]
        keyword_weight: keyword search weight [0, 1]
        fusion_strategy: fusion strategy (maps onto chain_search.py)
            - rrf: Reciprocal Rank Fusion (recommended, default)
            - staged: staged cascade → staged_cascade_search
            - binary: binary rerank cascade → binary_rerank_cascade_search
            - hybrid: hybrid cascade → hybrid_cascade_search
        kind_filter: symbol kind filter
        limit: maximum number of results
        include_match_reason: whether to generate match reasons (heuristic, not LLM)

    Returns:
        Results sorted by fusion_score, descending

    Degradation:
        - No vector index: vector_score=None; FTS + structural search are used
        - No relationship data: structural_score=None; vector search only
    """
```
---
## 5. Known Issues and Solutions
### 5.1 P0 Blockers
| Issue | Location | Solution |
|------|------|----------|
| **Index schema mismatch** | `chain_search.py:313-324` vs `dir_index.py:304-312` | Support both `full_path` and `path` |
| **Missing file-symbol query** | `global_index.py:214-260` | Add `get_file_symbols()` |
| **Missing outgoing-call query** | `dir_index.py:333-342` | Add `RelationshipQuery` |
| **Inconsistent relationship types** | `entities.py:74-79` | Normalize `calls` → `call` |
### 5.2 Design Flaws (found in Gemini review)
| Flaw | Impact | Solution |
|------|------|----------|
| **Incomplete call graph** | `file_context` lacks outgoing calls | Add directed call-query APIs |
| **Undefined disambiguation** | Multiple definitions cannot be told apart | Implement `rank_by_proximity()` |
| **AI features too costly** | `explanation` would require an LLM | Make it optional, off by default |
| **Inconsistent fusion parameters** | 3 branches but only 2 weights | Add `keyword_weight` |
### 5.3 Disambiguation Algorithm
**V1 implementation** (file-path proximity):
```python
def rank_by_proximity(
    results: List[DefinitionResult],
    file_context: str
) -> List[DefinitionResult]:
    """Rank results by file proximity (V1: path proximity)."""
    def proximity_score(result):
        # 1. Same directory scores highest
        if os.path.dirname(result.file_path) == os.path.dirname(file_context):
            return 100
        # 2. Otherwise, the length of the common path prefix
        common = os.path.commonpath([result.file_path, file_context])
        return len(common)
    return sorted(results, key=proximity_score, reverse=True)
```
**V2 enhancement plan** (import-graph distance):
```python
def rank_by_import_distance(
    results: List[DefinitionResult],
    file_context: str,
    import_graph: Dict[str, Set[str]]
) -> List[DefinitionResult]:
    """Rank results by import-graph distance (V2)."""
    def import_distance(result):
        # BFS for the shortest import path
        return bfs_shortest_path(
            import_graph,
            file_context,
            result.file_path
        )
    # Combine: 0.6 * import_distance + 0.4 * path proximity
    return sorted(results, key=lambda r: (
        0.6 * import_distance(r) +
        0.4 * (100 - proximity_score(r))
    ))
```
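The `bfs_shortest_path` helper the V2 sketch relies on is an ordinary breadth-first search over the import graph; a minimal version, assuming the graph maps each file to the set of files it imports:
```python
from collections import deque
from typing import Dict, Set

def bfs_shortest_path(graph: Dict[str, Set[str]], start: str, goal: str) -> int:
    """Shortest hop count from start to goal; a large value means unreachable."""
    if start == goal:
        return 0
    seen, queue = {start}, deque([(start, 0)])
    while queue:
        node, dist = queue.popleft()
        for neighbor in graph.get(node, set()):
            if neighbor == goal:
                return dist + 1
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append((neighbor, dist + 1))
    return 10**6  # unreachable: rank as very distant
```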
### 5.4 Reference Implementation: `get_file_symbols()`
**Location**: `src/codexlens/storage/global_index.py`
```python
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
    """
    Get all symbols defined in the given file.

    Args:
        file_path: file path (relative or absolute)

    Returns:
        Symbols sorted by start line
    """
    file_path_str = str(Path(file_path).resolve())
    with self._lock:
        conn = self._get_connection()
        rows = conn.execute(
            """
            SELECT symbol_name, symbol_kind, file_path, start_line, end_line
            FROM global_symbols
            WHERE project_id = ? AND file_path = ?
            ORDER BY start_line
            """,
            (self.project_id, file_path_str),
        ).fetchall()
        return [
            Symbol(
                name=row["symbol_name"],
                kind=row["symbol_kind"],
                range=(row["start_line"], row["end_line"]),
                file=row["file_path"],
            )
            for row in rows
        ]
```
---
## 6. Implementation Plan
### Phase 0: Infrastructure (16h)
| Task | Effort | Notes |
|------|------|------|
| Fix `search_references` schema | 4h | Support both schemas |
| Add `GlobalSymbolIndex.get_file_symbols()` | 4h | File symbol query (see 5.4) |
| Add `RelationshipQuery` class | 6h | Directed call queries |
| Relationship-type normalization layer | 2h | `calls` → `call` |
### Phase 1: API Layer (48h)
| Task | Effort | Complexity |
|------|------|--------|
| `find_definition()` | 4h | S |
| `find_references()` | 8h | M |
| `workspace_symbols()` | 4h | S |
| `get_hover()` | 4h | S |
| `file_context()` | 16h | L |
| `semantic_search()` | 12h | M |
### Phase 2: Tests and Documentation (16h)
| Task | Effort |
|------|------|
| Unit tests (≥80% coverage) | 8h |
| API documentation | 4h |
| Example code | 4h |
### Critical Path
```
Phase 0.1 (schema fix)
Phase 0.2 (file symbols) → Phase 1.5 (file_context)
Phase 1 (remaining APIs)
Phase 2 (tests)
```
---
## 7. Testing Strategy
### 7.1 Unit Tests
```python
# test_global_index.py
def test_get_file_symbols():
    index = GlobalSymbolIndex(":memory:")
    index.update_file_symbols(project_id=1, file_path="test.py", symbols=[...])
    results = index.get_file_symbols("test.py")
    assert len(results) == 3

# test_relationship_query.py
def test_outgoing_calls():
    store = DirIndexStore(":memory:")
    calls = store.get_outgoing_calls("src/auth.py", "login")
    assert calls[0].relationship == "call"  # normalized
```
### 7.2 Schema Compatibility Tests
```python
def test_search_references_both_schemas():
    """Test reference search against both schemas."""
    # old schema: files(path, ...)
    # new schema: files(full_path, ...)
```
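One way to implement the compatibility fix is to probe the `files` table for its path column once per connection; a minimal sketch covering the two schemas named above (the helper name is illustrative):
```python
import sqlite3

def detect_path_column(conn: sqlite3.Connection) -> str:
    """Return 'full_path' (new schema) or 'path' (old schema)."""
    columns = {row[1] for row in conn.execute("PRAGMA table_info(files)")}
    return "full_path" if "full_path" in columns else "path"
```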
### 7.3 Degradation Tests
```python
def test_semantic_search_without_vectors():
    results = semantic_search(project_root="/path/to/project", query="auth", mode="fusion")
    assert results[0].vector_score is None
    assert results[0].fusion_score > 0
```
---
## 8. Usage Examples
```python
from codexlens.api import (
    file_context,
    find_definition,
    find_references,
    semantic_search,
)

# 1. Get file context
result = file_context(
    project_root="/path/to/project",
    file_path="src/auth/login.py",
    format="brief"
)
print(result.summary)

# 2. Find a definition
definitions = find_definition(
    project_root="/path/to/project",
    symbol_name="UserService",
    symbol_kind="class"
)

# 3. Semantic search
results = semantic_search(
    project_root="/path/to/project",
    query="functions that handle user login validation",
    mode="fusion"
)
```
---
## 9. CCW Integration
| codexlens API | CCW MCP Tool |
|---------------|--------------|
| `file_context()` | `codexlens_file_context` |
| `find_definition()` | `codexlens_find_definition` |
| `find_references()` | `codexlens_find_references` |
| `workspace_symbols()` | `codexlens_workspace_symbol` |
| `get_hover()` | `codexlens_get_hover` |
| `semantic_search()` | `codexlens_semantic_search` |
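On the Python side, this mapping can be a thin dispatch table; a minimal sketch (the MCP protocol handling itself lives in CCW and is not shown here):
```python
from codexlens import api

TOOL_DISPATCH = {
    "codexlens_file_context": api.file_context,
    "codexlens_find_definition": api.find_definition,
    "codexlens_find_references": api.find_references,
    "codexlens_workspace_symbol": api.workspace_symbols,
    "codexlens_get_hover": api.get_hover,
    "codexlens_semantic_search": api.semantic_search,
}

def handle_tool_call(tool: str, arguments: dict):
    """Route an MCP tool invocation to the matching codexlens API function."""
    return TOOL_DISPATCH[tool](**arguments)
```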
---
## 10. Analysis Sources
| Tool | Session ID | Contribution |
|------|------------|------|
| Gemini | `1768618654438-gemini` | Architecture review, design flaws, fusion strategies |
| Codex | `1768618658183-codex` | Component reuse, complexity estimates, task breakdown |
| Gemini | `1768620615744-gemini` | Final review, improvement suggestions, APPROVED |
---
## 11. Version History
| Version | Date | Changes |
|------|------|------|
| 1.0 | 2025-01-17 | Initial version; merged multiple documents |
| 1.1 | 2025-01-17 | Applied Gemini review improvements: V1 limitation notes, strategy mapping, disambiguation enhancement, reference implementation |

View File

@@ -97,11 +97,25 @@ encoding = [
"chardet>=5.0",
]
# Clustering for staged hybrid search (HDBSCAN + sklearn)
clustering = [
"hdbscan>=0.8.1",
"scikit-learn>=1.3.0",
]
# Full features including tiktoken for accurate token counting
full = [
"tiktoken>=0.5.0",
]
# Language Server Protocol support
lsp = [
"pygls>=1.3.0",
]
[project.scripts]
codexlens-lsp = "codexlens.lsp:main"
[project.urls]
Homepage = "https://github.com/openai/codex-lens"

View File

@@ -52,5 +52,10 @@ Requires-Dist: transformers>=4.36; extra == "splade-gpu"
Requires-Dist: optimum[onnxruntime-gpu]>=1.16; extra == "splade-gpu"
Provides-Extra: encoding
Requires-Dist: chardet>=5.0; extra == "encoding"
Provides-Extra: clustering
Requires-Dist: hdbscan>=0.8.1; extra == "clustering"
Requires-Dist: scikit-learn>=1.3.0; extra == "clustering"
Provides-Extra: full
Requires-Dist: tiktoken>=0.5.0; extra == "full"
Provides-Extra: lsp
Requires-Dist: pygls>=1.3.0; extra == "lsp"

View File

@@ -2,6 +2,7 @@ pyproject.toml
src/codex_lens.egg-info/PKG-INFO
src/codex_lens.egg-info/SOURCES.txt
src/codex_lens.egg-info/dependency_links.txt
src/codex_lens.egg-info/entry_points.txt
src/codex_lens.egg-info/requires.txt
src/codex_lens.egg-info/top_level.txt
src/codexlens/__init__.py
@@ -18,6 +19,14 @@ src/codexlens/cli/output.py
src/codexlens/indexing/__init__.py
src/codexlens/indexing/embedding.py
src/codexlens/indexing/symbol_extractor.py
src/codexlens/lsp/__init__.py
src/codexlens/lsp/handlers.py
src/codexlens/lsp/providers.py
src/codexlens/lsp/server.py
src/codexlens/mcp/__init__.py
src/codexlens/mcp/hooks.py
src/codexlens/mcp/provider.py
src/codexlens/mcp/schema.py
src/codexlens/parsers/__init__.py
src/codexlens/parsers/encoding.py
src/codexlens/parsers/factory.py
@@ -31,6 +40,13 @@ src/codexlens/search/graph_expander.py
src/codexlens/search/hybrid_search.py
src/codexlens/search/query_parser.py
src/codexlens/search/ranking.py
src/codexlens/search/clustering/__init__.py
src/codexlens/search/clustering/base.py
src/codexlens/search/clustering/dbscan_strategy.py
src/codexlens/search/clustering/factory.py
src/codexlens/search/clustering/frequency_strategy.py
src/codexlens/search/clustering/hdbscan_strategy.py
src/codexlens/search/clustering/noop_strategy.py
src/codexlens/semantic/__init__.py
src/codexlens/semantic/ann_index.py
src/codexlens/semantic/base.py
@@ -84,6 +100,7 @@ tests/test_api_reranker.py
tests/test_chain_search.py
tests/test_cli_hybrid_search.py
tests/test_cli_output.py
tests/test_clustering_strategies.py
tests/test_code_extractor.py
tests/test_config.py
tests/test_dual_fts.py
@@ -122,6 +139,7 @@ tests/test_search_performance.py
tests/test_semantic.py
tests/test_semantic_search.py
tests/test_sqlite_store.py
tests/test_staged_cascade.py
tests/test_storage.py
tests/test_storage_concurrency.py
tests/test_symbol_extractor.py

View File

@@ -0,0 +1,2 @@
[console_scripts]
codexlens-lsp = codexlens.lsp:main

View File

@@ -8,12 +8,19 @@ tree-sitter-typescript>=0.23
pathspec>=0.11
watchdog>=3.0
[clustering]
hdbscan>=0.8.1
scikit-learn>=1.3.0
[encoding]
chardet>=5.0
[full]
tiktoken>=0.5.0
[lsp]
pygls>=1.3.0
[reranker]
optimum>=1.16
onnxruntime>=1.15

View File

@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.
This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.
Dataclasses (from models.py):
- CallInfo: Call relationship information
- MethodContext: Method context with call relationships
- FileContextResult: File context result with method summaries
- DefinitionResult: Definition lookup result
- ReferenceResult: Reference lookup result
- GroupedReferences: References grouped by definition
- SymbolInfo: Symbol information for workspace search
- HoverInfo: Hover information for a symbol
- SemanticResult: Semantic search result
Utility functions (from utils.py):
- resolve_project: Resolve and validate project root path
- normalize_relationship_type: Normalize relationship type to canonical form
- rank_by_proximity: Rank results by file path proximity
Example:
>>> from codexlens.api import (
... DefinitionResult,
... resolve_project,
... normalize_relationship_type
... )
>>> project = resolve_project("/path/to/project")
>>> rel_type = normalize_relationship_type("calls")
>>> print(rel_type)
call
"""
from __future__ import annotations
# Dataclasses
from .models import (
CallInfo,
MethodContext,
FileContextResult,
DefinitionResult,
ReferenceResult,
GroupedReferences,
SymbolInfo,
HoverInfo,
SemanticResult,
)
# Utility functions
from .utils import (
resolve_project,
normalize_relationship_type,
rank_by_proximity,
rank_by_score,
)
# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search
__all__ = [
# Dataclasses
"CallInfo",
"MethodContext",
"FileContextResult",
"DefinitionResult",
"ReferenceResult",
"GroupedReferences",
"SymbolInfo",
"HoverInfo",
"SemanticResult",
# Utility functions
"resolve_project",
"normalize_relationship_type",
"rank_by_proximity",
"rank_by_score",
# API functions
"find_definition",
"workspace_symbols",
"get_hover",
"file_context",
"find_references",
"semantic_search",
]

View File

@@ -0,0 +1,126 @@
"""find_definition API implementation.
This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity
logger = logging.getLogger(__name__)
def find_definition(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
file_context: Optional[str] = None,
limit: int = 10
) -> List[DefinitionResult]:
"""Find definition locations for a symbol.
Uses a 3-stage fallback strategy:
1. Exact match with kind filter
2. Exact match without kind filter
3. Prefix match
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to find
symbol_kind: Optional symbol kind filter (class, function, etc.)
file_context: Optional file path for proximity ranking
limit: Maximum number of results to return
Returns:
List of DefinitionResult sorted by proximity if file_context provided
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project_by_source(str(project_path))
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Stage 1: Exact match with kind filter
results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
if results:
logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 2: Exact match without kind (if kind was specified)
if symbol_kind:
results = _search_with_kind(global_index, symbol_name, None, limit)
if results:
logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
# Stage 3: Prefix match
results = global_index.search(
name=symbol_name,
kind=None,
limit=limit,
prefix_mode=True
)
if results:
logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
return _rank_and_convert(results, file_context)
logger.debug(f"No definitions found for {symbol_name}")
return []
def _search_with_kind(
global_index: GlobalSymbolIndex,
symbol_name: str,
symbol_kind: Optional[str],
limit: int
) -> List[Symbol]:
"""Search for symbols with optional kind filter."""
return global_index.search(
name=symbol_name,
kind=symbol_kind,
limit=limit,
prefix_mode=False
)
def _rank_and_convert(
symbols: List[Symbol],
file_context: Optional[str]
) -> List[DefinitionResult]:
"""Convert symbols to DefinitionResult and rank by proximity."""
results = [
DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Could extract from file if needed
container=None, # Could extract from parent symbol
score=1.0
)
for sym in symbols
]
return rank_by_proximity(results, file_context)

View File

@@ -0,0 +1,271 @@
"""file_context API implementation.
This module provides the file_context() function for retrieving
method call graphs from a source file.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional, Tuple
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
FileContextResult,
MethodContext,
CallInfo,
)
from .utils import resolve_project, normalize_relationship_type
logger = logging.getLogger(__name__)
def file_context(
project_root: str,
file_path: str,
include_calls: bool = True,
include_callers: bool = True,
max_depth: int = 1,
format: str = "brief"
) -> FileContextResult:
"""Get method call context for a code file.
Retrieves all methods/functions in the file along with their
outgoing calls and incoming callers.
Args:
project_root: Project root directory (for index location)
file_path: Path to the code file to analyze
include_calls: Whether to include outgoing calls
include_callers: Whether to include incoming callers
max_depth: Call chain depth (V1 only supports 1)
format: Output format (brief | detailed | tree)
Returns:
FileContextResult with method contexts and summary
Raises:
IndexNotFoundError: If project is not indexed
FileNotFoundError: If file does not exist
ValueError: If max_depth > 1 (V1 limitation)
"""
# V1 limitation: only depth=1 supported
if max_depth > 1:
raise ValueError(
f"max_depth > 1 not supported in V1. "
f"Requested: {max_depth}, supported: 1"
)
project_path = resolve_project(project_root)
file_path_resolved = Path(file_path).resolve()
# Validate file exists
if not file_path_resolved.exists():
raise FileNotFoundError(f"File not found: {file_path_resolved}")
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project_by_source(str(project_path))
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Get all symbols in the file
symbols = global_index.get_file_symbols(str(file_path_resolved))
# Filter to functions, methods, and classes
method_symbols = [
s for s in symbols
if s.kind in ("function", "method", "class")
]
logger.debug(f"Found {len(method_symbols)} methods in {file_path}")
# Try to find dir_index for relationship queries
dir_index = _find_dir_index(project_info, file_path_resolved)
# Build method contexts
methods: List[MethodContext] = []
outgoing_resolved = True
incoming_resolved = True
targets_resolved = True
for symbol in method_symbols:
calls: List[CallInfo] = []
callers: List[CallInfo] = []
if include_calls and dir_index:
try:
outgoing = dir_index.get_outgoing_calls(
str(file_path_resolved),
symbol.name
)
for target_name, rel_type, line, target_file in outgoing:
calls.append(CallInfo(
symbol_name=target_name,
file_path=target_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
if target_file is None:
targets_resolved = False
except Exception as e:
logger.debug(f"Failed to get outgoing calls: {e}")
outgoing_resolved = False
if include_callers and dir_index:
try:
incoming = dir_index.get_incoming_calls(symbol.name)
for source_name, rel_type, line, source_file in incoming:
callers.append(CallInfo(
symbol_name=source_name,
file_path=source_file,
line=line,
relationship=normalize_relationship_type(rel_type)
))
except Exception as e:
logger.debug(f"Failed to get incoming calls: {e}")
incoming_resolved = False
methods.append(MethodContext(
name=symbol.name,
kind=symbol.kind,
line_range=symbol.range if symbol.range else (1, 1),
signature=None, # Could extract from source
calls=calls,
callers=callers
))
# Detect language from file extension
language = _detect_language(file_path_resolved)
# Generate summary
summary = _generate_summary(file_path_resolved, methods, format)
return FileContextResult(
file_path=str(file_path_resolved),
language=language,
methods=methods,
summary=summary,
discovery_status={
"outgoing_resolved": outgoing_resolved,
"incoming_resolved": incoming_resolved,
"targets_resolved": targets_resolved
}
)
def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
"""Find the dir_index that contains the file.
Args:
project_info: Project information from registry
file_path: Path to the file
Returns:
DirIndexStore if found, None otherwise
"""
try:
# Look for _index.db in file's directory or parent directories
current = file_path.parent
while current != current.parent:
index_db = current / "_index.db"
if index_db.exists():
return DirIndexStore(str(index_db))
# Also check in project's index_root
relative = current.relative_to(project_info.source_root)
index_in_cache = project_info.index_root / relative / "_index.db"
if index_in_cache.exists():
return DirIndexStore(str(index_in_cache))
current = current.parent
except Exception as e:
logger.debug(f"Failed to find dir_index: {e}")
return None
def _detect_language(file_path: Path) -> str:
"""Detect programming language from file extension.
Args:
file_path: Path to the file
Returns:
Language name
"""
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".jsx": "javascript",
".tsx": "typescript",
".go": "go",
".rs": "rust",
".java": "java",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
}
return ext_map.get(file_path.suffix.lower(), "unknown")
def _generate_summary(
file_path: Path,
methods: List[MethodContext],
format: str
) -> str:
"""Generate human-readable summary of file context.
Args:
file_path: Path to the file
methods: List of method contexts
format: Output format (brief | detailed | tree)
Returns:
Markdown-formatted summary
"""
lines = [f"## {file_path.name} ({len(methods)} methods)\n"]
for method in methods:
start, end = method.line_range
lines.append(f"### {method.name} (line {start}-{end})")
if method.calls:
calls_str = ", ".join(
f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.calls
)
lines.append(f"- Calls: {calls_str}")
if method.callers:
callers_str = ", ".join(
f"{c.symbol_name} ({c.file_path}:{c.line})"
if format == "detailed"
else c.symbol_name
for c in method.callers
)
lines.append(f"- Called by: {callers_str}")
if not method.calls and not method.callers:
lines.append("- (no call relationships)")
lines.append("")
return "\n".join(lines)

View File

@@ -0,0 +1,148 @@
"""get_hover API implementation.
This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def get_hover(
project_root: str,
symbol_name: str,
file_path: Optional[str] = None
) -> Optional[HoverInfo]:
"""Get detailed hover information for a symbol.
Args:
project_root: Project root directory (for index location)
symbol_name: Name of the symbol to look up
file_path: Optional file path to disambiguate when symbol
appears in multiple files
Returns:
HoverInfo if symbol found, None otherwise
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project_by_source(str(project_path))
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search for the symbol
results = global_index.search(
name=symbol_name,
kind=None,
limit=50,
prefix_mode=False
)
if not results:
logger.debug(f"No hover info found for {symbol_name}")
return None
# If file_path provided, filter to that file
if file_path:
file_path_resolved = str(Path(file_path).resolve())
matching = [s for s in results if s.file == file_path_resolved]
if matching:
results = matching
# Take the first result
symbol = results[0]
# Build hover info
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=_extract_signature(symbol),
documentation=_extract_documentation(symbol),
file_path=symbol.file or "",
line_range=symbol.range if symbol.range else (1, 1),
type_info=_extract_type_info(symbol)
)
def _extract_signature(symbol: Symbol) -> str:
"""Extract signature from symbol.
For now, generates a basic signature based on kind and name.
In a full implementation, this would parse the actual source code.
Args:
symbol: The symbol to extract signature from
Returns:
Signature string
"""
if symbol.kind == "function":
return f"def {symbol.name}(...)"
elif symbol.kind == "method":
return f"def {symbol.name}(self, ...)"
elif symbol.kind == "class":
return f"class {symbol.name}"
elif symbol.kind == "variable":
return symbol.name
elif symbol.kind == "constant":
return f"{symbol.name} = ..."
else:
return f"{symbol.kind} {symbol.name}"
def _extract_documentation(symbol: Symbol) -> Optional[str]:
"""Extract documentation from symbol.
In a full implementation, this would parse docstrings from source.
For now, returns None.
Args:
symbol: The symbol to extract documentation from
Returns:
Documentation string if available, None otherwise
"""
# Would need to read source file and parse docstring
# For V1, return None
return None
def _extract_type_info(symbol: Symbol) -> Optional[str]:
"""Extract type information from symbol.
In a full implementation, this would parse type annotations.
For now, returns None.
Args:
symbol: The symbol to extract type info from
Returns:
Type info string if available, None otherwise
"""
# Would need to parse type annotations from source
# For V1, return None
return None

View File

@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.
This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple
# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================
@dataclass
class CallInfo:
"""Call relationship information.
Attributes:
symbol_name: Name of the called/calling symbol
file_path: Target file path (may be None if unresolved)
line: Line number of the call
relationship: Type of relationship (call | import | inheritance)
"""
symbol_name: str
file_path: Optional[str]
line: int
relationship: str # call | import | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MethodContext:
"""Method context with call relationships.
Attributes:
name: Method/function name
kind: Symbol kind (function | method | class)
line_range: Start and end line numbers
signature: Function signature (if available)
calls: List of outgoing calls
callers: List of incoming calls
"""
name: str
kind: str # function | method | class
line_range: Tuple[int, int]
signature: Optional[str]
calls: List[CallInfo] = field(default_factory=list)
callers: List[CallInfo] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"line_range": list(self.line_range),
"calls": [c.to_dict() for c in self.calls],
"callers": [c.to_dict() for c in self.callers],
}
if self.signature is not None:
result["signature"] = self.signature
return result
@dataclass
class FileContextResult:
"""File context result with method summaries.
Attributes:
file_path: Path to the analyzed file
language: Programming language
methods: List of method contexts
summary: Human-readable summary
discovery_status: Status flags for call resolution
"""
file_path: str
language: str
methods: List[MethodContext]
summary: str
discovery_status: Dict[str, bool] = field(default_factory=lambda: {
"outgoing_resolved": False,
"incoming_resolved": True,
"targets_resolved": False
})
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
return {
"file_path": self.file_path,
"language": self.language,
"methods": [m.to_dict() for m in self.methods],
"summary": self.summary,
"discovery_status": self.discovery_status,
}
# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================
@dataclass
class DefinitionResult:
"""Definition lookup result.
Attributes:
name: Symbol name
kind: Symbol kind (class, function, method, etc.)
file_path: File where symbol is defined
line: Start line number
end_line: End line number
signature: Symbol signature (if available)
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
end_line: int
signature: Optional[str] = None
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================
@dataclass
class ReferenceResult:
"""Reference lookup result.
Attributes:
file_path: File containing the reference
line: Line number
column: Column number
context_line: The line of code containing the reference
relationship: Type of reference (call | import | type_annotation | inheritance)
"""
file_path: str
line: int
column: int
context_line: str
relationship: str # call | import | type_annotation | inheritance
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class GroupedReferences:
"""References grouped by definition.
Used when a symbol has multiple definitions (e.g., overloads).
Attributes:
definition: The definition this group refers to
references: List of references to this definition
"""
definition: DefinitionResult
references: List[ReferenceResult] = field(default_factory=list)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"definition": self.definition.to_dict(),
"references": [r.to_dict() for r in self.references],
}
# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================
@dataclass
class SymbolInfo:
"""Symbol information for workspace search.
Attributes:
name: Symbol name
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
container: Containing class/module (if any)
score: Match score for ranking
"""
name: str
kind: str
file_path: str
line: int
container: Optional[str] = None
score: float = 1.0
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}
# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================
@dataclass
class HoverInfo:
"""Hover information for a symbol.
Attributes:
name: Symbol name
kind: Symbol kind
signature: Symbol signature
documentation: Documentation string (if available)
file_path: File where symbol is defined
line_range: Start and end line numbers
type_info: Type information (if available)
"""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: Tuple[int, int]
type_info: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
result = {
"name": self.name,
"kind": self.kind,
"signature": self.signature,
"file_path": self.file_path,
"line_range": list(self.line_range),
}
if self.documentation is not None:
result["documentation"] = self.documentation
if self.type_info is not None:
result["type_info"] = self.type_info
return result
# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================
@dataclass
class SemanticResult:
"""Semantic search result.
Attributes:
symbol_name: Name of the matched symbol
kind: Symbol kind
file_path: File where symbol is defined
line: Line number
vector_score: Vector similarity score (None if not available)
structural_score: Structural match score (None if not available)
fusion_score: Combined fusion score
snippet: Code snippet
match_reason: Explanation of why this matched (optional)
"""
symbol_name: str
kind: str
file_path: str
line: int
vector_score: Optional[float]
structural_score: Optional[float]
fusion_score: float
snippet: str
match_reason: Optional[str] = None
def to_dict(self) -> dict:
"""Convert to dictionary, filtering None values."""
return {k: v for k, v in asdict(self).items() if v is not None}

View File

@@ -0,0 +1,345 @@
"""Find references API for codexlens.
This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional, Dict
from .models import (
DefinitionResult,
ReferenceResult,
GroupedReferences,
)
from .utils import (
resolve_project,
normalize_relationship_type,
)
logger = logging.getLogger(__name__)
def _read_line_from_file(file_path: str, line: int) -> str:
"""Read a specific line from a file.
Args:
file_path: Path to the file
line: Line number (1-based)
Returns:
The line content, stripped of trailing whitespace.
Returns empty string if file cannot be read or line doesn't exist.
"""
try:
path = Path(file_path)
if not path.exists():
return ""
with path.open("r", encoding="utf-8", errors="replace") as f:
for i, content in enumerate(f, 1):
if i == line:
return content.rstrip()
return ""
except Exception as exc:
logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
return ""
def _transform_to_reference_result(
raw_ref: "RawReferenceResult",
) -> ReferenceResult:
"""Transform raw ChainSearchEngine reference to API ReferenceResult.
Args:
raw_ref: Raw reference result from ChainSearchEngine
Returns:
API ReferenceResult with context_line and normalized relationship
"""
# Read the actual line from the file
context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)
# Normalize relationship type
relationship = normalize_relationship_type(raw_ref.relationship_type)
return ReferenceResult(
file_path=raw_ref.file_path,
line=raw_ref.line,
column=raw_ref.column,
context_line=context_line,
relationship=relationship,
)
def find_references(
project_root: str,
symbol_name: str,
symbol_kind: Optional[str] = None,
include_definition: bool = True,
group_by_definition: bool = True,
limit: int = 100,
) -> List[GroupedReferences]:
"""Find all reference locations for a symbol.
Multi-definition case returns grouped results to resolve ambiguity.
This function wraps ChainSearchEngine.search_references() and groups
the results by definition location. Each GroupedReferences contains
a definition and all references that point to it.
Args:
project_root: Project root directory path
symbol_name: Name of the symbol to find references for
symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
include_definition: Whether to include the definition location
in the result (default True)
group_by_definition: Whether to group references by definition.
If False, returns a single group with all references.
(default True)
limit: Maximum number of references to return (default 100)
Returns:
List of GroupedReferences. Each group contains:
- definition: The DefinitionResult for this symbol definition
- references: List of ReferenceResult pointing to this definition
Raises:
ValueError: If project_root does not exist or is not a directory
Examples:
>>> refs = find_references("/path/to/project", "authenticate")
>>> for group in refs:
... print(f"Definition: {group.definition.file_path}:{group.definition.line}")
... for ref in group.references:
... print(f" Reference: {ref.file_path}:{ref.line} ({ref.relationship})")
Note:
Reference relationship types are normalized:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
"""
# Validate and resolve project root
project_path = resolve_project(project_root)
# Import here to avoid circular imports
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
from codexlens.entities import Symbol
# Initialize infrastructure
config = Config()
registry = RegistryStore(config.registry_db_path)
mapper = PathMapper(config.index_root)
# Create chain search engine
engine = ChainSearchEngine(registry, mapper, config=config)
try:
# Step 1: Find definitions for the symbol
definitions: List[DefinitionResult] = []
if include_definition or group_by_definition:
# Search for symbol definitions
symbols = engine.search_symbols(
name=symbol_name,
source_path=project_path,
kind=symbol_kind,
)
# Convert Symbol to DefinitionResult
for sym in symbols:
# Only include exact name matches for definitions
if sym.name != symbol_name:
continue
# Optionally filter by kind
if symbol_kind and sym.kind != symbol_kind:
continue
definitions.append(DefinitionResult(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
end_line=sym.range[1] if sym.range else 1,
signature=None, # Not available from Symbol
container=None, # Not available from Symbol
score=1.0,
))
# Step 2: Get all references using ChainSearchEngine
raw_references = engine.search_references(
symbol_name=symbol_name,
source_path=project_path,
depth=-1,
limit=limit,
)
# Step 3: Transform raw references to API ReferenceResult
api_references: List[ReferenceResult] = []
for raw_ref in raw_references:
api_ref = _transform_to_reference_result(raw_ref)
api_references.append(api_ref)
# Step 4: Group references by definition
if group_by_definition and definitions:
return _group_references_by_definition(
definitions=definitions,
references=api_references,
include_definition=include_definition,
)
else:
# Return single group with placeholder definition or first definition
if definitions:
definition = definitions[0]
else:
# Create placeholder definition when no definition found
definition = DefinitionResult(
name=symbol_name,
kind=symbol_kind or "unknown",
file_path="",
line=0,
end_line=0,
signature=None,
container=None,
score=0.0,
)
return [GroupedReferences(
definition=definition,
references=api_references,
)]
finally:
engine.close()
def _group_references_by_definition(
definitions: List[DefinitionResult],
references: List[ReferenceResult],
include_definition: bool = True,
) -> List[GroupedReferences]:
"""Group references by their likely definition.
Uses file proximity heuristic to assign references to definitions.
References in the same file or directory as a definition are
assigned to that definition.
Args:
definitions: List of definition locations
references: List of reference locations
include_definition: Whether to include definition in results
Returns:
List of GroupedReferences with references assigned to definitions
"""
if not definitions:
return []
if len(definitions) == 1:
# Single definition - all references belong to it
return [GroupedReferences(
definition=definitions[0],
references=references,
)]
# Multiple definitions - group by proximity
groups: Dict[int, List[ReferenceResult]] = {
i: [] for i in range(len(definitions))
}
for ref in references:
# Find the closest definition by file proximity
best_def_idx = 0
best_score = -1
for i, defn in enumerate(definitions):
score = _proximity_score(ref.file_path, defn.file_path)
if score > best_score:
best_score = score
best_def_idx = i
groups[best_def_idx].append(ref)
# Build result groups
result: List[GroupedReferences] = []
for i, defn in enumerate(definitions):
# Skip definitions with no references if not including definition itself
if not include_definition and not groups[i]:
continue
result.append(GroupedReferences(
definition=defn,
references=groups[i],
))
return result
def _proximity_score(ref_path: str, def_path: str) -> int:
"""Calculate proximity score between two file paths.
Args:
ref_path: Reference file path
def_path: Definition file path
Returns:
Proximity score (higher = closer):
- Same file: 1000
- Same directory: 100
- Otherwise: common path prefix length
"""
import os
if not ref_path or not def_path:
return 0
# Normalize paths
ref_path = os.path.normpath(ref_path)
def_path = os.path.normpath(def_path)
# Same file
if ref_path == def_path:
return 1000
ref_dir = os.path.dirname(ref_path)
def_dir = os.path.dirname(def_path)
# Same directory
if ref_dir == def_dir:
return 100
# Common path prefix
try:
common = os.path.commonpath([ref_path, def_path])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
# Type alias for the raw reference from ChainSearchEngine
class RawReferenceResult:
"""Type stub for ChainSearchEngine.ReferenceResult.
This is only used for type hints and is replaced at runtime
by the actual import.
"""
file_path: str
line: int
column: int
context: str
relationship_type: str

View File

@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.
This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional
from .models import SemanticResult
from .utils import resolve_project
logger = logging.getLogger(__name__)
def semantic_search(
project_root: str,
query: str,
mode: str = "fusion",
vector_weight: float = 0.5,
structural_weight: float = 0.3,
keyword_weight: float = 0.2,
fusion_strategy: str = "rrf",
kind_filter: Optional[List[str]] = None,
limit: int = 20,
include_match_reason: bool = False,
) -> List[SemanticResult]:
"""Semantic search - combining vector and structural search.
This function provides a high-level API for semantic code search,
combining vector similarity, structural (symbol + relationships),
and keyword-based search methods with configurable fusion.
Args:
project_root: Project root directory
query: Natural language query
mode: Search mode
- vector: Vector search only
- structural: Structural search only (symbol + relationships)
- fusion: Fusion search (default)
vector_weight: Vector search weight [0, 1] (default 0.5)
structural_weight: Structural search weight [0, 1] (default 0.3)
keyword_weight: Keyword search weight [0, 1] (default 0.2)
fusion_strategy: Fusion strategy (maps to chain_search.py)
- rrf: Reciprocal Rank Fusion (recommended, default)
- staged: Staged cascade -> staged_cascade_search
- binary: Binary rerank cascade -> binary_cascade_search
- hybrid: Hybrid cascade -> hybrid_cascade_search
kind_filter: Symbol type filter (e.g., ["function", "class"])
limit: Max return count (default 20)
include_match_reason: Generate match reason (heuristic, not LLM)
Returns:
Results sorted by fusion_score
Degradation:
- No vector index: vector_score=None, uses FTS + structural search
- No relationship data: structural_score=None, vector search only
Examples:
>>> results = semantic_search(
... "/path/to/project",
... "authentication handler",
... mode="fusion",
... fusion_strategy="rrf"
... )
>>> for r in results:
... print(f"{r.symbol_name}: {r.fusion_score:.3f}")
"""
# Validate and resolve project path
project_path = resolve_project(project_root)
# Normalize weights to sum to 1.0
total_weight = vector_weight + structural_weight + keyword_weight
if total_weight > 0:
vector_weight = vector_weight / total_weight
structural_weight = structural_weight / total_weight
keyword_weight = keyword_weight / total_weight
else:
# Default to equal weights if all zero
vector_weight = structural_weight = keyword_weight = 1.0 / 3.0
# Initialize search infrastructure
try:
from codexlens.config import Config
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
except ImportError as exc:
logger.error("Failed to import search dependencies: %s", exc)
return []
# Load config
config = Config.load()
# Get or create registry and mapper
try:
registry = RegistryStore.default()
mapper = PathMapper(registry)
except Exception as exc:
logger.error("Failed to initialize search infrastructure: %s", exc)
return []
# Build search options based on mode
search_options = _build_search_options(
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
limit=limit,
)
# Execute search based on fusion_strategy
try:
with ChainSearchEngine(registry, mapper, config=config) as engine:
chain_result = _execute_search(
engine=engine,
query=query,
source_path=project_path,
fusion_strategy=fusion_strategy,
options=search_options,
limit=limit,
)
except Exception as exc:
logger.error("Search execution failed: %s", exc)
return []
# Transform results to SemanticResult
semantic_results = _transform_results(
results=chain_result.results,
mode=mode,
vector_weight=vector_weight,
structural_weight=structural_weight,
keyword_weight=keyword_weight,
kind_filter=kind_filter,
include_match_reason=include_match_reason,
query=query,
)
return semantic_results[:limit]
def _build_search_options(
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
limit: int,
) -> "SearchOptions":
"""Build SearchOptions based on mode and weights.
Args:
mode: Search mode (vector, structural, fusion)
vector_weight: Vector search weight
structural_weight: Structural search weight
keyword_weight: Keyword search weight
limit: Result limit
Returns:
Configured SearchOptions
"""
from codexlens.search.chain_search import SearchOptions
# Default options
options = SearchOptions(
total_limit=limit * 2, # Fetch extra for filtering
limit_per_dir=limit,
include_symbols=True, # Always include symbols for structural
)
if mode == "vector":
# Pure vector mode
options.hybrid_mode = True
options.enable_vector = True
options.pure_vector = True
options.enable_fuzzy = False
elif mode == "structural":
# Structural only - use FTS + symbols
options.hybrid_mode = True
options.enable_vector = False
options.enable_fuzzy = True
options.include_symbols = True
else:
# Fusion mode (default)
options.hybrid_mode = True
options.enable_vector = vector_weight > 0
options.enable_fuzzy = keyword_weight > 0
options.include_symbols = structural_weight > 0
# Set custom weights for RRF
if options.enable_vector and keyword_weight > 0:
options.hybrid_weights = {
"vector": vector_weight,
"exact": keyword_weight * 0.7,
"fuzzy": keyword_weight * 0.3,
}
return options
def _execute_search(
engine: "ChainSearchEngine",
query: str,
source_path: Path,
fusion_strategy: str,
options: "SearchOptions",
limit: int,
) -> "ChainSearchResult":
"""Execute search using appropriate strategy.
Maps fusion_strategy to ChainSearchEngine methods:
- rrf: Standard hybrid search with RRF fusion
- staged: staged_cascade_search
- binary: binary_cascade_search
- hybrid: hybrid_cascade_search
Args:
engine: ChainSearchEngine instance
query: Search query
source_path: Project root path
fusion_strategy: Strategy name
options: Search options
limit: Result limit
Returns:
ChainSearchResult from the search
"""
from codexlens.search.chain_search import ChainSearchResult
if fusion_strategy == "staged":
# Use staged cascade search (4-stage pipeline)
return engine.staged_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "binary":
# Use binary cascade search (binary coarse + dense fine)
return engine.binary_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
elif fusion_strategy == "hybrid":
# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
return engine.hybrid_cascade_search(
query=query,
source_path=source_path,
k=limit,
coarse_k=limit * 5,
options=options,
)
else:
# Default: rrf - Standard search with RRF fusion
return engine.search(
query=query,
source_path=source_path,
options=options,
)
def _transform_results(
results: List,
mode: str,
vector_weight: float,
structural_weight: float,
keyword_weight: float,
kind_filter: Optional[List[str]],
include_match_reason: bool,
query: str,
) -> List[SemanticResult]:
"""Transform ChainSearchEngine results to SemanticResult.
Args:
results: List of SearchResult objects
mode: Search mode
vector_weight: Vector weight used
structural_weight: Structural weight used
keyword_weight: Keyword weight used
kind_filter: Optional symbol kind filter
include_match_reason: Whether to generate match reasons
query: Original query (for match reason generation)
Returns:
List of SemanticResult objects
"""
semantic_results = []
for result in results:
# Extract symbol info
symbol_name = getattr(result, "symbol_name", None)
symbol_kind = getattr(result, "symbol_kind", None)
start_line = getattr(result, "start_line", None)
# Use symbol object if available
if hasattr(result, "symbol") and result.symbol:
symbol_name = symbol_name or result.symbol.name
symbol_kind = symbol_kind or result.symbol.kind
if hasattr(result.symbol, "range") and result.symbol.range:
start_line = start_line or result.symbol.range[0]
# Filter by kind if specified
if kind_filter and symbol_kind:
if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
continue
# Determine scores based on mode and metadata
metadata = getattr(result, "metadata", {}) or {}
fusion_score = result.score
# Try to extract source scores from metadata
source_scores = metadata.get("source_scores", {})
vector_score: Optional[float] = None
structural_score: Optional[float] = None
if mode == "vector":
# In pure vector mode, the main score is the vector score
vector_score = result.score
structural_score = None
elif mode == "structural":
# In structural mode, no vector score
vector_score = None
structural_score = result.score
else:
# Fusion mode - try to extract individual scores
if "vector" in source_scores:
vector_score = source_scores["vector"]
elif metadata.get("fusion_method") == "simple_weighted":
# From weighted fusion
vector_score = source_scores.get("vector")
# Structural score approximation (from exact/fuzzy FTS)
fts_scores = []
if "exact" in source_scores:
fts_scores.append(source_scores["exact"])
if "fuzzy" in source_scores:
fts_scores.append(source_scores["fuzzy"])
if "splade" in source_scores:
fts_scores.append(source_scores["splade"])
if fts_scores:
structural_score = max(fts_scores)
# Build snippet
snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
if len(snippet) > 500:
snippet = snippet[:500] + "..."
# Generate match reason if requested
match_reason = None
if include_match_reason:
match_reason = _generate_match_reason(
query=query,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
snippet=snippet,
vector_score=vector_score,
structural_score=structural_score,
)
semantic_result = SemanticResult(
symbol_name=symbol_name or Path(result.path).stem,
kind=symbol_kind or "unknown",
file_path=result.path,
line=start_line or 1,
vector_score=vector_score,
structural_score=structural_score,
fusion_score=fusion_score,
snippet=snippet,
match_reason=match_reason,
)
semantic_results.append(semantic_result)
# Sort by fusion_score descending
semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)
return semantic_results
def _generate_match_reason(
query: str,
symbol_name: Optional[str],
symbol_kind: Optional[str],
snippet: str,
vector_score: Optional[float],
structural_score: Optional[float],
) -> str:
"""Generate human-readable match reason heuristically.
This is a simple heuristic-based approach, not LLM-powered.
Args:
query: Original search query
symbol_name: Symbol name if available
symbol_kind: Symbol kind if available
snippet: Code snippet
vector_score: Vector similarity score
structural_score: Structural match score
Returns:
Human-readable explanation string
"""
reasons = []
# Check for direct name match
query_lower = query.lower()
query_words = set(query_lower.split())
if symbol_name:
name_lower = symbol_name.lower()
# Direct substring match
if query_lower in name_lower or name_lower in query_lower:
reasons.append(f"Symbol name '{symbol_name}' matches query")
# Word overlap
name_words = set(_split_camel_case(symbol_name).lower().split())
overlap = query_words & name_words
if overlap and not reasons:
reasons.append(f"Symbol name contains: {', '.join(overlap)}")
# Check snippet for keyword matches
snippet_lower = snippet.lower()
matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
if matching_words and len(reasons) < 2:
reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")
# Add score-based reasoning
if vector_score is not None and vector_score > 0.7:
reasons.append("High semantic similarity")
elif vector_score is not None and vector_score > 0.5:
reasons.append("Moderate semantic similarity")
if structural_score is not None and structural_score > 0.8:
reasons.append("Strong structural match")
# Symbol kind context
if symbol_kind and len(reasons) < 3:
reasons.append(f"Matched {symbol_kind}")
if not reasons:
reasons.append("Partial relevance based on content analysis")
return "; ".join(reasons[:3])
def _split_camel_case(name: str) -> str:
"""Split camelCase and PascalCase to words.
Args:
name: Symbol name in camelCase or PascalCase
Returns:
Space-separated words
"""
import re
    # Split at lower->upper boundaries (camelCase -> camel Case)
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Split a trailing acronym from the next word (HTTPResponse -> HTTP Response)
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
# Replace underscores with spaces
result = result.replace("_", " ")
return result
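# Boundary examples for the two regexes above (comment sketch):
#
#     >>> _split_camel_case("parseHTTPResponse")
#     'parse HTTP Response'
#     >>> _split_camel_case("global_symbol_index")
#     'global symbol index'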


@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.
This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""
from __future__ import annotations
import fnmatch
import logging
from pathlib import Path
from typing import List, Optional
from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project
logger = logging.getLogger(__name__)
def workspace_symbols(
project_root: str,
query: str,
kind_filter: Optional[List[str]] = None,
file_pattern: Optional[str] = None,
limit: int = 50
) -> List[SymbolInfo]:
"""Search for symbols across the entire workspace.
Uses prefix matching for efficient searching.
Args:
project_root: Project root directory (for index location)
query: Search query (prefix match)
kind_filter: Optional list of symbol kinds to include
(e.g., ["class", "function"])
file_pattern: Optional glob pattern to filter by file path
(e.g., "*.py", "src/**/*.ts")
limit: Maximum number of results to return
Returns:
List of SymbolInfo sorted by score
Raises:
IndexNotFoundError: If project is not indexed
"""
project_path = resolve_project(project_root)
# Get project info from registry
registry = RegistryStore()
project_info = registry.get_project_by_source(str(project_path))
if project_info is None:
raise IndexNotFoundError(f"Project not indexed: {project_path}")
# Open global symbol index
index_db = project_info.index_root / "_global_symbols.db"
if not index_db.exists():
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
# Search with prefix matching
# If kind_filter has multiple kinds, we need to search for each
all_results: List[Symbol] = []
if kind_filter and len(kind_filter) > 0:
# Search for each kind separately
for kind in kind_filter:
results = global_index.search(
name=query,
kind=kind,
limit=limit,
prefix_mode=True
)
all_results.extend(results)
else:
# Search without kind filter
all_results = global_index.search(
name=query,
kind=None,
limit=limit,
prefix_mode=True
)
logger.debug(f"Found {len(all_results)} symbols matching '{query}'")
# Apply file pattern filter if specified
if file_pattern:
all_results = [
sym for sym in all_results
if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
]
logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")
# Convert to SymbolInfo and sort by relevance
symbols = [
SymbolInfo(
name=sym.name,
kind=sym.kind,
file_path=sym.file or "",
line=sym.range[0] if sym.range else 1,
container=None, # Could extract from parent
score=_calculate_score(sym.name, query)
)
for sym in all_results
]
# Sort by score (exact matches first)
symbols.sort(key=lambda s: s.score, reverse=True)
return symbols[:limit]
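# Usage sketch (assumes "D:/project" was indexed beforehand; output shape is
# illustrative):
#
#     >>> hits = workspace_symbols(
#     ...     project_root="D:/project",
#     ...     query="Config",
#     ...     kind_filter=["class"],
#     ...     file_pattern="*.py",
#     ...     limit=10,
#     ... )  # doctest: +SKIP
#     >>> [(h.name, round(h.score, 2)) for h in hits]  # doctest: +SKIP
#     [('Config', 1.0), ('ConfigLoader', 0.9)]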
def _calculate_score(symbol_name: str, query: str) -> float:
"""Calculate relevance score for a symbol match.
Scoring:
    - Exact match: 1.0
    - Case-insensitive exact match: 0.9
    - Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
    - Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
    - Any other candidate returned by the index: 0.5
Args:
symbol_name: The matched symbol name
query: The search query
Returns:
Score between 0.0 and 1.0
"""
if symbol_name == query:
return 1.0
if symbol_name.lower() == query.lower():
return 0.9
if symbol_name.startswith(query):
ratio = len(query) / len(symbol_name)
return 0.8 + 0.2 * ratio
if symbol_name.lower().startswith(query.lower()):
ratio = len(query) / len(symbol_name)
return 0.6 + 0.2 * ratio
return 0.5
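# Concrete values for the tiers above (rounded to avoid float noise):
#
#     >>> _calculate_score("Config", "Config")
#     1.0
#     >>> _calculate_score("config", "Config")
#     0.9
#     >>> round(_calculate_score("ConfigLoader", "Config"), 2)  # 0.8 + 0.2 * 6/12
#     0.9
#     >>> round(_calculate_score("configloader", "Config"), 2)  # 0.6 + 0.2 * 6/12
#     0.7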


@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.
This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable
from .models import DefinitionResult
# Type variable for generic ranking
T = TypeVar('T')
def resolve_project(project_root: str) -> Path:
"""Resolve and validate project root path.
Args:
project_root: Path to project root (relative or absolute)
Returns:
Resolved absolute Path
Raises:
ValueError: If path does not exist or is not a directory
"""
path = Path(project_root).resolve()
if not path.exists():
raise ValueError(f"Project root does not exist: {path}")
if not path.is_dir():
raise ValueError(f"Project root is not a directory: {path}")
return path
# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
# Plural to singular
"calls": "call",
"imports": "import",
"inherits": "inheritance",
"uses": "use",
# Already normalized (passthrough)
"call": "call",
"import": "import",
"inheritance": "inheritance",
"use": "use",
"type_annotation": "type_annotation",
}
def normalize_relationship_type(relationship: str) -> str:
"""Normalize relationship type to canonical form.
Converts plural forms and variations to standard singular forms:
- 'calls' -> 'call'
- 'imports' -> 'import'
- 'inherits' -> 'inheritance'
- 'uses' -> 'use'
Args:
relationship: Raw relationship type string
Returns:
Normalized relationship type
Examples:
>>> normalize_relationship_type('calls')
'call'
>>> normalize_relationship_type('inherits')
'inheritance'
>>> normalize_relationship_type('call')
'call'
"""
return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)
def rank_by_proximity(
results: List[DefinitionResult],
file_context: Optional[str] = None
) -> List[DefinitionResult]:
"""Rank results by file path proximity to context.
V1 Implementation: Uses path-based proximity scoring.
Scoring algorithm:
1. Same directory: highest score (100)
2. Otherwise: length of common path prefix
Args:
results: List of definition results to rank
file_context: Reference file path for proximity calculation.
If None, returns results unchanged.
Returns:
Results sorted by proximity score (highest first)
Examples:
>>> results = [
... DefinitionResult(name="foo", kind="function",
... file_path="/a/b/c.py", line=1, end_line=10),
... DefinitionResult(name="foo", kind="function",
... file_path="/a/x/y.py", line=1, end_line=10),
... ]
>>> ranked = rank_by_proximity(results, "/a/b/test.py")
>>> ranked[0].file_path
'/a/b/c.py'
"""
if not file_context or not results:
return results
def proximity_score(result: DefinitionResult) -> int:
"""Calculate proximity score for a result."""
result_dir = os.path.dirname(result.file_path)
context_dir = os.path.dirname(file_context)
# Same directory gets highest score
if result_dir == context_dir:
return 100
# Otherwise, score by common path prefix length
try:
common = os.path.commonpath([result.file_path, file_context])
return len(common)
except ValueError:
# No common path (different drives on Windows)
return 0
return sorted(results, key=proximity_score, reverse=True)
def rank_by_score(
results: List[T],
score_fn: Callable[[T], float],
reverse: bool = True
) -> List[T]:
"""Generic ranking function by custom score.
Args:
results: List of items to rank
score_fn: Function to extract score from item
reverse: If True, highest scores first (default)
Returns:
Sorted list
"""
return sorted(results, key=score_fn, reverse=reverse)
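# Minimal usage sketch: rank workspace hits by their precomputed relevance
# score, best first by default.
#
#     >>> best_first = rank_by_score(symbols, lambda s: s.score)  # doctest: +SKIP
#     >>> worst_first = rank_by_score(symbols, lambda s: s.score, reverse=False)  # doctest: +SKIP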


@@ -154,6 +154,13 @@ class Config:
cascade_fine_k: int = 10 # Number of final results after reranking
cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
# Staged cascade search configuration (4-stage pipeline)
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
staged_lsp_depth: int = 2 # LSP relationship expansion depth in Stage 2
staged_clustering_strategy: str = "auto" # "auto", "hdbscan", "dbscan", "frequency", "noop"
staged_clustering_min_size: int = 3 # Minimum cluster size for Stage 3 grouping
enable_staged_rerank: bool = True # Enable optional cross-encoder reranking in Stage 4
# RRF fusion configuration
fusion_method: str = "rrf" # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
rrf_k: int = 60 # RRF constant (default 60)


@@ -0,0 +1,7 @@
"""codex-lens Language Server Protocol implementation."""
from __future__ import annotations
from codexlens.lsp.server import CodexLensLanguageServer, main
__all__ = ["CodexLensLanguageServer", "main"]


@@ -0,0 +1,551 @@
"""LSP request handlers for codex-lens.
This module contains handlers for LSP requests:
- textDocument/definition
- textDocument/completion
- workspace/symbol
- textDocument/didSave
- textDocument/hover
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote, unquote
try:
from lsprotocol import types as lsp
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.entities import Symbol
from codexlens.lsp.server import server
logger = logging.getLogger(__name__)
# Symbol kind mapping from codex-lens to LSP
SYMBOL_KIND_MAP = {
"class": lsp.SymbolKind.Class,
"function": lsp.SymbolKind.Function,
"method": lsp.SymbolKind.Method,
"variable": lsp.SymbolKind.Variable,
"constant": lsp.SymbolKind.Constant,
"property": lsp.SymbolKind.Property,
"field": lsp.SymbolKind.Field,
"interface": lsp.SymbolKind.Interface,
"module": lsp.SymbolKind.Module,
"namespace": lsp.SymbolKind.Namespace,
"package": lsp.SymbolKind.Package,
"enum": lsp.SymbolKind.Enum,
"enum_member": lsp.SymbolKind.EnumMember,
"struct": lsp.SymbolKind.Struct,
"type": lsp.SymbolKind.TypeParameter,
"type_alias": lsp.SymbolKind.TypeParameter,
}
# Completion kind mapping from codex-lens to LSP
COMPLETION_KIND_MAP = {
"class": lsp.CompletionItemKind.Class,
"function": lsp.CompletionItemKind.Function,
"method": lsp.CompletionItemKind.Method,
"variable": lsp.CompletionItemKind.Variable,
"constant": lsp.CompletionItemKind.Constant,
"property": lsp.CompletionItemKind.Property,
"field": lsp.CompletionItemKind.Field,
"interface": lsp.CompletionItemKind.Interface,
"module": lsp.CompletionItemKind.Module,
"enum": lsp.CompletionItemKind.Enum,
"enum_member": lsp.CompletionItemKind.EnumMember,
"struct": lsp.CompletionItemKind.Struct,
"type": lsp.CompletionItemKind.TypeParameter,
"type_alias": lsp.CompletionItemKind.TypeParameter,
}
def _path_to_uri(path: Union[str, Path]) -> str:
"""Convert a file path to a URI.
Args:
path: File path (string or Path object)
Returns:
File URI string
"""
path_str = str(Path(path).resolve())
    # POSIX paths already start with "/"; Windows drive paths need their
    # backslashes normalized and an extra leading slash added.
    if path_str.startswith("/"):
        return f"file://{quote(path_str)}"
    return f"file:///{quote(path_str.replace(chr(92), '/'))}"
def _uri_to_path(uri: str) -> Path:
"""Convert a URI to a file path.
Args:
uri: File URI string
Returns:
Path object
"""
    path = unquote(uri.replace("file://", "", 1))
    # file:///C:/... style Windows URIs carry an extra slash before the
    # drive letter; POSIX URIs (file:///home/...) must keep theirs.
    if re.match(r"^/[A-Za-z]:", path):
        path = path[1:]
    return Path(path)
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
"""Extract the word at the given position in the document.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Word at position, or None if no word found
"""
lines = document_text.splitlines()
if line >= len(lines):
return None
line_text = lines[line]
if character > len(line_text):
return None
# Find word boundaries
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
for match in word_pattern.finditer(line_text):
if match.start() <= character <= match.end():
return match.group()
return None
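# Example of the boundary matching above: any cursor position inside (or at
# the edges of) an identifier resolves to that identifier.
#
#     >>> _get_word_at_position("result = engine.search(query)", 0, 11)
#     'engine'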
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
"""Extract the incomplete word prefix at the given position.
Args:
document_text: Full document text
line: 0-based line number
character: 0-based character position
Returns:
Prefix string (may be empty)
"""
lines = document_text.splitlines()
if line >= len(lines):
return ""
line_text = lines[line]
if character > len(line_text):
character = len(line_text)
# Extract text before cursor
before_cursor = line_text[:character]
# Find the start of the current word
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
if match:
return match.group()
return ""
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
"""Convert a codex-lens Symbol to an LSP Location.
Args:
symbol: codex-lens Symbol object
Returns:
LSP Location, or None if symbol has no file
"""
if not symbol.file:
return None
# LSP uses 0-based lines, codex-lens uses 1-based
start_line = max(0, symbol.range[0] - 1)
end_line = max(0, symbol.range[1] - 1)
return lsp.Location(
uri=_path_to_uri(symbol.file),
range=lsp.Range(
start=lsp.Position(line=start_line, character=0),
end=lsp.Position(line=end_line, character=0),
),
)
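# Conversion sketch (hypothetical Symbol constructor arguments; field names
# follow the attribute access above): a 1-based range (10, 20) becomes a
# 0-based LSP range spanning lines 9..19.
#
#     >>> sym = Symbol(name="App", kind="class", file="/repo/src/app.py", range=(10, 20))  # doctest: +SKIP
#     >>> loc = symbol_to_location(sym)  # doctest: +SKIP
#     >>> (loc.range.start.line, loc.range.end.line)  # doctest: +SKIP
#     (9, 19)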
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
"""Map codex-lens symbol kind to LSP SymbolKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP SymbolKind
"""
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
"""Map codex-lens symbol kind to LSP CompletionItemKind.
Args:
kind: codex-lens symbol kind string
Returns:
LSP CompletionItemKind
"""
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
# -----------------------------------------------------------------------------
# LSP Request Handlers
# -----------------------------------------------------------------------------
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
def lsp_definition(
params: lsp.DefinitionParams,
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
"""Handle textDocument/definition request.
Finds the definition of the symbol at the cursor position.
"""
if not server.global_index:
logger.debug("No global index available for definition lookup")
return None
# Get document
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
# Get word at position
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
logger.debug("No word found at position")
return None
logger.debug("Looking up definition for: %s", word)
# Search for exact symbol match
try:
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=False, # Exact match preferred
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
# Fall back to prefix search
symbols = server.global_index.search(
name=word,
limit=10,
prefix_mode=True,
)
exact_matches = [s for s in symbols if s.name == word]
if not exact_matches:
logger.debug("No definition found for: %s", word)
return None
# Convert to LSP locations
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
if len(locations) == 1:
return locations[0]
elif locations:
return locations
else:
return None
except Exception as exc:
logger.error("Error looking up definition: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
"""Handle textDocument/references request.
Finds all references to the symbol at the cursor position using
the code_relationships table for accurate call-site tracking.
Falls back to same-name symbol search if search_engine is unavailable.
"""
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Finding references for: %s", word)
try:
# Try using search_engine.search_references() for accurate reference tracking
if server.search_engine and server.workspace_root:
references = server.search_engine.search_references(
symbol_name=word,
source_path=server.workspace_root,
limit=200,
)
if references:
locations = []
for ref in references:
locations.append(
lsp.Location(
uri=_path_to_uri(ref.file_path),
range=lsp.Range(
start=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column,
),
end=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column + len(word),
),
),
)
)
return locations if locations else None
# Fallback: search for symbols with same name using global_index
if server.global_index:
symbols = server.global_index.search(
name=word,
limit=100,
prefix_mode=False,
)
# Filter for exact matches
exact_matches = [s for s in symbols if s.name == word]
locations = []
for sym in exact_matches:
loc = symbol_to_location(sym)
if loc:
locations.append(loc)
return locations if locations else None
return None
except Exception as exc:
logger.error("Error finding references: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
"""Handle textDocument/completion request.
Provides code completion suggestions based on indexed symbols.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
prefix = _get_prefix_at_position(
document.source,
params.position.line,
params.position.character,
)
if not prefix or len(prefix) < 2:
# Require at least 2 characters for completion
return None
logger.debug("Completing prefix: %s", prefix)
try:
symbols = server.global_index.search(
name=prefix,
limit=50,
prefix_mode=True,
)
if not symbols:
return None
# Convert to completion items
items = []
seen_names = set()
for sym in symbols:
if sym.name in seen_names:
continue
seen_names.add(sym.name)
items.append(
lsp.CompletionItem(
label=sym.name,
kind=_symbol_kind_to_completion_kind(sym.kind),
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
sort_text=sym.name.lower(),
)
)
return lsp.CompletionList(
is_incomplete=len(symbols) >= 50,
items=items,
)
except Exception as exc:
logger.error("Error getting completions: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
"""Handle textDocument/hover request.
Provides hover information for the symbol at the cursor position
using HoverProvider for rich symbol information including
signature, documentation, and location.
"""
if not server.global_index:
return None
document = server.workspace.get_text_document(params.text_document.uri)
if not document:
return None
word = _get_word_at_position(
document.source,
params.position.line,
params.position.character,
)
if not word:
return None
logger.debug("Hover for: %s", word)
try:
# Use HoverProvider for rich symbol information
from codexlens.lsp.providers import HoverProvider
provider = HoverProvider(server.global_index, server.registry)
info = provider.get_hover_info(word)
if not info:
return None
# Format as markdown with signature and location
content = provider.format_hover_markdown(info)
return lsp.Hover(
contents=lsp.MarkupContent(
kind=lsp.MarkupKind.Markdown,
value=content,
),
)
except Exception as exc:
logger.error("Error getting hover info: %s", exc)
return None
@server.feature(lsp.WORKSPACE_SYMBOL)
def lsp_workspace_symbol(
params: lsp.WorkspaceSymbolParams,
) -> Optional[List[lsp.SymbolInformation]]:
"""Handle workspace/symbol request.
Searches for symbols across the workspace.
"""
if not server.global_index:
return None
query = params.query
if not query or len(query) < 2:
return None
logger.debug("Workspace symbol search: %s", query)
try:
symbols = server.global_index.search(
name=query,
limit=100,
prefix_mode=True,
)
if not symbols:
return None
result = []
for sym in symbols:
loc = symbol_to_location(sym)
if loc:
result.append(
lsp.SymbolInformation(
name=sym.name,
kind=_symbol_kind_to_lsp(sym.kind),
location=loc,
container_name=Path(sym.file).parent.name if sym.file else None,
)
)
return result if result else None
except Exception as exc:
logger.error("Error searching workspace symbols: %s", exc)
return None
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
"""Handle textDocument/didSave notification.
Triggers incremental re-indexing of the saved file.
Note: Full incremental indexing requires WatcherManager integration,
which is planned for Phase 2.
"""
file_path = _uri_to_path(params.text_document.uri)
logger.info("File saved: %s", file_path)
# Phase 1: Just log the save event
# Phase 2 will integrate with WatcherManager for incremental indexing
# if server.watcher_manager:
# server.watcher_manager.trigger_reindex(file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
"""Handle textDocument/didOpen notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File opened: %s", file_path)
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
"""Handle textDocument/didClose notification."""
file_path = _uri_to_path(params.text_document.uri)
logger.debug("File closed: %s", file_path)


@@ -0,0 +1,177 @@
"""LSP feature providers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
@dataclass
class HoverInfo:
"""Hover information for a symbol."""
name: str
kind: str
signature: str
documentation: Optional[str]
file_path: str
line_range: tuple # (start_line, end_line)
class HoverProvider:
"""Provides hover information for symbols."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
registry: Optional["RegistryStore"] = None,
) -> None:
"""Initialize hover provider.
Args:
global_index: Global symbol index for lookups
registry: Optional registry store for index path resolution
"""
self.global_index = global_index
self.registry = registry
def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
"""Get hover information for a symbol.
Args:
symbol_name: Name of the symbol to look up
Returns:
HoverInfo or None if symbol not found
"""
# Look up symbol in global index using exact match
symbols = self.global_index.search(
name=symbol_name,
limit=1,
prefix_mode=False,
)
# Filter for exact name match
exact_matches = [s for s in symbols if s.name == symbol_name]
if not exact_matches:
return None
symbol = exact_matches[0]
# Extract signature from source file
signature = self._extract_signature(symbol)
# Symbol uses 'file' attribute and 'range' tuple
file_path = symbol.file or ""
start_line, end_line = symbol.range
return HoverInfo(
name=symbol.name,
kind=symbol.kind,
signature=signature,
documentation=None, # Symbol doesn't have docstring field
file_path=file_path,
line_range=(start_line, end_line),
)
def _extract_signature(self, symbol) -> str:
"""Extract function/class signature from source file.
Args:
symbol: Symbol object with file and range information
Returns:
Extracted signature string or fallback kind + name
"""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return f"{symbol.kind} {symbol.name}"
content = file_path.read_text(encoding="utf-8", errors="ignore")
lines = content.split("\n")
# Extract signature lines (first line of definition + continuation)
start_line = symbol.range[0] - 1 # Convert 1-based to 0-based
if start_line >= len(lines) or start_line < 0:
return f"{symbol.kind} {symbol.name}"
signature_lines = []
first_line = lines[start_line]
signature_lines.append(first_line)
# Continue if multiline signature (no closing paren + colon yet)
# Look for patterns like "def func(", "class Foo(", etc.
i = start_line + 1
max_lines = min(start_line + 5, len(lines))
while i < max_lines:
line = signature_lines[-1]
# Stop if we see closing pattern
if "):" in line or line.rstrip().endswith(":"):
break
signature_lines.append(lines[i])
i += 1
return "\n".join(signature_lines)
except Exception as e:
logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
return f"{symbol.kind} {symbol.name}"
def format_hover_markdown(self, info: HoverInfo) -> str:
"""Format hover info as Markdown.
Args:
info: HoverInfo object to format
Returns:
Markdown-formatted hover content
"""
parts = []
# Detect language for code fence based on file extension
ext = Path(info.file_path).suffix.lower() if info.file_path else ""
lang_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".tsx": "typescript",
".jsx": "javascript",
".java": "java",
".go": "go",
".rs": "rust",
".c": "c",
".cpp": "cpp",
".h": "c",
".hpp": "cpp",
".cs": "csharp",
".rb": "ruby",
".php": "php",
}
lang = lang_map.get(ext, "")
# Code block with signature
parts.append(f"```{lang}\n{info.signature}\n```")
# Documentation if available
if info.documentation:
parts.append(f"\n---\n\n{info.documentation}")
# Location info
file_name = Path(info.file_path).name if info.file_path else "unknown"
parts.append(
f"\n---\n\n*{info.kind}* defined in "
f"`{file_name}` "
f"(line {info.line_range[0]})"
)
return "\n".join(parts)


@@ -0,0 +1,263 @@
"""codex-lens LSP Server implementation using pygls.
This module provides the main Language Server class and entry point.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Optional
try:
from lsprotocol import types as lsp
from pygls.lsp.server import LanguageServer
except ImportError as exc:
raise ImportError(
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
) from exc
from codexlens.config import Config
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
logger = logging.getLogger(__name__)
class CodexLensLanguageServer(LanguageServer):
"""Language Server for codex-lens code indexing.
Provides IDE features using codex-lens symbol index:
- Go to Definition
- Find References
- Code Completion
- Hover Information
- Workspace Symbol Search
Attributes:
registry: Global project registry for path lookups
mapper: Path mapper for source/index conversions
global_index: Project-wide symbol index
search_engine: Chain search engine for symbol search
workspace_root: Current workspace root path
"""
def __init__(self) -> None:
super().__init__(name="codexlens-lsp", version="0.1.0")
self.registry: Optional[RegistryStore] = None
self.mapper: Optional[PathMapper] = None
self.global_index: Optional[GlobalSymbolIndex] = None
self.search_engine: Optional[ChainSearchEngine] = None
self.workspace_root: Optional[Path] = None
self._config: Optional[Config] = None
def initialize_components(self, workspace_root: Path) -> bool:
"""Initialize codex-lens components for the workspace.
Args:
workspace_root: Root path of the workspace
Returns:
True if initialization succeeded, False otherwise
"""
self.workspace_root = workspace_root.resolve()
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
try:
# Initialize registry
self.registry = RegistryStore()
self.registry.initialize()
# Initialize path mapper
self.mapper = PathMapper()
# Try to find project in registry
project_info = self.registry.find_by_source_path(str(self.workspace_root))
if project_info:
project_id = int(project_info["id"])
index_root = Path(project_info["index_root"])
# Initialize global symbol index
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
self.global_index = GlobalSymbolIndex(global_db, project_id)
self.global_index.initialize()
# Initialize search engine
self._config = Config()
self.search_engine = ChainSearchEngine(
registry=self.registry,
mapper=self.mapper,
config=self._config,
)
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
return True
else:
logger.warning(
"Workspace not indexed by codex-lens: %s. "
"Run 'codexlens index %s' to index first.",
self.workspace_root,
self.workspace_root,
)
return False
except Exception as exc:
logger.error("Failed to initialize codex-lens: %s", exc)
return False
def shutdown_components(self) -> None:
"""Clean up codex-lens components."""
if self.global_index:
try:
self.global_index.close()
except Exception as exc:
logger.debug("Error closing global index: %s", exc)
self.global_index = None
if self.search_engine:
try:
self.search_engine.close()
except Exception as exc:
logger.debug("Error closing search engine: %s", exc)
self.search_engine = None
if self.registry:
try:
self.registry.close()
except Exception as exc:
logger.debug("Error closing registry: %s", exc)
self.registry = None
# Create server instance
server = CodexLensLanguageServer()
@server.feature(lsp.INITIALIZE)
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
"""Handle LSP initialize request."""
logger.info("LSP initialize request received")
# Get workspace root
workspace_root: Optional[Path] = None
    if params.root_uri:
        raw = params.root_uri.replace("file://", "", 1).replace("file:", "", 1)
        # file:///C:/... URIs put an extra slash before a Windows drive letter
        if len(raw) > 2 and raw[0] == "/" and raw[2] == ":":
            raw = raw[1:]
        workspace_root = Path(raw)
elif params.root_path:
workspace_root = Path(params.root_path)
if workspace_root:
server.initialize_components(workspace_root)
# Declare server capabilities
return lsp.InitializeResult(
capabilities=lsp.ServerCapabilities(
text_document_sync=lsp.TextDocumentSyncOptions(
open_close=True,
change=lsp.TextDocumentSyncKind.Incremental,
save=lsp.SaveOptions(include_text=False),
),
definition_provider=True,
references_provider=True,
completion_provider=lsp.CompletionOptions(
trigger_characters=[".", ":"],
resolve_provider=False,
),
hover_provider=True,
workspace_symbol_provider=True,
),
server_info=lsp.ServerInfo(
name="codexlens-lsp",
version="0.1.0",
),
)
@server.feature(lsp.SHUTDOWN)
def lsp_shutdown(params: None) -> None:
"""Handle LSP shutdown request."""
logger.info("LSP shutdown request received")
server.shutdown_components()
def main() -> int:
"""Entry point for codexlens-lsp command.
Returns:
Exit code (0 for success)
"""
# Import handlers to register them with the server
# This must be done before starting the server
import codexlens.lsp.handlers # noqa: F401
parser = argparse.ArgumentParser(
description="codex-lens Language Server",
prog="codexlens-lsp",
)
parser.add_argument(
"--stdio",
action="store_true",
default=True,
help="Use stdio for communication (default)",
)
parser.add_argument(
"--tcp",
action="store_true",
help="Use TCP for communication",
)
parser.add_argument(
"--host",
default="127.0.0.1",
help="TCP host (default: 127.0.0.1)",
)
parser.add_argument(
"--port",
type=int,
default=2087,
help="TCP port (default: 2087)",
)
parser.add_argument(
"--log-level",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Log level (default: INFO)",
)
parser.add_argument(
"--log-file",
help="Log file path (optional)",
)
args = parser.parse_args()
# Configure logging
log_handlers = []
if args.log_file:
log_handlers.append(logging.FileHandler(args.log_file))
else:
log_handlers.append(logging.StreamHandler(sys.stderr))
logging.basicConfig(
level=getattr(logging, args.log_level),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=log_handlers,
)
logger.info("Starting codexlens-lsp server")
if args.tcp:
logger.info("Starting TCP server on %s:%d", args.host, args.port)
server.start_tcp(args.host, args.port)
else:
logger.info("Starting stdio server")
server.start_io()
return 0
if __name__ == "__main__":
sys.exit(main())
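# Invocation sketch (flags exactly as declared by the argparse setup above):
#
#   codexlens-lsp                                  # stdio transport (default)
#   codexlens-lsp --tcp --host 127.0.0.1 --port 2087
#   codexlens-lsp --log-level DEBUG --log-file lsp.log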


@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
__all__ = [
"MCPContext",
"SymbolInfo",
"ReferenceInfo",
"RelatedSymbol",
"MCPProvider",
"HookManager",
"create_context_for_prompt",
]


@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING
from codexlens.mcp.schema import MCPContext
if TYPE_CHECKING:
from codexlens.mcp.provider import MCPProvider
logger = logging.getLogger(__name__)
class HookManager:
"""Manages hook registration and execution."""
def __init__(self, mcp_provider: "MCPProvider") -> None:
self.mcp_provider = mcp_provider
self._pre_hooks: Dict[str, Callable] = {}
self._post_hooks: Dict[str, Callable] = {}
# Register default hooks
self._register_default_hooks()
def _register_default_hooks(self) -> None:
"""Register built-in hooks."""
self._pre_hooks["explain"] = self._pre_explain_hook
self._pre_hooks["refactor"] = self._pre_refactor_hook
self._pre_hooks["document"] = self._pre_document_hook
def execute_pre_hook(
self,
action: str,
params: Dict[str, Any],
) -> Optional[MCPContext]:
"""Execute pre-tool hook to gather context.
Args:
action: The action being performed (e.g., "explain", "refactor")
params: Parameters for the action
Returns:
MCPContext to inject into prompt, or None
"""
hook = self._pre_hooks.get(action)
if not hook:
logger.debug(f"No pre-hook for action: {action}")
return None
try:
return hook(params)
except Exception as e:
logger.error(f"Pre-hook failed for {action}: {e}")
return None
def execute_post_hook(
self,
action: str,
result: Any,
) -> None:
"""Execute post-tool hook for proactive caching.
Args:
action: The action that was performed
result: Result of the action
"""
hook = self._post_hooks.get(action)
if not hook:
return
try:
hook(result)
except Exception as e:
logger.error(f"Post-hook failed for {action}: {e}")
def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'explain' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="symbol_explanation",
include_references=True,
include_related=True,
)
def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'refactor' action."""
symbol_name = params.get("symbol")
if not symbol_name:
return None
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="refactor_context",
include_references=True,
include_related=True,
max_references=20,
)
def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
"""Pre-hook for 'document' action."""
symbol_name = params.get("symbol")
file_path = params.get("file_path")
if symbol_name:
return self.mcp_provider.build_context(
symbol_name=symbol_name,
context_type="documentation_context",
include_references=False,
include_related=True,
)
elif file_path:
return self.mcp_provider.build_context_for_file(
Path(file_path),
context_type="file_documentation",
)
return None
def register_pre_hook(
self,
action: str,
hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
) -> None:
"""Register a custom pre-tool hook."""
self._pre_hooks[action] = hook
def register_post_hook(
self,
action: str,
hook: Callable[[Any], None],
) -> None:
"""Register a custom post-tool hook."""
self._post_hooks[action] = hook
def create_context_for_prompt(
mcp_provider: "MCPProvider",
action: str,
params: Dict[str, Any],
) -> str:
"""Create context string for prompt injection.
This is the main entry point for Claude Code hook integration.
Args:
mcp_provider: The MCP provider instance
action: Action being performed
params: Action parameters
Returns:
Formatted context string for prompt injection
"""
manager = HookManager(mcp_provider)
context = manager.execute_pre_hook(action, params)
if context:
return context.to_prompt_injection()
return ""


@@ -0,0 +1,202 @@
"""MCP context provider."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, List, TYPE_CHECKING
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
if TYPE_CHECKING:
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.registry import RegistryStore
from codexlens.search.chain_search import ChainSearchEngine
logger = logging.getLogger(__name__)
class MCPProvider:
"""Builds MCP context objects from codex-lens data."""
def __init__(
self,
global_index: "GlobalSymbolIndex",
search_engine: "ChainSearchEngine",
registry: "RegistryStore",
) -> None:
self.global_index = global_index
self.search_engine = search_engine
self.registry = registry
def build_context(
self,
symbol_name: str,
context_type: str = "symbol_explanation",
include_references: bool = True,
include_related: bool = True,
max_references: int = 10,
) -> Optional[MCPContext]:
"""Build comprehensive context for a symbol.
Args:
symbol_name: Name of the symbol to contextualize
context_type: Type of context being requested
include_references: Whether to include reference locations
include_related: Whether to include related symbols
max_references: Maximum number of references to include
Returns:
MCPContext object or None if symbol not found
"""
# Look up symbol
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
if not symbols:
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
return None
symbol = symbols[0]
# Build SymbolInfo
symbol_info = SymbolInfo(
name=symbol.name,
kind=symbol.kind,
file_path=symbol.file or "",
line_start=symbol.range[0],
line_end=symbol.range[1],
signature=None, # Symbol entity doesn't have signature
documentation=None, # Symbol entity doesn't have docstring
)
# Extract definition source code
definition = self._extract_definition(symbol)
# Get references
references = []
if include_references:
refs = self.search_engine.search_references(
symbol_name,
limit=max_references,
)
references = [
ReferenceInfo(
file_path=r.file_path,
line=r.line,
column=r.column,
context=r.context,
relationship_type=r.relationship_type,
)
for r in refs
]
# Get related symbols
related_symbols = []
if include_related:
related_symbols = self._get_related_symbols(symbol)
return MCPContext(
context_type=context_type,
symbol=symbol_info,
definition=definition,
references=references,
related_symbols=related_symbols,
metadata={
"source": "codex-lens",
},
)
def _extract_definition(self, symbol) -> Optional[str]:
"""Extract source code for symbol definition."""
try:
file_path = Path(symbol.file) if symbol.file else None
if not file_path or not file_path.exists():
return None
content = file_path.read_text(encoding='utf-8', errors='ignore')
lines = content.split("\n")
start = symbol.range[0] - 1
end = symbol.range[1]
if start >= len(lines):
return None
return "\n".join(lines[start:end])
except Exception as e:
logger.debug(f"Failed to extract definition: {e}")
return None
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
"""Get symbols related to the given symbol."""
related = []
try:
# Search for symbols that might be related by name patterns
# This is a simplified implementation - could be enhanced with relationship data
# Look for imports/callers via reference search
refs = self.search_engine.search_references(symbol.name, limit=20)
seen_names = set()
for ref in refs:
# Extract potential symbol name from context
if ref.relationship_type and ref.relationship_type not in seen_names:
related.append(RelatedSymbol(
name=f"{Path(ref.file_path).stem}",
kind="module",
relationship=ref.relationship_type,
file_path=ref.file_path,
))
seen_names.add(ref.relationship_type)
if len(related) >= 10:
break
except Exception as e:
logger.debug(f"Failed to get related symbols: {e}")
return related
def build_context_for_file(
self,
file_path: Path,
context_type: str = "file_overview",
) -> MCPContext:
"""Build context for an entire file."""
# Try to get symbols by searching with file path
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
symbols = []
# Search for common symbols that might be in this file
# This is a simplified approach - a full implementation would query by file path
try:
# Use the global index to search for symbols from this file
file_str = str(file_path.resolve())
# Get all symbols and filter by file path (not efficient but works)
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
except Exception as e:
logger.debug(f"Failed to get file symbols: {e}")
related = [
RelatedSymbol(
name=s.name,
kind=s.kind,
relationship="defines",
)
for s in symbols
]
return MCPContext(
context_type=context_type,
related_symbols=related,
metadata={
"file_path": str(file_path),
"symbol_count": len(symbols),
},
)


@@ -0,0 +1,113 @@
"""MCP data models."""
from __future__ import annotations
import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional
@dataclass
class SymbolInfo:
"""Information about a code symbol."""
name: str
kind: str
file_path: str
line_start: int
line_end: int
signature: Optional[str] = None
documentation: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class ReferenceInfo:
"""Information about a symbol reference."""
file_path: str
line: int
column: int
context: str
relationship_type: str
def to_dict(self) -> dict:
return asdict(self)
@dataclass
class RelatedSymbol:
"""Related symbol (import, call target, etc.)."""
name: str
kind: str
relationship: str # "imports", "calls", "inherits", "uses"
file_path: Optional[str] = None
def to_dict(self) -> dict:
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class MCPContext:
"""Model Context Protocol context object.
This is the structured context that gets injected into
LLM prompts to provide code understanding.
"""
version: str = "1.0"
context_type: str = "code_context"
symbol: Optional[SymbolInfo] = None
definition: Optional[str] = None
references: List[ReferenceInfo] = field(default_factory=list)
related_symbols: List[RelatedSymbol] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to dictionary for JSON serialization."""
result = {
"version": self.version,
"context_type": self.context_type,
"metadata": self.metadata,
}
if self.symbol:
result["symbol"] = self.symbol.to_dict()
if self.definition:
result["definition"] = self.definition
if self.references:
result["references"] = [r.to_dict() for r in self.references]
if self.related_symbols:
result["related_symbols"] = [s.to_dict() for s in self.related_symbols]
return result
def to_json(self, indent: int = 2) -> str:
"""Serialize to JSON string."""
return json.dumps(self.to_dict(), indent=indent)
def to_prompt_injection(self) -> str:
"""Format for injection into LLM prompt."""
parts = ["<code_context>"]
if self.symbol:
parts.append(f"## Symbol: {self.symbol.name}")
parts.append(f"Type: {self.symbol.kind}")
parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")
if self.definition:
parts.append("\n## Definition")
parts.append(f"```\n{self.definition}\n```")
if self.references:
parts.append(f"\n## References ({len(self.references)} found)")
for ref in self.references[:5]: # Limit to 5
parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
parts.append(f" ```\n {ref.context}\n ```")
if self.related_symbols:
parts.append("\n## Related Symbols")
for sym in self.related_symbols[:10]: # Limit to 10
parts.append(f"- {sym.name} ({sym.relationship})")
parts.append("</code_context>")
return "\n".join(parts)


@@ -6,10 +6,48 @@ from .chain_search import (
quick_search,
)
# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None
try:
from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
from .clustering import check_clustering_available
CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
_clustering_import_error = str(e)
def check_clustering_available() -> tuple[bool, str | None]:
"""Fallback when clustering module not loadable."""
return False, _clustering_import_error
# Clustering module exports (conditional)
try:
from .clustering import (
BaseClusteringStrategy,
ClusteringConfig,
ClusteringStrategyFactory,
get_strategy,
)
_clustering_exports = [
"BaseClusteringStrategy",
"ClusteringConfig",
"ClusteringStrategyFactory",
"get_strategy",
]
except ImportError:
_clustering_exports = []
__all__ = [
"ChainSearchEngine",
"SearchOptions",
"SearchStats",
"ChainSearchResult",
"quick_search",
# Clustering
"CLUSTERING_AVAILABLE",
"check_clustering_available",
*_clustering_exports,
]


@@ -116,6 +116,24 @@ class ChainSearchResult:
related_results: List[SearchResult] = field(default_factory=list)
@dataclass
class ReferenceResult:
"""Result from reference search in code_relationships table.
Attributes:
file_path: Path to the file containing the reference
line: Line number where the reference occurs (1-based)
column: Column number where the reference occurs (0-based)
context: Surrounding code snippet for context
relationship_type: Type of relationship (call, import, inheritance, etc.)
"""
file_path: str
line: int
column: int
context: str
relationship_type: str
class ChainSearchEngine:
"""Parallel chain search engine for hierarchical directory indexes.
@@ -810,7 +828,7 @@ class ChainSearchEngine:
k: int = 10,
coarse_k: int = 100,
options: Optional[SearchOptions] = None,
strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank"]] = None,
strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None,
) -> ChainSearchResult:
"""Unified cascade search entry point with strategy selection.
@@ -819,6 +837,7 @@ class ChainSearchEngine:
- "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original)
- "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance)
- "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking
- "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank
The strategy is determined with the following priority:
1. The `strategy` parameter (e.g., from CLI --cascade-strategy option)
@@ -831,7 +850,7 @@ class ChainSearchEngine:
k: Number of final results to return (default 10)
coarse_k: Number of coarse candidates from first stage (default 100)
options: Search configuration (uses defaults if None)
strategy: Cascade strategy - "binary", "hybrid", or "binary_rerank".
strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged".
Returns:
ChainSearchResult with reranked results and statistics
@@ -844,10 +863,12 @@ class ChainSearchEngine:
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
>>> # Use binary + cross-encoder (best balance of speed and quality)
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
>>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="staged")
"""
# Strategy priority: parameter > config > default
effective_strategy = strategy
valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank")
valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
if effective_strategy is None:
# Not passed via parameter, check config
if self._config is not None:
@@ -865,9 +886,635 @@ class ChainSearchEngine:
return self.binary_rerank_cascade_search(query, source_path, k, coarse_k, options)
elif effective_strategy == "dense_rerank":
return self.dense_rerank_cascade_search(query, source_path, k, coarse_k, options)
elif effective_strategy == "staged":
return self.staged_cascade_search(query, source_path, k, coarse_k, options)
else:
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
def staged_cascade_search(
self,
query: str,
source_path: Path,
k: int = 10,
coarse_k: int = 100,
options: Optional[SearchOptions] = None,
) -> ChainSearchResult:
"""Execute 4-stage cascade search pipeline with binary, LSP expansion, clustering, and optional reranking.
Staged cascade search process:
1. Stage 1 (Binary Coarse): Fast binary vector search using Hamming distance
to quickly filter to coarse_k candidates (256-bit binary vectors)
2. Stage 2 (LSP Expansion): Expand coarse candidates using GraphExpander to
include related symbols (definitions, references, callers/callees)
3. Stage 3 (Clustering): Use configurable clustering strategy to group similar
results and select representative results from each cluster
4. Stage 4 (Optional Rerank): If config.enable_staged_rerank is True, apply
cross-encoder reranking for final precision
This approach combines the speed of binary search with graph-based context
expansion and diversity-preserving clustering for high-quality results.
Performance characteristics:
- Stage 1: O(N) binary search with SIMD acceleration (~8ms)
- Stage 2: O(k * d) graph traversal where d is expansion depth
- Stage 3: O(n^2) clustering on expanded candidates
- Stage 4: Optional cross-encoder reranking (API call)
Args:
query: Natural language or keyword query string
source_path: Starting directory path
k: Number of final results to return (default 10)
coarse_k: Number of coarse candidates from first stage (default 100)
options: Search configuration (uses defaults if None)
Returns:
ChainSearchResult with per-stage statistics
Examples:
>>> engine = ChainSearchEngine(registry, mapper, config=config)
>>> result = engine.staged_cascade_search(
... "authentication handler",
... Path("D:/project/src"),
... k=10,
... coarse_k=100
... )
>>> for r in result.results:
... print(f"{r.path}: {r.score:.3f}")
"""
if not NUMPY_AVAILABLE:
self.logger.warning(
"NumPy not available, falling back to hybrid cascade search"
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
options = options or SearchOptions()
start_time = time.time()
stats = SearchStats()
# Per-stage timing stats
stage_times: Dict[str, float] = {}
stage_counts: Dict[str, int] = {}
        # Fall back to config defaults when the caller passed 0/None
if self._config is not None:
if hasattr(self._config, "cascade_coarse_k"):
coarse_k = coarse_k or self._config.cascade_coarse_k
if hasattr(self._config, "cascade_fine_k"):
k = k or self._config.cascade_fine_k
# Step 1: Find starting index
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=[],
symbols=[],
stats=stats
)
# Step 2: Collect all index paths
index_paths = self._collect_index_paths(start_index, options.depth)
stats.dirs_searched = len(index_paths)
if not index_paths:
self.logger.warning(f"No indexes collected from {start_index}")
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=[],
symbols=[],
stats=stats
)
# ========== Stage 1: Binary Coarse Search ==========
stage1_start = time.time()
coarse_results, index_root = self._stage1_binary_search(
query, index_paths, coarse_k, stats
)
stage_times["stage1_binary_ms"] = (time.time() - stage1_start) * 1000
stage_counts["stage1_candidates"] = len(coarse_results)
self.logger.debug(
"Staged Stage 1: Binary search found %d candidates in %.2fms",
len(coarse_results), stage_times["stage1_binary_ms"]
)
if not coarse_results:
self.logger.debug("No binary candidates found, falling back to hybrid cascade")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
# ========== Stage 2: LSP Graph Expansion ==========
stage2_start = time.time()
expanded_results = self._stage2_lsp_expand(coarse_results, index_root)
stage_times["stage2_expand_ms"] = (time.time() - stage2_start) * 1000
stage_counts["stage2_expanded"] = len(expanded_results)
self.logger.debug(
"Staged Stage 2: LSP expansion %d -> %d results in %.2fms",
len(coarse_results), len(expanded_results), stage_times["stage2_expand_ms"]
)
# ========== Stage 3: Clustering and Representative Selection ==========
stage3_start = time.time()
clustered_results = self._stage3_cluster_prune(expanded_results, k * 2)
stage_times["stage3_cluster_ms"] = (time.time() - stage3_start) * 1000
stage_counts["stage3_clustered"] = len(clustered_results)
self.logger.debug(
"Staged Stage 3: Clustering %d -> %d representatives in %.2fms",
len(expanded_results), len(clustered_results), stage_times["stage3_cluster_ms"]
)
# ========== Stage 4: Optional Cross-Encoder Reranking ==========
enable_rerank = False
if self._config is not None:
enable_rerank = getattr(self._config, "enable_staged_rerank", False)
if enable_rerank:
stage4_start = time.time()
final_results = self._stage4_optional_rerank(query, clustered_results, k)
stage_times["stage4_rerank_ms"] = (time.time() - stage4_start) * 1000
stage_counts["stage4_reranked"] = len(final_results)
self.logger.debug(
"Staged Stage 4: Reranking %d -> %d results in %.2fms",
len(clustered_results), len(final_results), stage_times["stage4_rerank_ms"]
)
else:
# Skip reranking, just take top-k by score
final_results = sorted(
clustered_results, key=lambda r: r.score, reverse=True
)[:k]
stage_counts["stage4_reranked"] = len(final_results)
# Deduplicate by path (keep highest score)
path_to_result: Dict[str, SearchResult] = {}
for result in final_results:
if result.path not in path_to_result or result.score > path_to_result[result.path].score:
path_to_result[result.path] = result
final_results = list(path_to_result.values())[:k]
# Optional: grouping of similar results
if options.group_results:
from codexlens.search.ranking import group_similar_results
final_results = group_similar_results(
final_results, score_threshold_abs=options.grouping_threshold
)
stats.files_matched = len(final_results)
stats.time_ms = (time.time() - start_time) * 1000
# Add per-stage stats to errors field (as JSON for now, will be proper field later)
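# Illustrative payload shape (field names from this method; timing values are hypothetical):
# STAGE_STATS:{"stage_times": {"stage1_binary_ms": 8.2, "stage2_expand_ms": 14.1,
#   "stage3_cluster_ms": 5.4}, "stage_counts": {"stage1_candidates": 100,
#   "stage2_expanded": 140, "stage3_clustered": 20, "stage4_reranked": 10}}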
stage_stats_json = json.dumps({
"stage_times": stage_times,
"stage_counts": stage_counts,
})
stats.errors.append(f"STAGE_STATS:{stage_stats_json}")
self.logger.debug(
"Staged cascade search complete: %d results in %.2fms "
"(stage1=%.1fms, stage2=%.1fms, stage3=%.1fms)",
len(final_results),
stats.time_ms,
stage_times.get("stage1_binary_ms", 0),
stage_times.get("stage2_expand_ms", 0),
stage_times.get("stage3_cluster_ms", 0),
)
return ChainSearchResult(
query=query,
results=final_results,
symbols=[],
stats=stats,
)
def _stage1_binary_search(
self,
query: str,
index_paths: List[Path],
coarse_k: int,
stats: SearchStats,
) -> Tuple[List[SearchResult], Optional[Path]]:
"""Stage 1: Binary vector coarse search using Hamming distance.
Reuses the binary coarse search logic from binary_cascade_search.
Args:
query: Search query string
index_paths: List of index database paths to search
coarse_k: Number of coarse candidates to retrieve
stats: SearchStats to update with errors
Returns:
Tuple of (list of SearchResult objects, index_root path or None)
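Note:
Hamming distance over the 256-bit packed vectors is converted to a
similarity score as score = 1.0 - distance / 256.0, so distance 0
maps to 1.0 and distance 256 maps to 0.0 (see the conversion below).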
"""
# Initialize binary embedding backend
try:
from codexlens.indexing.embedding import BinaryEmbeddingBackend
except ImportError as exc:
self.logger.warning(
"BinaryEmbeddingBackend not available: %s", exc
)
return [], None
# Try centralized BinarySearcher first (preferred for mmap indexes)
index_root = index_paths[0].parent if index_paths else None
coarse_candidates: List[Tuple[int, int, Path]] = [] # (chunk_id, distance, index_path)
used_centralized = False
if index_root:
binary_searcher = self._get_centralized_binary_searcher(index_root)
if binary_searcher is not None:
try:
from codexlens.semantic.embedder import Embedder
embedder = Embedder()
query_dense = embedder.embed_to_numpy([query])[0]
results = binary_searcher.search(query_dense, top_k=coarse_k)
for chunk_id, distance in results:
coarse_candidates.append((chunk_id, distance, index_root))
if coarse_candidates:
used_centralized = True
self.logger.debug(
"Stage 1 centralized binary search: %d candidates", len(results)
)
except Exception as exc:
self.logger.debug(f"Centralized binary search failed: {exc}")
if not used_centralized:
# Fallback to per-directory binary indexes
use_gpu = True
if self._config is not None:
use_gpu = getattr(self._config, "embedding_use_gpu", True)
try:
binary_backend = BinaryEmbeddingBackend(use_gpu=use_gpu)
query_binary = binary_backend.embed_packed([query])[0]
except Exception as exc:
self.logger.warning(f"Failed to generate binary query embedding: {exc}")
return [], index_root
for index_path in index_paths:
try:
binary_index = self._get_or_create_binary_index(index_path)
if binary_index is None or binary_index.count() == 0:
continue
ids, distances = binary_index.search(query_binary, coarse_k)
for chunk_id, dist in zip(ids, distances):
coarse_candidates.append((chunk_id, dist, index_path))
except Exception as exc:
self.logger.debug(
"Binary search failed for %s: %s", index_path, exc
)
if not coarse_candidates:
return [], index_root
# Sort by Hamming distance and take top coarse_k
coarse_candidates.sort(key=lambda x: x[1])
coarse_candidates = coarse_candidates[:coarse_k]
# Build SearchResult objects from candidates
coarse_results: List[SearchResult] = []
# Group candidates by index path for efficient retrieval
candidates_by_index: Dict[Path, List[int]] = {}
for chunk_id, _, idx_path in coarse_candidates:
if idx_path not in candidates_by_index:
candidates_by_index[idx_path] = []
candidates_by_index[idx_path].append(chunk_id)
# Retrieve chunk content
import sqlite3
central_meta_path = index_root / VECTORS_META_DB_NAME if index_root else None
central_meta_store = None
if central_meta_path and central_meta_path.exists():
central_meta_store = VectorMetadataStore(central_meta_path)
for idx_path, chunk_ids in candidates_by_index.items():
try:
chunks_data = []
if central_meta_store:
chunks_data = central_meta_store.get_chunks_by_ids(chunk_ids)
if not chunks_data and used_centralized:
meta_db_path = idx_path / VECTORS_META_DB_NAME
if meta_db_path.exists():
meta_store = VectorMetadataStore(meta_db_path)
chunks_data = meta_store.get_chunks_by_ids(chunk_ids)
if not chunks_data:
try:
conn = sqlite3.connect(str(idx_path))
conn.row_factory = sqlite3.Row
placeholders = ",".join("?" * len(chunk_ids))
cursor = conn.execute(
f"""
SELECT id, file_path, content, metadata, category
FROM semantic_chunks
WHERE id IN ({placeholders})
""",
chunk_ids
)
chunks_data = [
{
"id": row["id"],
"file_path": row["file_path"],
"content": row["content"],
"metadata": row["metadata"],
"category": row["category"],
}
for row in cursor.fetchall()
]
conn.close()
except Exception:
pass
for chunk in chunks_data:
chunk_id = chunk.get("id") or chunk.get("chunk_id")
distance = next(
(d for cid, d, _ in coarse_candidates if cid == chunk_id),
256
)
score = 1.0 - (distance / 256.0)
content = chunk.get("content", "")
# Extract symbol info from metadata if available
metadata = chunk.get("metadata")
symbol_name = None
symbol_kind = None
start_line = None
end_line = None
if metadata:
try:
meta_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
symbol_name = meta_dict.get("symbol_name")
symbol_kind = meta_dict.get("symbol_kind")
start_line = meta_dict.get("start_line")
end_line = meta_dict.get("end_line")
except Exception:
pass
result = SearchResult(
path=chunk.get("file_path", ""),
score=float(score),
excerpt=content[:500] if content else "",
content=content,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
start_line=start_line,
end_line=end_line,
)
coarse_results.append(result)
except Exception as exc:
self.logger.debug(
"Failed to retrieve chunks from %s: %s", idx_path, exc
)
stats.errors.append(f"Stage 1 chunk retrieval failed for {idx_path}: {exc}")
return coarse_results, index_root
def _stage2_lsp_expand(
self,
coarse_results: List[SearchResult],
index_root: Optional[Path],
) -> List[SearchResult]:
"""Stage 2: LSP-based graph expansion using GraphExpander.
Expands coarse results with related symbols (definitions, references,
callers, callees) using precomputed graph neighbors.
Args:
coarse_results: Results from Stage 1 binary search
index_root: Root path of the index (for graph database access)
Returns:
Combined list of original results plus expanded related results
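Example (illustrative; the invariant holds because originals are always kept):
>>> expanded = engine._stage2_lsp_expand(coarse_results, index_root)
>>> len(expanded) >= len(coarse_results)
True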
"""
if not coarse_results or index_root is None:
return coarse_results
try:
from codexlens.search.graph_expander import GraphExpander
# Get expansion depth from config
depth = 2
if self._config is not None:
depth = getattr(self._config, "graph_expansion_depth", 2)
expander = GraphExpander(self.mapper, config=self._config)
# Expand top results (limit expansion to avoid explosion)
max_expand = min(10, len(coarse_results))
max_related = 50
related_results = expander.expand(
coarse_results,
depth=depth,
max_expand=max_expand,
max_related=max_related,
)
if related_results:
self.logger.debug(
"Stage 2 expanded %d base results to %d related symbols",
len(coarse_results), len(related_results)
)
# Combine: original results + related results
# Keep original results first (higher relevance)
combined = list(coarse_results)
seen_keys = {(r.path, r.symbol_name, r.start_line) for r in coarse_results}
for related in related_results:
key = (related.path, related.symbol_name, related.start_line)
if key not in seen_keys:
seen_keys.add(key)
combined.append(related)
return combined
except ImportError as exc:
self.logger.debug("GraphExpander not available: %s", exc)
return coarse_results
except Exception as exc:
self.logger.debug("Stage 2 LSP expansion failed: %s", exc)
return coarse_results
def _stage3_cluster_prune(
self,
expanded_results: List[SearchResult],
target_count: int,
) -> List[SearchResult]:
"""Stage 3: Cluster expanded results and select representatives.
Uses the extensible clustering infrastructure from codexlens.search.clustering
to group similar results and select the best representative from each cluster.
Args:
expanded_results: Results from Stage 2 expansion
target_count: Target number of representative results
Returns:
List of representative results (one per cluster)
"""
if not expanded_results:
return []
# If few results, skip clustering
if len(expanded_results) <= target_count:
return expanded_results
try:
from codexlens.search.clustering import (
ClusteringConfig,
get_strategy,
)
# Get clustering config from config
strategy_name = "auto"
min_cluster_size = 3
if self._config is not None:
strategy_name = getattr(self._config, "staged_clustering_strategy", "auto")
min_cluster_size = getattr(self._config, "staged_clustering_min_size", 3)
# Get embeddings for clustering
# Try to get dense embeddings from results' content
embeddings = self._get_embeddings_for_clustering(expanded_results)
if embeddings is None or len(embeddings) == 0:
# No embeddings available, fall back to score-based selection
self.logger.debug("No embeddings for clustering, using score-based selection")
return sorted(
expanded_results, key=lambda r: r.score, reverse=True
)[:target_count]
# Create clustering config
config = ClusteringConfig(
min_cluster_size=min(min_cluster_size, max(2, len(expanded_results) // 5)),
min_samples=2,
metric="cosine",
)
# Get strategy with fallback
strategy = get_strategy(strategy_name, config, fallback=True)
# Cluster and select representatives
representatives = strategy.fit_predict(embeddings, expanded_results)
self.logger.debug(
"Stage 3 clustered %d results into %d representatives using %s",
len(expanded_results), len(representatives), type(strategy).__name__
)
# If clustering returned too few, supplement with top-scored unclustered
if len(representatives) < target_count:
rep_paths = {r.path for r in representatives}
remaining = [r for r in expanded_results if r.path not in rep_paths]
remaining_sorted = sorted(remaining, key=lambda r: r.score, reverse=True)
representatives.extend(remaining_sorted[:target_count - len(representatives)])
return representatives[:target_count]
except ImportError as exc:
self.logger.debug("Clustering not available: %s", exc)
return sorted(
expanded_results, key=lambda r: r.score, reverse=True
)[:target_count]
except Exception as exc:
self.logger.debug("Stage 3 clustering failed: %s", exc)
return sorted(
expanded_results, key=lambda r: r.score, reverse=True
)[:target_count]
def _stage4_optional_rerank(
self,
query: str,
clustered_results: List[SearchResult],
k: int,
) -> List[SearchResult]:
"""Stage 4: Optional cross-encoder reranking.
Applies cross-encoder reranking if enabled in config.
Args:
query: Search query string
clustered_results: Results from Stage 3 clustering
k: Number of final results to return
Returns:
Reranked results sorted by cross-encoder score
"""
if not clustered_results:
return []
# Use existing _cross_encoder_rerank method
return self._cross_encoder_rerank(query, clustered_results, k)
def _get_embeddings_for_clustering(
self,
results: List[SearchResult],
) -> Optional["np.ndarray"]:
"""Get dense embeddings for clustering results.
Tries to generate embeddings from result content for clustering.
Args:
results: List of SearchResult objects
Returns:
NumPy array of embeddings or None if not available
"""
if not NUMPY_AVAILABLE:
return None
if not results:
return None
try:
from codexlens.semantic.factory import get_embedder
# Get embedding settings from config
embedding_backend = "fastembed"
embedding_model = "code"
use_gpu = True
if self._config is not None:
embedding_backend = getattr(self._config, "embedding_backend", "fastembed")
embedding_model = getattr(self._config, "embedding_model", "code")
use_gpu = getattr(self._config, "embedding_use_gpu", True)
# Create embedder
if embedding_backend == "litellm":
embedder = get_embedder(backend="litellm", model=embedding_model)
else:
embedder = get_embedder(backend="fastembed", profile=embedding_model, use_gpu=use_gpu)
# Extract text content from results
texts = []
for result in results:
# Use content if available, otherwise use excerpt
text = result.content or result.excerpt or ""
if not text and result.path:
text = result.path
texts.append(text[:2000]) # Limit text length
# Generate embeddings
embeddings = embedder.embed_to_numpy(texts)
return embeddings
except ImportError as exc:
self.logger.debug("Embedder not available for clustering: %s", exc)
return None
except Exception as exc:
self.logger.debug("Failed to generate embeddings for clustering: %s", exc)
return None
def binary_rerank_cascade_search(
self,
query: str,
@@ -1990,6 +2637,220 @@ class ChainSearchEngine:
index_paths, name, kind, options.total_limit
)
def search_references(
self,
symbol_name: str,
source_path: Optional[Path] = None,
depth: int = -1,
limit: int = 100,
) -> List[ReferenceResult]:
"""Find all references to a symbol across the project.
Searches the code_relationships table in all index databases to find
where the given symbol is referenced (called, imported, inherited, etc.).
Args:
symbol_name: Fully qualified or simple name of the symbol to find references to
source_path: Starting path for search (default: workspace root from registry)
depth: Search depth (-1 = unlimited, 0 = current dir only)
limit: Maximum results to return (default 100)
Returns:
List of ReferenceResult objects sorted by file path and line number
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> refs = engine.search_references("authenticate", Path("D:/project/src"))
>>> for ref in refs[:10]:
... print(f"{ref.file_path}:{ref.line} ({ref.relationship_type})")
"""
import sqlite3
from concurrent.futures import as_completed
# Determine starting path
if source_path is None:
# Try to get workspace root from registry
mappings = self.registry.list_mappings()
if mappings:
source_path = Path(mappings[0].source_path)
else:
self.logger.warning("No source path provided and no mappings in registry")
return []
# Find starting index
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
# Collect all index paths
index_paths = self._collect_index_paths(start_index, depth)
if not index_paths:
self.logger.debug(f"No indexes collected from {start_index}")
return []
self.logger.debug(
"Searching %d indexes for references to '%s'",
len(index_paths), symbol_name
)
# Search in parallel
all_results: List[ReferenceResult] = []
executor = self._get_executor()
def search_single_index(index_path: Path) -> List[ReferenceResult]:
"""Search a single index for references."""
results: List[ReferenceResult] = []
try:
conn = sqlite3.connect(str(index_path), check_same_thread=False)
conn.row_factory = sqlite3.Row
# Query code_relationships for references to this symbol
# Match either target_qualified_name containing the symbol name
# or an exact match on the last component
# Try full_path first (new schema), fallback to path (old schema)
try:
rows = conn.execute(
"""
SELECT DISTINCT
f.full_path as source_file,
cr.source_line,
cr.relationship_type,
f.content
FROM code_relationships cr
JOIN symbols s ON s.id = cr.source_symbol_id
JOIN files f ON f.id = s.file_id
WHERE cr.target_qualified_name LIKE ?
OR cr.target_qualified_name LIKE ?
OR cr.target_qualified_name = ?
ORDER BY f.full_path, cr.source_line
LIMIT ?
""",
(
f"%{symbol_name}", # Ends with symbol name
f"%.{symbol_name}", # Qualified name ending with .symbol_name
symbol_name, # Exact match
limit,
)
).fetchall()
except sqlite3.OperationalError:
# Fallback for old schema with 'path' column
rows = conn.execute(
"""
SELECT DISTINCT
f.path as source_file,
cr.source_line,
cr.relationship_type,
f.content
FROM code_relationships cr
JOIN symbols s ON s.id = cr.source_symbol_id
JOIN files f ON f.id = s.file_id
WHERE cr.target_qualified_name LIKE ?
OR cr.target_qualified_name LIKE ?
OR cr.target_qualified_name = ?
ORDER BY f.path, cr.source_line
LIMIT ?
""",
(
f"%{symbol_name}", # Ends with symbol name
f"%.{symbol_name}", # Qualified name ending with .symbol_name
symbol_name, # Exact match
limit,
)
).fetchall()
for row in rows:
file_path = row["source_file"]
line = row["source_line"] or 1
rel_type = row["relationship_type"]
content = row["content"] or ""
# Extract context (3 lines around reference)
context = self._extract_context(content, line, context_lines=3)
results.append(ReferenceResult(
file_path=file_path,
line=line,
column=0, # Column info not stored in code_relationships
context=context,
relationship_type=rel_type,
))
conn.close()
except sqlite3.DatabaseError as exc:
self.logger.debug(
"Failed to search references in %s: %s", index_path, exc
)
except Exception as exc:
self.logger.debug(
"Unexpected error searching references in %s: %s", index_path, exc
)
return results
# Submit parallel searches
futures = {
executor.submit(search_single_index, idx_path): idx_path
for idx_path in index_paths
}
for future in as_completed(futures):
try:
results = future.result()
all_results.extend(results)
except Exception as exc:
idx_path = futures[future]
self.logger.debug(
"Reference search failed for %s: %s", idx_path, exc
)
# Deduplicate by (file_path, line)
seen: set = set()
unique_results: List[ReferenceResult] = []
for ref in all_results:
key = (ref.file_path, ref.line)
if key not in seen:
seen.add(key)
unique_results.append(ref)
# Sort by file path and line
unique_results.sort(key=lambda r: (r.file_path, r.line))
# Apply limit
return unique_results[:limit]
def _extract_context(
self,
content: str,
line: int,
context_lines: int = 3
) -> str:
"""Extract lines around a given line number from file content.
Args:
content: Full file content
line: Target line number (1-based)
context_lines: Number of lines to include before and after
Returns:
Context snippet as a string
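Example (minimal sketch with synthetic content):
>>> engine._extract_context("a\nb\nc\nd\ne", line=3, context_lines=1)
'b\nc\nd'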
"""
if not content:
return ""
lines = content.splitlines()
total_lines = len(lines)
if line < 1 or line > total_lines:
return ""
# Calculate range (0-indexed internally)
start = max(0, line - 1 - context_lines)
end = min(total_lines, line + context_lines)
context = lines[start:end]
return "\n".join(context)
# === Internal Methods ===
def _find_start_index(self, source_path: Path) -> Optional[Path]:

@@ -0,0 +1,124 @@
"""Clustering strategies for the staged hybrid search pipeline.
This module provides extensible clustering infrastructure for grouping
similar search results and selecting representative results.
Install with: pip install codexlens[clustering]
Example:
>>> from codexlens.search.clustering import (
... CLUSTERING_AVAILABLE,
... ClusteringConfig,
... get_strategy,
... )
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy with fallback
>>> strategy = get_strategy("auto", config)
>>> representatives = strategy.fit_predict(embeddings, results)
>>>
>>> # Or explicitly use a specific strategy
>>> if CLUSTERING_AVAILABLE:
... from codexlens.search.clustering import HDBSCANStrategy
... strategy = HDBSCANStrategy(config)
... representatives = strategy.fit_predict(embeddings, results)
"""
from __future__ import annotations
# Always export base classes and factory (no heavy dependencies)
from .base import BaseClusteringStrategy, ClusteringConfig
from .factory import (
ClusteringStrategyFactory,
check_clustering_strategy_available,
get_strategy,
)
from .noop_strategy import NoOpStrategy
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# Feature flag for clustering availability (hdbscan + sklearn)
CLUSTERING_AVAILABLE = False
HDBSCAN_AVAILABLE = False
DBSCAN_AVAILABLE = False
_import_error: str | None = None
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
"""Detect if clustering dependencies are available.
Returns:
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
"""
hdbscan_ok = False
dbscan_ok = False
try:
import hdbscan # noqa: F401
hdbscan_ok = True
except ImportError:
pass
try:
from sklearn.cluster import DBSCAN # noqa: F401
dbscan_ok = True
except ImportError:
pass
all_ok = hdbscan_ok and dbscan_ok
error = None
if not all_ok:
missing = []
if not hdbscan_ok:
missing.append("hdbscan")
if not dbscan_ok:
missing.append("scikit-learn")
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
return all_ok, hdbscan_ok, dbscan_ok, error
# Initialize on module load
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
_detect_clustering_available()
)
def check_clustering_available() -> tuple[bool, str | None]:
"""Check if all clustering dependencies are available.
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
return CLUSTERING_AVAILABLE, _import_error
# Conditionally export strategy implementations
__all__ = [
# Feature flags
"CLUSTERING_AVAILABLE",
"HDBSCAN_AVAILABLE",
"DBSCAN_AVAILABLE",
"check_clustering_available",
# Base classes
"BaseClusteringStrategy",
"ClusteringConfig",
# Factory
"ClusteringStrategyFactory",
"get_strategy",
"check_clustering_strategy_available",
# Always-available strategies
"NoOpStrategy",
"FrequencyStrategy",
"FrequencyConfig",
]
# Conditionally add strategy classes to __all__ and module namespace
if HDBSCAN_AVAILABLE:
from .hdbscan_strategy import HDBSCANStrategy
__all__.append("HDBSCANStrategy")
if DBSCAN_AVAILABLE:
from .dbscan_strategy import DBSCANStrategy
__all__.append("DBSCANStrategy")

@@ -0,0 +1,153 @@
"""Base classes for clustering strategies in the hybrid search pipeline.
This module defines the abstract base class for clustering strategies used
in the staged hybrid search pipeline. Strategies cluster search results
based on their embeddings and select representative results from each cluster.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class ClusteringConfig:
"""Configuration parameters for clustering strategies.
Attributes:
min_cluster_size: Minimum number of results to form a cluster.
HDBSCAN default is 5, but for search results 2-3 is often better.
min_samples: Number of samples in a neighborhood for a point to be
considered a core point. Lower values allow more clusters.
metric: Distance metric for clustering. Common options:
- 'euclidean': Standard L2 distance
- 'cosine': Cosine distance (1 - cosine_similarity)
- 'manhattan': L1 distance
cluster_selection_epsilon: Distance threshold for cluster selection.
Results within this distance may be merged into the same cluster.
allow_single_cluster: If True, allow all results to form one cluster.
Useful when results are very similar.
prediction_data: If True, generate prediction data for new points.
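Example (fields not passed keep the defaults shown above):
>>> cfg = ClusteringConfig(min_cluster_size=2, metric="cosine")
>>> cfg.min_samples
2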
"""
min_cluster_size: int = 3
min_samples: int = 2
metric: str = "cosine"
cluster_selection_epsilon: float = 0.0
allow_single_cluster: bool = True
prediction_data: bool = False
def __post_init__(self) -> None:
"""Validate configuration parameters."""
if self.min_cluster_size < 2:
raise ValueError("min_cluster_size must be >= 2")
if self.min_samples < 1:
raise ValueError("min_samples must be >= 1")
if self.metric not in ("euclidean", "cosine", "manhattan"):
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
if self.cluster_selection_epsilon < 0:
raise ValueError("cluster_selection_epsilon must be >= 0")
class BaseClusteringStrategy(ABC):
"""Abstract base class for clustering strategies.
Clustering strategies are used in the staged hybrid search pipeline to
group similar search results and select representative results from each
cluster, reducing redundancy while maintaining diversity.
Subclasses must implement:
- cluster(): Group results into clusters based on embeddings
- select_representatives(): Choose best result(s) from each cluster
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize the clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
"""
self.config = config or ClusteringConfig()
@abstractmethod
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results based on their embeddings.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Used for additional metadata during clustering.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Results not assigned to any cluster
(noise points) should be returned as single-element clusters.
Example:
>>> strategy = HDBSCANStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
>>> # Result indices 0, 2, 5 are in cluster 0
>>> # Result indices 1, 3 are in cluster 1
>>> # Result index 4 is a noise point (singleton cluster)
>>> # Result indices 6, 7, 8 are in cluster 2
"""
...
@abstractmethod
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
This method chooses the best result(s) from each cluster to include
in the final search results. The selection can be based on:
- Highest score within cluster
- Closest to cluster centroid
- Custom selection logic
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings array for centroid-based selection.
Returns:
List of representative SearchResult objects, one or more per cluster,
ordered by relevance (highest score first).
Example:
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns best result from each cluster
"""
...
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)

@@ -0,0 +1,197 @@
"""DBSCAN-based clustering strategy for search results.
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
is the fallback clustering strategy when HDBSCAN is not available.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class DBSCANStrategy(BaseClusteringStrategy):
"""DBSCAN-based clustering strategy.
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
DBSCAN requires an explicit eps parameter, which is auto-computed from the
distance distribution if not provided.
Example:
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = DBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
# Default eps percentile for auto-computation
DEFAULT_EPS_PERCENTILE: float = 15.0
def __init__(
self,
config: Optional[ClusteringConfig] = None,
eps: Optional[float] = None,
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
) -> None:
"""Initialize DBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
from the distance distribution.
eps_percentile: Percentile of pairwise distances to use for
auto-computing eps. Default is 15th percentile.
Raises:
ImportError: If sklearn is not installed.
"""
super().__init__(config)
self.eps = eps
self.eps_percentile = eps_percentile
# Validate sklearn is available
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError as exc:
raise ImportError(
"scikit-learn package is required for DBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def _compute_eps(self, embeddings: "np.ndarray") -> float:
"""Auto-compute eps from pairwise distance distribution.
Uses the specified percentile of pairwise distances as eps,
which typically captures local density well.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Returns:
Computed eps value.
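Example (worked, three points under NumPy's linear interpolation):
For upper-triangle distances [0.1, 0.4, 0.9], the 15th percentile
is 0.1 + 0.3 * (0.4 - 0.1) = 0.19, so eps would be ~0.19.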
"""
import numpy as np
from sklearn.metrics import pairwise_distances
# Compute pairwise distances
distances = pairwise_distances(embeddings, metric=self.config.metric)
# Get upper triangle (excluding diagonal)
upper_tri = distances[np.triu_indices_from(distances, k=1)]
if len(upper_tri) == 0:
# Only one point, return a default small eps
return 0.1
# Use percentile of distances as eps
eps = float(np.percentile(upper_tri, self.eps_percentile))
# Ensure eps is positive
return max(eps, 1e-6)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using DBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
from sklearn.cluster import DBSCAN
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: single result
if n_results == 1:
return [[0]]
# Determine eps value
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
# Configure DBSCAN clusterer
# Note: DBSCAN min_samples corresponds to min_cluster_size concept
clusterer = DBSCAN(
eps=eps,
min_samples=self.config.min_samples,
metric=self.config.metric,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives

@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.
Provides a unified interface for instantiating different clustering backends
with automatic fallback chain: hdbscan -> dbscan -> noop.
"""
from __future__ import annotations
from typing import Any, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
from .noop_strategy import NoOpStrategy
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
"""Check whether a specific clustering strategy can be used.
Args:
strategy: Strategy name to check. Options:
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
- "dbscan": DBSCAN clustering (requires sklearn)
- "frequency": Frequency-based clustering (always available)
- "noop": No-op strategy (always available)
Returns:
Tuple of (is_available, error_message).
error_message is None if available, otherwise contains install instructions.
"""
strategy = (strategy or "").strip().lower()
if strategy == "hdbscan":
try:
import hdbscan # noqa: F401
except ImportError:
return False, (
"hdbscan package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "dbscan":
try:
from sklearn.cluster import DBSCAN # noqa: F401
except ImportError:
return False, (
"scikit-learn package not available. "
"Install with: pip install codexlens[clustering]"
)
return True, None
if strategy == "frequency":
# Frequency strategy is always available (no external deps)
return True, None
if strategy == "noop":
return True, None
return False, (
f"Invalid clustering strategy: {strategy}. "
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
)
def get_strategy(
strategy: str = "hdbscan",
config: Optional[ClusteringConfig] = None,
*,
fallback: bool = True,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Factory function to create clustering strategy with fallback chain.
The fallback chain is: hdbscan -> dbscan -> noop
Args:
strategy: Clustering strategy to use. Options:
- "hdbscan": HDBSCAN clustering (default, recommended)
- "dbscan": DBSCAN clustering (fallback)
- "frequency": Frequency-based clustering (groups by symbol occurrence)
- "noop": No-op strategy (returns all results ungrouped)
- "auto": Try hdbscan, then dbscan, then noop
config: Clustering configuration. Uses defaults if not provided.
For frequency strategy, pass FrequencyConfig for full control.
fallback: If True (default), automatically fall back to next strategy
in the chain when primary is unavailable. If False, raise ImportError
when requested strategy is unavailable.
**kwargs: Additional strategy-specific arguments.
For DBSCANStrategy: eps, eps_percentile
For FrequencyStrategy: group_by, min_frequency, etc.
Returns:
BaseClusteringStrategy: Configured clustering strategy instance.
Raises:
ValueError: If strategy is not recognized.
ImportError: If required dependencies are not installed and fallback=False.
Example:
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3)
>>> # Auto-select best available strategy
>>> strategy = get_strategy("auto", config)
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
>>> strategy = get_strategy("hdbscan", config)
>>> # Use frequency-based strategy
>>> from codexlens.search.clustering import FrequencyConfig
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = get_strategy("frequency", freq_config)
"""
strategy = (strategy or "").strip().lower()
# Handle "auto" - try strategies in order
if strategy == "auto":
return _get_best_available_strategy(config, **kwargs)
if strategy == "hdbscan":
ok, err = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
if fallback:
# Try dbscan fallback
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
if ok_dbscan:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Final fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "dbscan":
ok, err = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
if fallback:
# Fallback to noop
return NoOpStrategy(config)
raise ImportError(err)
if strategy == "frequency":
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
# If config is not a FrequencyConfig, build one (from kwargs when provided)
if config is None or not isinstance(config, FrequencyConfig):
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
else:
freq_config = config
return FrequencyStrategy(freq_config)
if strategy == "noop":
return NoOpStrategy(config)
raise ValueError(
f"Unknown clustering strategy: {strategy}. "
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
)
def _get_best_available_strategy(
config: Optional[ClusteringConfig] = None,
**kwargs: Any,
) -> BaseClusteringStrategy:
"""Get the best available clustering strategy.
Tries strategies in order: hdbscan -> dbscan -> noop
Args:
config: Clustering configuration.
**kwargs: Additional strategy-specific arguments.
Returns:
Best available clustering strategy instance.
"""
# Try HDBSCAN first
ok, _ = check_clustering_strategy_available("hdbscan")
if ok:
from .hdbscan_strategy import HDBSCANStrategy
return HDBSCANStrategy(config)
# Try DBSCAN second
ok, _ = check_clustering_strategy_available("dbscan")
if ok:
from .dbscan_strategy import DBSCANStrategy
return DBSCANStrategy(config, **kwargs)
# Fallback to NoOp
return NoOpStrategy(config)
# Alias for backward compatibility
ClusteringStrategyFactory = type(
"ClusteringStrategyFactory",
(),
{
"get_strategy": staticmethod(get_strategy),
"check_available": staticmethod(check_clustering_strategy_available),
},
)

@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol/method name and prunes based on
occurrence frequency. High-frequency symbols (frequently referenced methods)
are considered more important and retained, while low-frequency results
(potentially noise) can be filtered out.
Use cases:
- Prioritize commonly called methods/functions
- Filter out one-off results that may be less relevant
- Deduplicate results pointing to the same symbol from different locations
"""
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
@dataclass
class FrequencyConfig(ClusteringConfig):
"""Configuration for frequency-based clustering strategy.
Attributes:
group_by: Field to group results by for frequency counting.
- 'symbol': Group by symbol_name (default, for method/function dedup)
- 'file': Group by file path
- 'symbol_kind': Group by symbol type (function, class, etc.)
min_frequency: Minimum occurrence count to keep a result.
Results appearing less than this are considered noise and pruned.
max_representatives_per_group: Maximum results to keep per symbol group.
frequency_weight: How much to boost score based on frequency.
Final score = original_score * (1 + frequency_weight * log(frequency + 1))
keep_mode: How to handle low-frequency results.
- 'filter': Remove results below min_frequency
- 'demote': Keep but lower their score ranking
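Example (worked, using the boost formula above):
With frequency_weight=0.1 and frequency=5, the boost factor is
1 + 0.1 * ln(6) ~= 1.18, so a result with score 0.50 ranks as ~0.59.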
"""
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
max_representatives_per_group: int = 3
frequency_weight: float = 0.1 # Boost factor for frequency
keep_mode: Literal["filter", "demote"] = "demote"
def __post_init__(self) -> None:
"""Validate configuration parameters."""
# Skip parent validation since we don't use HDBSCAN params
if self.min_frequency < 1:
raise ValueError("min_frequency must be >= 1")
if self.max_representatives_per_group < 1:
raise ValueError("max_representatives_per_group must be >= 1")
if self.frequency_weight < 0:
raise ValueError("frequency_weight must be >= 0")
if self.group_by not in ("symbol", "file", "symbol_kind"):
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
if self.keep_mode not in ("filter", "demote"):
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
class FrequencyStrategy(BaseClusteringStrategy):
"""Frequency-based clustering strategy for search result deduplication.
This strategy groups search results by symbol name (or file/kind) and:
1. Counts how many times each symbol appears in results
2. Higher frequency = more important (frequently referenced method)
3. Filters or demotes low-frequency results
4. Selects top representatives from each frequency group
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
- Does NOT require embeddings (works with metadata only)
- Is very fast (O(n) complexity)
- Is deterministic (no random initialization)
- Works well for symbol-level deduplication
Example:
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
>>> strategy = FrequencyStrategy(config)
>>> # Results with symbol "authenticate" appearing 5 times
>>> # will be prioritized over "helper_func" appearing once
>>> representatives = strategy.fit_predict(embeddings, results)
"""
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
"""Initialize the frequency strategy.
Args:
config: Frequency configuration. Uses defaults if not provided.
"""
self.config: FrequencyConfig = config or FrequencyConfig()
def _get_group_key(self, result: "SearchResult") -> str:
"""Extract grouping key from a search result.
Args:
result: SearchResult to extract key from.
Returns:
String key for grouping (symbol name, file path, or kind).
"""
if self.config.group_by == "symbol":
# Use symbol_name if available, otherwise fall back to file:line
symbol = getattr(result, "symbol_name", None)
if symbol:
return str(symbol)
# Fallback: use file path + start_line as pseudo-symbol
start_line = getattr(result, "start_line", 0) or 0
return f"{result.path}:{start_line}"
elif self.config.group_by == "file":
return str(result.path)
elif self.config.group_by == "symbol_kind":
kind = getattr(result, "symbol_kind", None)
return str(kind) if kind else "unknown"
return str(result.path) # Default fallback
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Group search results by frequency of occurrence.
Note: This method ignores embeddings and groups by metadata only.
The embeddings parameter is kept for interface compatibility.
Args:
embeddings: Ignored (kept for interface compatibility).
results: List of SearchResult objects to cluster.
Returns:
List of clusters (groups), where each cluster contains indices
of results with the same grouping key. Clusters are ordered by
frequency (highest frequency first).
"""
if not results:
return []
# Group results by key
groups: Dict[str, List[int]] = defaultdict(list)
for idx, result in enumerate(results):
key = self._get_group_key(result)
groups[key].append(idx)
# Sort groups by frequency (descending) then by key (for stability)
sorted_groups = sorted(
groups.items(),
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
)
# Convert to list of clusters
clusters = [indices for _, indices in sorted_groups]
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results based on frequency and score.
For each frequency group:
1. If frequency < min_frequency: filter or demote based on keep_mode
2. Sort by score within group
3. Apply frequency boost to scores
4. Select top N representatives
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (used for tie-breaking if provided).
Returns:
List of representative SearchResult objects, ordered by
frequency-adjusted score (highest first).
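Example (illustrative):
With min_frequency=2 and keep_mode="demote", a symbol seen only once
is kept but appended after all frequency-boosted representatives;
with keep_mode="filter" it is dropped entirely.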
"""
import math
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
demoted: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
frequency = len(cluster_indices)
# Get results in this cluster, sorted by score
cluster_results = [results[i] for i in cluster_indices]
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
# Check frequency threshold
if frequency < self.config.min_frequency:
if self.config.keep_mode == "filter":
# Skip low-frequency results entirely
continue
else: # demote mode
# Keep but add to demoted list (lower priority)
for result in cluster_results[: self.config.max_representatives_per_group]:
demoted.append(result)
continue
# Apply frequency boost and select top representatives
for result in cluster_results[: self.config.max_representatives_per_group]:
# Calculate frequency-boosted score
original_score = getattr(result, "score", 0.0)
# log(frequency + 1) to handle frequency=1 case smoothly
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
boosted_score = original_score * frequency_boost
# Record the boost without constructing a new result: SearchResult
# may be immutable, so the original score field is preserved and the
# boosted score is tracked in metadata instead
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
result.metadata["frequency"] = frequency
result.metadata["frequency_boosted_score"] = boosted_score
representatives.append(result)
# Sort representatives by boosted score (or original score as fallback)
def get_sort_score(r: "SearchResult") -> float:
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
return getattr(r, "score", 0.0)
representatives.sort(key=get_sort_score, reverse=True)
# Add demoted results at the end
if demoted:
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
representatives.extend(demoted)
return representatives
def fit_predict(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List["SearchResult"]:
"""Convenience method to cluster and select representatives in one call.
Args:
embeddings: NumPy array (may be ignored for frequency-based clustering).
results: List of SearchResult objects.
Returns:
List of representative SearchResult objects.
"""
clusters = self.cluster(embeddings, results)
return self.select_representatives(clusters, results, embeddings)

@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.
HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
is the primary clustering strategy for grouping similar search results.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class HDBSCANStrategy(BaseClusteringStrategy):
"""HDBSCAN-based clustering strategy.
Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
HDBSCAN is preferred over DBSCAN because it:
- Automatically determines the number of clusters
- Handles varying density clusters well
- Identifies noise points (outliers) effectively
Example:
>>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
>>> strategy = HDBSCANStrategy(config)
>>> clusters = strategy.cluster(embeddings, results)
>>> representatives = strategy.select_representatives(clusters, results)
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize HDBSCAN clustering strategy.
Args:
config: Clustering configuration. Uses defaults if not provided.
Raises:
ImportError: If hdbscan package is not installed.
"""
super().__init__(config)
# Validate hdbscan is available
try:
import hdbscan # noqa: F401
except ImportError as exc:
raise ImportError(
"hdbscan package is required for HDBSCANStrategy. "
"Install with: pip install codexlens[clustering]"
) from exc
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Cluster search results using HDBSCAN algorithm.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim)
containing the embedding vectors for each result.
results: List of SearchResult objects corresponding to embeddings.
Returns:
List of clusters, where each cluster is a list of indices
into the results list. Noise points are returned as singleton clusters.
"""
import hdbscan
import numpy as np
n_results = len(results)
if n_results == 0:
return []
# Handle edge case: fewer results than min_cluster_size
if n_results < self.config.min_cluster_size:
# Return each result as its own singleton cluster
return [[i] for i in range(n_results)]
# Configure HDBSCAN clusterer
clusterer = hdbscan.HDBSCAN(
min_cluster_size=self.config.min_cluster_size,
min_samples=self.config.min_samples,
metric=self.config.metric,
cluster_selection_epsilon=self.config.cluster_selection_epsilon,
allow_single_cluster=self.config.allow_single_cluster,
prediction_data=self.config.prediction_data,
)
# Fit and get cluster labels
# Labels: -1 = noise, 0+ = cluster index
labels = clusterer.fit_predict(embeddings)
# Group indices by cluster label
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
if label not in cluster_map:
cluster_map[label] = []
cluster_map[label].append(idx)
# Build result: non-noise clusters first, then noise as singletons
clusters: List[List[int]] = []
# Add proper clusters (label >= 0)
for label in sorted(cluster_map.keys()):
if label >= 0:
clusters.append(cluster_map[label])
# Add noise points as singleton clusters (label == -1)
if -1 in cluster_map:
for idx in cluster_map[-1]:
clusters.append([idx])
return clusters
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Select representative results from each cluster.
Selects the result with the highest score from each cluster.
Args:
clusters: List of clusters from cluster() method.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used in score-based selection).
Returns:
List of representative SearchResult objects, one per cluster,
ordered by score (highest first).
"""
if not clusters or not results:
return []
representatives: List["SearchResult"] = []
for cluster_indices in clusters:
if not cluster_indices:
continue
# Find the result with the highest score in this cluster
best_idx = max(cluster_indices, key=lambda i: results[i].score)
representatives.append(results[best_idx])
# Sort by score descending
representatives.sort(key=lambda r: r.score, reverse=True)
return representatives

@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.
NoOpStrategy returns all results ungrouped when clustering dependencies
are not available or clustering is disabled.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from .base import BaseClusteringStrategy, ClusteringConfig
if TYPE_CHECKING:
import numpy as np
from codexlens.entities import SearchResult
class NoOpStrategy(BaseClusteringStrategy):
"""No-op clustering strategy that returns all results ungrouped.
This strategy is used as a final fallback when no clustering dependencies
are available, or when clustering is explicitly disabled. Each result
is treated as its own singleton cluster.
Example:
>>> from codexlens.search.clustering import NoOpStrategy
>>> strategy = NoOpStrategy()
>>> clusters = strategy.cluster(embeddings, results)
>>> # Returns [[0], [1], [2], ...] - each result in its own cluster
>>> representatives = strategy.select_representatives(clusters, results)
>>> # Returns all results sorted by score
"""
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
"""Initialize NoOp clustering strategy.
Args:
config: Clustering configuration. Ignored for NoOpStrategy
but accepted for interface compatibility.
"""
super().__init__(config)
def cluster(
self,
embeddings: "np.ndarray",
results: List["SearchResult"],
) -> List[List[int]]:
"""Return each result as its own singleton cluster.
Args:
embeddings: NumPy array of shape (n_results, embedding_dim).
Not used but accepted for interface compatibility.
results: List of SearchResult objects.
Returns:
List of singleton clusters, one per result.
"""
return [[i] for i in range(len(results))]
def select_representatives(
self,
clusters: List[List[int]],
results: List["SearchResult"],
embeddings: Optional["np.ndarray"] = None,
) -> List["SearchResult"]:
"""Return all results sorted by score.
Since each cluster is a singleton, this effectively returns all
results sorted by score descending.
Args:
clusters: List of singleton clusters.
results: Original list of SearchResult objects.
embeddings: Optional embeddings (not used).
Returns:
All SearchResult objects sorted by score (highest first).
"""
if not results:
return []
# Return all results sorted by score
return sorted(results, key=lambda r: r.score, reverse=True)

@@ -1807,6 +1807,178 @@ class DirIndexStore:
for row in rows
]
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
"""Get all symbols in a specific file, sorted by start_line.
Args:
file_path: Full path to the file
Returns:
List of Symbol objects sorted by start_line
"""
file_path_str = str(Path(file_path).resolve())
with self._lock:
conn = self._get_connection()
# First get the file_id
file_row = conn.execute(
"SELECT id FROM files WHERE full_path=?",
(file_path_str,),
).fetchone()
if not file_row:
return []
file_id = int(file_row["id"])
rows = conn.execute(
"""
SELECT s.name, s.kind, s.start_line, s.end_line
FROM symbols s
WHERE s.file_id=?
ORDER BY s.start_line
""",
(file_id,),
).fetchall()
return [
Symbol(
name=row["name"],
kind=row["kind"],
range=(row["start_line"], row["end_line"]),
file=file_path_str,
)
for row in rows
]
def get_outgoing_calls(
self,
file_path: str | Path,
symbol_name: Optional[str] = None,
) -> List[Tuple[str, str, int, Optional[str]]]:
"""Get outgoing calls from symbols in a file.
Queries code_relationships table for calls originating from symbols
in the specified file.
Args:
file_path: Full path to the source file
symbol_name: Optional symbol name to filter by. If None, returns
calls from all symbols in the file.
Returns:
List of tuples: (target_name, relationship_type, source_line, target_file)
- target_name: Qualified name of the call target
- relationship_type: Type of relationship (e.g., "calls", "imports")
- source_line: Line number where the call occurs
- target_file: Target file path (may be None if unknown)
"""
file_path_str = str(Path(file_path).resolve())
with self._lock:
conn = self._get_connection()
# First get the file_id
file_row = conn.execute(
"SELECT id FROM files WHERE full_path=?",
(file_path_str,),
).fetchone()
if not file_row:
return []
file_id = int(file_row["id"])
if symbol_name:
rows = conn.execute(
"""
SELECT cr.target_qualified_name, cr.relationship_type,
cr.source_line, cr.target_file
FROM code_relationships cr
JOIN symbols s ON s.id = cr.source_symbol_id
WHERE s.file_id=? AND s.name=?
ORDER BY cr.source_line
""",
(file_id, symbol_name),
).fetchall()
else:
rows = conn.execute(
"""
SELECT cr.target_qualified_name, cr.relationship_type,
cr.source_line, cr.target_file
FROM code_relationships cr
JOIN symbols s ON s.id = cr.source_symbol_id
WHERE s.file_id=?
ORDER BY cr.source_line
""",
(file_id,),
).fetchall()
return [
(
row["target_qualified_name"],
row["relationship_type"],
int(row["source_line"]),
row["target_file"],
)
for row in rows
]
def get_incoming_calls(
self,
target_name: str,
limit: int = 100,
) -> List[Tuple[str, str, int, str]]:
"""Get incoming calls/references to a target symbol.
Queries code_relationships table for references to the specified
target symbol name.
Args:
            target_name: Name of the target symbol to find references for.
                Matched against target_qualified_name, either exactly or as
                a suffix (with or without a leading dot separator).
limit: Maximum number of results to return
Returns:
List of tuples: (source_symbol_name, relationship_type, source_line, source_file)
- source_symbol_name: Name of the calling symbol
- relationship_type: Type of relationship (e.g., "calls", "imports")
- source_line: Line number where the call occurs
- source_file: Full path to the source file
"""
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT s.name AS source_name, cr.relationship_type,
cr.source_line, f.full_path AS source_file
FROM code_relationships cr
JOIN symbols s ON s.id = cr.source_symbol_id
JOIN files f ON f.id = s.file_id
WHERE cr.target_qualified_name = ?
OR cr.target_qualified_name LIKE ?
OR cr.target_qualified_name LIKE ?
ORDER BY f.full_path, cr.source_line
LIMIT ?
""",
(
target_name,
f"%.{target_name}",
f"%{target_name}",
limit,
),
).fetchall()
return [
(
row["source_name"],
row["relationship_type"],
int(row["source_line"]),
row["source_file"],
)
for row in rows
]
# === Statistics ===
def stats(self) -> Dict[str, Any]:
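
The two queries above expose one hop of the call graph in each direction. A hedged sketch of how they combine, where `store` is assumed to be an already-open `DirIndexStore` (its constructor is outside this hunk):

```python
# Forward direction: what does `login` in auth.py call?
# Tuples are (target_name, relationship_type, source_line, target_file).
for target, rel_type, line, target_file in store.get_outgoing_calls(
    "/repo/src/auth.py", symbol_name="login"
):
    print(f"{rel_type:<8} -> {target} (line {line}, {target_file or 'unresolved'})")

# Reverse direction: who references `login` anywhere in this index?
# Tuples are (source_symbol_name, relationship_type, source_line, source_file).
for caller, rel_type, line, source_file in store.get_incoming_calls("login", limit=25):
    print(f"{source_file}:{line}  {caller} --{rel_type}--> login")
```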


@@ -270,6 +270,39 @@ class GlobalSymbolIndex:
symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
return [(s.file or "", s.range) for s in symbols]
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
"""Get all symbols in a specific file, sorted by start_line.
Args:
file_path: Full path to the file
Returns:
List of Symbol objects sorted by start_line
"""
file_path_str = str(Path(file_path).resolve())
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT symbol_name, symbol_kind, file_path, start_line, end_line
FROM global_symbols
WHERE project_id=? AND file_path=?
ORDER BY start_line
""",
(self.project_id, file_path_str),
).fetchall()
return [
Symbol(
name=row["symbol_name"],
kind=row["symbol_kind"],
range=(row["start_line"], row["end_line"]),
file=row["file_path"],
)
for row in rows
]
def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
with self._lock:
conn = self._get_connection()
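
Because results come back ordered by `start_line`, this method is enough to render a file outline directly. A short sketch, assuming `index` is an initialized `GlobalSymbolIndex` for the current project:

```python
# Print a line-ordered outline of one file from the global symbol index.
for sym in index.get_file_symbols("/repo/src/auth.py"):
    start, end = sym.range
    print(f"{start:>5}-{end:<5} {sym.kind:<10} {sym.name}")
```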


@@ -0,0 +1,282 @@
"""Tests for codexlens.api.references module."""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codexlens.api.references import (
find_references,
_read_line_from_file,
_proximity_score,
_group_references_by_definition,
_transform_to_reference_result,
)
from codexlens.api.models import (
DefinitionResult,
ReferenceResult,
GroupedReferences,
)
class TestReadLineFromFile:
"""Tests for _read_line_from_file helper."""
def test_read_existing_line(self, tmp_path):
"""Test reading an existing line from a file."""
test_file = tmp_path / "test.py"
test_file.write_text("line 1\nline 2\nline 3\n")
assert _read_line_from_file(str(test_file), 1) == "line 1"
assert _read_line_from_file(str(test_file), 2) == "line 2"
assert _read_line_from_file(str(test_file), 3) == "line 3"
def test_read_nonexistent_line(self, tmp_path):
"""Test reading a line that doesn't exist."""
test_file = tmp_path / "test.py"
test_file.write_text("line 1\nline 2\n")
assert _read_line_from_file(str(test_file), 10) == ""
def test_read_nonexistent_file(self):
"""Test reading from a file that doesn't exist."""
assert _read_line_from_file("/nonexistent/path/file.py", 1) == ""
def test_strips_trailing_whitespace(self, tmp_path):
"""Test that trailing whitespace is stripped."""
test_file = tmp_path / "test.py"
test_file.write_text("line with spaces \n")
assert _read_line_from_file(str(test_file), 1) == "line with spaces"
class TestProximityScore:
"""Tests for _proximity_score helper."""
def test_same_file(self):
"""Same file should return highest score."""
score = _proximity_score("/a/b/c.py", "/a/b/c.py")
assert score == 1000
def test_same_directory(self):
"""Same directory should return 100."""
score = _proximity_score("/a/b/x.py", "/a/b/y.py")
assert score == 100
def test_different_directories(self):
"""Different directories should return common prefix length."""
score = _proximity_score("/a/b/c/x.py", "/a/b/d/y.py")
# Common path is /a/b
assert score > 0
def test_empty_paths(self):
"""Empty paths should return 0."""
assert _proximity_score("", "/a/b/c.py") == 0
assert _proximity_score("/a/b/c.py", "") == 0
assert _proximity_score("", "") == 0
class TestGroupReferencesByDefinition:
"""Tests for _group_references_by_definition helper."""
def test_single_definition(self):
"""Single definition should have all references."""
definition = DefinitionResult(
name="foo",
kind="function",
file_path="/a/b/c.py",
line=10,
end_line=20,
)
references = [
ReferenceResult(
file_path="/a/b/d.py",
line=5,
column=0,
context_line="foo()",
relationship="call",
),
ReferenceResult(
file_path="/a/x/y.py",
line=10,
column=0,
context_line="foo()",
relationship="call",
),
]
result = _group_references_by_definition([definition], references)
assert len(result) == 1
assert result[0].definition == definition
assert len(result[0].references) == 2
def test_multiple_definitions(self):
"""Multiple definitions should group by proximity."""
def1 = DefinitionResult(
name="foo",
kind="function",
file_path="/a/b/c.py",
line=10,
end_line=20,
)
def2 = DefinitionResult(
name="foo",
kind="function",
file_path="/x/y/z.py",
line=10,
end_line=20,
)
# Reference closer to def1
ref1 = ReferenceResult(
file_path="/a/b/d.py",
line=5,
column=0,
context_line="foo()",
relationship="call",
)
# Reference closer to def2
ref2 = ReferenceResult(
file_path="/x/y/w.py",
line=10,
column=0,
context_line="foo()",
relationship="call",
)
result = _group_references_by_definition(
[def1, def2], [ref1, ref2], include_definition=True
)
assert len(result) == 2
# Each definition should have the closer reference
def1_refs = [g for g in result if g.definition == def1][0].references
def2_refs = [g for g in result if g.definition == def2][0].references
assert any(r.file_path == "/a/b/d.py" for r in def1_refs)
assert any(r.file_path == "/x/y/w.py" for r in def2_refs)
def test_empty_definitions(self):
"""Empty definitions should return empty result."""
result = _group_references_by_definition([], [])
assert result == []
class TestTransformToReferenceResult:
"""Tests for _transform_to_reference_result helper."""
def test_normalizes_relationship_type(self, tmp_path):
"""Test that relationship type is normalized."""
test_file = tmp_path / "test.py"
test_file.write_text("def foo(): pass\n")
# Create a mock raw reference
raw_ref = MagicMock()
raw_ref.file_path = str(test_file)
raw_ref.line = 1
raw_ref.column = 0
raw_ref.relationship_type = "calls" # Plural form
result = _transform_to_reference_result(raw_ref)
assert result.relationship == "call" # Normalized form
assert result.context_line == "def foo(): pass"
class TestFindReferences:
"""Tests for find_references API function."""
def test_raises_for_invalid_project_root(self):
"""Test that ValueError is raised for invalid project root."""
with pytest.raises(ValueError, match="does not exist"):
find_references("/nonexistent/path", "some_symbol")
@patch("codexlens.search.chain_search.ChainSearchEngine")
@patch("codexlens.storage.registry.RegistryStore")
@patch("codexlens.storage.path_mapper.PathMapper")
@patch("codexlens.config.Config")
def test_returns_grouped_references(
self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
):
"""Test that find_references returns GroupedReferences."""
# Setup mocks
mock_engine = MagicMock()
mock_engine_class.return_value = mock_engine
# Mock symbol search (for definitions)
mock_symbol = MagicMock()
mock_symbol.name = "test_func"
mock_symbol.kind = "function"
mock_symbol.file = str(tmp_path / "test.py")
mock_symbol.range = (10, 20)
mock_engine.search_symbols.return_value = [mock_symbol]
# Mock reference search
mock_ref = MagicMock()
mock_ref.file_path = str(tmp_path / "caller.py")
mock_ref.line = 5
mock_ref.column = 0
mock_ref.relationship_type = "call"
mock_engine.search_references.return_value = [mock_ref]
# Create test files
test_file = tmp_path / "test.py"
test_file.write_text("def test_func():\n pass\n")
caller_file = tmp_path / "caller.py"
caller_file.write_text("test_func()\n")
# Call find_references
result = find_references(str(tmp_path), "test_func")
# Verify result structure
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], GroupedReferences)
assert result[0].definition.name == "test_func"
assert len(result[0].references) == 1
@patch("codexlens.search.chain_search.ChainSearchEngine")
@patch("codexlens.storage.registry.RegistryStore")
@patch("codexlens.storage.path_mapper.PathMapper")
@patch("codexlens.config.Config")
def test_respects_include_definition_false(
self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
):
"""Test include_definition=False behavior."""
mock_engine = MagicMock()
mock_engine_class.return_value = mock_engine
mock_engine.search_symbols.return_value = []
mock_engine.search_references.return_value = []
result = find_references(
str(tmp_path), "test_func", include_definition=False
)
# Should still return a result with placeholder definition
assert len(result) == 1
assert result[0].definition.name == "test_func"
class TestImports:
"""Tests for module imports and exports."""
def test_find_references_exported_from_api(self):
"""Test that find_references is exported from codexlens.api."""
from codexlens.api import find_references as api_find_references
assert callable(api_find_references)
def test_models_exported_from_api(self):
"""Test that result models are exported from codexlens.api."""
from codexlens.api import (
GroupedReferences,
ReferenceResult,
DefinitionResult,
)
assert GroupedReferences is not None
assert ReferenceResult is not None
assert DefinitionResult is not None
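
The proximity tests above pin down the grouping heuristic: same file scores 1000, same directory 100, anything else the length of the common path prefix. A standalone sketch of that scoring (an illustration of the asserted behavior, not the shipped `_proximity_score`):

```python
import os

def proximity_score(ref_path: str, def_path: str) -> int:
    """Score how close a reference is to a candidate definition."""
    if not ref_path or not def_path:
        return 0
    if ref_path == def_path:
        return 1000  # same file wins outright
    if os.path.dirname(ref_path) == os.path.dirname(def_path):
        return 100   # same directory
    try:
        return len(os.path.commonpath([ref_path, def_path]))
    except ValueError:  # e.g. mixed absolute/relative paths
        return 0

assert proximity_score("/a/b/c.py", "/a/b/c.py") == 1000
assert proximity_score("/a/b/x.py", "/a/b/y.py") == 100
assert proximity_score("/a/b/c/x.py", "/a/b/d/y.py") > 0  # common prefix /a/b
```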


@@ -0,0 +1,528 @@
"""Tests for semantic_search API."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codexlens.api import SemanticResult
from codexlens.api.semantic import (
semantic_search,
_build_search_options,
_generate_match_reason,
_split_camel_case,
_transform_results,
)
class TestSemanticSearchFunctionSignature:
"""Test that semantic_search has the correct function signature."""
def test_function_accepts_all_parameters(self):
"""Verify function signature matches spec."""
import inspect
sig = inspect.signature(semantic_search)
params = list(sig.parameters.keys())
expected_params = [
"project_root",
"query",
"mode",
"vector_weight",
"structural_weight",
"keyword_weight",
"fusion_strategy",
"kind_filter",
"limit",
"include_match_reason",
]
assert params == expected_params
def test_default_parameter_values(self):
"""Verify default parameter values match spec."""
import inspect
sig = inspect.signature(semantic_search)
assert sig.parameters["mode"].default == "fusion"
assert sig.parameters["vector_weight"].default == 0.5
assert sig.parameters["structural_weight"].default == 0.3
assert sig.parameters["keyword_weight"].default == 0.2
assert sig.parameters["fusion_strategy"].default == "rrf"
assert sig.parameters["kind_filter"].default is None
assert sig.parameters["limit"].default == 20
assert sig.parameters["include_match_reason"].default is False
class TestBuildSearchOptions:
"""Test _build_search_options helper function."""
def test_vector_mode_options(self):
"""Test options for pure vector mode."""
options = _build_search_options(
mode="vector",
vector_weight=1.0,
structural_weight=0.0,
keyword_weight=0.0,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is True
assert options.pure_vector is True
assert options.enable_fuzzy is False
def test_structural_mode_options(self):
"""Test options for structural mode."""
options = _build_search_options(
mode="structural",
vector_weight=0.0,
structural_weight=1.0,
keyword_weight=0.0,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is False
assert options.enable_fuzzy is True
assert options.include_symbols is True
def test_fusion_mode_options(self):
"""Test options for fusion mode (default)."""
options = _build_search_options(
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is True # vector_weight > 0
assert options.enable_fuzzy is True # keyword_weight > 0
assert options.include_symbols is True # structural_weight > 0
class TestTransformResults:
"""Test _transform_results helper function."""
def test_transforms_basic_result(self):
"""Test basic result transformation."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "def authenticate():"
mock_result.symbol_name = "authenticate"
mock_result.symbol_kind = "function"
mock_result.start_line = 10
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].symbol_name == "authenticate"
assert results[0].kind == "function"
assert results[0].file_path == "/project/src/auth.py"
assert results[0].line == 10
assert results[0].fusion_score == 0.85
def test_kind_filter_excludes_non_matching(self):
"""Test that kind_filter excludes non-matching results."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "AUTH_TOKEN = 'secret'"
mock_result.symbol_name = "AUTH_TOKEN"
mock_result.symbol_kind = "variable"
mock_result.start_line = 5
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=["function", "class"], # Exclude variable
include_match_reason=False,
query="auth",
)
assert len(results) == 0
def test_kind_filter_includes_matching(self):
"""Test that kind_filter includes matching results."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "class AuthManager:"
mock_result.symbol_name = "AuthManager"
mock_result.symbol_kind = "class"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=["function", "class"], # Include class
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].symbol_name == "AuthManager"
def test_include_match_reason_generates_reason(self):
"""Test that include_match_reason generates match reasons."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "def authenticate(user, password):"
mock_result.symbol_name = "authenticate"
mock_result.symbol_kind = "function"
mock_result.start_line = 10
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=True,
query="authenticate",
)
assert len(results) == 1
assert results[0].match_reason is not None
assert "authenticate" in results[0].match_reason.lower()
class TestGenerateMatchReason:
"""Test _generate_match_reason helper function."""
def test_direct_name_match(self):
"""Test match reason for direct name match."""
reason = _generate_match_reason(
query="authenticate",
symbol_name="authenticate",
symbol_kind="function",
snippet="def authenticate(user): pass",
vector_score=0.8,
structural_score=None,
)
assert "authenticate" in reason.lower()
def test_keyword_match(self):
"""Test match reason for keyword match in snippet."""
reason = _generate_match_reason(
query="password validation",
symbol_name="verify_user",
symbol_kind="function",
snippet="def verify_user(password): validate(password)",
vector_score=0.6,
structural_score=None,
)
assert "password" in reason.lower() or "validation" in reason.lower()
def test_high_semantic_similarity(self):
"""Test match reason mentions semantic similarity for high vector score."""
reason = _generate_match_reason(
query="authentication",
symbol_name="login_handler",
symbol_kind="function",
snippet="def login_handler(): pass",
vector_score=0.85,
structural_score=None,
)
assert "semantic" in reason.lower()
def test_returns_string_even_with_no_matches(self):
"""Test that a reason string is always returned."""
reason = _generate_match_reason(
query="xyz123",
symbol_name="abc456",
symbol_kind="function",
snippet="completely unrelated code",
vector_score=0.3,
structural_score=None,
)
assert isinstance(reason, str)
assert len(reason) > 0
class TestSplitCamelCase:
"""Test _split_camel_case helper function."""
def test_camel_case(self):
"""Test splitting camelCase."""
result = _split_camel_case("authenticateUser")
assert "authenticate" in result.lower()
assert "user" in result.lower()
def test_pascal_case(self):
"""Test splitting PascalCase."""
result = _split_camel_case("AuthManager")
assert "auth" in result.lower()
assert "manager" in result.lower()
def test_snake_case(self):
"""Test splitting snake_case."""
result = _split_camel_case("auth_manager")
assert "auth" in result.lower()
assert "manager" in result.lower()
def test_mixed_case(self):
"""Test splitting mixed case."""
result = _split_camel_case("HTTPRequestHandler")
# Should handle acronyms
assert "http" in result.lower() or "request" in result.lower()
class TestSemanticResultDataclass:
"""Test SemanticResult dataclass structure."""
def test_semantic_result_fields(self):
"""Test SemanticResult has all required fields."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=0.8,
structural_score=0.6,
fusion_score=0.7,
snippet="def test(): pass",
match_reason="Test match",
)
assert result.symbol_name == "test"
assert result.kind == "function"
assert result.file_path == "/test.py"
assert result.line == 1
assert result.vector_score == 0.8
assert result.structural_score == 0.6
assert result.fusion_score == 0.7
assert result.snippet == "def test(): pass"
assert result.match_reason == "Test match"
def test_semantic_result_optional_fields(self):
"""Test SemanticResult with optional None fields."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=None, # Degraded - no vector index
structural_score=None, # Degraded - no relationships
fusion_score=0.5,
snippet="def test(): pass",
match_reason=None, # Not requested
)
assert result.vector_score is None
assert result.structural_score is None
assert result.match_reason is None
def test_semantic_result_to_dict(self):
"""Test SemanticResult.to_dict() filters None values."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=None,
structural_score=0.6,
fusion_score=0.7,
snippet="def test(): pass",
match_reason=None,
)
d = result.to_dict()
assert "symbol_name" in d
assert "vector_score" not in d # None values filtered
assert "structural_score" in d
assert "match_reason" not in d # None values filtered
class TestFusionStrategyMapping:
"""Test fusion_strategy parameter mapping via _execute_search."""
def test_rrf_strategy_calls_search(self):
"""Test that rrf strategy maps to standard search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="rrf",
options=mock_options,
limit=20,
)
mock_engine.search.assert_called_once()
def test_staged_strategy_calls_staged_cascade_search(self):
"""Test that staged strategy maps to staged_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.staged_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="staged",
options=mock_options,
limit=20,
)
mock_engine.staged_cascade_search.assert_called_once()
def test_binary_strategy_calls_binary_cascade_search(self):
"""Test that binary strategy maps to binary_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.binary_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="binary",
options=mock_options,
limit=20,
)
mock_engine.binary_cascade_search.assert_called_once()
def test_hybrid_strategy_calls_hybrid_cascade_search(self):
"""Test that hybrid strategy maps to hybrid_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.hybrid_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="hybrid",
options=mock_options,
limit=20,
)
mock_engine.hybrid_cascade_search.assert_called_once()
def test_unknown_strategy_defaults_to_rrf(self):
"""Test that unknown strategy defaults to standard search (rrf)."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="unknown_strategy",
options=mock_options,
limit=20,
)
mock_engine.search.assert_called_once()
class TestGracefulDegradation:
"""Test graceful degradation behavior."""
def test_vector_score_none_when_no_vector_index(self):
"""Test vector_score=None when vector index unavailable."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.5
mock_result.excerpt = "def auth(): pass"
mock_result.symbol_name = "auth"
mock_result.symbol_kind = "function"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {} # No vector score in metadata
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
# When no source_scores in metadata, vector_score should be None
assert results[0].vector_score is None
def test_structural_score_extracted_from_fts(self):
"""Test structural_score extracted from FTS scores."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.8
mock_result.excerpt = "def auth(): pass"
mock_result.symbol_name = "auth"
mock_result.symbol_kind = "function"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {
"source_scores": {
"exact": 0.9,
"fuzzy": 0.7,
}
}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].structural_score == 0.9 # max of exact/fuzzy
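
The signature test fixes the public API shape, so a call looks like the hedged sketch below (the project under `/repo` is assumed to be indexed already; field names follow the `SemanticResult` tests):

```python
from codexlens.api.semantic import semantic_search

hits = semantic_search(
    "/repo",
    "token refresh logic",
    mode="fusion",            # "vector" | "structural" | "fusion"
    fusion_strategy="rrf",    # "rrf" | "staged" | "binary" | "hybrid"
    kind_filter=["function", "class"],
    limit=10,
    include_match_reason=True,
)
for h in hits:
    print(f"{h.fusion_score:.2f} {h.kind:<8} {h.symbol_name}  {h.file_path}:{h.line}")
    if h.match_reason:
        print(f"      reason: {h.match_reason}")
```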


@@ -0,0 +1 @@
"""Tests package for LSP module."""


@@ -0,0 +1,477 @@
"""Tests for hover provider."""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock
import tempfile
from codexlens.entities import Symbol
class TestHoverInfo:
"""Test HoverInfo dataclass."""
def test_hover_info_import(self):
"""HoverInfo can be imported."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.providers import HoverInfo
assert HoverInfo is not None
def test_hover_info_fields(self):
"""HoverInfo has all required fields."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo
info = HoverInfo(
name="my_function",
kind="function",
signature="def my_function(x: int) -> str:",
documentation="A test function.",
file_path="/test/file.py",
line_range=(10, 15),
)
assert info.name == "my_function"
assert info.kind == "function"
assert info.signature == "def my_function(x: int) -> str:"
assert info.documentation == "A test function."
assert info.file_path == "/test/file.py"
assert info.line_range == (10, 15)
def test_hover_info_optional_documentation(self):
"""Documentation can be None."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="/test.py",
line_range=(1, 2),
)
assert info.documentation is None
class TestHoverProvider:
"""Test HoverProvider class."""
def test_provider_import(self):
"""HoverProvider can be imported."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
assert HoverProvider is not None
def test_returns_none_for_unknown_symbol(self):
"""Returns None when symbol not found."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_index.search.return_value = []
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("unknown_symbol")
assert result is None
mock_index.search.assert_called_once_with(
name="unknown_symbol", limit=1, prefix_mode=False
)
def test_returns_none_for_non_exact_match(self):
"""Returns None when search returns non-exact matches."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Return a symbol with different name (prefix match but not exact)
mock_symbol = Mock()
mock_symbol.name = "my_function_extended"
mock_symbol.kind = "function"
mock_symbol.file = "/test/file.py"
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("my_function")
assert result is None
def test_returns_hover_info_for_known_symbol(self):
"""Returns HoverInfo for found symbol."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = None # No file, will use fallback signature
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("my_func")
assert result is not None
assert result.name == "my_func"
assert result.kind == "function"
assert result.line_range == (10, 15)
assert result.signature == "function my_func"
def test_extracts_signature_from_file(self):
"""Extracts signature from actual file content."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Create a temporary file with Python content
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("# comment\n")
f.write("def test_function(x: int, y: str) -> bool:\n")
f.write(" return True\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "test_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (2, 3) # Line 2 (1-based)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test_function")
assert result is not None
assert "def test_function(x: int, y: str) -> bool:" in result.signature
finally:
Path(temp_path).unlink(missing_ok=True)
def test_extracts_multiline_signature(self):
"""Extracts multiline function signature."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Create a temporary file with multiline signature
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("def complex_function(\n")
f.write(" arg1: int,\n")
f.write(" arg2: str,\n")
f.write(") -> bool:\n")
f.write(" return True\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "complex_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 5) # Line 1 (1-based)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("complex_function")
assert result is not None
assert "def complex_function(" in result.signature
# Should capture multiline signature
assert "arg1: int" in result.signature
finally:
Path(temp_path).unlink(missing_ok=True)
def test_handles_nonexistent_file_gracefully(self):
"""Returns fallback signature when file doesn't exist."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/nonexistent/path/file.py"
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("my_func")
assert result is not None
assert result.signature == "function my_func"
def test_handles_invalid_line_range(self):
"""Returns fallback signature when line range is invalid."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("def test():\n")
f.write(" pass\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "test"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (100, 105) # Line beyond file length
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test")
assert result is not None
assert result.signature == "function test"
finally:
Path(temp_path).unlink(missing_ok=True)
class TestFormatHoverMarkdown:
"""Test markdown formatting."""
def test_format_python_signature(self):
"""Formats Python signature with python code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func(x: int) -> str:",
documentation=None,
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```python" in result
assert "def func(x: int) -> str:" in result
assert "function" in result
assert "file.py" in result
assert "line 10" in result
def test_format_javascript_signature(self):
"""Formats JavaScript signature with javascript code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="myFunc",
kind="function",
signature="function myFunc(x) {",
documentation=None,
file_path="/test/file.js",
line_range=(5, 10),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```javascript" in result
assert "function myFunc(x) {" in result
def test_format_typescript_signature(self):
"""Formats TypeScript signature with typescript code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="myFunc",
kind="function",
signature="function myFunc(x: number): string {",
documentation=None,
file_path="/test/file.ts",
line_range=(5, 10),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```typescript" in result
def test_format_with_documentation(self):
"""Includes documentation when available."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation="This is a test function.",
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "This is a test function." in result
assert "---" in result # Separator before docs
def test_format_without_documentation(self):
"""Does not include documentation section when None."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
        # Exactly one "---" separator (for the location line) should appear;
        # no second separator is added when documentation is None.
lines = result.split("\n")
separator_count = sum(1 for line in lines if line.strip() == "---")
assert separator_count == 1 # Only location separator
def test_format_unknown_extension(self):
"""Uses empty code fence for unknown file extensions."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="func code here",
documentation=None,
file_path="/test/file.xyz",
line_range=(1, 2),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
# Should have code fence without language specifier
assert "```\n" in result or "```xyz" not in result
def test_format_class_symbol(self):
"""Formats class symbol correctly."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="MyClass",
kind="class",
signature="class MyClass:",
documentation=None,
file_path="/test/file.py",
line_range=(1, 20),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "class MyClass:" in result
assert "*class*" in result
assert "line 1" in result
def test_format_empty_file_path(self):
"""Handles empty file path gracefully."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="",
line_range=(1, 2),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "unknown" in result or "```" in result
class TestHoverProviderRegistry:
"""Test HoverProvider with registry integration."""
def test_provider_accepts_none_registry(self):
"""HoverProvider works without registry."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_index.search.return_value = []
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test")
assert result is None
assert provider.registry is None
def test_provider_stores_registry(self):
"""HoverProvider stores registry reference."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
assert provider.global_index is mock_index
assert provider.registry is mock_registry
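
Taken together, the formatting tests describe the hover payload: a language-fenced signature, the symbol kind and location line, and an optional documentation section after a separator. A sketch that exercises the provider end to end (the assertions mirror the tests above):

```python
from unittest.mock import Mock

from codexlens.lsp.providers import HoverInfo, HoverProvider

provider = HoverProvider(Mock(), None)  # no registry needed for formatting
info = HoverInfo(
    name="authenticate",
    kind="function",
    signature="def authenticate(user: str) -> bool:",
    documentation="Check credentials against the user store.",
    file_path="/repo/src/auth.py",
    line_range=(10, 15),
)
md = provider.format_hover_markdown(info)
assert "```python" in md                    # fence chosen from the .py extension
assert "def authenticate(user: str)" in md  # signature inside the fence
assert "line 10" in md                      # location from line_range
assert "---" in md                          # separator before the docs
assert "Check credentials" in md            # documentation section
```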


@@ -0,0 +1,497 @@
"""Tests for reference search functionality.
This module tests the ReferenceResult dataclass and search_references method
in ChainSearchEngine, as well as the updated lsp_references handler.
"""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch
import sqlite3
import tempfile
import os
class TestReferenceResult:
"""Test ReferenceResult dataclass."""
def test_reference_result_fields(self):
"""ReferenceResult has all required fields."""
from codexlens.search.chain_search import ReferenceResult
ref = ReferenceResult(
file_path="/test/file.py",
line=10,
column=5,
context="def foo():",
relationship_type="call",
)
assert ref.file_path == "/test/file.py"
assert ref.line == 10
assert ref.column == 5
assert ref.context == "def foo():"
assert ref.relationship_type == "call"
def test_reference_result_with_empty_context(self):
"""ReferenceResult can have empty context."""
from codexlens.search.chain_search import ReferenceResult
ref = ReferenceResult(
file_path="/test/file.py",
line=1,
column=0,
context="",
relationship_type="import",
)
assert ref.context == ""
def test_reference_result_different_relationship_types(self):
"""ReferenceResult supports different relationship types."""
from codexlens.search.chain_search import ReferenceResult
types = ["call", "import", "inheritance", "implementation", "usage"]
for rel_type in types:
ref = ReferenceResult(
file_path="/test/file.py",
line=1,
column=0,
context="test",
relationship_type=rel_type,
)
assert ref.relationship_type == rel_type
class TestExtractContext:
"""Test the _extract_context helper method."""
def test_extract_context_middle_of_file(self):
"""Extract context from middle of file."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
content = "\n".join([
"line 1",
"line 2",
"line 3",
"line 4", # target line
"line 5",
"line 6",
"line 7",
])
# Create minimal mock engine to test _extract_context
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=4, context_lines=2)
assert "line 2" in context
assert "line 3" in context
assert "line 4" in context
assert "line 5" in context
assert "line 6" in context
def test_extract_context_start_of_file(self):
"""Extract context at start of file."""
from codexlens.search.chain_search import ChainSearchEngine
content = "\n".join([
"line 1", # target
"line 2",
"line 3",
"line 4",
])
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=1, context_lines=2)
assert "line 1" in context
assert "line 2" in context
assert "line 3" in context
def test_extract_context_end_of_file(self):
"""Extract context at end of file."""
from codexlens.search.chain_search import ChainSearchEngine
content = "\n".join([
"line 1",
"line 2",
"line 3",
"line 4", # target
])
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=4, context_lines=2)
assert "line 2" in context
assert "line 3" in context
assert "line 4" in context
def test_extract_context_empty_content(self):
"""Extract context from empty content."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context("", line=1, context_lines=3)
assert context == ""
def test_extract_context_invalid_line(self):
"""Extract context with invalid line number."""
from codexlens.search.chain_search import ChainSearchEngine
content = "line 1\nline 2\nline 3"
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
# Line 0 (invalid)
assert engine._extract_context(content, line=0, context_lines=1) == ""
# Line beyond end
assert engine._extract_context(content, line=100, context_lines=1) == ""
class TestSearchReferences:
"""Test search_references method."""
def test_returns_empty_for_no_source_path_and_no_registry(self):
"""Returns empty list when no source path and registry has no mappings."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_registry.list_mappings.return_value = []
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
results = engine.search_references("test_symbol")
assert results == []
def test_returns_empty_for_no_indexes(self):
"""Returns empty list when no indexes found."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
mock_mapper.source_to_index_db.return_value = Path("/nonexistent/_index.db")
engine = ChainSearchEngine(mock_registry, mock_mapper)
with patch.object(engine, "_find_start_index", return_value=None):
results = engine.search_references("test_symbol", Path("/some/path"))
assert results == []
def test_deduplicates_results(self):
"""Removes duplicate file:line references."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
# Create a temporary database with duplicate relationships
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'def test(): pass');
INSERT INTO symbols VALUES (1, 1, 'test_func', 'function', 1, 1);
INSERT INTO code_relationships VALUES (1, 1, 'target_func', 'call', 10, NULL);
INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 10, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target_func", Path(tmpdir))
# Should only have 1 result due to deduplication
assert len(results) == 1
assert results[0].line == 10
def test_sorts_by_file_and_line(self):
"""Results sorted by file path then line number."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/b_file.py', 'python', 'content');
INSERT INTO files VALUES (2, '/test/a_file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'func1', 'function', 1, 1);
INSERT INTO symbols VALUES (2, 2, 'func2', 'function', 1, 1);
INSERT INTO code_relationships VALUES (1, 1, 'target', 'call', 20, NULL);
INSERT INTO code_relationships VALUES (2, 1, 'target', 'call', 10, NULL);
INSERT INTO code_relationships VALUES (3, 2, 'target', 'call', 5, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target", Path(tmpdir))
# Should be sorted: a_file.py:5, b_file.py:10, b_file.py:20
assert len(results) == 3
assert results[0].file_path == "/test/a_file.py"
assert results[0].line == 5
assert results[1].file_path == "/test/b_file.py"
assert results[1].line == 10
assert results[2].file_path == "/test/b_file.py"
assert results[2].line == 20
def test_respects_limit(self):
"""Returns at most limit results."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'func', 'function', 1, 1);
""")
# Insert many relationships
for i in range(50):
conn.execute(
"INSERT INTO code_relationships VALUES (?, 1, 'target', 'call', ?, NULL)",
(i + 1, i + 1)
)
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target", Path(tmpdir), limit=10)
assert len(results) == 10
def test_matches_qualified_name(self):
"""Matches symbols by qualified name suffix."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'caller', 'function', 1, 1);
-- Fully qualified name
INSERT INTO code_relationships VALUES (1, 1, 'module.submodule.target_func', 'call', 10, NULL);
-- Simple name
INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 20, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target_func", Path(tmpdir))
# Should find both references
assert len(results) == 2
class TestLspReferencesHandler:
"""Test the LSP references handler."""
def test_handler_uses_search_engine(self):
"""Handler uses search_engine.search_references when available."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _path_to_uri
from codexlens.search.chain_search import ReferenceResult
# Create mock references
mock_references = [
ReferenceResult(
file_path="/test/file1.py",
line=10,
column=5,
context="def foo():",
relationship_type="call",
),
ReferenceResult(
file_path="/test/file2.py",
line=20,
column=0,
context="import foo",
relationship_type="import",
),
]
# Verify conversion to LSP Location
locations = []
for ref in mock_references:
locations.append(
lsp.Location(
uri=_path_to_uri(ref.file_path),
range=lsp.Range(
start=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column,
),
end=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column + len("foo"),
),
),
)
)
assert len(locations) == 2
# First reference at line 10 (0-indexed = 9)
assert locations[0].range.start.line == 9
assert locations[0].range.start.character == 5
# Second reference at line 20 (0-indexed = 19)
assert locations[1].range.start.line == 19
assert locations[1].range.start.character == 0
def test_handler_falls_back_to_global_index(self):
"""Handler falls back to global_index when search_engine unavailable."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.handlers import symbol_to_location
from codexlens.entities import Symbol
# Test fallback path converts Symbol to Location
symbol = Symbol(
name="test_func",
kind="function",
range=(10, 15),
file="/test/file.py",
)
location = symbol_to_location(symbol)
assert location is not None
# LSP uses 0-based lines
assert location.range.start.line == 9
assert location.range.end.line == 14
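
The `_extract_context` tests pin down a simple windowing rule: a 1-based target line, `context_lines` of padding on each side, clamped at file boundaries, and an empty string for empty content or out-of-range lines. A standalone sketch of that rule (not the shipped method):

```python
def extract_context(content: str, line: int, context_lines: int = 3) -> str:
    lines = content.splitlines()
    if not content or line < 1 or line > len(lines):
        return ""
    lo = max(0, line - 1 - context_lines)  # convert to 0-based, clamp at top
    hi = min(len(lines), line + context_lines)
    return "\n".join(lines[lo:hi])

text = "line 1\nline 2\nline 3\nline 4"
assert "line 2" in extract_context(text, line=4, context_lines=2)
assert extract_context(text, line=0, context_lines=1) == ""
assert extract_context(text, line=100, context_lines=1) == ""
assert extract_context("", line=1, context_lines=3) == ""
```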


@@ -0,0 +1,210 @@
"""Tests for codex-lens LSP server."""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from codexlens.entities import Symbol
class TestCodexLensLanguageServer:
"""Tests for CodexLensLanguageServer."""
def test_server_import(self):
"""Test that server module can be imported."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.server import CodexLensLanguageServer, server
assert CodexLensLanguageServer is not None
assert server is not None
assert server.name == "codexlens-lsp"
def test_server_initialization(self):
"""Test server instance creation."""
pytest.importorskip("pygls")
from codexlens.lsp.server import CodexLensLanguageServer
ls = CodexLensLanguageServer()
assert ls.registry is None
assert ls.mapper is None
assert ls.global_index is None
assert ls.search_engine is None
assert ls.workspace_root is None
class TestDefinitionHandler:
"""Tests for definition handler."""
def test_definition_lookup(self):
"""Test definition lookup returns location for known symbol."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import symbol_to_location
symbol = Symbol(
name="test_function",
kind="function",
range=(10, 15),
file="/path/to/file.py",
)
location = symbol_to_location(symbol)
assert location is not None
assert isinstance(location, lsp.Location)
# LSP uses 0-based lines
assert location.range.start.line == 9
assert location.range.end.line == 14
def test_definition_no_file(self):
"""Test definition lookup returns None for symbol without file."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import symbol_to_location
symbol = Symbol(
name="test_function",
kind="function",
range=(10, 15),
file=None,
)
location = symbol_to_location(symbol)
assert location is None
class TestCompletionHandler:
"""Tests for completion handler."""
def test_get_prefix_at_position(self):
"""Test extracting prefix at cursor position."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _get_prefix_at_position
document_text = "def hello_world():\n print(hel"
# Cursor at end of "hel"
prefix = _get_prefix_at_position(document_text, 1, 14)
assert prefix == "hel"
# Cursor at beginning of line (after whitespace)
prefix = _get_prefix_at_position(document_text, 1, 4)
assert prefix == ""
# Cursor after "he" in "hello_world" - returns text before cursor
prefix = _get_prefix_at_position(document_text, 0, 6)
assert prefix == "he"
# Cursor at end of "hello_world"
prefix = _get_prefix_at_position(document_text, 0, 15)
assert prefix == "hello_world"
def test_get_word_at_position(self):
"""Test extracting word at cursor position."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _get_word_at_position
document_text = "def hello_world():\n print(msg)"
# Cursor on "hello_world"
word = _get_word_at_position(document_text, 0, 6)
assert word == "hello_world"
# Cursor on "print"
word = _get_word_at_position(document_text, 1, 6)
assert word == "print"
# Cursor on "msg"
word = _get_word_at_position(document_text, 1, 11)
assert word == "msg"
def test_symbol_kind_mapping(self):
"""Test symbol kind to completion kind mapping."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _symbol_kind_to_completion_kind
assert _symbol_kind_to_completion_kind("function") == lsp.CompletionItemKind.Function
assert _symbol_kind_to_completion_kind("class") == lsp.CompletionItemKind.Class
assert _symbol_kind_to_completion_kind("method") == lsp.CompletionItemKind.Method
assert _symbol_kind_to_completion_kind("variable") == lsp.CompletionItemKind.Variable
# Unknown kind should default to Text
assert _symbol_kind_to_completion_kind("unknown") == lsp.CompletionItemKind.Text
class TestWorkspaceSymbolHandler:
"""Tests for workspace symbol handler."""
def test_symbol_kind_to_lsp(self):
"""Test symbol kind to LSP SymbolKind mapping."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _symbol_kind_to_lsp
assert _symbol_kind_to_lsp("function") == lsp.SymbolKind.Function
assert _symbol_kind_to_lsp("class") == lsp.SymbolKind.Class
assert _symbol_kind_to_lsp("method") == lsp.SymbolKind.Method
assert _symbol_kind_to_lsp("interface") == lsp.SymbolKind.Interface
# Unknown kind should default to Variable
assert _symbol_kind_to_lsp("unknown") == lsp.SymbolKind.Variable
class TestUriConversion:
"""Tests for URI path conversion."""
def test_path_to_uri(self):
"""Test path to URI conversion."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _path_to_uri
# Unix path
uri = _path_to_uri("/home/user/file.py")
assert uri.startswith("file://")
assert "file.py" in uri
def test_uri_to_path(self):
"""Test URI to path conversion."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _uri_to_path
# Basic URI
path = _uri_to_path("file:///home/user/file.py")
assert path.name == "file.py"
class TestMainEntryPoint:
"""Tests for main entry point."""
def test_main_help(self):
"""Test that main shows help without errors."""
pytest.importorskip("pygls")
import sys
from unittest.mock import patch
# Patch sys.argv to show help
with patch.object(sys, 'argv', ['codexlens-lsp', '--help']):
from codexlens.lsp.server import main
with pytest.raises(SystemExit) as exc_info:
main()
# Help exits with 0
assert exc_info.value.code == 0
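
The recurring 1-based-to-0-based conversion these handler tests assert can be captured in a few lines; a sketch assuming only that `lsprotocol` is installed:

```python
from lsprotocol import types as lsp

def to_lsp_range(start_line: int, end_line: int) -> lsp.Range:
    """Map a 1-based index range to a 0-based LSP Range."""
    return lsp.Range(
        start=lsp.Position(line=max(0, start_line - 1), character=0),
        end=lsp.Position(line=max(0, end_line - 1), character=0),
    )

r = to_lsp_range(10, 15)
assert (r.start.line, r.end.line) == (9, 14)  # matches the handler tests above
```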


@@ -0,0 +1 @@
"""Tests for MCP (Model Context Protocol) module."""


@@ -0,0 +1,208 @@
"""Tests for MCP hooks module."""
import pytest
from unittest.mock import Mock, patch
from pathlib import Path
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
from codexlens.mcp.schema import MCPContext, SymbolInfo
class TestHookManager:
"""Test HookManager class."""
@pytest.fixture
def mock_provider(self):
"""Create a mock MCP provider."""
provider = Mock()
provider.build_context.return_value = MCPContext(
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
context_type="symbol_explanation",
)
provider.build_context_for_file.return_value = MCPContext(
context_type="file_overview",
)
return provider
@pytest.fixture
def hook_manager(self, mock_provider):
"""Create a HookManager with mocked provider."""
return HookManager(mock_provider)
def test_default_hooks_registered(self, hook_manager):
"""Default hooks are registered on initialization."""
assert "explain" in hook_manager._pre_hooks
assert "refactor" in hook_manager._pre_hooks
assert "document" in hook_manager._pre_hooks
def test_execute_pre_hook_returns_context(self, hook_manager, mock_provider):
"""execute_pre_hook returns MCPContext for registered hook."""
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
assert result is not None
assert isinstance(result, MCPContext)
mock_provider.build_context.assert_called_once()
def test_execute_pre_hook_returns_none_for_unknown_action(self, hook_manager):
"""execute_pre_hook returns None for unregistered action."""
result = hook_manager.execute_pre_hook("unknown_action", {"symbol": "test"})
assert result is None
def test_execute_pre_hook_handles_exception(self, hook_manager, mock_provider):
"""execute_pre_hook handles provider exceptions gracefully."""
mock_provider.build_context.side_effect = Exception("Provider failed")
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
assert result is None
def test_execute_post_hook_no_error_for_unregistered(self, hook_manager):
"""execute_post_hook doesn't error for unregistered action."""
# Should not raise
hook_manager.execute_post_hook("unknown", {"result": "data"})
def test_pre_explain_hook_calls_build_context(self, hook_manager, mock_provider):
"""_pre_explain_hook calls build_context correctly."""
hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
mock_provider.build_context.assert_called_with(
symbol_name="my_func",
context_type="symbol_explanation",
include_references=True,
include_related=True,
)
def test_pre_explain_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
"""_pre_explain_hook returns None when symbol param missing."""
result = hook_manager.execute_pre_hook("explain", {})
assert result is None
mock_provider.build_context.assert_not_called()
def test_pre_refactor_hook_calls_build_context(self, hook_manager, mock_provider):
"""_pre_refactor_hook calls build_context with refactor settings."""
hook_manager.execute_pre_hook("refactor", {"symbol": "my_class"})
mock_provider.build_context.assert_called_with(
symbol_name="my_class",
context_type="refactor_context",
include_references=True,
include_related=True,
max_references=20,
)
def test_pre_refactor_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
"""_pre_refactor_hook returns None when symbol param missing."""
result = hook_manager.execute_pre_hook("refactor", {})
assert result is None
mock_provider.build_context.assert_not_called()
def test_pre_document_hook_with_symbol(self, hook_manager, mock_provider):
"""_pre_document_hook uses build_context when symbol provided."""
hook_manager.execute_pre_hook("document", {"symbol": "my_func"})
mock_provider.build_context.assert_called_with(
symbol_name="my_func",
context_type="documentation_context",
include_references=False,
include_related=True,
)
def test_pre_document_hook_with_file_path(self, hook_manager, mock_provider):
"""_pre_document_hook uses build_context_for_file when file_path provided."""
hook_manager.execute_pre_hook("document", {"file_path": "/src/module.py"})
mock_provider.build_context_for_file.assert_called_once()
call_args = mock_provider.build_context_for_file.call_args
assert call_args[0][0] == Path("/src/module.py")
assert call_args[1].get("context_type") == "file_documentation"
def test_pre_document_hook_prefers_symbol_over_file(self, hook_manager, mock_provider):
"""_pre_document_hook prefers symbol when both provided."""
hook_manager.execute_pre_hook(
"document", {"symbol": "my_func", "file_path": "/src/module.py"}
)
mock_provider.build_context.assert_called_once()
mock_provider.build_context_for_file.assert_not_called()
def test_pre_document_hook_returns_none_without_params(self, hook_manager, mock_provider):
"""_pre_document_hook returns None when neither symbol nor file_path provided."""
result = hook_manager.execute_pre_hook("document", {})
assert result is None
mock_provider.build_context.assert_not_called()
mock_provider.build_context_for_file.assert_not_called()
def test_register_pre_hook(self, hook_manager):
"""register_pre_hook adds custom hook."""
custom_hook = Mock(return_value=MCPContext())
hook_manager.register_pre_hook("custom_action", custom_hook)
assert "custom_action" in hook_manager._pre_hooks
hook_manager.execute_pre_hook("custom_action", {"data": "value"})
custom_hook.assert_called_once_with({"data": "value"})
def test_register_post_hook(self, hook_manager):
"""register_post_hook adds custom hook."""
custom_hook = Mock()
hook_manager.register_post_hook("custom_action", custom_hook)
assert "custom_action" in hook_manager._post_hooks
hook_manager.execute_post_hook("custom_action", {"result": "data"})
custom_hook.assert_called_once_with({"result": "data"})
def test_execute_post_hook_handles_exception(self, hook_manager):
"""execute_post_hook handles hook exceptions gracefully."""
failing_hook = Mock(side_effect=Exception("Hook failed"))
hook_manager.register_post_hook("failing", failing_hook)
# Should not raise
hook_manager.execute_post_hook("failing", {"data": "value"})
class TestCreateContextForPrompt:
"""Test create_context_for_prompt function."""
def test_returns_prompt_injection_string(self):
"""create_context_for_prompt returns formatted string."""
mock_provider = Mock()
mock_provider.build_context.return_value = MCPContext(
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
definition="def test_func(): pass",
)
result = create_context_for_prompt(
mock_provider, "explain", {"symbol": "test_func"}
)
assert isinstance(result, str)
assert "<code_context>" in result
assert "test_func" in result
assert "</code_context>" in result
def test_returns_empty_string_when_no_context(self):
"""create_context_for_prompt returns empty string when no context built."""
mock_provider = Mock()
mock_provider.build_context.return_value = None
result = create_context_for_prompt(
mock_provider, "explain", {"symbol": "nonexistent"}
)
assert result == ""
def test_returns_empty_string_for_unknown_action(self):
"""create_context_for_prompt returns empty string for unregistered action."""
mock_provider = Mock()
result = create_context_for_prompt(
mock_provider, "unknown_action", {"data": "value"}
)
assert result == ""
mock_provider.build_context.assert_not_called()
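The hook tests above translate into the following usage pattern. This is a sketch distilled from the assertions, with a `Mock` standing in for a real `MCPProvider`, not API documentation.

```python
# Usage sketch distilled from the tests above; the Mock stands in for a
# real MCPProvider wired to the index and search engine.
from unittest.mock import Mock

from codexlens.mcp.hooks import HookManager, create_context_for_prompt
from codexlens.mcp.schema import MCPContext, SymbolInfo

provider = Mock()
provider.build_context.return_value = MCPContext(
    symbol=SymbolInfo("my_func", "function", "/test.py", 1, 10),
    definition="def my_func(): pass",
)

hooks = HookManager(provider)  # registers explain/refactor/document pre-hooks

# Pre-hooks return an MCPContext, or None for unknown actions, missing
# parameters, or provider failures.
ctx = hooks.execute_pre_hook("explain", {"symbol": "my_func"})
assert ctx is not None and "<code_context>" in ctx.to_prompt_injection()

# One-step helper: returns "" for unknown actions or when no context is built.
prompt = create_context_for_prompt(provider, "explain", {"symbol": "my_func"})
```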

@@ -0,0 +1,383 @@
"""Tests for MCP provider."""
import pytest
from unittest.mock import Mock, MagicMock, patch
from pathlib import Path
import tempfile
import os
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.schema import MCPContext, SymbolInfo, ReferenceInfo
class TestMCPProvider:
"""Test MCPProvider class."""
@pytest.fixture
def mock_global_index(self):
"""Create a mock global index."""
return Mock()
@pytest.fixture
def mock_search_engine(self):
"""Create a mock search engine."""
return Mock()
@pytest.fixture
def mock_registry(self):
"""Create a mock registry."""
return Mock()
@pytest.fixture
def provider(self, mock_global_index, mock_search_engine, mock_registry):
"""Create an MCPProvider with mocked dependencies."""
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
def test_build_context_returns_none_for_unknown_symbol(self, provider, mock_global_index):
"""build_context returns None when symbol is not found."""
mock_global_index.search.return_value = []
result = provider.build_context("unknown_symbol")
assert result is None
mock_global_index.search.assert_called_once_with(
"unknown_symbol", prefix_mode=False, limit=1
)
def test_build_context_returns_mcp_context(
self, provider, mock_global_index, mock_search_engine
):
"""build_context returns MCPContext for known symbol."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func")
assert result is not None
assert isinstance(result, MCPContext)
assert result.symbol is not None
assert result.symbol.name == "my_func"
assert result.symbol.kind == "function"
assert result.context_type == "symbol_explanation"
def test_build_context_with_custom_context_type(
self, provider, mock_global_index, mock_search_engine
):
"""build_context respects custom context_type."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func", context_type="refactor_context")
assert result is not None
assert result.context_type == "refactor_context"
def test_build_context_includes_references(
self, provider, mock_global_index, mock_search_engine
):
"""build_context includes references when include_references=True."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_ref = Mock()
mock_ref.file_path = "/caller.py"
mock_ref.line = 25
mock_ref.column = 4
mock_ref.context = "result = my_func()"
mock_ref.relationship_type = "call"
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = [mock_ref]
result = provider.build_context("my_func", include_references=True)
assert result is not None
assert len(result.references) == 1
assert result.references[0].file_path == "/caller.py"
assert result.references[0].line == 25
assert result.references[0].relationship_type == "call"
def test_build_context_excludes_references_when_disabled(
self, provider, mock_global_index, mock_search_engine
):
"""build_context excludes references when include_references=False."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
# Disable both references and related to avoid any search_references calls
result = provider.build_context(
"my_func", include_references=False, include_related=False
)
assert result is not None
assert len(result.references) == 0
mock_search_engine.search_references.assert_not_called()
def test_build_context_respects_max_references(
self, provider, mock_global_index, mock_search_engine
):
"""build_context passes max_references to search engine."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
# Disable include_related to test only the references call
provider.build_context("my_func", max_references=5, include_related=False)
mock_search_engine.search_references.assert_called_once_with(
"my_func", limit=5
)
def test_build_context_includes_metadata(
self, provider, mock_global_index, mock_search_engine
):
"""build_context includes source metadata."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func")
assert result is not None
assert result.metadata.get("source") == "codex-lens"
def test_extract_definition_with_valid_file(self, provider):
"""_extract_definition reads file content correctly."""
# Create a temporary file with some content
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("# Line 1\n")
f.write("# Line 2\n")
f.write("def my_func():\n") # Line 3
f.write(" pass\n") # Line 4
f.write("# Line 5\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.file = temp_path
mock_symbol.range = (3, 4) # 1-based line numbers
definition = provider._extract_definition(mock_symbol)
assert definition is not None
assert "def my_func():" in definition
assert "pass" in definition
finally:
os.unlink(temp_path)
def test_extract_definition_returns_none_for_missing_file(self, provider):
"""_extract_definition returns None for non-existent file."""
mock_symbol = Mock()
mock_symbol.file = "/nonexistent/path/file.py"
mock_symbol.range = (1, 5)
definition = provider._extract_definition(mock_symbol)
assert definition is None
def test_extract_definition_returns_none_for_none_file(self, provider):
"""_extract_definition returns None when symbol.file is None."""
mock_symbol = Mock()
mock_symbol.file = None
mock_symbol.range = (1, 5)
definition = provider._extract_definition(mock_symbol)
assert definition is None
def test_build_context_for_file_returns_context(
self, provider, mock_global_index
):
"""build_context_for_file returns MCPContext."""
mock_global_index.search.return_value = []
result = provider.build_context_for_file(
Path("/test/file.py"),
context_type="file_overview",
)
assert result is not None
assert isinstance(result, MCPContext)
assert result.context_type == "file_overview"
assert result.metadata.get("file_path") == str(Path("/test/file.py"))
def test_build_context_for_file_includes_symbols(
self, provider, mock_global_index
):
"""build_context_for_file includes symbols from the file."""
# Create temp file to get resolved path
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("def func(): pass\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "func"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 1)
mock_global_index.search.return_value = [mock_symbol]
result = provider.build_context_for_file(Path(temp_path))
assert result is not None
# Smoke check only: path filtering may legitimately drop the symbol,
# so assert the container type rather than a (vacuously non-negative) count
assert isinstance(result.related_symbols, list)
finally:
os.unlink(temp_path)
class TestMCPProviderRelatedSymbols:
"""Test related symbols functionality."""
@pytest.fixture
def provider(self):
"""Create provider with mocks."""
mock_global_index = Mock()
mock_search_engine = Mock()
mock_registry = Mock()
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
def test_get_related_symbols_from_references(self, provider):
"""_get_related_symbols extracts symbols from references."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
mock_ref1 = Mock()
mock_ref1.file_path = "/caller1.py"
mock_ref1.relationship_type = "call"
mock_ref2 = Mock()
mock_ref2.file_path = "/caller2.py"
mock_ref2.relationship_type = "import"
provider.search_engine.search_references.return_value = [mock_ref1, mock_ref2]
related = provider._get_related_symbols(mock_symbol)
assert len(related) == 2
assert related[0].relationship == "call"
assert related[1].relationship == "import"
def test_get_related_symbols_limits_results(self, provider):
"""_get_related_symbols limits to 10 unique relationship types."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
# Create 15 references with unique relationship types
refs = []
for i in range(15):
ref = Mock()
ref.file_path = f"/file{i}.py"
ref.relationship_type = f"type{i}"
refs.append(ref)
provider.search_engine.search_references.return_value = refs
related = provider._get_related_symbols(mock_symbol)
assert len(related) <= 10
def test_get_related_symbols_handles_exception(self, provider):
"""_get_related_symbols handles exceptions gracefully."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
provider.search_engine.search_references.side_effect = Exception("Search failed")
related = provider._get_related_symbols(mock_symbol)
assert related == []
class TestMCPProviderIntegration:
"""Integration-style tests for MCPProvider."""
def test_full_context_workflow(self):
"""Test complete context building workflow."""
# Create temp file
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("def my_function(arg1, arg2):\n")
f.write(" '''This is my function.'''\n")
f.write(" return arg1 + arg2\n")
temp_path = f.name
try:
# Setup mocks
mock_global_index = Mock()
mock_search_engine = Mock()
mock_registry = Mock()
mock_symbol = Mock()
mock_symbol.name = "my_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 3)
mock_ref = Mock()
mock_ref.file_path = "/user.py"
mock_ref.line = 10
mock_ref.column = 4
mock_ref.context = "result = my_function(1, 2)"
mock_ref.relationship_type = "call"
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = [mock_ref]
provider = MCPProvider(mock_global_index, mock_search_engine, mock_registry)
context = provider.build_context("my_function")
assert context is not None
assert context.symbol.name == "my_function"
assert context.definition is not None
assert "def my_function" in context.definition
assert len(context.references) == 1
assert context.references[0].relationship_type == "call"
# Test serialization
json_str = context.to_json()
assert "my_function" in json_str
# Test prompt injection
prompt = context.to_prompt_injection()
assert "<code_context>" in prompt
assert "my_function" in prompt
assert "</code_context>" in prompt
finally:
os.unlink(temp_path)
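Stitching the provider tests together, `build_context` is constrained to roughly the shape below. This is a hypothetical outline under those constraints, not the shipped implementation; in particular the `max_references` default of 10 is an assumption (the tests only pin down explicit values).

```python
# Hypothetical outline of MCPProvider.build_context, reconstructed from the
# assertions above; not the shipped implementation.
from codexlens.mcp.schema import MCPContext, ReferenceInfo, SymbolInfo

def build_context(self, symbol_name, context_type="symbol_explanation",
                  include_references=True, include_related=True,
                  max_references=10):
    matches = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
    if not matches:
        return None  # unknown symbol
    symbol = matches[0]
    context = MCPContext(
        symbol=SymbolInfo(symbol.name, symbol.kind, symbol.file, *symbol.range),
        context_type=context_type,
        metadata={"source": "codex-lens"},
    )
    # Definition text is read from symbol.file over the 1-based symbol.range;
    # missing or unreadable files yield definition=None.
    context.definition = self._extract_definition(symbol)
    if include_references:
        refs = self.search_engine.search_references(symbol_name, limit=max_references)
        context.references = [
            ReferenceInfo(r.file_path, r.line, r.column, r.context, r.relationship_type)
            for r in refs
        ]
    if include_related:
        # Capped at 10 unique relationship types; failures degrade to [].
        context.related_symbols = self._get_related_symbols(symbol)
    return context
```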

@@ -0,0 +1,288 @@
"""Tests for MCP schema."""
import pytest
import json
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
class TestSymbolInfo:
"""Test SymbolInfo dataclass."""
def test_to_dict_includes_all_fields(self):
"""SymbolInfo.to_dict() includes all non-None fields."""
info = SymbolInfo(
name="func",
kind="function",
file_path="/test.py",
line_start=10,
line_end=20,
signature="def func():",
documentation="Test doc",
)
d = info.to_dict()
assert d["name"] == "func"
assert d["kind"] == "function"
assert d["file_path"] == "/test.py"
assert d["line_start"] == 10
assert d["line_end"] == 20
assert d["signature"] == "def func():"
assert d["documentation"] == "Test doc"
def test_to_dict_excludes_none(self):
"""SymbolInfo.to_dict() excludes None fields."""
info = SymbolInfo(
name="func",
kind="function",
file_path="/test.py",
line_start=10,
line_end=20,
)
d = info.to_dict()
assert "signature" not in d
assert "documentation" not in d
assert "name" in d
assert "kind" in d
def test_basic_creation(self):
"""SymbolInfo can be created with required fields only."""
info = SymbolInfo(
name="MyClass",
kind="class",
file_path="/src/module.py",
line_start=1,
line_end=50,
)
assert info.name == "MyClass"
assert info.kind == "class"
assert info.signature is None
assert info.documentation is None
class TestReferenceInfo:
"""Test ReferenceInfo dataclass."""
def test_to_dict(self):
"""ReferenceInfo.to_dict() returns all fields."""
ref = ReferenceInfo(
file_path="/src/main.py",
line=25,
column=4,
context="result = func()",
relationship_type="call",
)
d = ref.to_dict()
assert d["file_path"] == "/src/main.py"
assert d["line"] == 25
assert d["column"] == 4
assert d["context"] == "result = func()"
assert d["relationship_type"] == "call"
def test_all_fields_required(self):
"""ReferenceInfo requires all fields."""
ref = ReferenceInfo(
file_path="/test.py",
line=10,
column=0,
context="import module",
relationship_type="import",
)
assert ref.file_path == "/test.py"
assert ref.relationship_type == "import"
class TestRelatedSymbol:
"""Test RelatedSymbol dataclass."""
def test_to_dict_includes_all_fields(self):
"""RelatedSymbol.to_dict() includes all non-None fields."""
sym = RelatedSymbol(
name="BaseClass",
kind="class",
relationship="inherits",
file_path="/src/base.py",
)
d = sym.to_dict()
assert d["name"] == "BaseClass"
assert d["kind"] == "class"
assert d["relationship"] == "inherits"
assert d["file_path"] == "/src/base.py"
def test_to_dict_excludes_none(self):
"""RelatedSymbol.to_dict() excludes None file_path."""
sym = RelatedSymbol(
name="helper",
kind="function",
relationship="calls",
)
d = sym.to_dict()
assert "file_path" not in d
assert d["name"] == "helper"
assert d["relationship"] == "calls"
class TestMCPContext:
"""Test MCPContext dataclass."""
def test_to_dict_basic(self):
"""MCPContext.to_dict() returns basic structure."""
ctx = MCPContext(context_type="test")
d = ctx.to_dict()
assert d["version"] == "1.0"
assert d["context_type"] == "test"
assert d["metadata"] == {}
def test_to_dict_with_symbol(self):
"""MCPContext.to_dict() includes symbol when present."""
ctx = MCPContext(
context_type="test",
symbol=SymbolInfo("f", "function", "/t.py", 1, 2),
)
d = ctx.to_dict()
assert "symbol" in d
assert d["symbol"]["name"] == "f"
assert d["symbol"]["kind"] == "function"
def test_to_dict_with_references(self):
"""MCPContext.to_dict() includes references when present."""
ctx = MCPContext(
context_type="test",
references=[
ReferenceInfo("/a.py", 10, 0, "call()", "call"),
ReferenceInfo("/b.py", 20, 5, "import x", "import"),
],
)
d = ctx.to_dict()
assert "references" in d
assert len(d["references"]) == 2
assert d["references"][0]["line"] == 10
def test_to_dict_with_related_symbols(self):
"""MCPContext.to_dict() includes related_symbols when present."""
ctx = MCPContext(
context_type="test",
related_symbols=[
RelatedSymbol("Base", "class", "inherits"),
RelatedSymbol("helper", "function", "calls"),
],
)
d = ctx.to_dict()
assert "related_symbols" in d
assert len(d["related_symbols"]) == 2
def test_to_json(self):
"""MCPContext.to_json() returns valid JSON."""
ctx = MCPContext(context_type="test")
j = ctx.to_json()
parsed = json.loads(j)
assert parsed["version"] == "1.0"
assert parsed["context_type"] == "test"
def test_to_json_with_indent(self):
"""MCPContext.to_json() respects indent parameter."""
ctx = MCPContext(context_type="test")
j = ctx.to_json(indent=4)
# Check it's properly indented
assert " " in j
def test_to_prompt_injection_basic(self):
"""MCPContext.to_prompt_injection() returns formatted string."""
ctx = MCPContext(
symbol=SymbolInfo("my_func", "function", "/test.py", 10, 20),
definition="def my_func(): pass",
)
prompt = ctx.to_prompt_injection()
assert "<code_context>" in prompt
assert "my_func" in prompt
assert "def my_func()" in prompt
assert "</code_context>" in prompt
def test_to_prompt_injection_with_references(self):
"""MCPContext.to_prompt_injection() includes references."""
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
references=[
ReferenceInfo("/a.py", 10, 0, "func()", "call"),
ReferenceInfo("/b.py", 20, 0, "from x import func", "import"),
],
)
prompt = ctx.to_prompt_injection()
assert "References (2 found)" in prompt
assert "/a.py:10" in prompt
assert "call" in prompt
def test_to_prompt_injection_limits_references(self):
"""MCPContext.to_prompt_injection() limits references to 5."""
refs = [
ReferenceInfo(f"/file{i}.py", i, 0, f"ref{i}", "call")
for i in range(10)
]
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
references=refs,
)
prompt = ctx.to_prompt_injection()
# Should show "10 found" but only include 5
assert "References (10 found)" in prompt
assert "/file0.py" in prompt
assert "/file4.py" in prompt
assert "/file5.py" not in prompt
def test_to_prompt_injection_with_related_symbols(self):
"""MCPContext.to_prompt_injection() includes related symbols."""
ctx = MCPContext(
symbol=SymbolInfo("MyClass", "class", "/test.py", 1, 50),
related_symbols=[
RelatedSymbol("BaseClass", "class", "inherits"),
RelatedSymbol("helper", "function", "calls"),
],
)
prompt = ctx.to_prompt_injection()
assert "Related Symbols" in prompt
assert "BaseClass (inherits)" in prompt
assert "helper (calls)" in prompt
def test_to_prompt_injection_limits_related_symbols(self):
"""MCPContext.to_prompt_injection() limits related symbols to 10."""
related = [
RelatedSymbol(f"sym{i}", "function", "calls")
for i in range(15)
]
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
related_symbols=related,
)
prompt = ctx.to_prompt_injection()
assert "sym0 (calls)" in prompt
assert "sym9 (calls)" in prompt
assert "sym10 (calls)" not in prompt
def test_empty_context(self):
"""MCPContext works with minimal data."""
ctx = MCPContext()
d = ctx.to_dict()
assert d["version"] == "1.0"
assert d["context_type"] == "code_context"
prompt = ctx.to_prompt_injection()
assert "<code_context>" in prompt
assert "</code_context>" in prompt
def test_metadata_preserved(self):
"""MCPContext preserves custom metadata."""
ctx = MCPContext(
context_type="custom",
metadata={
"source": "codex-lens",
"indexed_at": "2024-01-01",
"custom_key": "custom_value",
},
)
d = ctx.to_dict()
assert d["metadata"]["source"] == "codex-lens"
assert d["metadata"]["custom_key"] == "custom_value"

@@ -79,3 +79,87 @@ def test_symbol_filtering_handles_path_failures(monkeypatch: pytest.MonkeyPatch,
if os.name == "nt":
assert "CrossDrive" in caplog.text
def test_cascade_search_strategy_routing(temp_paths: Path) -> None:
"""Test cascade_search() routes to correct strategy implementation."""
from unittest.mock import patch
from codexlens.search.chain_search import ChainSearchResult, SearchStats
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
mapper = PathMapper(index_root=temp_paths / "indexes")
config = Config(data_dir=temp_paths / "data")
engine = ChainSearchEngine(registry, mapper, config=config)
source_path = temp_paths / "src"
# Test strategy='staged' routing
with patch.object(engine, "staged_cascade_search") as mock_staged:
mock_staged.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="staged")
mock_staged.assert_called_once()
# Test strategy='binary' routing
with patch.object(engine, "binary_cascade_search") as mock_binary:
mock_binary.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="binary")
mock_binary.assert_called_once()
# Test strategy='hybrid' routing
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="hybrid")
mock_hybrid.assert_called_once()
# Test strategy='binary_rerank' routing
with patch.object(engine, "binary_rerank_cascade_search") as mock_br:
mock_br.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="binary_rerank")
mock_br.assert_called_once()
# Test strategy='dense_rerank' routing
with patch.object(engine, "dense_rerank_cascade_search") as mock_dr:
mock_dr.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="dense_rerank")
mock_dr.assert_called_once()
# Test default routing (no strategy specified) - defaults to binary
with patch.object(engine, "binary_cascade_search") as mock_default:
mock_default.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path)
mock_default.assert_called_once()
def test_cascade_search_invalid_strategy(temp_paths: Path) -> None:
"""Test cascade_search() defaults to 'binary' for invalid strategy."""
from unittest.mock import patch
from codexlens.search.chain_search import ChainSearchResult, SearchStats
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
mapper = PathMapper(index_root=temp_paths / "indexes")
config = Config(data_dir=temp_paths / "data")
engine = ChainSearchEngine(registry, mapper, config=config)
source_path = temp_paths / "src"
# Invalid strategy should default to binary
with patch.object(engine, "binary_cascade_search") as mock_binary:
mock_binary.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="invalid_strategy")
mock_binary.assert_called_once()
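Together, the two tests above pin `cascade_search` down to a dispatch table with a binary default. One plausible implementation shape follows; the strategy method names come from the patches above, but the body is hypothetical.

```python
# Hypothetical routing sketch consistent with the two tests above; the real
# method may differ in details beyond what the tests assert.
_STRATEGY_METHODS = {
    "staged": "staged_cascade_search",
    "binary": "binary_cascade_search",
    "hybrid": "hybrid_cascade_search",
    "binary_rerank": "binary_rerank_cascade_search",
    "dense_rerank": "dense_rerank_cascade_search",
}

def cascade_search(self, query, source_path, strategy="binary", **kwargs):
    # Unknown strategy names (and the no-strategy default) route to binary.
    method = _STRATEGY_METHODS.get(strategy, "binary_cascade_search")
    return getattr(self, method)(query, source_path, **kwargs)
```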

@@ -0,0 +1,766 @@
"""Unit tests for clustering strategies in the hybrid search pipeline.
Tests cover:
1. HDBSCANStrategy - Primary HDBSCAN clustering
2. DBSCANStrategy - Fallback DBSCAN clustering
3. NoOpStrategy - No-op fallback when clustering unavailable
4. ClusteringStrategyFactory - Factory with fallback chain
"""
from __future__ import annotations
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from codexlens.entities import SearchResult
from codexlens.search.clustering import (
BaseClusteringStrategy,
ClusteringConfig,
ClusteringStrategyFactory,
NoOpStrategy,
check_clustering_strategy_available,
get_strategy,
)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def sample_results() -> List[SearchResult]:
"""Create sample search results for testing."""
return [
SearchResult(path="a.py", score=0.9, excerpt="def foo(): pass"),
SearchResult(path="b.py", score=0.8, excerpt="def foo(): pass"),
SearchResult(path="c.py", score=0.7, excerpt="def bar(): pass"),
SearchResult(path="d.py", score=0.6, excerpt="def bar(): pass"),
SearchResult(path="e.py", score=0.5, excerpt="def baz(): pass"),
]
@pytest.fixture
def mock_embeddings():
"""Create mock embeddings for 5 results.
Creates embeddings that should form 2 clusters:
- Results 0, 1 (similar to each other)
- Results 2, 3 (similar to each other)
- Result 4 (noise/singleton)
"""
import numpy as np
# Create embeddings in 3D for simplicity
return np.array(
[
[1.0, 0.0, 0.0], # Result 0 - cluster A
[0.9, 0.1, 0.0], # Result 1 - cluster A
[0.0, 1.0, 0.0], # Result 2 - cluster B
[0.1, 0.9, 0.0], # Result 3 - cluster B
[0.0, 0.0, 1.0], # Result 4 - noise/singleton
],
dtype=np.float32,
)
@pytest.fixture
def default_config() -> ClusteringConfig:
"""Create default clustering configuration."""
return ClusteringConfig(
min_cluster_size=2,
min_samples=1,
metric="euclidean",
)
# =============================================================================
# Test ClusteringConfig
# =============================================================================
class TestClusteringConfig:
"""Tests for ClusteringConfig validation."""
def test_default_values(self):
"""Test default configuration values."""
config = ClusteringConfig()
assert config.min_cluster_size == 3
assert config.min_samples == 2
assert config.metric == "cosine"
assert config.cluster_selection_epsilon == 0.0
assert config.allow_single_cluster is True
assert config.prediction_data is False
def test_custom_values(self):
"""Test custom configuration values."""
config = ClusteringConfig(
min_cluster_size=5,
min_samples=3,
metric="euclidean",
cluster_selection_epsilon=0.1,
allow_single_cluster=False,
prediction_data=True,
)
assert config.min_cluster_size == 5
assert config.min_samples == 3
assert config.metric == "euclidean"
def test_invalid_min_cluster_size(self):
"""Test validation rejects min_cluster_size < 2."""
with pytest.raises(ValueError, match="min_cluster_size must be >= 2"):
ClusteringConfig(min_cluster_size=1)
def test_invalid_min_samples(self):
"""Test validation rejects min_samples < 1."""
with pytest.raises(ValueError, match="min_samples must be >= 1"):
ClusteringConfig(min_samples=0)
def test_invalid_metric(self):
"""Test validation rejects invalid metric."""
with pytest.raises(ValueError, match="metric must be one of"):
ClusteringConfig(metric="invalid")
def test_invalid_epsilon(self):
"""Test validation rejects negative epsilon."""
with pytest.raises(ValueError, match="cluster_selection_epsilon must be >= 0"):
ClusteringConfig(cluster_selection_epsilon=-0.1)
# =============================================================================
# Test NoOpStrategy
# =============================================================================
class TestNoOpStrategy:
"""Tests for NoOpStrategy - always available."""
def test_cluster_returns_singleton_clusters(
self, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns each result as singleton cluster."""
strategy = NoOpStrategy()
clusters = strategy.cluster(mock_embeddings, sample_results)
assert len(clusters) == 5
for i, cluster in enumerate(clusters):
assert cluster == [i]
def test_cluster_empty_results(self):
"""Test cluster() with empty results."""
import numpy as np
strategy = NoOpStrategy()
clusters = strategy.cluster(np.array([]), [])
assert clusters == []
def test_select_representatives_returns_all_sorted(
self, sample_results: List[SearchResult]
):
"""Test select_representatives() returns all results sorted by score."""
strategy = NoOpStrategy()
clusters = [[i] for i in range(len(sample_results))]
representatives = strategy.select_representatives(clusters, sample_results)
assert len(representatives) == 5
# Check sorted by score descending
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
def test_select_representatives_empty(self):
"""Test select_representatives() with empty input."""
strategy = NoOpStrategy()
representatives = strategy.select_representatives([], [])
assert representatives == []
def test_fit_predict_convenience_method(
self, sample_results: List[SearchResult], mock_embeddings
):
"""Test fit_predict() convenience method."""
strategy = NoOpStrategy()
representatives = strategy.fit_predict(mock_embeddings, sample_results)
assert len(representatives) == 5
# All results returned, sorted by score
assert representatives[0].score >= representatives[-1].score
# =============================================================================
# Test HDBSCANStrategy
# =============================================================================
class TestHDBSCANStrategy:
"""Tests for HDBSCANStrategy - requires hdbscan package."""
@pytest.fixture
def hdbscan_strategy(self, default_config):
"""Create HDBSCANStrategy if available."""
try:
from codexlens.search.clustering import HDBSCANStrategy
return HDBSCANStrategy(default_config)
except ImportError:
pytest.skip("hdbscan not installed")
def test_cluster_returns_list_of_lists(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns List[List[int]]."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
for cluster in clusters:
assert isinstance(cluster, list)
for idx in cluster:
assert isinstance(idx, int)
assert 0 <= idx < len(sample_results)
def test_cluster_covers_all_results(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test all result indices appear in clusters."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
all_indices = set()
for cluster in clusters:
all_indices.update(cluster)
assert all_indices == set(range(len(sample_results)))
def test_cluster_empty_results(self, hdbscan_strategy):
"""Test cluster() with empty results."""
import numpy as np
clusters = hdbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
assert clusters == []
def test_cluster_single_result(self, hdbscan_strategy):
"""Test cluster() with single result."""
import numpy as np
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = hdbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_cluster_fewer_than_min_cluster_size(self, hdbscan_strategy):
"""Test cluster() with fewer results than min_cluster_size."""
import numpy as np
# Strategy has min_cluster_size=2, so 1 result returns singleton
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = hdbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_select_representatives_picks_highest_score(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test select_representatives() picks highest score per cluster."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = hdbscan_strategy.select_representatives(
clusters, sample_results
)
# Each representative should be the highest-scored in its cluster
for rep in representatives:
# Find the cluster containing this representative
rep_idx = next(
i for i, r in enumerate(sample_results) if r.path == rep.path
)
for cluster in clusters:
if rep_idx in cluster:
cluster_scores = [sample_results[i].score for i in cluster]
assert rep.score == max(cluster_scores)
break
def test_select_representatives_sorted_by_score(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test representatives are sorted by score descending."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = hdbscan_strategy.select_representatives(
clusters, sample_results
)
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
def test_fit_predict_end_to_end(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test fit_predict() end-to-end clustering."""
representatives = hdbscan_strategy.fit_predict(mock_embeddings, sample_results)
# Should have fewer or equal representatives than input
assert len(representatives) <= len(sample_results)
# All representatives should be from original results
rep_paths = {r.path for r in representatives}
original_paths = {r.path for r in sample_results}
assert rep_paths.issubset(original_paths)
# =============================================================================
# Test DBSCANStrategy
# =============================================================================
class TestDBSCANStrategy:
"""Tests for DBSCANStrategy - requires sklearn."""
@pytest.fixture
def dbscan_strategy(self, default_config):
"""Create DBSCANStrategy if available."""
try:
from codexlens.search.clustering import DBSCANStrategy
return DBSCANStrategy(default_config)
except ImportError:
pytest.skip("sklearn not installed")
def test_cluster_returns_list_of_lists(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns List[List[int]]."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
for cluster in clusters:
assert isinstance(cluster, list)
for idx in cluster:
assert isinstance(idx, int)
assert 0 <= idx < len(sample_results)
def test_cluster_covers_all_results(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test all result indices appear in clusters."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
all_indices = set()
for cluster in clusters:
all_indices.update(cluster)
assert all_indices == set(range(len(sample_results)))
def test_cluster_empty_results(self, dbscan_strategy):
"""Test cluster() with empty results."""
import numpy as np
clusters = dbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
assert clusters == []
def test_cluster_single_result(self, dbscan_strategy):
"""Test cluster() with single result."""
import numpy as np
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = dbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_cluster_with_explicit_eps(self, default_config):
"""Test cluster() with explicit eps parameter."""
try:
from codexlens.search.clustering import DBSCANStrategy
except ImportError:
pytest.skip("sklearn not installed")
import numpy as np
strategy = DBSCANStrategy(default_config, eps=0.5)
results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(3)]
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0]])
clusters = strategy.cluster(embeddings, results)
# With eps=0.5, first two should cluster, third should be separate
assert len(clusters) >= 2
def test_auto_compute_eps(self, dbscan_strategy, mock_embeddings):
"""Test eps auto-computation from distance distribution."""
# Should not raise - eps is computed automatically
results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(5)]
clusters = dbscan_strategy.cluster(mock_embeddings, results)
assert len(clusters) > 0
def test_select_representatives_picks_highest_score(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test select_representatives() picks highest score per cluster."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = dbscan_strategy.select_representatives(
clusters, sample_results
)
# Each representative should be the highest-scored in its cluster
for rep in representatives:
rep_idx = next(
i for i, r in enumerate(sample_results) if r.path == rep.path
)
for cluster in clusters:
if rep_idx in cluster:
cluster_scores = [sample_results[i].score for i in cluster]
assert rep.score == max(cluster_scores)
break
def test_select_representatives_sorted_by_score(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test representatives are sorted by score descending."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = dbscan_strategy.select_representatives(
clusters, sample_results
)
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
# =============================================================================
# Test ClusteringStrategyFactory
# =============================================================================
class TestClusteringStrategyFactory:
"""Tests for ClusteringStrategyFactory."""
def test_check_noop_always_available(self):
"""Test noop strategy is always available."""
ok, err = check_clustering_strategy_available("noop")
assert ok is True
assert err is None
def test_check_invalid_strategy(self):
"""Test invalid strategy name returns error."""
ok, err = check_clustering_strategy_available("invalid")
assert ok is False
assert "Invalid clustering strategy" in err
def test_get_strategy_noop(self, default_config):
"""Test get_strategy('noop') returns NoOpStrategy."""
strategy = get_strategy("noop", default_config)
assert isinstance(strategy, NoOpStrategy)
def test_get_strategy_auto_returns_something(self, default_config):
"""Test get_strategy('auto') returns a strategy."""
strategy = get_strategy("auto", default_config)
assert isinstance(strategy, BaseClusteringStrategy)
def test_get_strategy_with_fallback_enabled(self, default_config):
"""Test fallback when primary strategy unavailable."""
# Mock hdbscan unavailable
with patch.dict("sys.modules", {"hdbscan": None}):
# Should fall back to dbscan or noop
strategy = get_strategy("hdbscan", default_config, fallback=True)
assert isinstance(strategy, BaseClusteringStrategy)
def test_get_strategy_fallback_disabled_raises(self, default_config):
"""Test ImportError when fallback disabled and strategy unavailable."""
with patch(
"codexlens.search.clustering.factory.check_clustering_strategy_available"
) as mock_check:
mock_check.return_value = (False, "Test error")
with pytest.raises(ImportError, match="Test error"):
get_strategy("hdbscan", default_config, fallback=False)
def test_get_strategy_invalid_raises(self, default_config):
"""Test ValueError for invalid strategy name."""
with pytest.raises(ValueError, match="Unknown clustering strategy"):
get_strategy("invalid", default_config)
def test_factory_class_interface(self, default_config):
"""Test ClusteringStrategyFactory class interface."""
strategy = ClusteringStrategyFactory.get_strategy("noop", default_config)
assert isinstance(strategy, NoOpStrategy)
ok, err = ClusteringStrategyFactory.check_available("noop")
assert ok is True
@pytest.mark.skipif(
not check_clustering_strategy_available("hdbscan")[0],
reason="hdbscan not installed",
)
def test_get_strategy_hdbscan(self, default_config):
"""Test get_strategy('hdbscan') returns HDBSCANStrategy."""
from codexlens.search.clustering import HDBSCANStrategy
strategy = get_strategy("hdbscan", default_config)
assert isinstance(strategy, HDBSCANStrategy)
@pytest.mark.skipif(
not check_clustering_strategy_available("dbscan")[0],
reason="sklearn not installed",
)
def test_get_strategy_dbscan(self, default_config):
"""Test get_strategy('dbscan') returns DBSCANStrategy."""
from codexlens.search.clustering import DBSCANStrategy
strategy = get_strategy("dbscan", default_config)
assert isinstance(strategy, DBSCANStrategy)
@pytest.mark.skipif(
not check_clustering_strategy_available("dbscan")[0],
reason="sklearn not installed",
)
def test_get_strategy_dbscan_with_kwargs(self, default_config):
"""Test DBSCANStrategy kwargs passed through factory."""
strategy = get_strategy("dbscan", default_config, eps=0.3, eps_percentile=20.0)
assert strategy.eps == 0.3
assert strategy.eps_percentile == 20.0
# =============================================================================
# Integration Tests
# =============================================================================
class TestClusteringIntegration:
"""Integration tests for clustering strategies."""
def test_all_strategies_same_interface(
self, sample_results: List[SearchResult], mock_embeddings, default_config
):
"""Test all strategies have consistent interface."""
strategies = [NoOpStrategy(default_config)]
# Add available strategies
try:
from codexlens.search.clustering import HDBSCANStrategy
strategies.append(HDBSCANStrategy(default_config))
except ImportError:
pass
try:
from codexlens.search.clustering import DBSCANStrategy
strategies.append(DBSCANStrategy(default_config))
except ImportError:
pass
for strategy in strategies:
# All should implement cluster()
clusters = strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
# All should implement select_representatives()
reps = strategy.select_representatives(clusters, sample_results)
assert isinstance(reps, list)
assert all(isinstance(r, SearchResult) for r in reps)
# All should implement fit_predict()
reps = strategy.fit_predict(mock_embeddings, sample_results)
assert isinstance(reps, list)
def test_clustering_reduces_redundancy(
self, default_config
):
"""Test clustering reduces redundant similar results."""
import numpy as np
# Create results with very similar embeddings
results = [
SearchResult(path=f"{i}.py", score=0.9 - i * 0.01, excerpt="def foo(): pass")
for i in range(10)
]
# Very similar embeddings - should cluster together
embeddings = np.array(
[[1.0 + i * 0.01, 0.0, 0.0] for i in range(10)], dtype=np.float32
)
strategy = get_strategy("auto", default_config)
representatives = strategy.fit_predict(embeddings, results)
# Should have fewer representatives than input (clustering reduced redundancy)
# NoOp returns all, but HDBSCAN/DBSCAN should reduce
assert len(representatives) <= len(results)
# =============================================================================
# Test FrequencyStrategy
# =============================================================================
class TestFrequencyStrategy:
"""Tests for FrequencyStrategy - frequency-based clustering."""
@pytest.fixture
def frequency_config(self):
"""Create FrequencyConfig for testing."""
from codexlens.search.clustering import FrequencyConfig
return FrequencyConfig(min_frequency=1, max_representatives_per_group=3)
@pytest.fixture
def frequency_strategy(self, frequency_config):
"""Create FrequencyStrategy instance."""
from codexlens.search.clustering import FrequencyStrategy
return FrequencyStrategy(frequency_config)
@pytest.fixture
def symbol_results(self) -> List[SearchResult]:
"""Create sample results with symbol names for frequency testing."""
return [
SearchResult(path="auth.py", score=0.9, excerpt="authenticate user", symbol_name="authenticate"),
SearchResult(path="login.py", score=0.85, excerpt="authenticate login", symbol_name="authenticate"),
SearchResult(path="session.py", score=0.8, excerpt="authenticate session", symbol_name="authenticate"),
SearchResult(path="utils.py", score=0.7, excerpt="helper function", symbol_name="helper_func"),
SearchResult(path="validate.py", score=0.6, excerpt="validate input", symbol_name="validate"),
SearchResult(path="check.py", score=0.55, excerpt="validate data", symbol_name="validate"),
]
def test_frequency_strategy_available(self):
"""Test FrequencyStrategy is always available (no deps)."""
ok, err = check_clustering_strategy_available("frequency")
assert ok is True
assert err is None
def test_get_strategy_frequency(self):
"""Test get_strategy('frequency') returns FrequencyStrategy."""
from codexlens.search.clustering import FrequencyStrategy
strategy = get_strategy("frequency")
assert isinstance(strategy, FrequencyStrategy)
def test_cluster_groups_by_symbol(self, frequency_strategy, symbol_results):
"""Test cluster() groups results by symbol name."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
# Should have 3 groups: authenticate(3), validate(2), helper_func(1)
assert len(clusters) == 3
# First cluster should be authenticate (highest frequency)
first_cluster_symbols = [symbol_results[i].symbol_name for i in clusters[0]]
assert all(s == "authenticate" for s in first_cluster_symbols)
assert len(clusters[0]) == 3
def test_cluster_orders_by_frequency(self, frequency_strategy, symbol_results):
"""Test clusters are ordered by frequency (descending)."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
# Verify frequency ordering
frequencies = [len(c) for c in clusters]
assert frequencies == sorted(frequencies, reverse=True)
def test_select_representatives_adds_frequency_metadata(self, frequency_strategy, symbol_results):
"""Test representatives have frequency metadata."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
reps = frequency_strategy.select_representatives(clusters, symbol_results, embeddings)
# Check frequency metadata
for rep in reps:
assert "frequency" in rep.metadata
assert rep.metadata["frequency"] >= 1
def test_min_frequency_filter_mode(self, symbol_results):
"""Test min_frequency with filter mode removes low-frequency results."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(min_frequency=2, keep_mode="filter")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# helper_func (freq=1) should be filtered out
rep_symbols = [r.symbol_name for r in reps]
assert "helper_func" not in rep_symbols
assert "authenticate" in rep_symbols
assert "validate" in rep_symbols
def test_min_frequency_demote_mode(self, symbol_results):
"""Test min_frequency with demote mode keeps but deprioritizes low-frequency."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(min_frequency=2, keep_mode="demote")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# helper_func should still be present but at the end
rep_symbols = [r.symbol_name for r in reps]
assert "helper_func" in rep_symbols
# Should be demoted to end
helper_idx = rep_symbols.index("helper_func")
assert helper_idx == len(rep_symbols) - 1
def test_group_by_file(self, symbol_results):
"""Test grouping by file path instead of symbol."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(group_by="file")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
clusters = strategy.cluster(embeddings, symbol_results)
# Each file should be its own group (all unique paths)
assert len(clusters) == 6
def test_max_representatives_per_group(self, symbol_results):
"""Test max_representatives_per_group limits output per symbol."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(max_representatives_per_group=1)
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# Should have at most 1 per group = 3 groups = 3 reps
assert len(reps) == 3
def test_frequency_boost_score(self, symbol_results):
"""Test frequency_weight boosts high-frequency results."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(frequency_weight=0.5) # Strong boost
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# High-frequency results should have boosted scores in metadata
for rep in reps:
if rep.metadata.get("frequency", 1) > 1:
assert rep.metadata.get("frequency_boosted_score", 0) > rep.score
def test_empty_results(self, frequency_strategy):
"""Test handling of empty results."""
import numpy as np
clusters = frequency_strategy.cluster(np.array([]).reshape(0, 128), [])
assert clusters == []
reps = frequency_strategy.select_representatives([], [], None)
assert reps == []
def test_factory_with_kwargs(self):
"""Test factory passes kwargs to FrequencyConfig."""
strategy = get_strategy("frequency", min_frequency=3, group_by="file")
assert strategy.config.min_frequency == 3
assert strategy.config.group_by == "file"
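Outside the test harness, the factory behaviour verified above amounts to the following usage. This is a sketch mirroring the tests; the toy results and embeddings are illustrative only.

```python
# Usage sketch of the clustering factory, mirroring the tests above.
import numpy as np

from codexlens.entities import SearchResult
from codexlens.search.clustering import ClusteringConfig, get_strategy

results = [
    SearchResult(path=f"{i}.py", score=0.9 - 0.1 * i, excerpt="def f(): pass")
    for i in range(4)
]
embeddings = np.array(
    [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]], dtype=np.float32
)

config = ClusteringConfig(min_cluster_size=2, min_samples=1, metric="euclidean")

# "auto" picks the best available backend; named strategies fall back down
# the chain (hdbscan -> dbscan -> noop) unless fallback=False, which raises
# ImportError when the backend is missing.
strategy = get_strategy("auto", config)

# One highest-scored representative per cluster, sorted by score descending.
representatives = strategy.fit_predict(embeddings, results)
```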

@@ -0,0 +1,698 @@
"""Integration tests for staged cascade search pipeline.
Tests the 4-stage pipeline:
1. Stage 1: Binary coarse search
2. Stage 2: LSP graph expansion
3. Stage 3: Clustering and representative selection
4. Stage 4: Optional cross-encoder reranking
"""
from __future__ import annotations
import json
import tempfile
from pathlib import Path
from typing import List
from unittest.mock import MagicMock, Mock, patch
import pytest
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
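# Illustrative outline only (not part of the test suite): how the staged
# pipeline under test fits together, per the module docstring above. The
# stage 1-2 method names match the stubs patched in the tests below; the
# stage 3/4 steps are placeholders for the clustering and rerank phases.
#
#   results, index_root = engine._stage1_binary_search(
#       query, index_paths, coarse_k=config.cascade_coarse_k, stats=stats)
#   expanded = engine._stage2_lsp_expand(results, index_root=index_root)
#   representatives = <stage 3: clustering / representative selection,
#                      driven by config.staged_clustering_strategy>
#   final = <stage 4: optional cross-encoder rerank, gated by
#            config.enable_staged_rerank>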
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def temp_paths():
"""Create temporary directory structure."""
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
root = Path(tmpdir.name)
yield root
try:
tmpdir.cleanup()
except (PermissionError, OSError):
pass
@pytest.fixture
def mock_registry(temp_paths: Path):
"""Create mock registry store."""
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
return registry
@pytest.fixture
def mock_mapper(temp_paths: Path):
"""Create path mapper."""
return PathMapper(index_root=temp_paths / "indexes")
@pytest.fixture
def mock_config():
"""Create mock config with staged cascade settings."""
config = MagicMock(spec=Config)
config.cascade_coarse_k = 100
config.cascade_fine_k = 10
config.enable_staged_rerank = False
config.staged_clustering_strategy = "auto"
config.staged_clustering_min_size = 3
config.graph_expansion_depth = 2
return config
@pytest.fixture
def sample_binary_results() -> List[SearchResult]:
"""Create sample binary search results for testing."""
return [
SearchResult(
path="a.py",
score=0.95,
excerpt="def authenticate_user(username, password):",
symbol_name="authenticate_user",
symbol_kind="function",
start_line=10,
end_line=15,
),
SearchResult(
path="b.py",
score=0.85,
excerpt="class AuthManager:",
symbol_name="AuthManager",
symbol_kind="class",
start_line=5,
end_line=20,
),
SearchResult(
path="c.py",
score=0.75,
excerpt="def check_credentials(user, pwd):",
symbol_name="check_credentials",
symbol_kind="function",
start_line=30,
end_line=35,
),
]
@pytest.fixture
def sample_expanded_results() -> List[SearchResult]:
"""Create sample expanded results (after LSP expansion)."""
return [
SearchResult(
path="a.py",
score=0.95,
excerpt="def authenticate_user(username, password):",
symbol_name="authenticate_user",
symbol_kind="function",
),
SearchResult(
path="a.py",
score=0.90,
excerpt="def verify_password(pwd):",
symbol_name="verify_password",
symbol_kind="function",
),
SearchResult(
path="b.py",
score=0.85,
excerpt="class AuthManager:",
symbol_name="AuthManager",
symbol_kind="class",
),
SearchResult(
path="b.py",
score=0.80,
excerpt="def login(self, user):",
symbol_name="login",
symbol_kind="function",
),
SearchResult(
path="c.py",
score=0.75,
excerpt="def check_credentials(user, pwd):",
symbol_name="check_credentials",
symbol_kind="function",
),
SearchResult(
path="d.py",
score=0.70,
excerpt="class UserModel:",
symbol_name="UserModel",
symbol_kind="class",
),
]
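# Note: relative to sample_binary_results, this fixture adds same-file
# neighbors (verify_password, login) plus one new file (d.py), mimicking the
# extra hits Stage 2 graph expansion is expected to contribute.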
# =============================================================================
# Test Stage Methods
# =============================================================================
class TestStage1BinarySearch:
"""Tests for Stage 1: Binary coarse search."""
def test_stage1_returns_results_with_index_root(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search returns results and index_root."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock the binary embedding backend (import is inside the method)
with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
mock_index = MagicMock()
mock_index.count.return_value = 10
mock_index.search.return_value = ([1, 2, 3], [10, 20, 30])
mock_binary_idx.return_value = mock_index
index_paths = [Path("/fake/index1/_index.db")]
stats = SearchStats()
results, index_root = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
assert isinstance(results, list)
assert isinstance(index_root, (Path, type(None)))
def test_stage1_handles_empty_index_paths(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search handles empty index paths."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
index_paths = []
stats = SearchStats()
results, index_root = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
assert results == []
assert index_root is None
def test_stage1_aggregates_results_from_multiple_indexes(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search aggregates results from multiple indexes."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
mock_index = MagicMock()
mock_index.count.return_value = 10
# Return different results for different calls
mock_index.search.side_effect = [
([1, 2], [10, 20]),
([3, 4], [15, 25]),
]
mock_binary_idx.return_value = mock_index
index_paths = [
Path("/fake/index1/_index.db"),
Path("/fake/index2/_index.db"),
]
stats = SearchStats()
results, _ = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
# Should aggregate candidates from both indexes
assert isinstance(results, list)
class TestStage2LSPExpand:
"""Tests for Stage 2: LSP graph expansion."""
def test_stage2_returns_expanded_results(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand returns expanded results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Import is inside the method, so we need to patch it there
with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
mock_expander = MagicMock()
mock_expander.expand.return_value = [
SearchResult(path="related.py", score=0.7, excerpt="related")
]
mock_expander_cls.return_value = mock_expander
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake/index")
)
assert isinstance(expanded, list)
# Should include original results
assert len(expanded) >= len(sample_binary_results)
def test_stage2_handles_no_index_root(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand handles missing index_root."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
expanded = engine._stage2_lsp_expand(sample_binary_results, index_root=None)
# Should return original results unchanged
assert expanded == sample_binary_results
def test_stage2_handles_empty_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage2_lsp_expand handles empty input."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
expanded = engine._stage2_lsp_expand([], index_root=Path("/fake"))
assert expanded == []
def test_stage2_deduplicates_results(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand deduplicates by (path, symbol_name, start_line)."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock expander to return duplicate of first result
with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
mock_expander = MagicMock()
duplicate = SearchResult(
path=sample_binary_results[0].path,
score=0.5,
excerpt="duplicate",
symbol_name=sample_binary_results[0].symbol_name,
start_line=sample_binary_results[0].start_line,
)
mock_expander.expand.return_value = [duplicate]
mock_expander_cls.return_value = mock_expander
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake")
)
# Should not include duplicate
assert len(expanded) == len(sample_binary_results)
class TestStage3ClusterPrune:
"""Tests for Stage 3: Clustering and representative selection."""
def test_stage3_returns_representatives(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test _stage3_cluster_prune returns representative results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
import numpy as np
# Mock embeddings
mock_embed.return_value = np.random.rand(
len(sample_expanded_results), 128
).astype(np.float32)
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
assert isinstance(clustered, list)
assert len(clustered) <= len(sample_expanded_results)
assert all(isinstance(r, SearchResult) for r in clustered)
def test_stage3_handles_few_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage3_cluster_prune skips clustering for few results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
few_results = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="b.py", score=0.8, excerpt="b"),
]
clustered = engine._stage3_cluster_prune(few_results, target_count=5)
# Should return all results unchanged
assert clustered == few_results
def test_stage3_handles_no_embeddings(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test _stage3_cluster_prune falls back to score-based selection without embeddings."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
mock_embed.return_value = None
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should return top-scored results
assert len(clustered) <= 3
# Should be sorted by score descending
scores = [r.score for r in clustered]
assert scores == sorted(scores, reverse=True)
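    # Fallback contract checked above: with no embeddings available, Stage 3
    # is expected to reduce to "sort by score descending, truncate to
    # target_count".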
def test_stage3_uses_config_clustering_strategy(
self, mock_registry, mock_mapper, sample_expanded_results
):
"""Test _stage3_cluster_prune uses config clustering strategy."""
config = MagicMock(spec=Config)
config.staged_clustering_strategy = "auto"
config.staged_clustering_min_size = 2
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
import numpy as np
mock_embed.return_value = np.random.rand(
len(sample_expanded_results), 128
).astype(np.float32)
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should use clustering (auto will pick best available)
# Result should be a list of SearchResult objects
assert isinstance(clustered, list)
assert all(isinstance(r, SearchResult) for r in clustered)
class TestStage4OptionalRerank:
"""Tests for Stage 4: Optional cross-encoder reranking."""
def test_stage4_reranks_with_reranker(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage4_optional_rerank uses _cross_encoder_rerank."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
results = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="b.py", score=0.8, excerpt="b"),
SearchResult(path="c.py", score=0.7, excerpt="c"),
]
# Mock the _cross_encoder_rerank method that _stage4 calls
with patch.object(engine, "_cross_encoder_rerank") as mock_rerank:
mock_rerank.return_value = [
SearchResult(path="c.py", score=0.95, excerpt="c"),
SearchResult(path="a.py", score=0.85, excerpt="a"),
]
reranked = engine._stage4_optional_rerank("query", results, k=2)
mock_rerank.assert_called_once_with("query", results, 2)
assert len(reranked) <= 2
# First result should be reranked winner
assert reranked[0].path == "c.py"
def test_stage4_handles_empty_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage4_optional_rerank handles empty input."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
reranked = engine._stage4_optional_rerank("query", [], k=2)
# Should return empty list
assert reranked == []
# =============================================================================
# Integration Tests
# =============================================================================
class TestStagedCascadeIntegration:
"""Integration tests for staged_cascade_search() end-to-end."""
def test_staged_cascade_returns_chain_result(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search returns ChainSearchResult."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock all stages
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src", k=10, coarse_k=100
)
from codexlens.search.chain_search import ChainSearchResult
assert isinstance(result, ChainSearchResult)
assert result.query == "query"
assert len(result.results) <= 10
def test_staged_cascade_includes_stage_stats(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search includes per-stage timing stats."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src"
)
# Check for stage stats in errors field
stage_stats = None
for err in result.stats.errors:
if err.startswith("STAGE_STATS:"):
stage_stats = json.loads(err.replace("STAGE_STATS:", ""))
break
assert stage_stats is not None
assert "stage_times" in stage_stats
assert "stage_counts" in stage_stats
assert "stage1_binary_ms" in stage_stats["stage_times"]
assert "stage1_candidates" in stage_stats["stage_counts"]
def test_staged_cascade_with_rerank_enabled(
self, mock_registry, mock_mapper, temp_paths
):
"""Test staged_cascade_search with reranking enabled."""
config = MagicMock(spec=Config)
config.cascade_coarse_k = 100
config.cascade_fine_k = 10
config.enable_staged_rerank = True
config.staged_clustering_strategy = "auto"
config.graph_expansion_depth = 2
engine = ChainSearchEngine(mock_registry, mock_mapper, config=config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage4_optional_rerank") as mock_stage4:
mock_stage4.return_value = [
SearchResult(path="a.py", score=0.95, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src"
)
# Verify stage 4 was called
mock_stage4.assert_called_once()
def test_staged_cascade_fallback_to_hybrid(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search falls back to hybrid when numpy unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch("codexlens.search.chain_search.NUMPY_AVAILABLE", False):
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = MagicMock()
engine.staged_cascade_search("query", temp_paths / "src")
# Should fall back to hybrid cascade
mock_hybrid.assert_called_once()
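    # Degradation contract: the staged pipeline is gated on the module-level
    # NUMPY_AVAILABLE flag in chain_search; without numpy the engine is
    # expected to delegate wholesale to hybrid_cascade_search rather than fail.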
def test_staged_cascade_deduplicates_final_results(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search deduplicates results by path."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
# Return duplicates with different scores
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="a.py", score=0.8, excerpt="a duplicate"),
SearchResult(path="b.py", score=0.7, excerpt="b"),
]
result = engine.staged_cascade_search(
"query", temp_paths / "src", k=10
)
# Should deduplicate a.py (keep higher score)
paths = [r.path for r in result.results]
assert len(paths) == len(set(paths))
# a.py should have score 0.9
a_result = next(r for r in result.results if r.path == "a.py")
assert a_result.score == 0.9
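    # Final dedup contract checked above: at most one result per path, keeping
    # the highest-scoring duplicate.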
# =============================================================================
# Graceful Degradation Tests
# =============================================================================
class TestStagedCascadeGracefulDegradation:
"""Tests for graceful degradation when dependencies unavailable."""
def test_falls_back_when_clustering_unavailable(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test clustering stage falls back gracefully when clustering unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
mock_embed.return_value = None
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should fall back to score-based selection
assert len(clustered) <= 3
def test_falls_back_when_graph_expander_unavailable(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test LSP expansion falls back when GraphExpander unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Patch the import inside the method
with patch("codexlens.search.graph_expander.GraphExpander", side_effect=ImportError):
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake")
)
# Should return original results
assert expanded == sample_binary_results
def test_handles_stage_failures_gracefully(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged pipeline handles stage failures gracefully."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
# Stage 1 returns no results
mock_stage1.return_value = ([], None)
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = MagicMock()
engine.staged_cascade_search("query", temp_paths / "src")
# Should fall back to hybrid when stage 1 fails
mock_hybrid.assert_called_once()