Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-05 01:50:27 +08:00
feat(cli): add --rule option with template auto-discovery

Refactor the ccw CLI template system:
- Add a template-discovery.ts module supporting flat-layout template auto-discovery
- Add a --rule <template> option that automatically loads the protocol and template
- Migrate the template directory from a nested layout (prompts/category/file.txt) to a flat layout (prompts/category-function.txt)
- Update all agent/command files to use the $PROTO and $TMPL environment variables instead of the $(cat ...) pattern
- Support fuzzy matching: --rule 02-review-architecture matches analysis-review-architecture.txt

Other updates:
- Dashboard: add Claude Manager and Issue Manager pages
- Codex-lens: enhance the chain_search and clustering modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
676 codex-lens/docs/CODEXLENS_LSP_API_SPEC.md (Normal file)
@@ -0,0 +1,676 @@
# Codexlens LSP API Specification

**Version**: 1.1
**Status**: ✅ APPROVED (Gemini review)
**Architecture**: codexlens provides the Python API; CCW implements the MCP endpoints
**Analysis sources**: Gemini (architecture review) + Codex (implementation review)
**Last updated**: 2025-01-17

---

## 1. Overview

### 1.1 Background

Based on an analysis of the cclsp MCP server implementation, this document specifies codexlens's LSP-style search API, which gives AI agents code-intelligence capabilities.

### 1.2 Architecture Decision

**CCW implements the MCP endpoints; codexlens only provides a Python API**

```
┌─────────────────────────────────────────────────────────────┐
│                         Claude Code                         │
│  ┌───────────────────────────────────────────────────────┐  │
│  │                      MCP Client                       │  │
│  └───────────────────────────────────────────────────────┘  │
│                             │                               │
│                             ▼                               │
│  ┌───────────────────────────────────────────────────────┐  │
│  │                    CCW MCP Server                     │  │
│  │  ┌─────────────────────────────────────────────────┐  │  │
│  │  │              MCP Tool Handlers                  │  │  │
│  │  │  • codexlens_file_context                       │  │  │
│  │  │  • codexlens_find_definition                    │  │  │
│  │  │  • codexlens_find_references                    │  │  │
│  │  │  • codexlens_semantic_search                    │  │  │
│  │  └──────────────────────┬──────────────────────────┘  │  │
│  └─────────────────────────┼─────────────────────────────┘  │
└────────────────────────────┼────────────────────────────────┘
                             │ Python API calls
                             ▼
┌─────────────────────────────────────────────────────────────┐
│                          codexlens                          │
│  ┌───────────────────────────────────────────────────────┐  │
│  │                   Public API Layer                    │  │
│  │  codexlens.api.file_context()                         │  │
│  │  codexlens.api.find_definition()                      │  │
│  │  codexlens.api.find_references()                      │  │
│  │  codexlens.api.semantic_search()                      │  │
│  └──────────────────────┬────────────────────────────────┘  │
│                         │                                   │
│  ┌──────────────────────▼────────────────────────────────┐  │
│  │                   Core Components                     │  │
│  │  GlobalSymbolIndex | ChainSearchEngine | HoverProvider│  │
│  └───────────────────────────────────────────────────────┘  │
│                         │                                   │
│  ┌──────────────────────▼────────────────────────────────┐  │
│  │               SQLite Index Databases                  │  │
│  │  global_symbols.db | *.index.db (per-directory)       │  │
│  └───────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘
```

### 1.3 Separation of Responsibilities

| Component | Responsibilities |
|------|------|
| **codexlens** | Python API, index queries, search algorithms, result aggregation, graceful degradation |
| **CCW** | MCP protocol, parameter validation, result serialization, error handling, project_root inference |

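To make the split concrete, here is a minimal sketch of a CCW-side handler: validation, serialization, and error translation stay in CCW, while all searching happens in codexlens. The handler shape is illustrative (the real CCW server lives in a separate codebase); only the `codexlens.api` call, `to_dict()`, and `IndexNotFoundError` come from this spec:

```python
import json

from codexlens.api import find_definition
from codexlens.errors import IndexNotFoundError

def handle_codexlens_find_definition(params: dict) -> str:
    """CCW-side MCP handler sketch: parameter checks and JSON live here."""
    symbol = params.get("symbol_name")
    if not symbol:
        return json.dumps({"error": "symbol_name is required"})
    try:
        results = find_definition(
            project_root=params["project_root"],
            symbol_name=symbol,
            symbol_kind=params.get("symbol_kind"),
            limit=int(params.get("limit", 10)),
        )
        return json.dumps([r.to_dict() for r in results])
    except IndexNotFoundError as exc:
        # Graceful failure: return a useful text response instead of raising
        return json.dumps({"error": str(exc)})
```
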
### 1.4 codexlens vs. cclsp

| Feature | cclsp | codexlens |
|------|-------|-----------|
| Data source | Live LSP servers | Prebuilt SQLite indexes |
| Startup time | 200-3000 ms | <50 ms |
| Response time | 50-500 ms | <5 ms |
| Cross-language | One LSP server per language | Unified Python/TS/JS/Go index |
| Dependencies | Requires language servers | No external dependencies |
| Accuracy | 100% (compiler-grade) | 95%+ (tree-sitter) |
| Rename support | Yes | No (read-only index) |
| Live diagnostics | Yes | Via IDE MCP |

**Recommendation**: use codexlens for fast search and cclsp for precise refactoring.

---

## 2. cclsp Design Patterns (Reference)

### 2.1 MCP Tool Interface Design

| Pattern | Description | Code location |
|------|------|----------|
| **Name-based lookup** | Accepts `symbol_name` instead of file coordinates | `index.ts:70` |
| **Safe disambiguation** | Two-step `rename_symbol` → `rename_symbol_strict` | `index.ts:133, 172` |
| **Complexity abstraction** | Hides LSP protocol details | `index.ts:211` |
| **Graceful failure** | Returns a useful text response | Global |

### 2.2 Symbol Resolution Algorithm

```
1. getDocumentSymbols (lsp-client.ts:1406)
   └─ Fetch all symbols in the file

2. Handle both formats:
   ├─ DocumentSymbol[] → flatten
   └─ SymbolInformation[] → secondary lookup

3. Filter: symbol.name === symbolName && symbol.kind

4. Fallback: when there are no results, retry without the kind constraint

5. Aggregate: iterate over all matches and aggregate definition locations
```

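As a Python paraphrase of steps 3-5 (cclsp itself is TypeScript; the dict-shaped symbol list and the `resolve_symbol` name here are illustrative only):

```python
def resolve_symbol(file_symbols, symbol_name, symbol_kind=None):
    """Filter by name + kind, then retry without kind, then aggregate."""
    matches = [
        s for s in file_symbols
        if s["name"] == symbol_name
        and (symbol_kind is None or s["kind"] == symbol_kind)
    ]
    if not matches and symbol_kind is not None:
        # Fallback: drop the kind constraint and retry
        matches = [s for s in file_symbols if s["name"] == symbol_name]
    # Aggregate definition locations across all matches
    return [(s["file"], s["line"]) for s in matches]
```
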

---

## 3. Requirements

### Requirement 1: File Context Query (`file_context`)

**Purpose**: read a code file and return a summary of the call relationships of every method in it.

**Example output**:
```markdown
## src/auth/login.py (3 methods)

### login_user (line 15-45)
- Calls: validate_password (auth/utils.py:23), create_session (session/manager.py:89)
- Called by: handle_login_request (api/routes.py:156), test_login (tests/test_auth.py:34)

### validate_token (line 47-62)
- Calls: decode_jwt (auth/jwt.py:12)
- Called by: auth_middleware (middleware/auth.py:28)
```

### Requirement 2: General LSP Search (cclsp-compatible)

| Endpoint | Purpose |
|------|------|
| `find_definition` | Find definition locations by symbol name |
| `find_references` | Find all references to a symbol |
| `workspace_symbols` | Workspace-wide symbol search |
| `get_hover` | Get hover information for a symbol |

### Requirement 3: Vector + LSP Fusion Search

**Purpose**: combine vector-based semantic search with structured LSP search.

**Fusion strategies** (a minimal RRF sketch follows this list):
- **RRF** (preferred): simple, needs no score normalization, robust
- **Cascade**: for specific scenarios; vector search first, then LSP
- **Adaptive**: long-term goal; picks a strategy automatically per query type

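To illustrate why RRF needs no score normalization, the sketch below fuses ranked lists purely by rank position; the `k = 60` constant and the toy result lists are illustrative assumptions, not part of the spec:

```python
from collections import defaultdict
from typing import Dict, List

def rrf_fuse(ranked_lists: List[List[str]], k: int = 60) -> List[str]:
    """Reciprocal Rank Fusion: score(d) = sum over lists of 1 / (k + rank)."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)

# Toy example: the vector and structural rankings disagree on order.
vector_hits = ["auth.login_user", "auth.validate_token", "session.create"]
structural_hits = ["auth.validate_token", "auth.login_user"]
print(rrf_fuse([vector_hits, structural_hits]))
```
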
---

## 4. API Specification

### 4.1 Module Structure

```
src/codexlens/
├─ api/                        [new] public API layer
│   ├─ __init__.py             exports all APIs
│   ├─ file_context.py         file context
│   ├─ definition.py           definition lookup
│   ├─ references.py           reference lookup
│   ├─ symbols.py              symbol search
│   ├─ hover.py                hover information
│   └─ semantic.py             semantic search
│
├─ storage/
│   ├─ global_index.py         [extended] get_file_symbols()
│   └─ relationship_query.py   [new] directed call queries
│
└─ search/
    └─ chain_search.py         [fixed] schema compatibility
```

### 4.2 `codexlens.api.file_context()`

```python
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple

@dataclass
class CallInfo:
    """Call relationship information."""
    symbol_name: str
    file_path: Optional[str]  # target file (may be None)
    line: int
    relationship: str  # call | import | inheritance

@dataclass
class MethodContext:
    """Method context."""
    name: str
    kind: str  # function | method | class
    line_range: Tuple[int, int]
    signature: Optional[str]
    calls: List[CallInfo]    # outgoing calls
    callers: List[CallInfo]  # incoming calls

@dataclass
class FileContextResult:
    """File context result."""
    file_path: str
    language: str
    methods: List[MethodContext]
    summary: str  # human-readable summary
    discovery_status: Dict[str, bool] = field(default_factory=lambda: {
        "outgoing_resolved": False,
        "incoming_resolved": True,
        "targets_resolved": False
    })

def file_context(
    project_root: str,
    file_path: str,
    include_calls: bool = True,
    include_callers: bool = True,
    max_depth: int = 1,
    format: str = "brief"  # brief | detailed | tree
) -> FileContextResult:
    """
    Get the method call context of a code file.

    Args:
        project_root: Project root directory (used to locate the index)
        file_path: Path to the code file
        include_calls: Whether to include outgoing calls
        include_callers: Whether to include incoming calls
        max_depth: Call chain depth (1 = direct calls)
            ⚠️ V1 limitation: the current version only supports max_depth=1;
            deep call-chain analysis will be implemented in V2
        format: Output format

    Returns:
        FileContextResult

    Raises:
        IndexNotFoundError: Project is not indexed
        FileNotFoundError: File does not exist

    Note:
        V1 implementation limits:
        - max_depth only supports 1 (direct calls)
        - Outgoing call target files may be None (unresolved)
        - Deep call-chain analysis is planned as a V2 feature
    """
```

### 4.3 `codexlens.api.find_definition()`

```python
@dataclass
class DefinitionResult:
    """Definition lookup result."""
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str]
    container: Optional[str]  # containing class/module
    score: float

def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """
    Find definition locations by symbol name.

    Fallback strategy:
    1. Exact match + kind filter
    2. Exact match (kind removed)
    3. Prefix match
    """
```

### 4.4 `codexlens.api.find_references()`

```python
@dataclass
class ReferenceResult:
    """Reference result."""
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

@dataclass
class GroupedReferences:
    """References grouped by definition."""
    definition: DefinitionResult
    references: List[ReferenceResult]

def find_references(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    include_definition: bool = True,
    group_by_definition: bool = True,
    limit: int = 100
) -> List[GroupedReferences]:
    """
    Find all reference locations of a symbol.

    When a symbol has multiple definitions, results are returned grouped,
    which resolves reference ambiguity.
    """
```

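When `group_by_definition` is true and a symbol has several definitions, each reference has to be attributed to one of them. A minimal sketch of one plausible heuristic, reusing the path-proximity idea from §5.3 (the `group_references` helper is illustrative and assumes the dataclasses above; the spec does not fix the attribution rule):

```python
import os
from typing import List

def group_references(
    definitions: List[DefinitionResult],
    references: List[ReferenceResult],
) -> List[GroupedReferences]:
    """Attach each reference to the definition with the closest file path."""
    if not definitions:
        return []
    groups = [GroupedReferences(definition=d, references=[]) for d in definitions]
    for ref in references:
        # Longest shared path prefix wins; same-file references always match.
        best = max(
            groups,
            key=lambda g: len(os.path.commonprefix(
                [g.definition.file_path, ref.file_path]
            )),
        )
        best.references.append(ref)
    return groups
```
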
### 4.5 `codexlens.api.workspace_symbols()`

```python
@dataclass
class SymbolInfo:
    """Symbol information."""
    name: str
    kind: str
    file_path: str
    line: int
    container: Optional[str]
    score: float

def workspace_symbols(
    project_root: str,
    query: str,
    kind_filter: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    limit: int = 50
) -> List[SymbolInfo]:
    """Search symbols across the whole workspace (prefix match)."""
```

### 4.6 `codexlens.api.get_hover()`

```python
@dataclass
class HoverInfo:
    """Hover information."""
    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: Tuple[int, int]
    type_info: Optional[str]

def get_hover(
    project_root: str,
    symbol_name: str,
    file_path: Optional[str] = None
) -> Optional[HoverInfo]:
    """Get detailed hover information for a symbol."""
```

### 4.7 `codexlens.api.semantic_search()`

```python
@dataclass
class SemanticResult:
    """Semantic search result."""
    symbol_name: str
    kind: str
    file_path: str
    line: int
    vector_score: Optional[float]
    structural_score: Optional[float]
    fusion_score: float
    snippet: str
    match_reason: Optional[str]

def semantic_search(
    project_root: str,
    query: str,
    mode: str = "fusion",  # vector | structural | fusion
    vector_weight: float = 0.5,
    structural_weight: float = 0.3,
    keyword_weight: float = 0.2,
    fusion_strategy: str = "rrf",  # rrf | staged | binary | hybrid
    kind_filter: Optional[List[str]] = None,
    limit: int = 20,
    include_match_reason: bool = False
) -> List[SemanticResult]:
    """
    Semantic search that combines vector and structural search.

    Args:
        project_root: Project root directory
        query: Natural-language query
        mode: Search mode
            - vector: vector search only
            - structural: structural search only (symbols + relationships)
            - fusion: fused search (default)
        vector_weight: Vector search weight [0, 1]
        structural_weight: Structural search weight [0, 1]
        keyword_weight: Keyword search weight [0, 1]
        fusion_strategy: Fusion strategy (maps to chain_search.py)
            - rrf: Reciprocal Rank Fusion (recommended, default)
            - staged: staged cascade → staged_cascade_search
            - binary: binary rerank cascade → binary_rerank_cascade_search
            - hybrid: hybrid cascade → hybrid_cascade_search
        kind_filter: Symbol kind filter
        limit: Maximum number of results
        include_match_reason: Whether to generate a match reason (heuristic, not LLM)

    Returns:
        Result list sorted by fusion_score

    Degradation behavior:
        - No vector index: vector_score=None, falls back to FTS + structural search
        - No relationship data: structural_score=None, vector search only
    """
```

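The degradation contract matters to callers: scores can be `None`, so consumers should branch on availability rather than assume all signals exist. A short usage sketch (paths and query illustrative):

```python
results = semantic_search(
    project_root="/path/to/project",
    query="token refresh logic",
    mode="fusion",
)
for r in results:
    # vector_score is None when no vector index exists; fusion_score is always set
    vec = f"{r.vector_score:.2f}" if r.vector_score is not None else "n/a"
    print(f"{r.fusion_score:.3f}  vec={vec}  {r.file_path}:{r.line}  {r.symbol_name}")
```
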

---

## 5. Known Issues and Solutions

### 5.1 P0 Blockers

| Issue | Location | Solution |
|------|------|----------|
| **Index schema mismatch** | `chain_search.py:313-324` vs `dir_index.py:304-312` | Support both `full_path` and `path` (see the probe sketch below) |
| **Missing file-symbol query** | `global_index.py:214-260` | Add `get_file_symbols()` |
| **Missing outgoing-call query** | `dir_index.py:333-342` | Add a `RelationshipQuery` class |
| **Inconsistent relationship types** | `entities.py:74-79` | Normalize `calls` → `call` |

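The schema mismatch can be bridged at query time. A minimal sketch, assuming direct `sqlite3` access to a per-directory index; the `_path_column` helper name is illustrative, not an existing codexlens function:

```python
import sqlite3

def _path_column(conn: sqlite3.Connection) -> str:
    """Return the file-path column of the files table (old schema: path, new: full_path)."""
    cols = {row[1] for row in conn.execute("PRAGMA table_info(files)")}
    return "full_path" if "full_path" in cols else "path"

def file_paths(conn: sqlite3.Connection) -> list:
    # Build the query against whichever column this index actually has.
    col = _path_column(conn)
    return [row[0] for row in conn.execute(f"SELECT {col} FROM files")]
```
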
### 5.2 Design Flaws (found by Gemini)

| Flaw | Impact | Solution |
|------|------|----------|
| **Incomplete call graph** | `file_context` lacked outgoing calls | Add a directed-call API |
| **Undefined disambiguation** | Multiple definitions could not be told apart | Implement `rank_by_proximity()` |
| **Overly expensive AI feature** | `explanation` required an LLM | Make it optional, off by default |
| **Inconsistent fusion parameters** | 3 branches but only 2 weights | Add `keyword_weight` |

### 5.3 Disambiguation Algorithm

**V1 implementation** (based on file-path proximity):

```python
import os

def rank_by_proximity(
    results: List[DefinitionResult],
    file_context: str
) -> List[DefinitionResult]:
    """Rank by file proximity (V1: path proximity)."""
    def proximity_score(result):
        # 1. Same directory scores highest
        if os.path.dirname(result.file_path) == os.path.dirname(file_context):
            return 100
        # 2. Otherwise, length of the common path prefix
        common = os.path.commonpath([result.file_path, file_context])
        return len(common)

    return sorted(results, key=proximity_score, reverse=True)
```

**V2 enhancement plan** (based on import-graph distance):

```python
def rank_by_import_distance(
    results: List[DefinitionResult],
    file_context: str,
    import_graph: Dict[str, Set[str]]
) -> List[DefinitionResult]:
    """Rank by import-graph distance (V2 plan sketch; bfs_shortest_path
    and proximity_score are assumed helpers)."""
    def import_distance(result):
        # BFS over the import graph for the shortest import path
        return bfs_shortest_path(
            import_graph,
            file_context,
            result.file_path
        )

    # Combine: 0.6 * import_distance + 0.4 * path proximity
    return sorted(results, key=lambda r: (
        0.6 * import_distance(r) +
        0.4 * (100 - proximity_score(r))
    ))
```

### 5.4 Reference Implementation: `get_file_symbols()`

**Location**: `src/codexlens/storage/global_index.py`

```python
def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
    """
    Get all symbols defined in the given file.

    Args:
        file_path: File path (relative or absolute)

    Returns:
        List of symbols sorted by line number
    """
    file_path_str = str(Path(file_path).resolve())
    with self._lock:
        conn = self._get_connection()
        rows = conn.execute(
            """
            SELECT symbol_name, symbol_kind, file_path, start_line, end_line
            FROM global_symbols
            WHERE project_id = ? AND file_path = ?
            ORDER BY start_line
            """,
            (self.project_id, file_path_str),
        ).fetchall()

    return [
        Symbol(
            name=row["symbol_name"],
            kind=row["symbol_kind"],
            range=(row["start_line"], row["end_line"]),
            file=row["file_path"],
        )
        for row in rows
    ]
```

---

## 6. Implementation Plan

### Phase 0: Infrastructure (16h)

| Task | Effort | Notes |
|------|------|------|
| Fix the `search_references` schema | 4h | Support both schemas |
| Add `GlobalSymbolIndex.get_file_symbols()` | 4h | File-symbol query (see 5.4) |
| Add a `RelationshipQuery` class | 6h | Directed call queries |
| Relationship-type normalization layer | 2h | `calls` → `call` |

### Phase 1: API Layer (48h)

| Task | Effort | Complexity |
|------|------|--------|
| `find_definition()` | 4h | S |
| `find_references()` | 8h | M |
| `workspace_symbols()` | 4h | S |
| `get_hover()` | 4h | S |
| `file_context()` | 16h | L |
| `semantic_search()` | 12h | M |

### Phase 2: Tests and Documentation (16h)

| Task | Effort |
|------|------|
| Unit tests (≥80% coverage) | 8h |
| API documentation | 4h |
| Example code | 4h |

### Critical Path

```
Phase 0.1 (schema fix)
        ↓
Phase 0.2 (file symbols) → Phase 1.5 (file_context)
        ↓
Phase 1 (remaining APIs)
        ↓
Phase 2 (tests)
```

---

## 7. Test Strategy

### 7.1 Unit Tests

```python
# test_global_index.py
def test_get_file_symbols():
    index = GlobalSymbolIndex(":memory:")
    index.update_file_symbols(project_id=1, file_path="test.py", symbols=[...])
    results = index.get_file_symbols("test.py")
    assert len(results) == 3

# test_relationship_query.py
def test_outgoing_calls():
    store = DirIndexStore(":memory:")
    calls = store.get_outgoing_calls("src/auth.py", "login")
    assert calls[0].relationship == "call"  # already normalized
```

### 7.2 Schema Compatibility Tests

```python
def test_search_references_both_schemas():
    """Test reference search against both schemas."""
    # Old schema: files(path, ...)
    # New schema: files(full_path, ...)
```

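One concrete way to exercise both schemas is to create each variant in an in-memory SQLite database and assert that column detection (as in the §5.1 probe sketch) picks the right one; `_path_column` is the same illustrative helper, not an existing codexlens function:

```python
import sqlite3

def _path_column(conn: sqlite3.Connection) -> str:
    cols = {row[1] for row in conn.execute("PRAGMA table_info(files)")}
    return "full_path" if "full_path" in cols else "path"

def test_path_column_detection():
    old = sqlite3.connect(":memory:")
    old.execute("CREATE TABLE files (path TEXT)")
    new = sqlite3.connect(":memory:")
    new.execute("CREATE TABLE files (full_path TEXT)")
    assert _path_column(old) == "path"
    assert _path_column(new) == "full_path"
```
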
### 7.3 Degradation Tests

```python
def test_semantic_search_without_vectors():
    results = semantic_search(project_root=".", query="auth", mode="fusion")
    # With no vector index, vector scores degrade to None but fusion still ranks.
    assert results[0].vector_score is None
    assert results[0].fusion_score > 0
```

---

## 8. Usage Examples

```python
from codexlens.api import (
    file_context,
    find_definition,
    find_references,
    semantic_search
)

# 1. Get file context
result = file_context(
    project_root="/path/to/project",
    file_path="src/auth/login.py",
    format="brief"
)
print(result.summary)

# 2. Find definitions
definitions = find_definition(
    project_root="/path/to/project",
    symbol_name="UserService",
    symbol_kind="class"
)

# 3. Semantic search
results = semantic_search(
    project_root="/path/to/project",
    query="function that handles user login validation",
    mode="fusion"
)
```

---

## 9. CCW Integration

| codexlens API | CCW MCP Tool |
|---------------|--------------|
| `file_context()` | `codexlens_file_context` |
| `find_definition()` | `codexlens_find_definition` |
| `find_references()` | `codexlens_find_references` |
| `workspace_symbols()` | `codexlens_workspace_symbol` |
| `get_hover()` | `codexlens_get_hover` |
| `semantic_search()` | `codexlens_semantic_search` |

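Since the tool names map one-to-one onto API functions, the CCW-side dispatch can stay table-driven. A minimal sketch in Python (the actual CCW server is a separate codebase; `TOOL_MAP` and `dispatch` are illustrative names, and serialization/error handling from §1.3 is omitted here):

```python
from codexlens import api

# MCP tool name -> codexlens API function
TOOL_MAP = {
    "codexlens_file_context": api.file_context,
    "codexlens_find_definition": api.find_definition,
    "codexlens_find_references": api.find_references,
    "codexlens_workspace_symbol": api.workspace_symbols,
    "codexlens_get_hover": api.get_hover,
    "codexlens_semantic_search": api.semantic_search,
}

def dispatch(tool_name: str, arguments: dict):
    """Route an MCP tool call to the matching codexlens API function."""
    return TOOL_MAP[tool_name](**arguments)
```
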
---

## 10. Analysis Sources

| Tool | Session ID | Contribution |
|------|------------|------|
| Gemini | `1768618654438-gemini` | Architecture review, design flaws, fusion strategy |
| Codex | `1768618658183-codex` | Component reuse, complexity estimates, task breakdown |
| Gemini | `1768620615744-gemini` | Final review, improvement suggestions, APPROVED |

---

## 11. Version History

| Version | Date | Changes |
|------|------|------|
| 1.0 | 2025-01-17 | Initial version; merged multiple documents |
| 1.1 | 2025-01-17 | Applied Gemini review improvements: V1 limitation notes, strategy mapping, disambiguation enhancements, reference implementation |

@@ -97,11 +97,25 @@ encoding = [
    "chardet>=5.0",
]

# Clustering for staged hybrid search (HDBSCAN + sklearn)
clustering = [
    "hdbscan>=0.8.1",
    "scikit-learn>=1.3.0",
]

# Full features including tiktoken for accurate token counting
full = [
    "tiktoken>=0.5.0",
]

# Language Server Protocol support
lsp = [
    "pygls>=1.3.0",
]

[project.scripts]
codexlens-lsp = "codexlens.lsp:main"

[project.urls]
Homepage = "https://github.com/openai/codex-lens"

@@ -52,5 +52,10 @@ Requires-Dist: transformers>=4.36; extra == "splade-gpu"
Requires-Dist: optimum[onnxruntime-gpu]>=1.16; extra == "splade-gpu"
Provides-Extra: encoding
Requires-Dist: chardet>=5.0; extra == "encoding"
Provides-Extra: clustering
Requires-Dist: hdbscan>=0.8.1; extra == "clustering"
Requires-Dist: scikit-learn>=1.3.0; extra == "clustering"
Provides-Extra: full
Requires-Dist: tiktoken>=0.5.0; extra == "full"
Provides-Extra: lsp
Requires-Dist: pygls>=1.3.0; extra == "lsp"

@@ -2,6 +2,7 @@ pyproject.toml
src/codex_lens.egg-info/PKG-INFO
src/codex_lens.egg-info/SOURCES.txt
src/codex_lens.egg-info/dependency_links.txt
src/codex_lens.egg-info/entry_points.txt
src/codex_lens.egg-info/requires.txt
src/codex_lens.egg-info/top_level.txt
src/codexlens/__init__.py
@@ -18,6 +19,14 @@ src/codexlens/cli/output.py
src/codexlens/indexing/__init__.py
src/codexlens/indexing/embedding.py
src/codexlens/indexing/symbol_extractor.py
src/codexlens/lsp/__init__.py
src/codexlens/lsp/handlers.py
src/codexlens/lsp/providers.py
src/codexlens/lsp/server.py
src/codexlens/mcp/__init__.py
src/codexlens/mcp/hooks.py
src/codexlens/mcp/provider.py
src/codexlens/mcp/schema.py
src/codexlens/parsers/__init__.py
src/codexlens/parsers/encoding.py
src/codexlens/parsers/factory.py
@@ -31,6 +40,13 @@ src/codexlens/search/graph_expander.py
src/codexlens/search/hybrid_search.py
src/codexlens/search/query_parser.py
src/codexlens/search/ranking.py
src/codexlens/search/clustering/__init__.py
src/codexlens/search/clustering/base.py
src/codexlens/search/clustering/dbscan_strategy.py
src/codexlens/search/clustering/factory.py
src/codexlens/search/clustering/frequency_strategy.py
src/codexlens/search/clustering/hdbscan_strategy.py
src/codexlens/search/clustering/noop_strategy.py
src/codexlens/semantic/__init__.py
src/codexlens/semantic/ann_index.py
src/codexlens/semantic/base.py
@@ -84,6 +100,7 @@ tests/test_api_reranker.py
tests/test_chain_search.py
tests/test_cli_hybrid_search.py
tests/test_cli_output.py
tests/test_clustering_strategies.py
tests/test_code_extractor.py
tests/test_config.py
tests/test_dual_fts.py
@@ -122,6 +139,7 @@ tests/test_search_performance.py
tests/test_semantic.py
tests/test_semantic_search.py
tests/test_sqlite_store.py
tests/test_staged_cascade.py
tests/test_storage.py
tests/test_storage_concurrency.py
tests/test_symbol_extractor.py

2 codex-lens/src/codex_lens.egg-info/entry_points.txt (Normal file)
@@ -0,0 +1,2 @@
[console_scripts]
codexlens-lsp = codexlens.lsp:main

@@ -8,12 +8,19 @@ tree-sitter-typescript>=0.23
pathspec>=0.11
watchdog>=3.0

[clustering]
hdbscan>=0.8.1
scikit-learn>=1.3.0

[encoding]
chardet>=5.0

[full]
tiktoken>=0.5.0

[lsp]
pygls>=1.3.0

[reranker]
optimum>=1.16
onnxruntime>=1.15

88 codex-lens/src/codexlens/api/__init__.py (Normal file)
@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.

This module exports all public API functions and dataclasses for the
codexlens LSP-like functionality.

Dataclasses (from models.py):
    - CallInfo: Call relationship information
    - MethodContext: Method context with call relationships
    - FileContextResult: File context result with method summaries
    - DefinitionResult: Definition lookup result
    - ReferenceResult: Reference lookup result
    - GroupedReferences: References grouped by definition
    - SymbolInfo: Symbol information for workspace search
    - HoverInfo: Hover information for a symbol
    - SemanticResult: Semantic search result

Utility functions (from utils.py):
    - resolve_project: Resolve and validate project root path
    - normalize_relationship_type: Normalize relationship type to canonical form
    - rank_by_proximity: Rank results by file path proximity

Example:
    >>> from codexlens.api import (
    ...     DefinitionResult,
    ...     resolve_project,
    ...     normalize_relationship_type
    ... )
    >>> project = resolve_project("/path/to/project")
    >>> normalize_relationship_type("calls")
    'call'
"""

from __future__ import annotations

# Dataclasses
from .models import (
    CallInfo,
    MethodContext,
    FileContextResult,
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
    SymbolInfo,
    HoverInfo,
    SemanticResult,
)

# Utility functions
from .utils import (
    resolve_project,
    normalize_relationship_type,
    rank_by_proximity,
    rank_by_score,
)

# API functions
from .definition import find_definition
from .symbols import workspace_symbols
from .hover import get_hover
from .file_context import file_context
from .references import find_references
from .semantic import semantic_search

__all__ = [
    # Dataclasses
    "CallInfo",
    "MethodContext",
    "FileContextResult",
    "DefinitionResult",
    "ReferenceResult",
    "GroupedReferences",
    "SymbolInfo",
    "HoverInfo",
    "SemanticResult",
    # Utility functions
    "resolve_project",
    "normalize_relationship_type",
    "rank_by_proximity",
    "rank_by_score",
    # API functions
    "find_definition",
    "workspace_symbols",
    "get_hover",
    "file_context",
    "find_references",
    "semantic_search",
]

126 codex-lens/src/codexlens/api/definition.py (Normal file)
@@ -0,0 +1,126 @@
"""find_definition API implementation.

This module provides the find_definition() function for looking up
symbol definitions with a 3-stage fallback strategy.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import DefinitionResult
from .utils import resolve_project, rank_by_proximity

logger = logging.getLogger(__name__)


def find_definition(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    file_context: Optional[str] = None,
    limit: int = 10
) -> List[DefinitionResult]:
    """Find definition locations for a symbol.

    Uses a 3-stage fallback strategy:
    1. Exact match with kind filter
    2. Exact match without kind filter
    3. Prefix match

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to find
        symbol_kind: Optional symbol kind filter (class, function, etc.)
        file_context: Optional file path for proximity ranking
        limit: Maximum number of results to return

    Returns:
        List of DefinitionResult sorted by proximity if file_context provided

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project_by_source(str(project_path))
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Stage 1: Exact match with kind filter
    results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
    if results:
        logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    # Stage 2: Exact match without kind (if kind was specified)
    if symbol_kind:
        results = _search_with_kind(global_index, symbol_name, None, limit)
        if results:
            logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
            return _rank_and_convert(results, file_context)

    # Stage 3: Prefix match
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=limit,
        prefix_mode=True
    )
    if results:
        logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
        return _rank_and_convert(results, file_context)

    logger.debug(f"No definitions found for {symbol_name}")
    return []


def _search_with_kind(
    global_index: GlobalSymbolIndex,
    symbol_name: str,
    symbol_kind: Optional[str],
    limit: int
) -> List[Symbol]:
    """Search for symbols with optional kind filter."""
    return global_index.search(
        name=symbol_name,
        kind=symbol_kind,
        limit=limit,
        prefix_mode=False
    )


def _rank_and_convert(
    symbols: List[Symbol],
    file_context: Optional[str]
) -> List[DefinitionResult]:
    """Convert symbols to DefinitionResult and rank by proximity."""
    results = [
        DefinitionResult(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            end_line=sym.range[1] if sym.range else 1,
            signature=None,  # Could extract from file if needed
            container=None,  # Could extract from parent symbol
            score=1.0
        )
        for sym in symbols
    ]
    return rank_by_proximity(results, file_context)

271 codex-lens/src/codexlens/api/file_context.py (Normal file)
@@ -0,0 +1,271 @@
"""file_context API implementation.

This module provides the file_context() function for retrieving
method call graphs from a source file.
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.dir_index import DirIndexStore
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import (
    FileContextResult,
    MethodContext,
    CallInfo,
)
from .utils import resolve_project, normalize_relationship_type

logger = logging.getLogger(__name__)


def file_context(
    project_root: str,
    file_path: str,
    include_calls: bool = True,
    include_callers: bool = True,
    max_depth: int = 1,
    format: str = "brief"
) -> FileContextResult:
    """Get method call context for a code file.

    Retrieves all methods/functions in the file along with their
    outgoing calls and incoming callers.

    Args:
        project_root: Project root directory (for index location)
        file_path: Path to the code file to analyze
        include_calls: Whether to include outgoing calls
        include_callers: Whether to include incoming callers
        max_depth: Call chain depth (V1 only supports 1)
        format: Output format (brief | detailed | tree)

    Returns:
        FileContextResult with method contexts and summary

    Raises:
        IndexNotFoundError: If project is not indexed
        FileNotFoundError: If file does not exist
        ValueError: If max_depth > 1 (V1 limitation)
    """
    # V1 limitation: only depth=1 supported
    if max_depth > 1:
        raise ValueError(
            f"max_depth > 1 not supported in V1. "
            f"Requested: {max_depth}, supported: 1"
        )

    project_path = resolve_project(project_root)
    file_path_resolved = Path(file_path).resolve()

    # Validate file exists
    if not file_path_resolved.exists():
        raise FileNotFoundError(f"File not found: {file_path_resolved}")

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project_by_source(str(project_path))
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Get all symbols in the file
    symbols = global_index.get_file_symbols(str(file_path_resolved))

    # Filter to functions, methods, and classes
    method_symbols = [
        s for s in symbols
        if s.kind in ("function", "method", "class")
    ]

    logger.debug(f"Found {len(method_symbols)} methods in {file_path}")

    # Try to find dir_index for relationship queries
    dir_index = _find_dir_index(project_info, file_path_resolved)

    # Build method contexts
    methods: List[MethodContext] = []
    outgoing_resolved = True
    incoming_resolved = True
    targets_resolved = True

    for symbol in method_symbols:
        calls: List[CallInfo] = []
        callers: List[CallInfo] = []

        if include_calls and dir_index:
            try:
                outgoing = dir_index.get_outgoing_calls(
                    str(file_path_resolved),
                    symbol.name
                )
                for target_name, rel_type, line, target_file in outgoing:
                    calls.append(CallInfo(
                        symbol_name=target_name,
                        file_path=target_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
                    if target_file is None:
                        targets_resolved = False
            except Exception as e:
                logger.debug(f"Failed to get outgoing calls: {e}")
                outgoing_resolved = False

        if include_callers and dir_index:
            try:
                incoming = dir_index.get_incoming_calls(symbol.name)
                for source_name, rel_type, line, source_file in incoming:
                    callers.append(CallInfo(
                        symbol_name=source_name,
                        file_path=source_file,
                        line=line,
                        relationship=normalize_relationship_type(rel_type)
                    ))
            except Exception as e:
                logger.debug(f"Failed to get incoming calls: {e}")
                incoming_resolved = False

        methods.append(MethodContext(
            name=symbol.name,
            kind=symbol.kind,
            line_range=symbol.range if symbol.range else (1, 1),
            signature=None,  # Could extract from source
            calls=calls,
            callers=callers
        ))

    # Detect language from file extension
    language = _detect_language(file_path_resolved)

    # Generate summary
    summary = _generate_summary(file_path_resolved, methods, format)

    return FileContextResult(
        file_path=str(file_path_resolved),
        language=language,
        methods=methods,
        summary=summary,
        discovery_status={
            "outgoing_resolved": outgoing_resolved,
            "incoming_resolved": incoming_resolved,
            "targets_resolved": targets_resolved
        }
    )


def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
    """Find the dir_index that contains the file.

    Args:
        project_info: Project information from registry
        file_path: Path to the file

    Returns:
        DirIndexStore if found, None otherwise
    """
    try:
        # Look for _index.db in file's directory or parent directories
        current = file_path.parent
        while current != current.parent:
            index_db = current / "_index.db"
            if index_db.exists():
                return DirIndexStore(str(index_db))

            # Also check in project's index_root
            relative = current.relative_to(project_info.source_root)
            index_in_cache = project_info.index_root / relative / "_index.db"
            if index_in_cache.exists():
                return DirIndexStore(str(index_in_cache))

            current = current.parent
    except Exception as e:
        logger.debug(f"Failed to find dir_index: {e}")

    return None


def _detect_language(file_path: Path) -> str:
    """Detect programming language from file extension.

    Args:
        file_path: Path to the file

    Returns:
        Language name
    """
    ext_map = {
        ".py": "python",
        ".js": "javascript",
        ".ts": "typescript",
        ".jsx": "javascript",
        ".tsx": "typescript",
        ".go": "go",
        ".rs": "rust",
        ".java": "java",
        ".c": "c",
        ".cpp": "cpp",
        ".h": "c",
        ".hpp": "cpp",
    }
    return ext_map.get(file_path.suffix.lower(), "unknown")


def _generate_summary(
    file_path: Path,
    methods: List[MethodContext],
    format: str
) -> str:
    """Generate human-readable summary of file context.

    Args:
        file_path: Path to the file
        methods: List of method contexts
        format: Output format (brief | detailed | tree)

    Returns:
        Markdown-formatted summary
    """
    lines = [f"## {file_path.name} ({len(methods)} methods)\n"]

    for method in methods:
        start, end = method.line_range
        lines.append(f"### {method.name} (line {start}-{end})")

        if method.calls:
            calls_str = ", ".join(
                f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.calls
            )
            lines.append(f"- Calls: {calls_str}")

        if method.callers:
            callers_str = ", ".join(
                f"{c.symbol_name} ({c.file_path}:{c.line})"
                if format == "detailed"
                else c.symbol_name
                for c in method.callers
            )
            lines.append(f"- Called by: {callers_str}")

        if not method.calls and not method.callers:
            lines.append("- (no call relationships)")

        lines.append("")

    return "\n".join(lines)

148 codex-lens/src/codexlens/api/hover.py (Normal file)
@@ -0,0 +1,148 @@
"""get_hover API implementation.

This module provides the get_hover() function for retrieving
detailed hover information for symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import HoverInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def get_hover(
    project_root: str,
    symbol_name: str,
    file_path: Optional[str] = None
) -> Optional[HoverInfo]:
    """Get detailed hover information for a symbol.

    Args:
        project_root: Project root directory (for index location)
        symbol_name: Name of the symbol to look up
        file_path: Optional file path to disambiguate when symbol
            appears in multiple files

    Returns:
        HoverInfo if symbol found, None otherwise

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project_by_source(str(project_path))
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search for the symbol
    results = global_index.search(
        name=symbol_name,
        kind=None,
        limit=50,
        prefix_mode=False
    )

    if not results:
        logger.debug(f"No hover info found for {symbol_name}")
        return None

    # If file_path provided, filter to that file
    if file_path:
        file_path_resolved = str(Path(file_path).resolve())
        matching = [s for s in results if s.file == file_path_resolved]
        if matching:
            results = matching

    # Take the first result
    symbol = results[0]

    # Build hover info
    return HoverInfo(
        name=symbol.name,
        kind=symbol.kind,
        signature=_extract_signature(symbol),
        documentation=_extract_documentation(symbol),
        file_path=symbol.file or "",
        line_range=symbol.range if symbol.range else (1, 1),
        type_info=_extract_type_info(symbol)
    )


def _extract_signature(symbol: Symbol) -> str:
    """Extract signature from symbol.

    For now, generates a basic signature based on kind and name.
    In a full implementation, this would parse the actual source code.

    Args:
        symbol: The symbol to extract signature from

    Returns:
        Signature string
    """
    if symbol.kind == "function":
        return f"def {symbol.name}(...)"
    elif symbol.kind == "method":
        return f"def {symbol.name}(self, ...)"
    elif symbol.kind == "class":
        return f"class {symbol.name}"
    elif symbol.kind == "variable":
        return symbol.name
    elif symbol.kind == "constant":
        return f"{symbol.name} = ..."
    else:
        return f"{symbol.kind} {symbol.name}"


def _extract_documentation(symbol: Symbol) -> Optional[str]:
    """Extract documentation from symbol.

    In a full implementation, this would parse docstrings from source.
    For now, returns None.

    Args:
        symbol: The symbol to extract documentation from

    Returns:
        Documentation string if available, None otherwise
    """
    # Would need to read source file and parse docstring
    # For V1, return None
    return None


def _extract_type_info(symbol: Symbol) -> Optional[str]:
    """Extract type information from symbol.

    In a full implementation, this would parse type annotations.
    For now, returns None.

    Args:
        symbol: The symbol to extract type info from

    Returns:
        Type info string if available, None otherwise
    """
    # Would need to parse type annotations from source
    # For V1, return None
    return None

281 codex-lens/src/codexlens/api/models.py (Normal file)
@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.

This module defines all result dataclasses used by the public API layer,
following the patterns established in mcp/schema.py.
"""

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Tuple


# =============================================================================
# Section 4.2: file_context dataclasses
# =============================================================================

@dataclass
class CallInfo:
    """Call relationship information.

    Attributes:
        symbol_name: Name of the called/calling symbol
        file_path: Target file path (may be None if unresolved)
        line: Line number of the call
        relationship: Type of relationship (call | import | inheritance)
    """
    symbol_name: str
    file_path: Optional[str]
    line: int
    relationship: str  # call | import | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MethodContext:
    """Method context with call relationships.

    Attributes:
        name: Method/function name
        kind: Symbol kind (function | method | class)
        line_range: Start and end line numbers
        signature: Function signature (if available)
        calls: List of outgoing calls
        callers: List of incoming calls
    """
    name: str
    kind: str  # function | method | class
    line_range: Tuple[int, int]
    signature: Optional[str]
    calls: List[CallInfo] = field(default_factory=list)
    callers: List[CallInfo] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "line_range": list(self.line_range),
            "calls": [c.to_dict() for c in self.calls],
            "callers": [c.to_dict() for c in self.callers],
        }
        if self.signature is not None:
            result["signature"] = self.signature
        return result


@dataclass
class FileContextResult:
    """File context result with method summaries.

    Attributes:
        file_path: Path to the analyzed file
        language: Programming language
        methods: List of method contexts
        summary: Human-readable summary
        discovery_status: Status flags for call resolution
    """
    file_path: str
    language: str
    methods: List[MethodContext]
    summary: str
    discovery_status: Dict[str, bool] = field(default_factory=lambda: {
        "outgoing_resolved": False,
        "incoming_resolved": True,
        "targets_resolved": False
    })

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return {
            "file_path": self.file_path,
            "language": self.language,
            "methods": [m.to_dict() for m in self.methods],
            "summary": self.summary,
            "discovery_status": self.discovery_status,
        }


# =============================================================================
# Section 4.3: find_definition dataclasses
# =============================================================================

@dataclass
class DefinitionResult:
    """Definition lookup result.

    Attributes:
        name: Symbol name
        kind: Symbol kind (class, function, method, etc.)
        file_path: File where symbol is defined
        line: Start line number
        end_line: End line number
        signature: Symbol signature (if available)
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    end_line: int
    signature: Optional[str] = None
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.4: find_references dataclasses
# =============================================================================

@dataclass
class ReferenceResult:
    """Reference lookup result.

    Attributes:
        file_path: File containing the reference
        line: Line number
        column: Column number
        context_line: The line of code containing the reference
        relationship: Type of reference (call | import | type_annotation | inheritance)
    """
    file_path: str
    line: int
    column: int
    context_line: str
    relationship: str  # call | import | type_annotation | inheritance

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return asdict(self)


@dataclass
class GroupedReferences:
    """References grouped by definition.

    Used when a symbol has multiple definitions (e.g., overloads).

    Attributes:
        definition: The definition this group refers to
        references: List of references to this definition
    """
    definition: DefinitionResult
    references: List[ReferenceResult] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "definition": self.definition.to_dict(),
            "references": [r.to_dict() for r in self.references],
        }


# =============================================================================
# Section 4.5: workspace_symbols dataclasses
# =============================================================================

@dataclass
class SymbolInfo:
    """Symbol information for workspace search.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        container: Containing class/module (if any)
        score: Match score for ranking
    """
    name: str
    kind: str
    file_path: str
    line: int
    container: Optional[str] = None
    score: float = 1.0

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}


# =============================================================================
# Section 4.6: get_hover dataclasses
# =============================================================================

@dataclass
class HoverInfo:
    """Hover information for a symbol.

    Attributes:
        name: Symbol name
        kind: Symbol kind
        signature: Symbol signature
        documentation: Documentation string (if available)
        file_path: File where symbol is defined
        line_range: Start and end line numbers
        type_info: Type information (if available)
    """
    name: str
    kind: str
    signature: str
    documentation: Optional[str]
    file_path: str
    line_range: Tuple[int, int]
    type_info: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        result = {
            "name": self.name,
            "kind": self.kind,
            "signature": self.signature,
            "file_path": self.file_path,
            "line_range": list(self.line_range),
        }
        if self.documentation is not None:
            result["documentation"] = self.documentation
        if self.type_info is not None:
            result["type_info"] = self.type_info
        return result


# =============================================================================
# Section 4.7: semantic_search dataclasses
# =============================================================================

@dataclass
class SemanticResult:
    """Semantic search result.

    Attributes:
        symbol_name: Name of the matched symbol
        kind: Symbol kind
        file_path: File where symbol is defined
        line: Line number
        vector_score: Vector similarity score (None if not available)
        structural_score: Structural match score (None if not available)
        fusion_score: Combined fusion score
        snippet: Code snippet
        match_reason: Explanation of why this matched (optional)
    """
    symbol_name: str
    kind: str
    file_path: str
    line: int
    vector_score: Optional[float]
    structural_score: Optional[float]
    fusion_score: float
    snippet: str
    match_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dictionary, filtering None values."""
        return {k: v for k, v in asdict(self).items() if v is not None}

345 codex-lens/src/codexlens/api/references.py (Normal file)
@@ -0,0 +1,345 @@
"""Find references API for codexlens.

This module implements the find_references() function that wraps
ChainSearchEngine.search_references() with grouped result structure
for multi-definition symbols.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional, Dict

from .models import (
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
)
from .utils import (
    resolve_project,
    normalize_relationship_type,
)


logger = logging.getLogger(__name__)


def _read_line_from_file(file_path: str, line: int) -> str:
    """Read a specific line from a file.

    Args:
        file_path: Path to the file
        line: Line number (1-based)

    Returns:
        The line content, stripped of trailing whitespace.
        Returns empty string if file cannot be read or line doesn't exist.
    """
    try:
        path = Path(file_path)
        if not path.exists():
            return ""

        with path.open("r", encoding="utf-8", errors="replace") as f:
            for i, content in enumerate(f, 1):
                if i == line:
                    return content.rstrip()
        return ""
    except Exception as exc:
        logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
        return ""


def _transform_to_reference_result(
    raw_ref: "RawReferenceResult",
) -> ReferenceResult:
    """Transform raw ChainSearchEngine reference to API ReferenceResult.

    Args:
        raw_ref: Raw reference result from ChainSearchEngine

    Returns:
        API ReferenceResult with context_line and normalized relationship
    """
    # Read the actual line from the file
    context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)

    # Normalize relationship type
    relationship = normalize_relationship_type(raw_ref.relationship_type)

    return ReferenceResult(
        file_path=raw_ref.file_path,
        line=raw_ref.line,
        column=raw_ref.column,
        context_line=context_line,
        relationship=relationship,
    )


def find_references(
    project_root: str,
    symbol_name: str,
    symbol_kind: Optional[str] = None,
    include_definition: bool = True,
    group_by_definition: bool = True,
    limit: int = 100,
) -> List[GroupedReferences]:
    """Find all reference locations for a symbol.

    Multi-definition case returns grouped results to resolve ambiguity.

    This function wraps ChainSearchEngine.search_references() and groups
    the results by definition location. Each GroupedReferences contains
    a definition and all references that point to it.

    Args:
        project_root: Project root directory path
        symbol_name: Name of the symbol to find references for
        symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
        include_definition: Whether to include the definition location
            in the result (default True)
        group_by_definition: Whether to group references by definition.
            If False, returns a single group with all references.
            (default True)
        limit: Maximum number of references to return (default 100)

    Returns:
        List of GroupedReferences. Each group contains:
        - definition: The DefinitionResult for this symbol definition
        - references: List of ReferenceResult pointing to this definition

    Raises:
        ValueError: If project_root does not exist or is not a directory

    Examples:
        >>> refs = find_references("/path/to/project", "authenticate")
        >>> for group in refs:
        ...     print(f"Definition: {group.definition.file_path}:{group.definition.line}")
        ...     for ref in group.references:
        ...         print(f"  Reference: {ref.file_path}:{ref.line} ({ref.relationship})")

    Note:
        Reference relationship types are normalized:
        - 'calls' -> 'call'
        - 'imports' -> 'import'
        - 'inherits' -> 'inheritance'
    """
    # Validate and resolve project root
    project_path = resolve_project(project_root)

    # Import here to avoid circular imports
    from codexlens.config import Config
    from codexlens.storage.registry import RegistryStore
    from codexlens.storage.path_mapper import PathMapper
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.search.chain_search import ChainSearchEngine
    from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
    from codexlens.entities import Symbol

    # Initialize infrastructure
    config = Config()
    registry = RegistryStore(config.registry_db_path)
    mapper = PathMapper(config.index_root)

    # Create chain search engine
    engine = ChainSearchEngine(registry, mapper, config=config)

    try:
        # Step 1: Find definitions for the symbol
        definitions: List[DefinitionResult] = []

        if include_definition or group_by_definition:
            # Search for symbol definitions
            symbols = engine.search_symbols(
                name=symbol_name,
                source_path=project_path,
                kind=symbol_kind,
            )

            # Convert Symbol to DefinitionResult
            for sym in symbols:
                # Only include exact name matches for definitions
                if sym.name != symbol_name:
                    continue

                # Optionally filter by kind
                if symbol_kind and sym.kind != symbol_kind:
                    continue

                definitions.append(DefinitionResult(
                    name=sym.name,
                    kind=sym.kind,
                    file_path=sym.file or "",
                    line=sym.range[0] if sym.range else 1,
                    end_line=sym.range[1] if sym.range else 1,
                    signature=None,  # Not available from Symbol
                    container=None,  # Not available from Symbol
                    score=1.0,
                ))

        # Step 2: Get all references using ChainSearchEngine
        raw_references = engine.search_references(
            symbol_name=symbol_name,
            source_path=project_path,
            depth=-1,
            limit=limit,
        )

        # Step 3: Transform raw references to API ReferenceResult
        api_references: List[ReferenceResult] = []
        for raw_ref in raw_references:
            api_ref = _transform_to_reference_result(raw_ref)
|
||||
api_references.append(api_ref)
|
||||
|
||||
# Step 4: Group references by definition
|
||||
if group_by_definition and definitions:
|
||||
return _group_references_by_definition(
|
||||
definitions=definitions,
|
||||
references=api_references,
|
||||
include_definition=include_definition,
|
||||
)
|
||||
else:
|
||||
# Return single group with placeholder definition or first definition
|
||||
if definitions:
|
||||
definition = definitions[0]
|
||||
else:
|
||||
# Create placeholder definition when no definition found
|
||||
definition = DefinitionResult(
|
||||
name=symbol_name,
|
||||
kind=symbol_kind or "unknown",
|
||||
file_path="",
|
||||
line=0,
|
||||
end_line=0,
|
||||
signature=None,
|
||||
container=None,
|
||||
score=0.0,
|
||||
)
|
||||
|
||||
return [GroupedReferences(
|
||||
definition=definition,
|
||||
references=api_references,
|
||||
)]
|
||||
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
|
||||
def _group_references_by_definition(
|
||||
definitions: List[DefinitionResult],
|
||||
references: List[ReferenceResult],
|
||||
include_definition: bool = True,
|
||||
) -> List[GroupedReferences]:
|
||||
"""Group references by their likely definition.
|
||||
|
||||
Uses file proximity heuristic to assign references to definitions.
|
||||
References in the same file or directory as a definition are
|
||||
assigned to that definition.
|
||||
|
||||
Args:
|
||||
definitions: List of definition locations
|
||||
references: List of reference locations
|
||||
include_definition: Whether to include definition in results
|
||||
|
||||
Returns:
|
||||
List of GroupedReferences with references assigned to definitions
|
||||
"""
|
||||
import os
|
||||
|
||||
if not definitions:
|
||||
return []
|
||||
|
||||
if len(definitions) == 1:
|
||||
# Single definition - all references belong to it
|
||||
return [GroupedReferences(
|
||||
definition=definitions[0],
|
||||
references=references,
|
||||
)]
|
||||
|
||||
# Multiple definitions - group by proximity
|
||||
groups: Dict[int, List[ReferenceResult]] = {
|
||||
i: [] for i in range(len(definitions))
|
||||
}
|
||||
|
||||
for ref in references:
|
||||
# Find the closest definition by file proximity
|
||||
best_def_idx = 0
|
||||
best_score = -1
|
||||
|
||||
for i, defn in enumerate(definitions):
|
||||
score = _proximity_score(ref.file_path, defn.file_path)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_def_idx = i
|
||||
|
||||
groups[best_def_idx].append(ref)
|
||||
|
||||
# Build result groups
|
||||
result: List[GroupedReferences] = []
|
||||
for i, defn in enumerate(definitions):
|
||||
# Skip definitions with no references if not including definition itself
|
||||
if not include_definition and not groups[i]:
|
||||
continue
|
||||
|
||||
result.append(GroupedReferences(
|
||||
definition=defn,
|
||||
references=groups[i],
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _proximity_score(ref_path: str, def_path: str) -> int:
|
||||
"""Calculate proximity score between two file paths.
|
||||
|
||||
Args:
|
||||
ref_path: Reference file path
|
||||
def_path: Definition file path
|
||||
|
||||
Returns:
|
||||
Proximity score (higher = closer):
|
||||
- Same file: 1000
|
||||
- Same directory: 100
|
||||
- Otherwise: common path prefix length
|
||||
"""
|
||||
import os
|
||||
|
||||
if not ref_path or not def_path:
|
||||
return 0
|
||||
|
||||
# Normalize paths
|
||||
ref_path = os.path.normpath(ref_path)
|
||||
def_path = os.path.normpath(def_path)
|
||||
|
||||
# Same file
|
||||
if ref_path == def_path:
|
||||
return 1000
|
||||
|
||||
ref_dir = os.path.dirname(ref_path)
|
||||
def_dir = os.path.dirname(def_path)
|
||||
|
||||
# Same directory
|
||||
if ref_dir == def_dir:
|
||||
return 100
|
||||
|
||||
# Common path prefix
|
||||
try:
|
||||
common = os.path.commonpath([ref_path, def_path])
|
||||
return len(common)
|
||||
except ValueError:
|
||||
# No common path (different drives on Windows)
|
||||
return 0
|
||||
|
||||
|
||||
# Type alias for the raw reference from ChainSearchEngine
|
||||
class RawReferenceResult:
|
||||
"""Type stub for ChainSearchEngine.ReferenceResult.
|
||||
|
||||
This is only used for type hints and is replaced at runtime
|
||||
by the actual import.
|
||||
"""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context: str
|
||||
relationship_type: str
|
||||
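The file-proximity grouping above can be sanity-checked in isolation. A standalone sketch (the paths are made up) of the same scoring rule picking the closer of two candidate definitions:

```python
import os

def proximity_score(ref_path: str, def_path: str) -> int:
    # Mirrors _proximity_score above: same file > same directory > shared prefix.
    ref_path, def_path = os.path.normpath(ref_path), os.path.normpath(def_path)
    if ref_path == def_path:
        return 1000
    if os.path.dirname(ref_path) == os.path.dirname(def_path):
        return 100
    try:
        return len(os.path.commonpath([ref_path, def_path]))
    except ValueError:
        return 0

# A reference in src/auth/ is attributed to the auth definition, not the CLI one.
defs = ["/repo/src/auth/service.py", "/repo/src/cli/service.py"]
ref = "/repo/src/auth/handlers.py"
print(max(defs, key=lambda d: proximity_score(ref, d)))
# /repo/src/auth/service.py
```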
codex-lens/src/codexlens/api/semantic.py (new file, 471 lines)

```python
"""Semantic search API with RRF fusion.

This module provides the semantic_search() function for combining
vector, structural, and keyword search with configurable fusion strategies.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import List, Optional

from .models import SemanticResult
from .utils import resolve_project

logger = logging.getLogger(__name__)


def semantic_search(
    project_root: str,
    query: str,
    mode: str = "fusion",
    vector_weight: float = 0.5,
    structural_weight: float = 0.3,
    keyword_weight: float = 0.2,
    fusion_strategy: str = "rrf",
    kind_filter: Optional[List[str]] = None,
    limit: int = 20,
    include_match_reason: bool = False,
) -> List[SemanticResult]:
    """Semantic search combining vector and structural search.

    This function provides a high-level API for semantic code search,
    combining vector similarity, structural (symbol + relationships),
    and keyword-based search methods with configurable fusion.

    Args:
        project_root: Project root directory
        query: Natural language query
        mode: Search mode
            - vector: Vector search only
            - structural: Structural search only (symbol + relationships)
            - fusion: Fusion search (default)
        vector_weight: Vector search weight [0, 1] (default 0.5)
        structural_weight: Structural search weight [0, 1] (default 0.3)
        keyword_weight: Keyword search weight [0, 1] (default 0.2)
        fusion_strategy: Fusion strategy (maps to chain_search.py)
            - rrf: Reciprocal Rank Fusion (recommended, default)
            - staged: Staged cascade -> staged_cascade_search
            - binary: Binary rerank cascade -> binary_cascade_search
            - hybrid: Hybrid cascade -> hybrid_cascade_search
        kind_filter: Symbol type filter (e.g., ["function", "class"])
        limit: Max return count (default 20)
        include_match_reason: Generate match reason (heuristic, not LLM)

    Returns:
        Results sorted by fusion_score

    Degradation:
        - No vector index: vector_score=None, uses FTS + structural search
        - No relationship data: structural_score=None, vector search only

    Examples:
        >>> results = semantic_search(
        ...     "/path/to/project",
        ...     "authentication handler",
        ...     mode="fusion",
        ...     fusion_strategy="rrf"
        ... )
        >>> for r in results:
        ...     print(f"{r.symbol_name}: {r.fusion_score:.3f}")
    """
    # Validate and resolve project path
    project_path = resolve_project(project_root)

    # Normalize weights to sum to 1.0
    total_weight = vector_weight + structural_weight + keyword_weight
    if total_weight > 0:
        vector_weight = vector_weight / total_weight
        structural_weight = structural_weight / total_weight
        keyword_weight = keyword_weight / total_weight
    else:
        # Default to equal weights if all zero
        vector_weight = structural_weight = keyword_weight = 1.0 / 3.0

    # Initialize search infrastructure
    try:
        from codexlens.config import Config
        from codexlens.storage.registry import RegistryStore
        from codexlens.storage.path_mapper import PathMapper
        from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
    except ImportError as exc:
        logger.error("Failed to import search dependencies: %s", exc)
        return []

    # Load config
    config = Config.load()

    # Get or create registry and mapper
    try:
        registry = RegistryStore.default()
        mapper = PathMapper(registry)
    except Exception as exc:
        logger.error("Failed to initialize search infrastructure: %s", exc)
        return []

    # Build search options based on mode
    search_options = _build_search_options(
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        limit=limit,
    )

    # Execute search based on fusion_strategy
    try:
        with ChainSearchEngine(registry, mapper, config=config) as engine:
            chain_result = _execute_search(
                engine=engine,
                query=query,
                source_path=project_path,
                fusion_strategy=fusion_strategy,
                options=search_options,
                limit=limit,
            )
    except Exception as exc:
        logger.error("Search execution failed: %s", exc)
        return []

    # Transform results to SemanticResult
    semantic_results = _transform_results(
        results=chain_result.results,
        mode=mode,
        vector_weight=vector_weight,
        structural_weight=structural_weight,
        keyword_weight=keyword_weight,
        kind_filter=kind_filter,
        include_match_reason=include_match_reason,
        query=query,
    )

    return semantic_results[:limit]


def _build_search_options(
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    limit: int,
) -> "SearchOptions":
    """Build SearchOptions based on mode and weights.

    Args:
        mode: Search mode (vector, structural, fusion)
        vector_weight: Vector search weight
        structural_weight: Structural search weight
        keyword_weight: Keyword search weight
        limit: Result limit

    Returns:
        Configured SearchOptions
    """
    from codexlens.search.chain_search import SearchOptions

    # Default options
    options = SearchOptions(
        total_limit=limit * 2,  # Fetch extra for filtering
        limit_per_dir=limit,
        include_symbols=True,  # Always include symbols for structural
    )

    if mode == "vector":
        # Pure vector mode
        options.hybrid_mode = True
        options.enable_vector = True
        options.pure_vector = True
        options.enable_fuzzy = False
    elif mode == "structural":
        # Structural only - use FTS + symbols
        options.hybrid_mode = True
        options.enable_vector = False
        options.enable_fuzzy = True
        options.include_symbols = True
    else:
        # Fusion mode (default)
        options.hybrid_mode = True
        options.enable_vector = vector_weight > 0
        options.enable_fuzzy = keyword_weight > 0
        options.include_symbols = structural_weight > 0

        # Set custom weights for RRF
        if options.enable_vector and keyword_weight > 0:
            options.hybrid_weights = {
                "vector": vector_weight,
                "exact": keyword_weight * 0.7,
                "fuzzy": keyword_weight * 0.3,
            }

    return options


def _execute_search(
    engine: "ChainSearchEngine",
    query: str,
    source_path: Path,
    fusion_strategy: str,
    options: "SearchOptions",
    limit: int,
) -> "ChainSearchResult":
    """Execute search using the appropriate strategy.

    Maps fusion_strategy to ChainSearchEngine methods:
    - rrf: Standard hybrid search with RRF fusion
    - staged: staged_cascade_search
    - binary: binary_cascade_search
    - hybrid: hybrid_cascade_search

    Args:
        engine: ChainSearchEngine instance
        query: Search query
        source_path: Project root path
        fusion_strategy: Strategy name
        options: Search options
        limit: Result limit

    Returns:
        ChainSearchResult from the search
    """
    from codexlens.search.chain_search import ChainSearchResult

    if fusion_strategy == "staged":
        # Use staged cascade search (4-stage pipeline)
        return engine.staged_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "binary":
        # Use binary cascade search (binary coarse + dense fine)
        return engine.binary_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "hybrid":
        # Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
        return engine.hybrid_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    else:
        # Default: rrf - standard search with RRF fusion
        return engine.search(
            query=query,
            source_path=source_path,
            options=options,
        )


def _transform_results(
    results: List,
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    kind_filter: Optional[List[str]],
    include_match_reason: bool,
    query: str,
) -> List[SemanticResult]:
    """Transform ChainSearchEngine results to SemanticResult.

    Args:
        results: List of SearchResult objects
        mode: Search mode
        vector_weight: Vector weight used
        structural_weight: Structural weight used
        keyword_weight: Keyword weight used
        kind_filter: Optional symbol kind filter
        include_match_reason: Whether to generate match reasons
        query: Original query (for match reason generation)

    Returns:
        List of SemanticResult objects
    """
    semantic_results = []

    for result in results:
        # Extract symbol info
        symbol_name = getattr(result, "symbol_name", None)
        symbol_kind = getattr(result, "symbol_kind", None)
        start_line = getattr(result, "start_line", None)

        # Use symbol object if available
        if hasattr(result, "symbol") and result.symbol:
            symbol_name = symbol_name or result.symbol.name
            symbol_kind = symbol_kind or result.symbol.kind
            if hasattr(result.symbol, "range") and result.symbol.range:
                start_line = start_line or result.symbol.range[0]

        # Filter by kind if specified
        if kind_filter and symbol_kind:
            if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
                continue

        # Determine scores based on mode and metadata
        metadata = getattr(result, "metadata", {}) or {}
        fusion_score = result.score

        # Try to extract source scores from metadata
        source_scores = metadata.get("source_scores", {})
        vector_score: Optional[float] = None
        structural_score: Optional[float] = None

        if mode == "vector":
            # In pure vector mode, the main score is the vector score
            vector_score = result.score
            structural_score = None
        elif mode == "structural":
            # In structural mode, no vector score
            vector_score = None
            structural_score = result.score
        else:
            # Fusion mode - try to extract individual scores
            if "vector" in source_scores:
                vector_score = source_scores["vector"]
            elif metadata.get("fusion_method") == "simple_weighted":
                # From weighted fusion
                vector_score = source_scores.get("vector")

            # Structural score approximation (from exact/fuzzy FTS)
            fts_scores = []
            if "exact" in source_scores:
                fts_scores.append(source_scores["exact"])
            if "fuzzy" in source_scores:
                fts_scores.append(source_scores["fuzzy"])
            if "splade" in source_scores:
                fts_scores.append(source_scores["splade"])

            if fts_scores:
                structural_score = max(fts_scores)

        # Build snippet
        snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
        if len(snippet) > 500:
            snippet = snippet[:500] + "..."

        # Generate match reason if requested
        match_reason = None
        if include_match_reason:
            match_reason = _generate_match_reason(
                query=query,
                symbol_name=symbol_name,
                symbol_kind=symbol_kind,
                snippet=snippet,
                vector_score=vector_score,
                structural_score=structural_score,
            )

        semantic_result = SemanticResult(
            symbol_name=symbol_name or Path(result.path).stem,
            kind=symbol_kind or "unknown",
            file_path=result.path,
            line=start_line or 1,
            vector_score=vector_score,
            structural_score=structural_score,
            fusion_score=fusion_score,
            snippet=snippet,
            match_reason=match_reason,
        )

        semantic_results.append(semantic_result)

    # Sort by fusion_score descending
    semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)

    return semantic_results


def _generate_match_reason(
    query: str,
    symbol_name: Optional[str],
    symbol_kind: Optional[str],
    snippet: str,
    vector_score: Optional[float],
    structural_score: Optional[float],
) -> str:
    """Generate a human-readable match reason heuristically.

    This is a simple heuristic-based approach, not LLM-powered.

    Args:
        query: Original search query
        symbol_name: Symbol name if available
        symbol_kind: Symbol kind if available
        snippet: Code snippet
        vector_score: Vector similarity score
        structural_score: Structural match score

    Returns:
        Human-readable explanation string
    """
    reasons = []

    # Check for direct name match
    query_lower = query.lower()
    query_words = set(query_lower.split())

    if symbol_name:
        name_lower = symbol_name.lower()
        # Direct substring match
        if query_lower in name_lower or name_lower in query_lower:
            reasons.append(f"Symbol name '{symbol_name}' matches query")
        # Word overlap
        name_words = set(_split_camel_case(symbol_name).lower().split())
        overlap = query_words & name_words
        if overlap and not reasons:
            reasons.append(f"Symbol name contains: {', '.join(overlap)}")

    # Check snippet for keyword matches
    snippet_lower = snippet.lower()
    matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
    if matching_words and len(reasons) < 2:
        reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")

    # Add score-based reasoning
    if vector_score is not None and vector_score > 0.7:
        reasons.append("High semantic similarity")
    elif vector_score is not None and vector_score > 0.5:
        reasons.append("Moderate semantic similarity")

    if structural_score is not None and structural_score > 0.8:
        reasons.append("Strong structural match")

    # Symbol kind context
    if symbol_kind and len(reasons) < 3:
        reasons.append(f"Matched {symbol_kind}")

    if not reasons:
        reasons.append("Partial relevance based on content analysis")

    return "; ".join(reasons[:3])


def _split_camel_case(name: str) -> str:
    """Split camelCase and PascalCase into words.

    Args:
        name: Symbol name in camelCase or PascalCase

    Returns:
        Space-separated words
    """
    import re

    # Insert space before uppercase letters
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Insert space before uppercase followed by lowercase
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
    # Replace underscores with spaces
    result = result.replace("_", " ")

    return result
```
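Because the three weights are renormalized to sum to 1.0, callers may pass any relative proportions. A usage sketch, assuming an already-indexed project at the given path:

```python
from codexlens.api.semantic import semantic_search

results = semantic_search(
    "/path/to/project",
    "token refresh logic",
    mode="fusion",
    vector_weight=2.0,         # normalized to 2/4 = 0.5
    structural_weight=1.0,     # normalized to 1/4 = 0.25
    keyword_weight=1.0,        # normalized to 1/4 = 0.25
    fusion_strategy="staged",  # routed to staged_cascade_search()
    kind_filter=["function", "method"],
    limit=10,
)
for r in results:
    print(f"{r.fusion_score:.3f}  {r.kind:<8} {r.symbol_name}  {r.file_path}:{r.line}")
```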
codex-lens/src/codexlens/api/symbols.py (new file, 146 lines)

```python
"""workspace_symbols API implementation.

This module provides the workspace_symbols() function for searching
symbols across the entire workspace with prefix matching.
"""

from __future__ import annotations

import fnmatch
import logging
from pathlib import Path
from typing import List, Optional

from ..entities import Symbol
from ..storage.global_index import GlobalSymbolIndex
from ..storage.registry import RegistryStore
from ..errors import IndexNotFoundError
from .models import SymbolInfo
from .utils import resolve_project

logger = logging.getLogger(__name__)


def workspace_symbols(
    project_root: str,
    query: str,
    kind_filter: Optional[List[str]] = None,
    file_pattern: Optional[str] = None,
    limit: int = 50
) -> List[SymbolInfo]:
    """Search for symbols across the entire workspace.

    Uses prefix matching for efficient searching.

    Args:
        project_root: Project root directory (for index location)
        query: Search query (prefix match)
        kind_filter: Optional list of symbol kinds to include
            (e.g., ["class", "function"])
        file_pattern: Optional glob pattern to filter by file path
            (e.g., "*.py", "src/**/*.ts")
        limit: Maximum number of results to return

    Returns:
        List of SymbolInfo sorted by score

    Raises:
        IndexNotFoundError: If project is not indexed
    """
    project_path = resolve_project(project_root)

    # Get project info from registry
    registry = RegistryStore()
    project_info = registry.get_project_by_source(str(project_path))
    if project_info is None:
        raise IndexNotFoundError(f"Project not indexed: {project_path}")

    # Open global symbol index
    index_db = project_info.index_root / "_global_symbols.db"
    if not index_db.exists():
        raise IndexNotFoundError(f"Global symbol index not found: {index_db}")

    global_index = GlobalSymbolIndex(str(index_db), project_info.id)

    # Search with prefix matching.
    # If kind_filter has multiple kinds, we need to search for each.
    all_results: List[Symbol] = []

    if kind_filter and len(kind_filter) > 0:
        # Search for each kind separately
        for kind in kind_filter:
            results = global_index.search(
                name=query,
                kind=kind,
                limit=limit,
                prefix_mode=True
            )
            all_results.extend(results)
    else:
        # Search without kind filter
        all_results = global_index.search(
            name=query,
            kind=None,
            limit=limit,
            prefix_mode=True
        )

    logger.debug(f"Found {len(all_results)} symbols matching '{query}'")

    # Apply file pattern filter if specified
    if file_pattern:
        all_results = [
            sym for sym in all_results
            if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
        ]
        logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")

    # Convert to SymbolInfo and sort by relevance
    symbols = [
        SymbolInfo(
            name=sym.name,
            kind=sym.kind,
            file_path=sym.file or "",
            line=sym.range[0] if sym.range else 1,
            container=None,  # Could extract from parent
            score=_calculate_score(sym.name, query)
        )
        for sym in all_results
    ]

    # Sort by score (exact matches first)
    symbols.sort(key=lambda s: s.score, reverse=True)

    return symbols[:limit]


def _calculate_score(symbol_name: str, query: str) -> float:
    """Calculate relevance score for a symbol match.

    Scoring:
        - Exact match: 1.0
        - Case-insensitive exact match: 0.9
        - Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
        - Case-insensitive prefix match: 0.6 + 0.2 * (query_len / symbol_len)
        - Fallback: 0.5

    Args:
        symbol_name: The matched symbol name
        query: The search query

    Returns:
        Score between 0.0 and 1.0
    """
    if symbol_name == query:
        return 1.0

    if symbol_name.lower() == query.lower():
        return 0.9

    if symbol_name.startswith(query):
        ratio = len(query) / len(symbol_name)
        return 0.8 + 0.2 * ratio

    if symbol_name.lower().startswith(query.lower()):
        ratio = len(query) / len(symbol_name)
        return 0.6 + 0.2 * ratio

    return 0.5
```
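The scoring tiers are deterministic and easy to verify directly. A standalone sketch replaying `_calculate_score` on a few hypothetical symbol names for the query `auth`:

```python
# Reimplementation of the tiers above for a self-contained check.
def calculate_score(symbol_name: str, query: str) -> float:
    if symbol_name == query:
        return 1.0
    if symbol_name.lower() == query.lower():
        return 0.9
    if symbol_name.startswith(query):
        return 0.8 + 0.2 * (len(query) / len(symbol_name))
    if symbol_name.lower().startswith(query.lower()):
        return 0.6 + 0.2 * (len(query) / len(symbol_name))
    return 0.5

for name in ["auth", "Auth", "authenticate", "AuthManager", "OAuthClient"]:
    print(f"{name:<14} {calculate_score(name, 'auth'):.2f}")
# auth           1.00   exact match
# Auth           0.90   case-insensitive exact
# authenticate   0.87   prefix match (ratio 4/12)
# AuthManager    0.67   case-insensitive prefix (ratio 4/11)
# OAuthClient    0.50   no prefix relation, floor score
```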
codex-lens/src/codexlens/api/utils.py (new file, 153 lines)

```python
"""Utility functions for the codexlens API.

This module provides helper functions for:
- Project resolution
- Relationship type normalization
- Result ranking by proximity
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, TypeVar, Callable

from .models import DefinitionResult


# Type variable for generic ranking
T = TypeVar('T')


def resolve_project(project_root: str) -> Path:
    """Resolve and validate the project root path.

    Args:
        project_root: Path to project root (relative or absolute)

    Returns:
        Resolved absolute Path

    Raises:
        ValueError: If the path does not exist or is not a directory
    """
    path = Path(project_root).resolve()
    if not path.exists():
        raise ValueError(f"Project root does not exist: {path}")
    if not path.is_dir():
        raise ValueError(f"Project root is not a directory: {path}")
    return path


# Relationship type normalization mapping
_RELATIONSHIP_NORMALIZATION = {
    # Plural to singular
    "calls": "call",
    "imports": "import",
    "inherits": "inheritance",
    "uses": "use",
    # Already normalized (passthrough)
    "call": "call",
    "import": "import",
    "inheritance": "inheritance",
    "use": "use",
    "type_annotation": "type_annotation",
}


def normalize_relationship_type(relationship: str) -> str:
    """Normalize a relationship type to its canonical form.

    Converts plural forms and variations to standard singular forms:
    - 'calls' -> 'call'
    - 'imports' -> 'import'
    - 'inherits' -> 'inheritance'
    - 'uses' -> 'use'

    Args:
        relationship: Raw relationship type string

    Returns:
        Normalized relationship type

    Examples:
        >>> normalize_relationship_type('calls')
        'call'
        >>> normalize_relationship_type('inherits')
        'inheritance'
        >>> normalize_relationship_type('call')
        'call'
    """
    return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)


def rank_by_proximity(
    results: List[DefinitionResult],
    file_context: Optional[str] = None
) -> List[DefinitionResult]:
    """Rank results by file path proximity to context.

    V1 implementation: uses path-based proximity scoring.

    Scoring algorithm:
    1. Same directory: highest score (100)
    2. Otherwise: length of common path prefix

    Args:
        results: List of definition results to rank
        file_context: Reference file path for proximity calculation.
            If None, returns results unchanged.

    Returns:
        Results sorted by proximity score (highest first)

    Examples:
        >>> results = [
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/b/c.py", line=1, end_line=10),
        ...     DefinitionResult(name="foo", kind="function",
        ...                      file_path="/a/x/y.py", line=1, end_line=10),
        ... ]
        >>> ranked = rank_by_proximity(results, "/a/b/test.py")
        >>> ranked[0].file_path
        '/a/b/c.py'
    """
    if not file_context or not results:
        return results

    def proximity_score(result: DefinitionResult) -> int:
        """Calculate the proximity score for a result."""
        result_dir = os.path.dirname(result.file_path)
        context_dir = os.path.dirname(file_context)

        # Same directory gets the highest score
        if result_dir == context_dir:
            return 100

        # Otherwise, score by common path prefix length
        try:
            common = os.path.commonpath([result.file_path, file_context])
            return len(common)
        except ValueError:
            # No common path (different drives on Windows)
            return 0

    return sorted(results, key=proximity_score, reverse=True)


def rank_by_score(
    results: List[T],
    score_fn: Callable[[T], float],
    reverse: bool = True
) -> List[T]:
    """Generic ranking function by custom score.

    Args:
        results: List of items to rank
        score_fn: Function to extract the score from an item
        reverse: If True, highest scores first (default)

    Returns:
        Sorted list
    """
    return sorted(results, key=score_fn, reverse=reverse)
```
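`rank_by_score` is the generic counterpart for when proximity is not the right signal. A small sketch (the result type and scores are invented) ranking hits by fusion score:

```python
from dataclasses import dataclass
from typing import Callable, List, TypeVar

T = TypeVar("T")

@dataclass
class Hit:  # stand-in for any result type that carries a numeric score
    name: str
    fusion_score: float

def rank_by_score(results: List[T], score_fn: Callable[[T], float],
                  reverse: bool = True) -> List[T]:
    # Same behavior as the helper above: stable sort by the extracted score.
    return sorted(results, key=score_fn, reverse=reverse)

hits = [Hit("parse", 0.41), Hit("parse_config", 0.87), Hit("parser", 0.63)]
ranked = rank_by_score(hits, lambda h: h.fusion_score)
print([h.name for h in ranked])  # ['parse_config', 'parser', 'parse']
```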
@@ -154,6 +154,13 @@ class Config:

```python
    cascade_fine_k: int = 10  # Number of final results after reranking
    cascade_strategy: str = "binary"  # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)

    # Staged cascade search configuration (4-stage pipeline)
    staged_coarse_k: int = 200  # Number of coarse candidates from Stage 1 binary search
    staged_lsp_depth: int = 2  # LSP relationship expansion depth in Stage 2
    staged_clustering_strategy: str = "auto"  # "auto", "hdbscan", "dbscan", "frequency", "noop"
    staged_clustering_min_size: int = 3  # Minimum cluster size for Stage 3 grouping
    enable_staged_rerank: bool = True  # Enable optional cross-encoder reranking in Stage 4

    # RRF fusion configuration
    fusion_method: str = "rrf"  # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
    rrf_k: int = 60  # RRF constant (default 60)
```
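For reference, reciprocal rank fusion assigns each document score(d) = sum over rankings of 1/(k + rank_i(d)); `rrf_k` is the k in that formula. A minimal sketch (the candidate lists are made up) of how k=60 fuses two rankings:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def rrf_fuse(rankings: List[List[str]], k: int = 60) -> List[Tuple[str, float]]:
    # score(d) = sum over lists of 1 / (k + rank), with rank starting at 1
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc in enumerate(ranking, 1):
            scores[doc] += 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

vector_hits = ["auth.py", "token.py", "session.py"]
fts_hits = ["token.py", "auth.py", "login.py"]
for doc, score in rrf_fuse([vector_hits, fts_hits], k=60):
    print(f"{doc:<12} {score:.4f}")
# auth.py and token.py tie at the top: each is ranked 1st in one list and 2nd in the other.
```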
codex-lens/src/codexlens/lsp/__init__.py (new file, 7 lines)

```python
"""codex-lens Language Server Protocol implementation."""

from __future__ import annotations

from codexlens.lsp.server import CodexLensLanguageServer, main

__all__ = ["CodexLensLanguageServer", "main"]
```
codex-lens/src/codexlens/lsp/handlers.py (new file, 551 lines)

```python
"""LSP request handlers for codex-lens.

This module contains handlers for LSP requests:
- textDocument/definition
- textDocument/references
- textDocument/completion
- workspace/symbol
- textDocument/didSave
- textDocument/hover
"""

from __future__ import annotations

import logging
import re
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote, unquote

try:
    from lsprotocol import types as lsp
except ImportError as exc:
    raise ImportError(
        "LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
    ) from exc

from codexlens.entities import Symbol
from codexlens.lsp.server import server

logger = logging.getLogger(__name__)

# Symbol kind mapping from codex-lens to LSP
SYMBOL_KIND_MAP = {
    "class": lsp.SymbolKind.Class,
    "function": lsp.SymbolKind.Function,
    "method": lsp.SymbolKind.Method,
    "variable": lsp.SymbolKind.Variable,
    "constant": lsp.SymbolKind.Constant,
    "property": lsp.SymbolKind.Property,
    "field": lsp.SymbolKind.Field,
    "interface": lsp.SymbolKind.Interface,
    "module": lsp.SymbolKind.Module,
    "namespace": lsp.SymbolKind.Namespace,
    "package": lsp.SymbolKind.Package,
    "enum": lsp.SymbolKind.Enum,
    "enum_member": lsp.SymbolKind.EnumMember,
    "struct": lsp.SymbolKind.Struct,
    "type": lsp.SymbolKind.TypeParameter,
    "type_alias": lsp.SymbolKind.TypeParameter,
}

# Completion kind mapping from codex-lens to LSP
COMPLETION_KIND_MAP = {
    "class": lsp.CompletionItemKind.Class,
    "function": lsp.CompletionItemKind.Function,
    "method": lsp.CompletionItemKind.Method,
    "variable": lsp.CompletionItemKind.Variable,
    "constant": lsp.CompletionItemKind.Constant,
    "property": lsp.CompletionItemKind.Property,
    "field": lsp.CompletionItemKind.Field,
    "interface": lsp.CompletionItemKind.Interface,
    "module": lsp.CompletionItemKind.Module,
    "enum": lsp.CompletionItemKind.Enum,
    "enum_member": lsp.CompletionItemKind.EnumMember,
    "struct": lsp.CompletionItemKind.Struct,
    "type": lsp.CompletionItemKind.TypeParameter,
    "type_alias": lsp.CompletionItemKind.TypeParameter,
}


def _path_to_uri(path: Union[str, Path]) -> str:
    """Convert a file path to a URI.

    Args:
        path: File path (string or Path object)

    Returns:
        File URI string
    """
    path_str = str(Path(path).resolve())
    # Handle Windows paths
    if path_str.startswith("/"):
        return f"file://{quote(path_str)}"
    else:
        return f"file:///{quote(path_str.replace(chr(92), '/'))}"


def _uri_to_path(uri: str) -> Path:
    """Convert a URI to a file path.

    Args:
        uri: File URI string

    Returns:
        Path object
    """
    path = uri.replace("file:///", "").replace("file://", "")
    return Path(unquote(path))


def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
    """Extract the word at the given position in the document.

    Args:
        document_text: Full document text
        line: 0-based line number
        character: 0-based character position

    Returns:
        Word at position, or None if no word found
    """
    lines = document_text.splitlines()
    if line >= len(lines):
        return None

    line_text = lines[line]
    if character > len(line_text):
        return None

    # Find word boundaries
    word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
    for match in word_pattern.finditer(line_text):
        if match.start() <= character <= match.end():
            return match.group()

    return None


def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
    """Extract the incomplete word prefix at the given position.

    Args:
        document_text: Full document text
        line: 0-based line number
        character: 0-based character position

    Returns:
        Prefix string (may be empty)
    """
    lines = document_text.splitlines()
    if line >= len(lines):
        return ""

    line_text = lines[line]
    if character > len(line_text):
        character = len(line_text)

    # Extract text before cursor
    before_cursor = line_text[:character]

    # Find the start of the current word
    match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
    if match:
        return match.group()

    return ""


def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
    """Convert a codex-lens Symbol to an LSP Location.

    Args:
        symbol: codex-lens Symbol object

    Returns:
        LSP Location, or None if symbol has no file
    """
    if not symbol.file:
        return None

    # LSP uses 0-based lines, codex-lens uses 1-based
    start_line = max(0, symbol.range[0] - 1)
    end_line = max(0, symbol.range[1] - 1)

    return lsp.Location(
        uri=_path_to_uri(symbol.file),
        range=lsp.Range(
            start=lsp.Position(line=start_line, character=0),
            end=lsp.Position(line=end_line, character=0),
        ),
    )


def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
    """Map codex-lens symbol kind to LSP SymbolKind.

    Args:
        kind: codex-lens symbol kind string

    Returns:
        LSP SymbolKind
    """
    return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)


def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
    """Map codex-lens symbol kind to LSP CompletionItemKind.

    Args:
        kind: codex-lens symbol kind string

    Returns:
        LSP CompletionItemKind
    """
    return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)


# -----------------------------------------------------------------------------
# LSP Request Handlers
# -----------------------------------------------------------------------------


@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
def lsp_definition(
    params: lsp.DefinitionParams,
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
    """Handle textDocument/definition request.

    Finds the definition of the symbol at the cursor position.
    """
    if not server.global_index:
        logger.debug("No global index available for definition lookup")
        return None

    # Get document
    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    # Get word at position
    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        logger.debug("No word found at position")
        return None

    logger.debug("Looking up definition for: %s", word)

    # Search for exact symbol match
    try:
        symbols = server.global_index.search(
            name=word,
            limit=10,
            prefix_mode=False,  # Exact match preferred
        )

        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == word]
        if not exact_matches:
            # Fall back to prefix search
            symbols = server.global_index.search(
                name=word,
                limit=10,
                prefix_mode=True,
            )
            exact_matches = [s for s in symbols if s.name == word]

        if not exact_matches:
            logger.debug("No definition found for: %s", word)
            return None

        # Convert to LSP locations
        locations = []
        for sym in exact_matches:
            loc = symbol_to_location(sym)
            if loc:
                locations.append(loc)

        if len(locations) == 1:
            return locations[0]
        elif locations:
            return locations
        else:
            return None

    except Exception as exc:
        logger.error("Error looking up definition: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
    """Handle textDocument/references request.

    Finds all references to the symbol at the cursor position using
    the code_relationships table for accurate call-site tracking.
    Falls back to same-name symbol search if search_engine is unavailable.
    """
    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        return None

    logger.debug("Finding references for: %s", word)

    try:
        # Try using search_engine.search_references() for accurate reference tracking
        if server.search_engine and server.workspace_root:
            references = server.search_engine.search_references(
                symbol_name=word,
                source_path=server.workspace_root,
                limit=200,
            )

            if references:
                locations = []
                for ref in references:
                    locations.append(
                        lsp.Location(
                            uri=_path_to_uri(ref.file_path),
                            range=lsp.Range(
                                start=lsp.Position(
                                    line=max(0, ref.line - 1),
                                    character=ref.column,
                                ),
                                end=lsp.Position(
                                    line=max(0, ref.line - 1),
                                    character=ref.column + len(word),
                                ),
                            ),
                        )
                    )
                return locations if locations else None

        # Fallback: search for symbols with the same name using global_index
        if server.global_index:
            symbols = server.global_index.search(
                name=word,
                limit=100,
                prefix_mode=False,
            )

            # Filter for exact matches
            exact_matches = [s for s in symbols if s.name == word]

            locations = []
            for sym in exact_matches:
                loc = symbol_to_location(sym)
                if loc:
                    locations.append(loc)

            return locations if locations else None

        return None

    except Exception as exc:
        logger.error("Error finding references: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
    """Handle textDocument/completion request.

    Provides code completion suggestions based on indexed symbols.
    """
    if not server.global_index:
        return None

    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    prefix = _get_prefix_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not prefix or len(prefix) < 2:
        # Require at least 2 characters for completion
        return None

    logger.debug("Completing prefix: %s", prefix)

    try:
        symbols = server.global_index.search(
            name=prefix,
            limit=50,
            prefix_mode=True,
        )

        if not symbols:
            return None

        # Convert to completion items
        items = []
        seen_names = set()

        for sym in symbols:
            if sym.name in seen_names:
                continue
            seen_names.add(sym.name)

            items.append(
                lsp.CompletionItem(
                    label=sym.name,
                    kind=_symbol_kind_to_completion_kind(sym.kind),
                    detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
                    sort_text=sym.name.lower(),
                )
            )

        return lsp.CompletionList(
            is_incomplete=len(symbols) >= 50,
            items=items,
        )

    except Exception as exc:
        logger.error("Error getting completions: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_HOVER)
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
    """Handle textDocument/hover request.

    Provides hover information for the symbol at the cursor position
    using HoverProvider for rich symbol information including
    signature, documentation, and location.
    """
    if not server.global_index:
        return None

    document = server.workspace.get_text_document(params.text_document.uri)
    if not document:
        return None

    word = _get_word_at_position(
        document.source,
        params.position.line,
        params.position.character,
    )

    if not word:
        return None

    logger.debug("Hover for: %s", word)

    try:
        # Use HoverProvider for rich symbol information
        from codexlens.lsp.providers import HoverProvider

        provider = HoverProvider(server.global_index, server.registry)
        info = provider.get_hover_info(word)

        if not info:
            return None

        # Format as markdown with signature and location
        content = provider.format_hover_markdown(info)

        return lsp.Hover(
            contents=lsp.MarkupContent(
                kind=lsp.MarkupKind.Markdown,
                value=content,
            ),
        )

    except Exception as exc:
        logger.error("Error getting hover info: %s", exc)
        return None


@server.feature(lsp.WORKSPACE_SYMBOL)
def lsp_workspace_symbol(
    params: lsp.WorkspaceSymbolParams,
) -> Optional[List[lsp.SymbolInformation]]:
    """Handle workspace/symbol request.

    Searches for symbols across the workspace.
    """
    if not server.global_index:
        return None

    query = params.query
    if not query or len(query) < 2:
        return None

    logger.debug("Workspace symbol search: %s", query)

    try:
        symbols = server.global_index.search(
            name=query,
            limit=100,
            prefix_mode=True,
        )

        if not symbols:
            return None

        result = []
        for sym in symbols:
            loc = symbol_to_location(sym)
            if loc:
                result.append(
                    lsp.SymbolInformation(
                        name=sym.name,
                        kind=_symbol_kind_to_lsp(sym.kind),
                        location=loc,
                        container_name=Path(sym.file).parent.name if sym.file else None,
                    )
                )

        return result if result else None

    except Exception as exc:
        logger.error("Error searching workspace symbols: %s", exc)
        return None


@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
    """Handle textDocument/didSave notification.

    Triggers incremental re-indexing of the saved file.
    Note: Full incremental indexing requires WatcherManager integration,
    which is planned for Phase 2.
    """
    file_path = _uri_to_path(params.text_document.uri)
    logger.info("File saved: %s", file_path)

    # Phase 1: Just log the save event.
    # Phase 2 will integrate with WatcherManager for incremental indexing:
    # if server.watcher_manager:
    #     server.watcher_manager.trigger_reindex(file_path)


@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
    """Handle textDocument/didOpen notification."""
    file_path = _uri_to_path(params.text_document.uri)
    logger.debug("File opened: %s", file_path)


@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
    """Handle textDocument/didClose notification."""
    file_path = _uri_to_path(params.text_document.uri)
    logger.debug("File closed: %s", file_path)
```
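The two position helpers above drive definition lookup and completion respectively. A standalone sketch of their behavior on a sample line (cursor offsets chosen for illustration):

```python
import re

LINE = "result = engine.search_references(symbol_name)"

def word_at(text: str, character: int):
    # Mirrors _get_word_at_position for a single line of text.
    for m in re.finditer(r"[a-zA-Z_][a-zA-Z0-9_]*", text):
        if m.start() <= character <= m.end():
            return m.group()
    return None

def prefix_at(text: str, character: int) -> str:
    # Mirrors _get_prefix_at_position: the incomplete word left of the cursor.
    m = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", text[:character])
    return m.group() if m else ""

print(word_at(LINE, 12))    # 'engine'  (cursor inside the receiver)
print(prefix_at(LINE, 22))  # 'search'  (completion prefix mid-identifier)
```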
177
codex-lens/src/codexlens/lsp/providers.py
Normal file
177
codex-lens/src/codexlens/lsp/providers.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""LSP feature providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HoverInfo:
|
||||
"""Hover information for a symbol."""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
signature: str
|
||||
documentation: Optional[str]
|
||||
file_path: str
|
||||
line_range: tuple # (start_line, end_line)
|
||||
|
||||
|
||||
class HoverProvider:
|
||||
"""Provides hover information for symbols."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
global_index: "GlobalSymbolIndex",
|
||||
registry: Optional["RegistryStore"] = None,
|
||||
) -> None:
|
||||
"""Initialize hover provider.
|
||||
|
||||
Args:
|
||||
global_index: Global symbol index for lookups
|
||||
registry: Optional registry store for index path resolution
|
||||
"""
|
||||
self.global_index = global_index
|
||||
self.registry = registry
|
||||
|
||||
def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
|
||||
"""Get hover information for a symbol.
|
||||
|
||||
Args:
|
||||
symbol_name: Name of the symbol to look up
|
||||
|
||||
Returns:
|
||||
HoverInfo or None if symbol not found
|
||||
"""
|
||||
# Look up symbol in global index using exact match
|
||||
symbols = self.global_index.search(
|
||||
name=symbol_name,
|
||||
limit=1,
|
||||
prefix_mode=False,
|
||||
)
|
||||
|
||||
        # Filter for exact name match
        exact_matches = [s for s in symbols if s.name == symbol_name]

        if not exact_matches:
            return None

        symbol = exact_matches[0]

        # Extract signature from source file
        signature = self._extract_signature(symbol)

        # Symbol uses 'file' attribute and 'range' tuple
        file_path = symbol.file or ""
        start_line, end_line = symbol.range

        return HoverInfo(
            name=symbol.name,
            kind=symbol.kind,
            signature=signature,
            documentation=None,  # Symbol doesn't have docstring field
            file_path=file_path,
            line_range=(start_line, end_line),
        )

    def _extract_signature(self, symbol) -> str:
        """Extract function/class signature from source file.

        Args:
            symbol: Symbol object with file and range information

        Returns:
            Extracted signature string or fallback kind + name
        """
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return f"{symbol.kind} {symbol.name}"

            content = file_path.read_text(encoding="utf-8", errors="ignore")
            lines = content.split("\n")

            # Extract signature lines (first line of definition + continuation)
            start_line = symbol.range[0] - 1  # Convert 1-based to 0-based
            if start_line >= len(lines) or start_line < 0:
                return f"{symbol.kind} {symbol.name}"

            signature_lines = []
            first_line = lines[start_line]
            signature_lines.append(first_line)

            # Continue if multiline signature (no closing paren + colon yet)
            # Look for patterns like "def func(", "class Foo(", etc.
            i = start_line + 1
            max_lines = min(start_line + 5, len(lines))
            while i < max_lines:
                line = signature_lines[-1]
                # Stop if we see closing pattern
                if "):" in line or line.rstrip().endswith(":"):
                    break
                signature_lines.append(lines[i])
                i += 1

            return "\n".join(signature_lines)

        except Exception as e:
            logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
            return f"{symbol.kind} {symbol.name}"

    def format_hover_markdown(self, info: HoverInfo) -> str:
        """Format hover info as Markdown.

        Args:
            info: HoverInfo object to format

        Returns:
            Markdown-formatted hover content
        """
        parts = []

        # Detect language for code fence based on file extension
        ext = Path(info.file_path).suffix.lower() if info.file_path else ""
        lang_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".tsx": "typescript",
            ".jsx": "javascript",
            ".java": "java",
            ".go": "go",
            ".rs": "rust",
            ".c": "c",
            ".cpp": "cpp",
            ".h": "c",
            ".hpp": "cpp",
            ".cs": "csharp",
            ".rb": "ruby",
            ".php": "php",
        }
        lang = lang_map.get(ext, "")

        # Code block with signature
        parts.append(f"```{lang}\n{info.signature}\n```")

        # Documentation if available
        if info.documentation:
            parts.append(f"\n---\n\n{info.documentation}")

        # Location info
        file_name = Path(info.file_path).name if info.file_path else "unknown"
        parts.append(
            f"\n---\n\n*{info.kind}* defined in "
            f"`{file_name}` "
            f"(line {info.line_range[0]})"
        )

        return "\n".join(parts)
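A minimal sketch of what the formatter above produces. Only the tail of this file is in the hunk, so the `HoverProvider` class name, its constructor, and the `get_hover` lookup method are assumptions, not part of this commit:

```python
# Hypothetical driver; only format_hover_markdown is defined above.
provider = HoverProvider(global_index)      # assumed constructor
info = provider.get_hover("authenticate")   # assumed lookup method
if info is not None:
    print(provider.format_hover_markdown(info))
    # ```python
    # def authenticate(user: str, token: str) -> bool:
    # ```
    # ---
    # *function* defined in `auth.py` (line 42)
```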
263 codex-lens/src/codexlens/lsp/server.py Normal file
@@ -0,0 +1,263 @@
"""codex-lens LSP Server implementation using pygls.

This module provides the main Language Server class and entry point.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path
from typing import Optional

try:
    from lsprotocol import types as lsp
    from pygls.lsp.server import LanguageServer
except ImportError as exc:
    raise ImportError(
        "LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
    ) from exc

from codexlens.config import Config
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

logger = logging.getLogger(__name__)


class CodexLensLanguageServer(LanguageServer):
    """Language Server for codex-lens code indexing.

    Provides IDE features using codex-lens symbol index:
    - Go to Definition
    - Find References
    - Code Completion
    - Hover Information
    - Workspace Symbol Search

    Attributes:
        registry: Global project registry for path lookups
        mapper: Path mapper for source/index conversions
        global_index: Project-wide symbol index
        search_engine: Chain search engine for symbol search
        workspace_root: Current workspace root path
    """

    def __init__(self) -> None:
        super().__init__(name="codexlens-lsp", version="0.1.0")

        self.registry: Optional[RegistryStore] = None
        self.mapper: Optional[PathMapper] = None
        self.global_index: Optional[GlobalSymbolIndex] = None
        self.search_engine: Optional[ChainSearchEngine] = None
        self.workspace_root: Optional[Path] = None
        self._config: Optional[Config] = None

    def initialize_components(self, workspace_root: Path) -> bool:
        """Initialize codex-lens components for the workspace.

        Args:
            workspace_root: Root path of the workspace

        Returns:
            True if initialization succeeded, False otherwise
        """
        self.workspace_root = workspace_root.resolve()
        logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)

        try:
            # Initialize registry
            self.registry = RegistryStore()
            self.registry.initialize()

            # Initialize path mapper
            self.mapper = PathMapper()

            # Try to find project in registry
            project_info = self.registry.find_by_source_path(str(self.workspace_root))

            if project_info:
                project_id = int(project_info["id"])
                index_root = Path(project_info["index_root"])

                # Initialize global symbol index
                global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
                self.global_index = GlobalSymbolIndex(global_db, project_id)
                self.global_index.initialize()

                # Initialize search engine
                self._config = Config()
                self.search_engine = ChainSearchEngine(
                    registry=self.registry,
                    mapper=self.mapper,
                    config=self._config,
                )

                logger.info("codex-lens initialized for project: %s", project_info["source_root"])
                return True
            else:
                logger.warning(
                    "Workspace not indexed by codex-lens: %s. "
                    "Run 'codexlens index %s' to index first.",
                    self.workspace_root,
                    self.workspace_root,
                )
                return False

        except Exception as exc:
            logger.error("Failed to initialize codex-lens: %s", exc)
            return False

    def shutdown_components(self) -> None:
        """Clean up codex-lens components."""
        if self.global_index:
            try:
                self.global_index.close()
            except Exception as exc:
                logger.debug("Error closing global index: %s", exc)
            self.global_index = None

        if self.search_engine:
            try:
                self.search_engine.close()
            except Exception as exc:
                logger.debug("Error closing search engine: %s", exc)
            self.search_engine = None

        if self.registry:
            try:
                self.registry.close()
            except Exception as exc:
                logger.debug("Error closing registry: %s", exc)
            self.registry = None


# Create server instance
server = CodexLensLanguageServer()


@server.feature(lsp.INITIALIZE)
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
    """Handle LSP initialize request."""
    logger.info("LSP initialize request received")

    # Get workspace root
    workspace_root: Optional[Path] = None
    if params.root_uri:
        workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
    elif params.root_path:
        workspace_root = Path(params.root_path)

    if workspace_root:
        server.initialize_components(workspace_root)

    # Declare server capabilities
    return lsp.InitializeResult(
        capabilities=lsp.ServerCapabilities(
            text_document_sync=lsp.TextDocumentSyncOptions(
                open_close=True,
                change=lsp.TextDocumentSyncKind.Incremental,
                save=lsp.SaveOptions(include_text=False),
            ),
            definition_provider=True,
            references_provider=True,
            completion_provider=lsp.CompletionOptions(
                trigger_characters=[".", ":"],
                resolve_provider=False,
            ),
            hover_provider=True,
            workspace_symbol_provider=True,
        ),
        server_info=lsp.ServerInfo(
            name="codexlens-lsp",
            version="0.1.0",
        ),
    )


@server.feature(lsp.SHUTDOWN)
def lsp_shutdown(params: None) -> None:
    """Handle LSP shutdown request."""
    logger.info("LSP shutdown request received")
    server.shutdown_components()


def main() -> int:
    """Entry point for codexlens-lsp command.

    Returns:
        Exit code (0 for success)
    """
    # Import handlers to register them with the server.
    # This must be done before starting the server.
    import codexlens.lsp.handlers  # noqa: F401

    parser = argparse.ArgumentParser(
        description="codex-lens Language Server",
        prog="codexlens-lsp",
    )
    parser.add_argument(
        "--stdio",
        action="store_true",
        default=True,
        help="Use stdio for communication (default)",
    )
    parser.add_argument(
        "--tcp",
        action="store_true",
        help="Use TCP for communication",
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="TCP host (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=2087,
        help="TCP port (default: 2087)",
    )
    parser.add_argument(
        "--log-level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default="INFO",
        help="Log level (default: INFO)",
    )
    parser.add_argument(
        "--log-file",
        help="Log file path (optional)",
    )

    args = parser.parse_args()

    # Configure logging
    log_handlers = []
    if args.log_file:
        log_handlers.append(logging.FileHandler(args.log_file))
    else:
        log_handlers.append(logging.StreamHandler(sys.stderr))

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=log_handlers,
    )

    logger.info("Starting codexlens-lsp server")

    if args.tcp:
        logger.info("Starting TCP server on %s:%d", args.host, args.port)
        server.start_tcp(args.host, args.port)
    else:
        logger.info("Starting stdio server")
        server.start_io()

    return 0


if __name__ == "__main__":
    sys.exit(main())
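The entry point above is normally launched by an editor over stdio. A hedged sketch for exercising the TCP path manually; the argument values are illustrative, not part of this commit:

```python
import sys

from codexlens.lsp.server import main

# Simulate `codexlens-lsp --tcp --port 2087 --log-level DEBUG` from a shell.
sys.argv = ["codexlens-lsp", "--tcp", "--port", "2087", "--log-level", "DEBUG"]
sys.exit(main())
```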
20 codex-lens/src/codexlens/mcp/__init__.py Normal file
@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt

__all__ = [
    "MCPContext",
    "SymbolInfo",
    "ReferenceInfo",
    "RelatedSymbol",
    "MCPProvider",
    "HookManager",
    "create_context_for_prompt",
]
170 codex-lens/src/codexlens/mcp/hooks.py Normal file
@@ -0,0 +1,170 @@
"""Hook interfaces for Claude Code integration."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING

from codexlens.mcp.schema import MCPContext

if TYPE_CHECKING:
    from codexlens.mcp.provider import MCPProvider

logger = logging.getLogger(__name__)


class HookManager:
    """Manages hook registration and execution."""

    def __init__(self, mcp_provider: "MCPProvider") -> None:
        self.mcp_provider = mcp_provider
        self._pre_hooks: Dict[str, Callable] = {}
        self._post_hooks: Dict[str, Callable] = {}

        # Register default hooks
        self._register_default_hooks()

    def _register_default_hooks(self) -> None:
        """Register built-in hooks."""
        self._pre_hooks["explain"] = self._pre_explain_hook
        self._pre_hooks["refactor"] = self._pre_refactor_hook
        self._pre_hooks["document"] = self._pre_document_hook

    def execute_pre_hook(
        self,
        action: str,
        params: Dict[str, Any],
    ) -> Optional[MCPContext]:
        """Execute pre-tool hook to gather context.

        Args:
            action: The action being performed (e.g., "explain", "refactor")
            params: Parameters for the action

        Returns:
            MCPContext to inject into prompt, or None
        """
        hook = self._pre_hooks.get(action)

        if not hook:
            logger.debug(f"No pre-hook for action: {action}")
            return None

        try:
            return hook(params)
        except Exception as e:
            logger.error(f"Pre-hook failed for {action}: {e}")
            return None

    def execute_post_hook(
        self,
        action: str,
        result: Any,
    ) -> None:
        """Execute post-tool hook for proactive caching.

        Args:
            action: The action that was performed
            result: Result of the action
        """
        hook = self._post_hooks.get(action)

        if not hook:
            return

        try:
            hook(result)
        except Exception as e:
            logger.error(f"Post-hook failed for {action}: {e}")

    def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'explain' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="symbol_explanation",
            include_references=True,
            include_related=True,
        )

    def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'refactor' action."""
        symbol_name = params.get("symbol")

        if not symbol_name:
            return None

        return self.mcp_provider.build_context(
            symbol_name=symbol_name,
            context_type="refactor_context",
            include_references=True,
            include_related=True,
            max_references=20,
        )

    def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
        """Pre-hook for 'document' action."""
        symbol_name = params.get("symbol")
        file_path = params.get("file_path")

        if symbol_name:
            return self.mcp_provider.build_context(
                symbol_name=symbol_name,
                context_type="documentation_context",
                include_references=False,
                include_related=True,
            )
        elif file_path:
            return self.mcp_provider.build_context_for_file(
                Path(file_path),
                context_type="file_documentation",
            )

        return None

    def register_pre_hook(
        self,
        action: str,
        hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
    ) -> None:
        """Register a custom pre-tool hook."""
        self._pre_hooks[action] = hook

    def register_post_hook(
        self,
        action: str,
        hook: Callable[[Any], None],
    ) -> None:
        """Register a custom post-tool hook."""
        self._post_hooks[action] = hook


def create_context_for_prompt(
    mcp_provider: "MCPProvider",
    action: str,
    params: Dict[str, Any],
) -> str:
    """Create context string for prompt injection.

    This is the main entry point for Claude Code hook integration.

    Args:
        mcp_provider: The MCP provider instance
        action: Action being performed
        params: Action parameters

    Returns:
        Formatted context string for prompt injection
    """
    manager = HookManager(mcp_provider)
    context = manager.execute_pre_hook(action, params)

    if context:
        return context.to_prompt_injection()

    return ""
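A minimal wiring sketch for the hook entry point above. It assumes `global_index`, `search_engine`, and `registry` were initialized as in `CodexLensLanguageServer.initialize_components`, and the symbol name is illustrative:

```python
from codexlens.mcp import MCPProvider, create_context_for_prompt

provider = MCPProvider(global_index, search_engine, registry)

context_str = create_context_for_prompt(
    provider,
    action="explain",
    params={"symbol": "authenticate"},
)
if context_str:
    prompt = f"{context_str}\n\nExplain what this symbol does."
```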
202 codex-lens/src/codexlens/mcp/provider.py Normal file
@@ -0,0 +1,202 @@
"""MCP context provider."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, List, TYPE_CHECKING

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)

if TYPE_CHECKING:
    from codexlens.storage.global_index import GlobalSymbolIndex
    from codexlens.storage.registry import RegistryStore
    from codexlens.search.chain_search import ChainSearchEngine

logger = logging.getLogger(__name__)


class MCPProvider:
    """Builds MCP context objects from codex-lens data."""

    def __init__(
        self,
        global_index: "GlobalSymbolIndex",
        search_engine: "ChainSearchEngine",
        registry: "RegistryStore",
    ) -> None:
        self.global_index = global_index
        self.search_engine = search_engine
        self.registry = registry

    def build_context(
        self,
        symbol_name: str,
        context_type: str = "symbol_explanation",
        include_references: bool = True,
        include_related: bool = True,
        max_references: int = 10,
    ) -> Optional[MCPContext]:
        """Build comprehensive context for a symbol.

        Args:
            symbol_name: Name of the symbol to contextualize
            context_type: Type of context being requested
            include_references: Whether to include reference locations
            include_related: Whether to include related symbols
            max_references: Maximum number of references to include

        Returns:
            MCPContext object, or None if the symbol is not found
        """
        # Look up symbol
        symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)

        if not symbols:
            logger.debug(f"Symbol not found for MCP context: {symbol_name}")
            return None

        symbol = symbols[0]

        # Build SymbolInfo
        symbol_info = SymbolInfo(
            name=symbol.name,
            kind=symbol.kind,
            file_path=symbol.file or "",
            line_start=symbol.range[0],
            line_end=symbol.range[1],
            signature=None,      # Symbol entity doesn't have a signature
            documentation=None,  # Symbol entity doesn't have a docstring
        )

        # Extract definition source code
        definition = self._extract_definition(symbol)

        # Get references
        references = []
        if include_references:
            refs = self.search_engine.search_references(
                symbol_name,
                limit=max_references,
            )
            references = [
                ReferenceInfo(
                    file_path=r.file_path,
                    line=r.line,
                    column=r.column,
                    context=r.context,
                    relationship_type=r.relationship_type,
                )
                for r in refs
            ]

        # Get related symbols
        related_symbols = []
        if include_related:
            related_symbols = self._get_related_symbols(symbol)

        return MCPContext(
            context_type=context_type,
            symbol=symbol_info,
            definition=definition,
            references=references,
            related_symbols=related_symbols,
            metadata={
                "source": "codex-lens",
            },
        )

    def _extract_definition(self, symbol) -> Optional[str]:
        """Extract source code for a symbol definition."""
        try:
            file_path = Path(symbol.file) if symbol.file else None
            if not file_path or not file_path.exists():
                return None

            content = file_path.read_text(encoding="utf-8", errors="ignore")
            lines = content.split("\n")

            start = symbol.range[0] - 1
            end = symbol.range[1]

            if start >= len(lines):
                return None

            return "\n".join(lines[start:end])
        except Exception as e:
            logger.debug(f"Failed to extract definition: {e}")
            return None

    def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
        """Get symbols related to the given symbol."""
        related = []

        try:
            # Search for symbols that might be related by name patterns.
            # This is a simplified implementation - could be enhanced with
            # real relationship data.

            # Look for imports/callers via reference search
            refs = self.search_engine.search_references(symbol.name, limit=20)

            # Track relationship types already represented (one module entry
            # per relationship type).
            seen_types = set()
            for ref in refs:
                if ref.relationship_type and ref.relationship_type not in seen_types:
                    related.append(RelatedSymbol(
                        name=f"{Path(ref.file_path).stem}",
                        kind="module",
                        relationship=ref.relationship_type,
                        file_path=ref.file_path,
                    ))
                    seen_types.add(ref.relationship_type)
                    if len(related) >= 10:
                        break

        except Exception as e:
            logger.debug(f"Failed to get related symbols: {e}")

        return related

    def build_context_for_file(
        self,
        file_path: Path,
        context_type: str = "file_overview",
    ) -> MCPContext:
        """Build context for an entire file."""
        # Try to get symbols by searching with the file path.
        # Note: GlobalSymbolIndex doesn't have search_by_file, so we use a
        # different approach.
        symbols = []

        # This is a simplified approach - a full implementation would query
        # by file path directly.
        try:
            # Use the global index to enumerate symbols, then filter by file
            file_str = str(file_path.resolve())
            # Get all symbols and filter by file path (not efficient, but works)
            all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
            symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
        except Exception as e:
            logger.debug(f"Failed to get file symbols: {e}")

        related = [
            RelatedSymbol(
                name=s.name,
                kind=s.kind,
                relationship="defines",
            )
            for s in symbols
        ]

        return MCPContext(
            context_type=context_type,
            related_symbols=related,
            metadata={
                "file_path": str(file_path),
                "symbol_count": len(symbols),
            },
        )
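A hedged consumption sketch for `build_context`, reusing the `provider` instance from the hooks sketch above; the symbol name and printed fields are illustrative:

```python
ctx = provider.build_context("authenticate", max_references=5)
if ctx is not None:
    print(ctx.symbol.name, ctx.symbol.kind)    # e.g. authenticate / function
    print(f"{len(ctx.references)} references, "
          f"{len(ctx.related_symbols)} related symbols")
    print(ctx.to_json()[:200])                 # serialized preview
```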
113 codex-lens/src/codexlens/mcp/schema.py Normal file
@@ -0,0 +1,113 @@
"""MCP data models."""

from __future__ import annotations

import json
from dataclasses import dataclass, field, asdict
from typing import List, Optional


@dataclass
class SymbolInfo:
    """Information about a code symbol."""
    name: str
    kind: str
    file_path: str
    line_start: int
    line_end: int
    signature: Optional[str] = None
    documentation: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class ReferenceInfo:
    """Information about a symbol reference."""
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str

    def to_dict(self) -> dict:
        return asdict(self)


@dataclass
class RelatedSymbol:
    """Related symbol (import, call target, etc.)."""
    name: str
    kind: str
    relationship: str  # "imports", "calls", "inherits", "uses"
    file_path: Optional[str] = None

    def to_dict(self) -> dict:
        return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class MCPContext:
    """Model Context Protocol context object.

    This is the structured context that gets injected into
    LLM prompts to provide code understanding.
    """
    version: str = "1.0"
    context_type: str = "code_context"
    symbol: Optional[SymbolInfo] = None
    definition: Optional[str] = None
    references: List[ReferenceInfo] = field(default_factory=list)
    related_symbols: List[RelatedSymbol] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        result = {
            "version": self.version,
            "context_type": self.context_type,
            "metadata": self.metadata,
        }

        if self.symbol:
            result["symbol"] = self.symbol.to_dict()
        if self.definition:
            result["definition"] = self.definition
        if self.references:
            result["references"] = [r.to_dict() for r in self.references]
        if self.related_symbols:
            result["related_symbols"] = [s.to_dict() for s in self.related_symbols]

        return result

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    def to_prompt_injection(self) -> str:
        """Format for injection into LLM prompt."""
        parts = ["<code_context>"]

        if self.symbol:
            parts.append(f"## Symbol: {self.symbol.name}")
            parts.append(f"Type: {self.symbol.kind}")
            parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")

        if self.definition:
            parts.append("\n## Definition")
            parts.append(f"```\n{self.definition}\n```")

        if self.references:
            parts.append(f"\n## References ({len(self.references)} found)")
            for ref in self.references[:5]:  # Limit to 5
                parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
                parts.append(f"  ```\n  {ref.context}\n  ```")

        if self.related_symbols:
            parts.append("\n## Related Symbols")
            for sym in self.related_symbols[:10]:  # Limit to 10
                parts.append(f"- {sym.name} ({sym.relationship})")

        parts.append("</code_context>")
        return "\n".join(parts)
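To make the injection format concrete, a small hand-built example; all values are illustrative:

```python
ctx = MCPContext(
    context_type="symbol_explanation",
    symbol=SymbolInfo(
        name="authenticate",
        kind="function",
        file_path="src/auth.py",
        line_start=42,
        line_end=58,
    ),
    definition="def authenticate(user, token):\n    ...",
)

print(ctx.to_prompt_injection())
# <code_context>
# ## Symbol: authenticate
# Type: function
# Location: src/auth.py:42
#
# ## Definition
# ... fenced definition block, then the closing </code_context> tag
```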
@@ -6,10 +6,48 @@ from .chain_search import (
    quick_search,
)

# Clustering availability flag (lazy import pattern)
CLUSTERING_AVAILABLE = False
_clustering_import_error: str | None = None

try:
    from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
    from .clustering import check_clustering_available
    CLUSTERING_AVAILABLE = _clustering_flag
except ImportError as e:
    _clustering_import_error = str(e)

    def check_clustering_available() -> tuple[bool, str | None]:
        """Fallback when clustering module not loadable."""
        return False, _clustering_import_error


# Clustering module exports (conditional)
try:
    from .clustering import (
        BaseClusteringStrategy,
        ClusteringConfig,
        ClusteringStrategyFactory,
        get_strategy,
    )
    _clustering_exports = [
        "BaseClusteringStrategy",
        "ClusteringConfig",
        "ClusteringStrategyFactory",
        "get_strategy",
    ]
except ImportError:
    _clustering_exports = []


__all__ = [
    "ChainSearchEngine",
    "SearchOptions",
    "SearchStats",
    "ChainSearchResult",
    "quick_search",
    # Clustering
    "CLUSTERING_AVAILABLE",
    "check_clustering_available",
    *_clustering_exports,
]
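A short sketch of how a caller might consume this lazy-import surface; the names come from the exports above, while the print formatting is illustrative:

```python
from codexlens.search import CLUSTERING_AVAILABLE, check_clustering_available

if not CLUSTERING_AVAILABLE:
    ok, err = check_clustering_available()
    # e.g. "clustering unavailable: No module named 'hdbscan'"
    print(f"clustering unavailable: {err}")
```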
@@ -116,6 +116,24 @@ class ChainSearchResult:
    related_results: List[SearchResult] = field(default_factory=list)


@dataclass
class ReferenceResult:
    """Result from reference search in code_relationships table.

    Attributes:
        file_path: Path to the file containing the reference
        line: Line number where the reference occurs (1-based)
        column: Column number where the reference occurs (0-based)
        context: Surrounding code snippet for context
        relationship_type: Type of relationship (call, import, inheritance, etc.)
    """
    file_path: str
    line: int
    column: int
    context: str
    relationship_type: str


class ChainSearchEngine:
    """Parallel chain search engine for hierarchical directory indexes.
@@ -810,7 +828,7 @@ class ChainSearchEngine:
        k: int = 10,
        coarse_k: int = 100,
        options: Optional[SearchOptions] = None,
-        strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank"]] = None,
+        strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None,
    ) -> ChainSearchResult:
        """Unified cascade search entry point with strategy selection.
@@ -819,6 +837,7 @@
        - "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original)
        - "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance)
        - "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking
        - "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank

        The strategy is determined with the following priority:
        1. The `strategy` parameter (e.g., from CLI --cascade-strategy option)
@@ -831,7 +850,7 @@
            k: Number of final results to return (default 10)
            coarse_k: Number of coarse candidates from first stage (default 100)
            options: Search configuration (uses defaults if None)
-            strategy: Cascade strategy - "binary", "hybrid", or "binary_rerank".
+            strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged".

        Returns:
            ChainSearchResult with reranked results and statistics
@@ -844,10 +863,12 @@
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
            >>> # Use binary + cross-encoder (best balance of speed and quality)
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
            >>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="staged")
        """
        # Strategy priority: parameter > config > default
        effective_strategy = strategy
-        valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank")
+        valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
        if effective_strategy is None:
            # Not passed via parameter, check config
            if self._config is not None:
@@ -865,9 +886,635 @@ class ChainSearchEngine:
            return self.binary_rerank_cascade_search(query, source_path, k, coarse_k, options)
        elif effective_strategy == "dense_rerank":
            return self.dense_rerank_cascade_search(query, source_path, k, coarse_k, options)
        elif effective_strategy == "staged":
            return self.staged_cascade_search(query, source_path, k, coarse_k, options)
        else:
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)

    def staged_cascade_search(
        self,
        query: str,
        source_path: Path,
        k: int = 10,
        coarse_k: int = 100,
        options: Optional[SearchOptions] = None,
    ) -> ChainSearchResult:
        """Execute 4-stage cascade search pipeline with binary, LSP expansion, clustering, and optional reranking.

        Staged cascade search process:
        1. Stage 1 (Binary Coarse): Fast binary vector search using Hamming distance
           to quickly filter to coarse_k candidates (256-bit binary vectors)
        2. Stage 2 (LSP Expansion): Expand coarse candidates using GraphExpander to
           include related symbols (definitions, references, callers/callees)
        3. Stage 3 (Clustering): Use configurable clustering strategy to group similar
           results and select representative results from each cluster
        4. Stage 4 (Optional Rerank): If config.enable_staged_rerank is True, apply
           cross-encoder reranking for final precision

        This approach combines the speed of binary search with graph-based context
        expansion and diversity-preserving clustering for high-quality results.

        Performance characteristics:
        - Stage 1: O(N) binary search with SIMD acceleration (~8ms)
        - Stage 2: O(k * d) graph traversal where d is expansion depth
        - Stage 3: O(n^2) clustering on expanded candidates
        - Stage 4: Optional cross-encoder reranking (API call)

        Args:
            query: Natural language or keyword query string
            source_path: Starting directory path
            k: Number of final results to return (default 10)
            coarse_k: Number of coarse candidates from first stage (default 100)
            options: Search configuration (uses defaults if None)

        Returns:
            ChainSearchResult with per-stage statistics

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper, config=config)
            >>> result = engine.staged_cascade_search(
            ...     "authentication handler",
            ...     Path("D:/project/src"),
            ...     k=10,
            ...     coarse_k=100
            ... )
            >>> for r in result.results:
            ...     print(f"{r.path}: {r.score:.3f}")
        """
        if not NUMPY_AVAILABLE:
            self.logger.warning(
                "NumPy not available, falling back to hybrid cascade search"
            )
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)

        options = options or SearchOptions()
        start_time = time.time()
        stats = SearchStats()

        # Per-stage timing stats
        stage_times: Dict[str, float] = {}
        stage_counts: Dict[str, int] = {}

        # Use config defaults if available
        if self._config is not None:
            if hasattr(self._config, "cascade_coarse_k"):
                coarse_k = coarse_k or self._config.cascade_coarse_k
            if hasattr(self._config, "cascade_fine_k"):
                k = k or self._config.cascade_fine_k

        # Step 1: Find starting index
        start_index = self._find_start_index(source_path)
        if not start_index:
            self.logger.warning(f"No index found for {source_path}")
            stats.time_ms = (time.time() - start_time) * 1000
            return ChainSearchResult(
                query=query,
                results=[],
                symbols=[],
                stats=stats
            )

        # Step 2: Collect all index paths
        index_paths = self._collect_index_paths(start_index, options.depth)
        stats.dirs_searched = len(index_paths)

        if not index_paths:
            self.logger.warning(f"No indexes collected from {start_index}")
            stats.time_ms = (time.time() - start_time) * 1000
            return ChainSearchResult(
                query=query,
                results=[],
                symbols=[],
                stats=stats
            )

        # ========== Stage 1: Binary Coarse Search ==========
        stage1_start = time.time()
        coarse_results, index_root = self._stage1_binary_search(
            query, index_paths, coarse_k, stats
        )
        stage_times["stage1_binary_ms"] = (time.time() - stage1_start) * 1000
        stage_counts["stage1_candidates"] = len(coarse_results)

        self.logger.debug(
            "Staged Stage 1: Binary search found %d candidates in %.2fms",
            len(coarse_results), stage_times["stage1_binary_ms"]
        )

        if not coarse_results:
            self.logger.debug("No binary candidates found, falling back to hybrid cascade")
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)

        # ========== Stage 2: LSP Graph Expansion ==========
        stage2_start = time.time()
        expanded_results = self._stage2_lsp_expand(coarse_results, index_root)
        stage_times["stage2_expand_ms"] = (time.time() - stage2_start) * 1000
        stage_counts["stage2_expanded"] = len(expanded_results)

        self.logger.debug(
            "Staged Stage 2: LSP expansion %d -> %d results in %.2fms",
            len(coarse_results), len(expanded_results), stage_times["stage2_expand_ms"]
        )

        # ========== Stage 3: Clustering and Representative Selection ==========
        stage3_start = time.time()
        clustered_results = self._stage3_cluster_prune(expanded_results, k * 2)
        stage_times["stage3_cluster_ms"] = (time.time() - stage3_start) * 1000
        stage_counts["stage3_clustered"] = len(clustered_results)

        self.logger.debug(
            "Staged Stage 3: Clustering %d -> %d representatives in %.2fms",
            len(expanded_results), len(clustered_results), stage_times["stage3_cluster_ms"]
        )

        # ========== Stage 4: Optional Cross-Encoder Reranking ==========
        enable_rerank = False
        if self._config is not None:
            enable_rerank = getattr(self._config, "enable_staged_rerank", False)

        if enable_rerank:
            stage4_start = time.time()
            final_results = self._stage4_optional_rerank(query, clustered_results, k)
            stage_times["stage4_rerank_ms"] = (time.time() - stage4_start) * 1000
            stage_counts["stage4_reranked"] = len(final_results)

            self.logger.debug(
                "Staged Stage 4: Reranking %d -> %d results in %.2fms",
                len(clustered_results), len(final_results), stage_times["stage4_rerank_ms"]
            )
        else:
            # Skip reranking, just take top-k by score
            final_results = sorted(
                clustered_results, key=lambda r: r.score, reverse=True
            )[:k]
            stage_counts["stage4_reranked"] = len(final_results)

        # Deduplicate by path (keep highest score)
        path_to_result: Dict[str, SearchResult] = {}
        for result in final_results:
            if result.path not in path_to_result or result.score > path_to_result[result.path].score:
                path_to_result[result.path] = result

        final_results = list(path_to_result.values())[:k]

        # Optional: grouping of similar results
        if options.group_results:
            from codexlens.search.ranking import group_similar_results
            final_results = group_similar_results(
                final_results, score_threshold_abs=options.grouping_threshold
            )

        stats.files_matched = len(final_results)
        stats.time_ms = (time.time() - start_time) * 1000

        # Add per-stage stats to errors field (as JSON for now, will be proper field later)
        stage_stats_json = json.dumps({
            "stage_times": stage_times,
            "stage_counts": stage_counts,
        })
        stats.errors.append(f"STAGE_STATS:{stage_stats_json}")

        self.logger.debug(
            "Staged cascade search complete: %d results in %.2fms "
            "(stage1=%.1fms, stage2=%.1fms, stage3=%.1fms)",
            len(final_results),
            stats.time_ms,
            stage_times.get("stage1_binary_ms", 0),
            stage_times.get("stage2_expand_ms", 0),
            stage_times.get("stage3_cluster_ms", 0),
        )

        return ChainSearchResult(
            query=query,
            results=final_results,
            symbols=[],
            stats=stats,
        )
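    # NOTE: Until per-stage stats get a dedicated field on SearchStats, callers
    # can recover them from the errors side channel written above. A minimal
    # sketch (the "STAGE_STATS:" prefix matches the append above; everything
    # else is illustrative):
    #
    #     for entry in result.stats.errors:
    #         if entry.startswith("STAGE_STATS:"):
    #             stage_stats = json.loads(entry[len("STAGE_STATS:"):])
    #             print(stage_stats["stage_times"], stage_stats["stage_counts"])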
    def _stage1_binary_search(
        self,
        query: str,
        index_paths: List[Path],
        coarse_k: int,
        stats: SearchStats,
    ) -> Tuple[List[SearchResult], Optional[Path]]:
        """Stage 1: Binary vector coarse search using Hamming distance.

        Reuses the binary coarse search logic from binary_cascade_search.

        Args:
            query: Search query string
            index_paths: List of index database paths to search
            coarse_k: Number of coarse candidates to retrieve
            stats: SearchStats to update with errors

        Returns:
            Tuple of (list of SearchResult objects, index_root path or None)
        """
        # Initialize binary embedding backend
        try:
            from codexlens.indexing.embedding import BinaryEmbeddingBackend
        except ImportError as exc:
            self.logger.warning(
                "BinaryEmbeddingBackend not available: %s", exc
            )
            return [], None

        # Try centralized BinarySearcher first (preferred for mmap indexes)
        index_root = index_paths[0].parent if index_paths else None
        coarse_candidates: List[Tuple[int, int, Path]] = []  # (chunk_id, distance, index_path)
        used_centralized = False

        if index_root:
            binary_searcher = self._get_centralized_binary_searcher(index_root)
            if binary_searcher is not None:
                try:
                    from codexlens.semantic.embedder import Embedder
                    embedder = Embedder()
                    query_dense = embedder.embed_to_numpy([query])[0]

                    results = binary_searcher.search(query_dense, top_k=coarse_k)
                    for chunk_id, distance in results:
                        coarse_candidates.append((chunk_id, distance, index_root))
                    if coarse_candidates:
                        used_centralized = True
                        self.logger.debug(
                            "Stage 1 centralized binary search: %d candidates", len(results)
                        )
                except Exception as exc:
                    self.logger.debug(f"Centralized binary search failed: {exc}")

        if not used_centralized:
            # Fallback to per-directory binary indexes
            use_gpu = True
            if self._config is not None:
                use_gpu = getattr(self._config, "embedding_use_gpu", True)

            try:
                binary_backend = BinaryEmbeddingBackend(use_gpu=use_gpu)
                query_binary = binary_backend.embed_packed([query])[0]
            except Exception as exc:
                self.logger.warning(f"Failed to generate binary query embedding: {exc}")
                return [], index_root

            for index_path in index_paths:
                try:
                    binary_index = self._get_or_create_binary_index(index_path)
                    if binary_index is None or binary_index.count() == 0:
                        continue
                    ids, distances = binary_index.search(query_binary, coarse_k)
                    for chunk_id, dist in zip(ids, distances):
                        coarse_candidates.append((chunk_id, dist, index_path))
                except Exception as exc:
                    self.logger.debug(
                        "Binary search failed for %s: %s", index_path, exc
                    )

        if not coarse_candidates:
            return [], index_root

        # Sort by Hamming distance and take top coarse_k
        coarse_candidates.sort(key=lambda x: x[1])
        coarse_candidates = coarse_candidates[:coarse_k]

        # Build SearchResult objects from candidates
        coarse_results: List[SearchResult] = []

        # Group candidates by index path for efficient retrieval
        candidates_by_index: Dict[Path, List[int]] = {}
        for chunk_id, _, idx_path in coarse_candidates:
            if idx_path not in candidates_by_index:
                candidates_by_index[idx_path] = []
            candidates_by_index[idx_path].append(chunk_id)

        # Retrieve chunk content
        import sqlite3
        central_meta_path = index_root / VECTORS_META_DB_NAME if index_root else None
        central_meta_store = None
        if central_meta_path and central_meta_path.exists():
            central_meta_store = VectorMetadataStore(central_meta_path)

        for idx_path, chunk_ids in candidates_by_index.items():
            try:
                chunks_data = []
                if central_meta_store:
                    chunks_data = central_meta_store.get_chunks_by_ids(chunk_ids)

                if not chunks_data and used_centralized:
                    meta_db_path = idx_path / VECTORS_META_DB_NAME
                    if meta_db_path.exists():
                        meta_store = VectorMetadataStore(meta_db_path)
                        chunks_data = meta_store.get_chunks_by_ids(chunk_ids)

                if not chunks_data:
                    try:
                        conn = sqlite3.connect(str(idx_path))
                        conn.row_factory = sqlite3.Row
                        placeholders = ",".join("?" * len(chunk_ids))
                        cursor = conn.execute(
                            f"""
                            SELECT id, file_path, content, metadata, category
                            FROM semantic_chunks
                            WHERE id IN ({placeholders})
                            """,
                            chunk_ids
                        )
                        chunks_data = [
                            {
                                "id": row["id"],
                                "file_path": row["file_path"],
                                "content": row["content"],
                                "metadata": row["metadata"],
                                "category": row["category"],
                            }
                            for row in cursor.fetchall()
                        ]
                        conn.close()
                    except Exception:
                        pass

                for chunk in chunks_data:
                    chunk_id = chunk.get("id") or chunk.get("chunk_id")
                    distance = next(
                        (d for cid, d, _ in coarse_candidates if cid == chunk_id),
                        256
                    )
                    score = 1.0 - (distance / 256.0)

                    content = chunk.get("content", "")

                    # Extract symbol info from metadata if available
                    metadata = chunk.get("metadata")
                    symbol_name = None
                    symbol_kind = None
                    start_line = None
                    end_line = None
                    if metadata:
                        try:
                            meta_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
                            symbol_name = meta_dict.get("symbol_name")
                            symbol_kind = meta_dict.get("symbol_kind")
                            start_line = meta_dict.get("start_line")
                            end_line = meta_dict.get("end_line")
                        except Exception:
                            pass

                    result = SearchResult(
                        path=chunk.get("file_path", ""),
                        score=float(score),
                        excerpt=content[:500] if content else "",
                        content=content,
                        symbol_name=symbol_name,
                        symbol_kind=symbol_kind,
                        start_line=start_line,
                        end_line=end_line,
                    )
                    coarse_results.append(result)
            except Exception as exc:
                self.logger.debug(
                    "Failed to retrieve chunks from %s: %s", idx_path, exc
                )
                stats.errors.append(f"Stage 1 chunk retrieval failed for {idx_path}: {exc}")

        return coarse_results, index_root

    def _stage2_lsp_expand(
        self,
        coarse_results: List[SearchResult],
        index_root: Optional[Path],
    ) -> List[SearchResult]:
        """Stage 2: LSP-based graph expansion using GraphExpander.

        Expands coarse results with related symbols (definitions, references,
        callers, callees) using precomputed graph neighbors.

        Args:
            coarse_results: Results from Stage 1 binary search
            index_root: Root path of the index (for graph database access)

        Returns:
            Combined list of original results plus expanded related results
        """
        if not coarse_results or index_root is None:
            return coarse_results

        try:
            from codexlens.search.graph_expander import GraphExpander

            # Get expansion depth from config
            depth = 2
            if self._config is not None:
                depth = getattr(self._config, "graph_expansion_depth", 2)

            expander = GraphExpander(self.mapper, config=self._config)

            # Expand top results (limit expansion to avoid explosion)
            max_expand = min(10, len(coarse_results))
            max_related = 50

            related_results = expander.expand(
                coarse_results,
                depth=depth,
                max_expand=max_expand,
                max_related=max_related,
            )

            if related_results:
                self.logger.debug(
                    "Stage 2 expanded %d base results to %d related symbols",
                    len(coarse_results), len(related_results)
                )

            # Combine: original results + related results.
            # Keep original results first (higher relevance).
            combined = list(coarse_results)
            seen_keys = {(r.path, r.symbol_name, r.start_line) for r in coarse_results}

            for related in related_results:
                key = (related.path, related.symbol_name, related.start_line)
                if key not in seen_keys:
                    seen_keys.add(key)
                    combined.append(related)

            return combined

        except ImportError as exc:
            self.logger.debug("GraphExpander not available: %s", exc)
            return coarse_results
        except Exception as exc:
            self.logger.debug("Stage 2 LSP expansion failed: %s", exc)
            return coarse_results

    def _stage3_cluster_prune(
        self,
        expanded_results: List[SearchResult],
        target_count: int,
    ) -> List[SearchResult]:
        """Stage 3: Cluster expanded results and select representatives.

        Uses the extensible clustering infrastructure from codexlens.search.clustering
        to group similar results and select the best representative from each cluster.

        Args:
            expanded_results: Results from Stage 2 expansion
            target_count: Target number of representative results

        Returns:
            List of representative results (one per cluster)
        """
        if not expanded_results:
            return []

        # If few results, skip clustering
        if len(expanded_results) <= target_count:
            return expanded_results

        try:
            from codexlens.search.clustering import (
                ClusteringConfig,
                get_strategy,
            )

            # Get clustering config from config
            strategy_name = "auto"
            min_cluster_size = 3

            if self._config is not None:
                strategy_name = getattr(self._config, "staged_clustering_strategy", "auto")
                min_cluster_size = getattr(self._config, "staged_clustering_min_size", 3)

            # Get embeddings for clustering.
            # Try to get dense embeddings from results' content.
            embeddings = self._get_embeddings_for_clustering(expanded_results)

            if embeddings is None or len(embeddings) == 0:
                # No embeddings available, fall back to score-based selection
                self.logger.debug("No embeddings for clustering, using score-based selection")
                return sorted(
                    expanded_results, key=lambda r: r.score, reverse=True
                )[:target_count]

            # Create clustering config
            config = ClusteringConfig(
                min_cluster_size=min(min_cluster_size, max(2, len(expanded_results) // 5)),
                min_samples=2,
                metric="cosine",
            )

            # Get strategy with fallback
            strategy = get_strategy(strategy_name, config, fallback=True)

            # Cluster and select representatives
            representatives = strategy.fit_predict(embeddings, expanded_results)

            self.logger.debug(
                "Stage 3 clustered %d results into %d representatives using %s",
                len(expanded_results), len(representatives), type(strategy).__name__
            )

            # If clustering returned too few, supplement with top-scored unclustered
            if len(representatives) < target_count:
                rep_paths = {r.path for r in representatives}
                remaining = [r for r in expanded_results if r.path not in rep_paths]
                remaining_sorted = sorted(remaining, key=lambda r: r.score, reverse=True)
                representatives.extend(remaining_sorted[:target_count - len(representatives)])

            return representatives[:target_count]

        except ImportError as exc:
            self.logger.debug("Clustering not available: %s", exc)
            return sorted(
                expanded_results, key=lambda r: r.score, reverse=True
            )[:target_count]
        except Exception as exc:
            self.logger.debug("Stage 3 clustering failed: %s", exc)
            return sorted(
                expanded_results, key=lambda r: r.score, reverse=True
            )[:target_count]

    def _stage4_optional_rerank(
        self,
        query: str,
        clustered_results: List[SearchResult],
        k: int,
    ) -> List[SearchResult]:
        """Stage 4: Optional cross-encoder reranking.

        Applies cross-encoder reranking if enabled in config.

        Args:
            query: Search query string
            clustered_results: Results from Stage 3 clustering
            k: Number of final results to return

        Returns:
            Reranked results sorted by cross-encoder score
        """
        if not clustered_results:
            return []

        # Use existing _cross_encoder_rerank method
        return self._cross_encoder_rerank(query, clustered_results, k)

    def _get_embeddings_for_clustering(
        self,
        results: List[SearchResult],
    ) -> Optional["np.ndarray"]:
        """Get dense embeddings for clustering results.

        Tries to generate embeddings from result content for clustering.

        Args:
            results: List of SearchResult objects

        Returns:
            NumPy array of embeddings or None if not available
        """
        if not NUMPY_AVAILABLE:
            return None

        if not results:
            return None

        try:
            from codexlens.semantic.factory import get_embedder

            # Get embedding settings from config
            embedding_backend = "fastembed"
            embedding_model = "code"
            use_gpu = True

            if self._config is not None:
                embedding_backend = getattr(self._config, "embedding_backend", "fastembed")
                embedding_model = getattr(self._config, "embedding_model", "code")
                use_gpu = getattr(self._config, "embedding_use_gpu", True)

            # Create embedder
            if embedding_backend == "litellm":
                embedder = get_embedder(backend="litellm", model=embedding_model)
            else:
                embedder = get_embedder(backend="fastembed", profile=embedding_model, use_gpu=use_gpu)

            # Extract text content from results
            texts = []
            for result in results:
                # Use content if available, otherwise use excerpt
                text = result.content or result.excerpt or ""
                if not text and result.path:
                    text = result.path
                texts.append(text[:2000])  # Limit text length

            # Generate embeddings
            embeddings = embedder.embed_to_numpy(texts)
            return embeddings

        except ImportError as exc:
            self.logger.debug("Embedder not available for clustering: %s", exc)
            return None
        except Exception as exc:
            self.logger.debug("Failed to generate embeddings for clustering: %s", exc)
            return None

    def binary_rerank_cascade_search(
        self,
        query: str,
@@ -1990,6 +2637,220 @@ class ChainSearchEngine:
|
||||
index_paths, name, kind, options.total_limit
|
||||
)
|
||||
|
||||
def search_references(
|
||||
self,
|
||||
symbol_name: str,
|
||||
source_path: Optional[Path] = None,
|
||||
depth: int = -1,
|
||||
limit: int = 100,
|
||||
) -> List[ReferenceResult]:
|
||||
"""Find all references to a symbol across the project.
|
||||
|
||||
Searches the code_relationships table in all index databases to find
|
||||
where the given symbol is referenced (called, imported, inherited, etc.).
|
||||
|
||||
Args:
|
||||
symbol_name: Fully qualified or simple name of the symbol to find references to
|
||||
source_path: Starting path for search (default: workspace root from registry)
|
||||
depth: Search depth (-1 = unlimited, 0 = current dir only)
|
||||
            limit: Maximum results to return (default 100)

        Returns:
            List of ReferenceResult objects sorted by file path and line number

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper)
            >>> refs = engine.search_references("authenticate", Path("D:/project/src"))
            >>> for ref in refs[:10]:
            ...     print(f"{ref.file_path}:{ref.line} ({ref.relationship_type})")
        """
        import sqlite3
        from concurrent.futures import as_completed

        # Determine starting path
        if source_path is None:
            # Try to get workspace root from registry
            mappings = self.registry.list_mappings()
            if mappings:
                source_path = Path(mappings[0].source_path)
            else:
                self.logger.warning("No source path provided and no mappings in registry")
                return []

        # Find starting index
        start_index = self._find_start_index(source_path)
        if not start_index:
            self.logger.warning(f"No index found for {source_path}")
            return []

        # Collect all index paths
        index_paths = self._collect_index_paths(start_index, depth)
        if not index_paths:
            self.logger.debug(f"No indexes collected from {start_index}")
            return []

        self.logger.debug(
            "Searching %d indexes for references to '%s'",
            len(index_paths), symbol_name
        )

        # Search in parallel
        all_results: List[ReferenceResult] = []
        executor = self._get_executor()

        def search_single_index(index_path: Path) -> List[ReferenceResult]:
            """Search a single index for references."""
            results: List[ReferenceResult] = []
            try:
                conn = sqlite3.connect(str(index_path), check_same_thread=False)
                conn.row_factory = sqlite3.Row

                # Query code_relationships for references to this symbol.
                # Match either target_qualified_name containing the symbol name
                # or an exact match on the last component.
                # Try full_path first (new schema), fallback to path (old schema).
                try:
                    rows = conn.execute(
                        """
                        SELECT DISTINCT
                            f.full_path as source_file,
                            cr.source_line,
                            cr.relationship_type,
                            f.content
                        FROM code_relationships cr
                        JOIN symbols s ON s.id = cr.source_symbol_id
                        JOIN files f ON f.id = s.file_id
                        WHERE cr.target_qualified_name LIKE ?
                           OR cr.target_qualified_name LIKE ?
                           OR cr.target_qualified_name = ?
                        ORDER BY f.full_path, cr.source_line
                        LIMIT ?
                        """,
                        (
                            f"%{symbol_name}",   # Ends with symbol name
                            f"%.{symbol_name}",  # Qualified name ending with .symbol_name
                            symbol_name,         # Exact match
                            limit,
                        ),
                    ).fetchall()
                except sqlite3.OperationalError:
                    # Fallback for old schema with 'path' column
                    rows = conn.execute(
                        """
                        SELECT DISTINCT
                            f.path as source_file,
                            cr.source_line,
                            cr.relationship_type,
                            f.content
                        FROM code_relationships cr
                        JOIN symbols s ON s.id = cr.source_symbol_id
                        JOIN files f ON f.id = s.file_id
                        WHERE cr.target_qualified_name LIKE ?
                           OR cr.target_qualified_name LIKE ?
                           OR cr.target_qualified_name = ?
                        ORDER BY f.path, cr.source_line
                        LIMIT ?
                        """,
                        (
                            f"%{symbol_name}",   # Ends with symbol name
                            f"%.{symbol_name}",  # Qualified name ending with .symbol_name
                            symbol_name,         # Exact match
                            limit,
                        ),
                    ).fetchall()

                for row in rows:
                    file_path = row["source_file"]
                    line = row["source_line"] or 1
                    rel_type = row["relationship_type"]
                    content = row["content"] or ""

                    # Extract context (3 lines around reference)
                    context = self._extract_context(content, line, context_lines=3)

                    results.append(ReferenceResult(
                        file_path=file_path,
                        line=line,
                        column=0,  # Column info not stored in code_relationships
                        context=context,
                        relationship_type=rel_type,
                    ))

                conn.close()
            except sqlite3.DatabaseError as exc:
                self.logger.debug(
                    "Failed to search references in %s: %s", index_path, exc
                )
            except Exception as exc:
                self.logger.debug(
                    "Unexpected error searching references in %s: %s", index_path, exc
                )

            return results

        # Submit parallel searches
        futures = {
            executor.submit(search_single_index, idx_path): idx_path
            for idx_path in index_paths
        }

        for future in as_completed(futures):
            try:
                results = future.result()
                all_results.extend(results)
            except Exception as exc:
                idx_path = futures[future]
                self.logger.debug(
                    "Reference search failed for %s: %s", idx_path, exc
                )

        # Deduplicate by (file_path, line)
        seen: set = set()
        unique_results: List[ReferenceResult] = []
        for ref in all_results:
            key = (ref.file_path, ref.line)
            if key not in seen:
                seen.add(key)
                unique_results.append(ref)

        # Sort by file path and line
        unique_results.sort(key=lambda r: (r.file_path, r.line))

        # Apply limit
        return unique_results[:limit]

    def _extract_context(
        self,
        content: str,
        line: int,
        context_lines: int = 3
    ) -> str:
        """Extract lines around a given line number from file content.

        Args:
            content: Full file content
            line: Target line number (1-based)
            context_lines: Number of lines to include before and after

        Returns:
            Context snippet as a string
        """
        if not content:
            return ""

        lines = content.splitlines()
        total_lines = len(lines)

        if line < 1 or line > total_lines:
            return ""

        # Calculate range (0-indexed internally)
        start = max(0, line - 1 - context_lines)
        end = min(total_lines, line + context_lines)

        context = lines[start:end]
        return "\n".join(context)
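A worked illustration of the window arithmetic in `_extract_context` above: the 1-based target line gets `context_lines` lines on each side, clamped to the file bounds (values below are hand-checked, the snippet is illustrative only):

```python
content = "\n".join(f"line {i}" for i in range(1, 11))  # 10 lines
line, context_lines = 2, 3
lines = content.splitlines()
start = max(0, line - 1 - context_lines)    # max(0, -2) -> 0 (clamped)
end = min(len(lines), line + context_lines) # min(10, 5) -> 5
print("\n".join(lines[start:end]))          # prints "line 1" .. "line 5"
```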
    # === Internal Methods ===

    def _find_start_index(self, source_path: Path) -> Optional[Path]:
124
codex-lens/src/codexlens/search/clustering/__init__.py
Normal file
@@ -0,0 +1,124 @@
"""Clustering strategies for the staged hybrid search pipeline.
|
||||
|
||||
This module provides extensible clustering infrastructure for grouping
|
||||
similar search results and selecting representative results.
|
||||
|
||||
Install with: pip install codexlens[clustering]
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import (
|
||||
... CLUSTERING_AVAILABLE,
|
||||
... ClusteringConfig,
|
||||
... get_strategy,
|
||||
... )
|
||||
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||
>>> # Auto-select best available strategy with fallback
|
||||
>>> strategy = get_strategy("auto", config)
|
||||
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||
>>>
|
||||
>>> # Or explicitly use a specific strategy
|
||||
>>> if CLUSTERING_AVAILABLE:
|
||||
... from codexlens.search.clustering import HDBSCANStrategy
|
||||
... strategy = HDBSCANStrategy(config)
|
||||
... representatives = strategy.fit_predict(embeddings, results)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Always export base classes and factory (no heavy dependencies)
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
from .factory import (
|
||||
ClusteringStrategyFactory,
|
||||
check_clustering_strategy_available,
|
||||
get_strategy,
|
||||
)
|
||||
from .noop_strategy import NoOpStrategy
|
||||
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||
|
||||
# Feature flag for clustering availability (hdbscan + sklearn)
|
||||
CLUSTERING_AVAILABLE = False
|
||||
HDBSCAN_AVAILABLE = False
|
||||
DBSCAN_AVAILABLE = False
|
||||
_import_error: str | None = None
|
||||
|
||||
|
||||
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
|
||||
"""Detect if clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
|
||||
"""
|
||||
hdbscan_ok = False
|
||||
dbscan_ok = False
|
||||
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
hdbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
dbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
all_ok = hdbscan_ok and dbscan_ok
|
||||
error = None
|
||||
if not all_ok:
|
||||
missing = []
|
||||
if not hdbscan_ok:
|
||||
missing.append("hdbscan")
|
||||
if not dbscan_ok:
|
||||
missing.append("scikit-learn")
|
||||
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
|
||||
|
||||
return all_ok, hdbscan_ok, dbscan_ok, error
|
||||
|
||||
|
||||
# Initialize on module load
|
||||
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
|
||||
_detect_clustering_available()
|
||||
)
|
||||
|
||||
|
||||
def check_clustering_available() -> tuple[bool, str | None]:
|
||||
"""Check if all clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, error_message).
|
||||
error_message is None if available, otherwise contains install instructions.
|
||||
"""
|
||||
return CLUSTERING_AVAILABLE, _import_error
|
||||
|
||||
|
||||
# Conditionally export strategy implementations
|
||||
__all__ = [
|
||||
# Feature flags
|
||||
"CLUSTERING_AVAILABLE",
|
||||
"HDBSCAN_AVAILABLE",
|
||||
"DBSCAN_AVAILABLE",
|
||||
"check_clustering_available",
|
||||
# Base classes
|
||||
"BaseClusteringStrategy",
|
||||
"ClusteringConfig",
|
||||
# Factory
|
||||
"ClusteringStrategyFactory",
|
||||
"get_strategy",
|
||||
"check_clustering_strategy_available",
|
||||
# Always-available strategies
|
||||
"NoOpStrategy",
|
||||
"FrequencyStrategy",
|
||||
"FrequencyConfig",
|
||||
]
|
||||
|
||||
# Conditionally add strategy classes to __all__ and module namespace
|
||||
if HDBSCAN_AVAILABLE:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
|
||||
__all__.append("HDBSCANStrategy")
|
||||
|
||||
if DBSCAN_AVAILABLE:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
|
||||
__all__.append("DBSCANStrategy")
|
||||
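A minimal sketch of how a caller might consume the feature flags and factory exported above (assumes only that codexlens is importable; no clustering extras required):

```python
from codexlens.search.clustering import check_clustering_available, get_strategy

ok, err = check_clustering_available()
strategy = get_strategy("auto")  # "auto" never raises: hdbscan -> dbscan -> noop
if not ok:
    print(f"using fallback strategy {type(strategy).__name__}: {err}")
```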
153
codex-lens/src/codexlens/search/clustering/base.py
Normal file
@@ -0,0 +1,153 @@
"""Base classes for clustering strategies in the hybrid search pipeline.
|
||||
|
||||
This module defines the abstract base class for clustering strategies used
|
||||
in the staged hybrid search pipeline. Strategies cluster search results
|
||||
based on their embeddings and select representative results from each cluster.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusteringConfig:
|
||||
"""Configuration parameters for clustering strategies.
|
||||
|
||||
Attributes:
|
||||
min_cluster_size: Minimum number of results to form a cluster.
|
||||
HDBSCAN default is 5, but for search results 2-3 is often better.
|
||||
min_samples: Number of samples in a neighborhood for a point to be
|
||||
considered a core point. Lower values allow more clusters.
|
||||
metric: Distance metric for clustering. Common options:
|
||||
- 'euclidean': Standard L2 distance
|
||||
- 'cosine': Cosine distance (1 - cosine_similarity)
|
||||
- 'manhattan': L1 distance
|
||||
cluster_selection_epsilon: Distance threshold for cluster selection.
|
||||
Results within this distance may be merged into the same cluster.
|
||||
allow_single_cluster: If True, allow all results to form one cluster.
|
||||
Useful when results are very similar.
|
||||
prediction_data: If True, generate prediction data for new points.
|
||||
"""
|
||||
|
||||
min_cluster_size: int = 3
|
||||
min_samples: int = 2
|
||||
metric: str = "cosine"
|
||||
cluster_selection_epsilon: float = 0.0
|
||||
allow_single_cluster: bool = True
|
||||
prediction_data: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
if self.min_cluster_size < 2:
|
||||
raise ValueError("min_cluster_size must be >= 2")
|
||||
if self.min_samples < 1:
|
||||
raise ValueError("min_samples must be >= 1")
|
||||
if self.metric not in ("euclidean", "cosine", "manhattan"):
|
||||
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
|
||||
if self.cluster_selection_epsilon < 0:
|
||||
raise ValueError("cluster_selection_epsilon must be >= 0")
|
||||
|
||||
|
||||
class BaseClusteringStrategy(ABC):
|
||||
"""Abstract base class for clustering strategies.
|
||||
|
||||
Clustering strategies are used in the staged hybrid search pipeline to
|
||||
group similar search results and select representative results from each
|
||||
cluster, reducing redundancy while maintaining diversity.
|
||||
|
||||
Subclasses must implement:
|
||||
- cluster(): Group results into clusters based on embeddings
|
||||
- select_representatives(): Choose best result(s) from each cluster
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize the clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
"""
|
||||
self.config = config or ClusteringConfig()
|
||||
|
||||
@abstractmethod
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results based on their embeddings.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
Used for additional metadata during clustering.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Results not assigned to any cluster
|
||||
(noise points) should be returned as single-element clusters.
|
||||
|
||||
Example:
|
||||
>>> strategy = HDBSCANStrategy()
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
|
||||
>>> # Result indices 0, 2, 5 are in cluster 0
|
||||
>>> # Result indices 1, 3 are in cluster 1
|
||||
>>> # Result index 4 is a noise point (singleton cluster)
|
||||
>>> # Result indices 6, 7, 8 are in cluster 2
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
This method chooses the best result(s) from each cluster to include
|
||||
in the final search results. The selection can be based on:
|
||||
- Highest score within cluster
|
||||
- Closest to cluster centroid
|
||||
- Custom selection logic
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings array for centroid-based selection.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one or more per cluster,
|
||||
ordered by relevance (highest score first).
|
||||
|
||||
Example:
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
>>> # Returns best result from each cluster
|
||||
"""
|
||||
...
|
||||
|
||||
def fit_predict(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List["SearchResult"]:
|
||||
"""Convenience method to cluster and select representatives in one call.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects.
|
||||
"""
|
||||
clusters = self.cluster(embeddings, results)
|
||||
return self.select_representatives(clusters, results, embeddings)
|
||||
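A toy subclass under the contract above, to show the minimum a custom strategy must implement (one cluster, top score wins; illustrative, not part of the package):

```python
from typing import List

from codexlens.search.clustering.base import BaseClusteringStrategy


class TopScoreStrategy(BaseClusteringStrategy):
    """Toy strategy: one cluster, keep only the highest-scoring result."""

    def cluster(self, embeddings, results) -> List[List[int]]:
        # Put every result index into a single cluster.
        return [list(range(len(results)))] if results else []

    def select_representatives(self, clusters, results, embeddings=None):
        # Keep the top-scoring result of the single cluster.
        return [max(results, key=lambda r: r.score)] if results else []
```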
197
codex-lens/src/codexlens/search/clustering/dbscan_strategy.py
Normal file
@@ -0,0 +1,197 @@
"""DBSCAN-based clustering strategy for search results.
|
||||
|
||||
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
|
||||
is the fallback clustering strategy when HDBSCAN is not available.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class DBSCANStrategy(BaseClusteringStrategy):
|
||||
"""DBSCAN-based clustering strategy.
|
||||
|
||||
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
|
||||
DBSCAN requires an explicit eps parameter, which is auto-computed from the
|
||||
distance distribution if not provided.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||
>>> strategy = DBSCANStrategy(config)
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
"""
|
||||
|
||||
# Default eps percentile for auto-computation
|
||||
DEFAULT_EPS_PERCENTILE: float = 15.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
eps: Optional[float] = None,
|
||||
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
|
||||
) -> None:
|
||||
"""Initialize DBSCAN clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
|
||||
from the distance distribution.
|
||||
eps_percentile: Percentile of pairwise distances to use for
|
||||
auto-computing eps. Default is 15th percentile.
|
||||
|
||||
Raises:
|
||||
ImportError: If sklearn is not installed.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.eps = eps
|
||||
self.eps_percentile = eps_percentile
|
||||
|
||||
# Validate sklearn is available
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"scikit-learn package is required for DBSCANStrategy. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
) from exc
|
||||
|
||||
def _compute_eps(self, embeddings: "np.ndarray") -> float:
|
||||
"""Auto-compute eps from pairwise distance distribution.
|
||||
|
||||
Uses the specified percentile of pairwise distances as eps,
|
||||
which typically captures local density well.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
|
||||
Returns:
|
||||
Computed eps value.
|
||||
"""
|
||||
import numpy as np
|
||||
from sklearn.metrics import pairwise_distances
|
||||
|
||||
# Compute pairwise distances
|
||||
distances = pairwise_distances(embeddings, metric=self.config.metric)
|
||||
|
||||
# Get upper triangle (excluding diagonal)
|
||||
upper_tri = distances[np.triu_indices_from(distances, k=1)]
|
||||
|
||||
if len(upper_tri) == 0:
|
||||
# Only one point, return a default small eps
|
||||
return 0.1
|
||||
|
||||
# Use percentile of distances as eps
|
||||
eps = float(np.percentile(upper_tri, self.eps_percentile))
|
||||
|
||||
# Ensure eps is positive
|
||||
return max(eps, 1e-6)
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results using DBSCAN algorithm.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Noise points are returned as singleton clusters.
|
||||
"""
|
||||
from sklearn.cluster import DBSCAN
|
||||
import numpy as np
|
||||
|
||||
n_results = len(results)
|
||||
if n_results == 0:
|
||||
return []
|
||||
|
||||
# Handle edge case: single result
|
||||
if n_results == 1:
|
||||
return [[0]]
|
||||
|
||||
# Determine eps value
|
||||
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
|
||||
|
||||
# Configure DBSCAN clusterer
|
||||
# Note: DBSCAN min_samples corresponds to min_cluster_size concept
|
||||
clusterer = DBSCAN(
|
||||
eps=eps,
|
||||
min_samples=self.config.min_samples,
|
||||
metric=self.config.metric,
|
||||
)
|
||||
|
||||
# Fit and get cluster labels
|
||||
# Labels: -1 = noise, 0+ = cluster index
|
||||
labels = clusterer.fit_predict(embeddings)
|
||||
|
||||
# Group indices by cluster label
|
||||
cluster_map: dict[int, list[int]] = {}
|
||||
for idx, label in enumerate(labels):
|
||||
if label not in cluster_map:
|
||||
cluster_map[label] = []
|
||||
cluster_map[label].append(idx)
|
||||
|
||||
# Build result: non-noise clusters first, then noise as singletons
|
||||
clusters: List[List[int]] = []
|
||||
|
||||
# Add proper clusters (label >= 0)
|
||||
for label in sorted(cluster_map.keys()):
|
||||
if label >= 0:
|
||||
clusters.append(cluster_map[label])
|
||||
|
||||
# Add noise points as singleton clusters (label == -1)
|
||||
if -1 in cluster_map:
|
||||
for idx in cluster_map[-1]:
|
||||
clusters.append([idx])
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
Selects the result with the highest score from each cluster.
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used in score-based selection).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one per cluster,
|
||||
ordered by score (highest first).
|
||||
"""
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
# Find the result with the highest score in this cluster
|
||||
best_idx = max(cluster_indices, key=lambda i: results[i].score)
|
||||
representatives.append(results[best_idx])
|
||||
|
||||
# Sort by score descending
|
||||
representatives.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return representatives
|
||||
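A standalone illustration of the eps auto-computation above: take the configured percentile of the pairwise-distance upper triangle and floor it at a tiny positive value (random data for demonstration only):

```python
import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 4))           # 8 fake results, 4-dim embeddings
d = pairwise_distances(embeddings, metric="cosine")
upper = d[np.triu_indices_from(d, k=1)]        # unique pairs, diagonal excluded
eps = max(float(np.percentile(upper, 15.0)), 1e-6)
print(f"auto eps = {eps:.4f}")
```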
202
codex-lens/src/codexlens/search/clustering/factory.py
Normal file
@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.
|
||||
|
||||
Provides a unified interface for instantiating different clustering backends
|
||||
with automatic fallback chain: hdbscan -> dbscan -> noop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
from .noop_strategy import NoOpStrategy
|
||||
|
||||
|
||||
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
|
||||
"""Check whether a specific clustering strategy can be used.
|
||||
|
||||
Args:
|
||||
strategy: Strategy name to check. Options:
|
||||
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
|
||||
- "dbscan": DBSCAN clustering (requires sklearn)
|
||||
- "frequency": Frequency-based clustering (always available)
|
||||
- "noop": No-op strategy (always available)
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, error_message).
|
||||
error_message is None if available, otherwise contains install instructions.
|
||||
"""
|
||||
strategy = (strategy or "").strip().lower()
|
||||
|
||||
if strategy == "hdbscan":
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
except ImportError:
|
||||
return False, (
|
||||
"hdbscan package not available. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
)
|
||||
return True, None
|
||||
|
||||
if strategy == "dbscan":
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
except ImportError:
|
||||
return False, (
|
||||
"scikit-learn package not available. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
)
|
||||
return True, None
|
||||
|
||||
if strategy == "frequency":
|
||||
# Frequency strategy is always available (no external deps)
|
||||
return True, None
|
||||
|
||||
if strategy == "noop":
|
||||
return True, None
|
||||
|
||||
return False, (
|
||||
f"Invalid clustering strategy: {strategy}. "
|
||||
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
|
||||
)
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy: str = "hdbscan",
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
*,
|
||||
fallback: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> BaseClusteringStrategy:
|
||||
"""Factory function to create clustering strategy with fallback chain.
|
||||
|
||||
The fallback chain is: hdbscan -> dbscan -> frequency -> noop
|
||||
|
||||
Args:
|
||||
strategy: Clustering strategy to use. Options:
|
||||
- "hdbscan": HDBSCAN clustering (default, recommended)
|
||||
- "dbscan": DBSCAN clustering (fallback)
|
||||
- "frequency": Frequency-based clustering (groups by symbol occurrence)
|
||||
- "noop": No-op strategy (returns all results ungrouped)
|
||||
- "auto": Try hdbscan, then dbscan, then noop
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
For frequency strategy, pass FrequencyConfig for full control.
|
||||
fallback: If True (default), automatically fall back to next strategy
|
||||
in the chain when primary is unavailable. If False, raise ImportError
|
||||
when requested strategy is unavailable.
|
||||
**kwargs: Additional strategy-specific arguments.
|
||||
For DBSCANStrategy: eps, eps_percentile
|
||||
For FrequencyStrategy: group_by, min_frequency, etc.
|
||||
|
||||
Returns:
|
||||
BaseClusteringStrategy: Configured clustering strategy instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If strategy is not recognized.
|
||||
ImportError: If required dependencies are not installed and fallback=False.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||
>>> # Auto-select best available strategy
|
||||
>>> strategy = get_strategy("auto", config)
|
||||
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
|
||||
>>> strategy = get_strategy("hdbscan", config)
|
||||
>>> # Use frequency-based strategy
|
||||
>>> from codexlens.search.clustering import FrequencyConfig
|
||||
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||
>>> strategy = get_strategy("frequency", freq_config)
|
||||
"""
|
||||
strategy = (strategy or "").strip().lower()
|
||||
|
||||
# Handle "auto" - try strategies in order
|
||||
if strategy == "auto":
|
||||
return _get_best_available_strategy(config, **kwargs)
|
||||
|
||||
if strategy == "hdbscan":
|
||||
ok, err = check_clustering_strategy_available("hdbscan")
|
||||
if ok:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
return HDBSCANStrategy(config)
|
||||
|
||||
if fallback:
|
||||
# Try dbscan fallback
|
||||
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
|
||||
if ok_dbscan:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
# Final fallback to noop
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ImportError(err)
|
||||
|
||||
if strategy == "dbscan":
|
||||
ok, err = check_clustering_strategy_available("dbscan")
|
||||
if ok:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
|
||||
if fallback:
|
||||
# Fallback to noop
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ImportError(err)
|
||||
|
||||
if strategy == "frequency":
|
||||
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||
# If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
|
||||
if config is None or not isinstance(config, FrequencyConfig):
|
||||
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
|
||||
else:
|
||||
freq_config = config
|
||||
return FrequencyStrategy(freq_config)
|
||||
|
||||
if strategy == "noop":
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown clustering strategy: {strategy}. "
|
||||
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
|
||||
)
|
||||
|
||||
|
||||
def _get_best_available_strategy(
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseClusteringStrategy:
|
||||
"""Get the best available clustering strategy.
|
||||
|
||||
Tries strategies in order: hdbscan -> dbscan -> noop
|
||||
|
||||
Args:
|
||||
config: Clustering configuration.
|
||||
**kwargs: Additional strategy-specific arguments.
|
||||
|
||||
Returns:
|
||||
Best available clustering strategy instance.
|
||||
"""
|
||||
# Try HDBSCAN first
|
||||
ok, _ = check_clustering_strategy_available("hdbscan")
|
||||
if ok:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
return HDBSCANStrategy(config)
|
||||
|
||||
# Try DBSCAN second
|
||||
ok, _ = check_clustering_strategy_available("dbscan")
|
||||
if ok:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
|
||||
# Fallback to NoOp
|
||||
return NoOpStrategy(config)
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
ClusteringStrategyFactory = type(
|
||||
"ClusteringStrategyFactory",
|
||||
(),
|
||||
{
|
||||
"get_strategy": staticmethod(get_strategy),
|
||||
"check_available": staticmethod(check_clustering_strategy_available),
|
||||
},
|
||||
)
|
||||
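A sketch of the two fallback modes the factory exposes: the default degrades down the chain silently, while `fallback=False` surfaces the ImportError (whether the except branch fires depends on whether hdbscan is installed in your environment):

```python
from codexlens.search.clustering import ClusteringConfig, get_strategy

config = ClusteringConfig(min_cluster_size=3)
strategy = get_strategy("hdbscan", config)  # degrades to dbscan/noop if missing
try:
    strict = get_strategy("hdbscan", config, fallback=False)
except ImportError as exc:
    print(f"hdbscan unavailable: {exc}")
```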
263
codex-lens/src/codexlens/search/clustering/frequency_strategy.py
Normal file
@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.
|
||||
|
||||
This strategy groups search results by symbol/method name and prunes based on
|
||||
occurrence frequency. High-frequency symbols (frequently referenced methods)
|
||||
are considered more important and retained, while low-frequency results
|
||||
(potentially noise) can be filtered out.
|
||||
|
||||
Use cases:
|
||||
- Prioritize commonly called methods/functions
|
||||
- Filter out one-off results that may be less relevant
|
||||
- Deduplicate results pointing to the same symbol from different locations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrequencyConfig(ClusteringConfig):
|
||||
"""Configuration for frequency-based clustering strategy.
|
||||
|
||||
Attributes:
|
||||
group_by: Field to group results by for frequency counting.
|
||||
- 'symbol': Group by symbol_name (default, for method/function dedup)
|
||||
- 'file': Group by file path
|
||||
- 'symbol_kind': Group by symbol type (function, class, etc.)
|
||||
min_frequency: Minimum occurrence count to keep a result.
|
||||
Results appearing less than this are considered noise and pruned.
|
||||
max_representatives_per_group: Maximum results to keep per symbol group.
|
||||
frequency_weight: How much to boost score based on frequency.
|
||||
Final score = original_score * (1 + frequency_weight * log(frequency))
|
||||
keep_mode: How to handle low-frequency results.
|
||||
- 'filter': Remove results below min_frequency
|
||||
- 'demote': Keep but lower their score ranking
|
||||
"""
|
||||
|
||||
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
|
||||
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
|
||||
max_representatives_per_group: int = 3
|
||||
frequency_weight: float = 0.1 # Boost factor for frequency
|
||||
keep_mode: Literal["filter", "demote"] = "demote"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
# Skip parent validation since we don't use HDBSCAN params
|
||||
if self.min_frequency < 1:
|
||||
raise ValueError("min_frequency must be >= 1")
|
||||
if self.max_representatives_per_group < 1:
|
||||
raise ValueError("max_representatives_per_group must be >= 1")
|
||||
if self.frequency_weight < 0:
|
||||
raise ValueError("frequency_weight must be >= 0")
|
||||
if self.group_by not in ("symbol", "file", "symbol_kind"):
|
||||
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
|
||||
if self.keep_mode not in ("filter", "demote"):
|
||||
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
|
||||
|
||||
|
||||
class FrequencyStrategy(BaseClusteringStrategy):
|
||||
"""Frequency-based clustering strategy for search result deduplication.
|
||||
|
||||
This strategy groups search results by symbol name (or file/kind) and:
|
||||
1. Counts how many times each symbol appears in results
|
||||
2. Higher frequency = more important (frequently referenced method)
|
||||
3. Filters or demotes low-frequency results
|
||||
4. Selects top representatives from each frequency group
|
||||
|
||||
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
|
||||
- Does NOT require embeddings (works with metadata only)
|
||||
- Is very fast (O(n) complexity)
|
||||
- Is deterministic (no random initialization)
|
||||
- Works well for symbol-level deduplication
|
||||
|
||||
Example:
|
||||
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||
>>> strategy = FrequencyStrategy(config)
|
||||
>>> # Results with symbol "authenticate" appearing 5 times
|
||||
>>> # will be prioritized over "helper_func" appearing once
|
||||
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
|
||||
"""Initialize the frequency strategy.
|
||||
|
||||
Args:
|
||||
config: Frequency configuration. Uses defaults if not provided.
|
||||
"""
|
||||
self.config: FrequencyConfig = config or FrequencyConfig()
|
||||
|
||||
def _get_group_key(self, result: "SearchResult") -> str:
|
||||
"""Extract grouping key from a search result.
|
||||
|
||||
Args:
|
||||
result: SearchResult to extract key from.
|
||||
|
||||
Returns:
|
||||
String key for grouping (symbol name, file path, or kind).
|
||||
"""
|
||||
if self.config.group_by == "symbol":
|
||||
# Use symbol_name if available, otherwise fall back to file:line
|
||||
symbol = getattr(result, "symbol_name", None)
|
||||
if symbol:
|
||||
return str(symbol)
|
||||
# Fallback: use file path + start_line as pseudo-symbol
|
||||
start_line = getattr(result, "start_line", 0) or 0
|
||||
return f"{result.path}:{start_line}"
|
||||
|
||||
elif self.config.group_by == "file":
|
||||
return str(result.path)
|
||||
|
||||
elif self.config.group_by == "symbol_kind":
|
||||
kind = getattr(result, "symbol_kind", None)
|
||||
return str(kind) if kind else "unknown"
|
||||
|
||||
return str(result.path) # Default fallback
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Group search results by frequency of occurrence.
|
||||
|
||||
Note: This method ignores embeddings and groups by metadata only.
|
||||
The embeddings parameter is kept for interface compatibility.
|
||||
|
||||
Args:
|
||||
embeddings: Ignored (kept for interface compatibility).
|
||||
results: List of SearchResult objects to cluster.
|
||||
|
||||
Returns:
|
||||
List of clusters (groups), where each cluster contains indices
|
||||
of results with the same grouping key. Clusters are ordered by
|
||||
frequency (highest frequency first).
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Group results by key
|
||||
groups: Dict[str, List[int]] = defaultdict(list)
|
||||
for idx, result in enumerate(results):
|
||||
key = self._get_group_key(result)
|
||||
groups[key].append(idx)
|
||||
|
||||
# Sort groups by frequency (descending) then by key (for stability)
|
||||
sorted_groups = sorted(
|
||||
groups.items(),
|
||||
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
|
||||
)
|
||||
|
||||
# Convert to list of clusters
|
||||
clusters = [indices for _, indices in sorted_groups]
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results based on frequency and score.
|
||||
|
||||
For each frequency group:
|
||||
1. If frequency < min_frequency: filter or demote based on keep_mode
|
||||
2. Sort by score within group
|
||||
3. Apply frequency boost to scores
|
||||
4. Select top N representatives
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (used for tie-breaking if provided).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, ordered by
|
||||
frequency-adjusted score (highest first).
|
||||
"""
|
||||
import math
|
||||
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
demoted: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
frequency = len(cluster_indices)
|
||||
|
||||
# Get results in this cluster, sorted by score
|
||||
cluster_results = [results[i] for i in cluster_indices]
|
||||
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||
|
||||
# Check frequency threshold
|
||||
if frequency < self.config.min_frequency:
|
||||
if self.config.keep_mode == "filter":
|
||||
# Skip low-frequency results entirely
|
||||
continue
|
||||
else: # demote mode
|
||||
# Keep but add to demoted list (lower priority)
|
||||
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||
demoted.append(result)
|
||||
continue
|
||||
|
||||
# Apply frequency boost and select top representatives
|
||||
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||
# Calculate frequency-boosted score
|
||||
original_score = getattr(result, "score", 0.0)
|
||||
# log(frequency + 1) to handle frequency=1 case smoothly
|
||||
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
|
||||
boosted_score = original_score * frequency_boost
|
||||
|
||||
# Create new result with boosted score and frequency metadata
|
||||
# Note: SearchResult might be immutable, so we preserve original
|
||||
# and track boosted score in metadata
|
||||
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
|
||||
result.metadata["frequency"] = frequency
|
||||
result.metadata["frequency_boosted_score"] = boosted_score
|
||||
|
||||
representatives.append(result)
|
||||
|
||||
# Sort representatives by boosted score (or original score as fallback)
|
||||
def get_sort_score(r: "SearchResult") -> float:
|
||||
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
|
||||
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
|
||||
return getattr(r, "score", 0.0)
|
||||
|
||||
representatives.sort(key=get_sort_score, reverse=True)
|
||||
|
||||
# Add demoted results at the end
|
||||
if demoted:
|
||||
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||
representatives.extend(demoted)
|
||||
|
||||
return representatives
|
||||
|
||||
def fit_predict(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List["SearchResult"]:
|
||||
"""Convenience method to cluster and select representatives in one call.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array (may be ignored for frequency-based clustering).
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects.
|
||||
"""
|
||||
clusters = self.cluster(embeddings, results)
|
||||
return self.select_representatives(clusters, results, embeddings)
|
||||
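The frequency boost in `select_representatives` worked by hand, so the effect of `frequency_weight` is visible: `boosted = score * (1 + weight * log(frequency + 1))` (numbers verified, snippet illustrative only):

```python
import math

score, weight = 0.80, 0.1
for frequency in (1, 5, 20):
    boosted = score * (1.0 + weight * math.log(frequency + 1))
    print(f"freq={frequency:>2}  boosted={boosted:.3f}")
# freq= 1 -> 0.855, freq= 5 -> 0.943, freq=20 -> 1.044
```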
153
codex-lens/src/codexlens/search/clustering/hdbscan_strategy.py
Normal file
@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.
|
||||
|
||||
HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
|
||||
is the primary clustering strategy for grouping similar search results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class HDBSCANStrategy(BaseClusteringStrategy):
|
||||
"""HDBSCAN-based clustering strategy.
|
||||
|
||||
Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
|
||||
HDBSCAN is preferred over DBSCAN because it:
|
||||
- Automatically determines the number of clusters
|
||||
- Handles varying density clusters well
|
||||
- Identifies noise points (outliers) effectively
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||
>>> strategy = HDBSCANStrategy(config)
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize HDBSCAN clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
|
||||
Raises:
|
||||
ImportError: If hdbscan package is not installed.
|
||||
"""
|
||||
super().__init__(config)
|
||||
# Validate hdbscan is available
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"hdbscan package is required for HDBSCANStrategy. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
) from exc
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results using HDBSCAN algorithm.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Noise points are returned as singleton clusters.
|
||||
"""
|
||||
import hdbscan
|
||||
import numpy as np
|
||||
|
||||
n_results = len(results)
|
||||
if n_results == 0:
|
||||
return []
|
||||
|
||||
# Handle edge case: fewer results than min_cluster_size
|
||||
if n_results < self.config.min_cluster_size:
|
||||
# Return each result as its own singleton cluster
|
||||
return [[i] for i in range(n_results)]
|
||||
|
||||
# Configure HDBSCAN clusterer
|
||||
clusterer = hdbscan.HDBSCAN(
|
||||
min_cluster_size=self.config.min_cluster_size,
|
||||
min_samples=self.config.min_samples,
|
||||
metric=self.config.metric,
|
||||
cluster_selection_epsilon=self.config.cluster_selection_epsilon,
|
||||
allow_single_cluster=self.config.allow_single_cluster,
|
||||
prediction_data=self.config.prediction_data,
|
||||
)
|
||||
|
||||
# Fit and get cluster labels
|
||||
# Labels: -1 = noise, 0+ = cluster index
|
||||
labels = clusterer.fit_predict(embeddings)
|
||||
|
||||
# Group indices by cluster label
|
||||
cluster_map: dict[int, list[int]] = {}
|
||||
for idx, label in enumerate(labels):
|
||||
if label not in cluster_map:
|
||||
cluster_map[label] = []
|
||||
cluster_map[label].append(idx)
|
||||
|
||||
# Build result: non-noise clusters first, then noise as singletons
|
||||
clusters: List[List[int]] = []
|
||||
|
||||
# Add proper clusters (label >= 0)
|
||||
for label in sorted(cluster_map.keys()):
|
||||
if label >= 0:
|
||||
clusters.append(cluster_map[label])
|
||||
|
||||
# Add noise points as singleton clusters (label == -1)
|
||||
if -1 in cluster_map:
|
||||
for idx in cluster_map[-1]:
|
||||
clusters.append([idx])
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
Selects the result with the highest score from each cluster.
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used in score-based selection).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one per cluster,
|
||||
ordered by score (highest first).
|
||||
"""
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
# Find the result with the highest score in this cluster
|
||||
best_idx = max(cluster_indices, key=lambda i: results[i].score)
|
||||
representatives.append(results[best_idx])
|
||||
|
||||
# Sort by score descending
|
||||
representatives.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return representatives
|
||||
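How the label-to-cluster mapping above behaves, reduced to plain lists: proper clusters (label >= 0) come first in label order, then each noise point (label -1) becomes its own singleton (output verified by hand):

```python
labels = [0, 1, 0, -1, 1, -1]  # as returned by fit_predict
cluster_map: dict[int, list[int]] = {}
for idx, label in enumerate(labels):
    cluster_map.setdefault(label, []).append(idx)
clusters = [cluster_map[l] for l in sorted(cluster_map) if l >= 0]
clusters += [[i] for i in cluster_map.get(-1, [])]
print(clusters)  # [[0, 2], [1, 4], [3], [5]]
```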
83
codex-lens/src/codexlens/search/clustering/noop_strategy.py
Normal file
@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.
|
||||
|
||||
NoOpStrategy returns all results ungrouped when clustering dependencies
|
||||
are not available or clustering is disabled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class NoOpStrategy(BaseClusteringStrategy):
|
||||
"""No-op clustering strategy that returns all results ungrouped.
|
||||
|
||||
This strategy is used as a final fallback when no clustering dependencies
|
||||
are available, or when clustering is explicitly disabled. Each result
|
||||
is treated as its own singleton cluster.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import NoOpStrategy
|
||||
>>> strategy = NoOpStrategy()
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> # Returns [[0], [1], [2], ...] - each result in its own cluster
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
>>> # Returns all results sorted by score
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize NoOp clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Ignored for NoOpStrategy
|
||||
but accepted for interface compatibility.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Return each result as its own singleton cluster.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
Not used but accepted for interface compatibility.
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of singleton clusters, one per result.
|
||||
"""
|
||||
return [[i] for i in range(len(results))]
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Return all results sorted by score.
|
||||
|
||||
Since each cluster is a singleton, this effectively returns all
|
||||
results sorted by score descending.
|
||||
|
||||
Args:
|
||||
clusters: List of singleton clusters.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used).
|
||||
|
||||
Returns:
|
||||
All SearchResult objects sorted by score (highest first).
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Return all results sorted by score
|
||||
return sorted(results, key=lambda r: r.score, reverse=True)
|
||||
@@ -1807,6 +1807,178 @@ class DirIndexStore:
                for row in rows
            ]

    def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
        """Get all symbols in a specific file, sorted by start_line.

        Args:
            file_path: Full path to the file

        Returns:
            List of Symbol objects sorted by start_line
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            # First get the file_id
            file_row = conn.execute(
                "SELECT id FROM files WHERE full_path=?",
                (file_path_str,),
            ).fetchone()

            if not file_row:
                return []

            file_id = int(file_row["id"])

            rows = conn.execute(
                """
                SELECT s.name, s.kind, s.start_line, s.end_line
                FROM symbols s
                WHERE s.file_id=?
                ORDER BY s.start_line
                """,
                (file_id,),
            ).fetchall()

            return [
                Symbol(
                    name=row["name"],
                    kind=row["kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=file_path_str,
                )
                for row in rows
            ]

    def get_outgoing_calls(
        self,
        file_path: str | Path,
        symbol_name: Optional[str] = None,
    ) -> List[Tuple[str, str, int, Optional[str]]]:
        """Get outgoing calls from symbols in a file.

        Queries code_relationships table for calls originating from symbols
        in the specified file.

        Args:
            file_path: Full path to the source file
            symbol_name: Optional symbol name to filter by. If None, returns
                calls from all symbols in the file.

        Returns:
            List of tuples: (target_name, relationship_type, source_line, target_file)
            - target_name: Qualified name of the call target
            - relationship_type: Type of relationship (e.g., "calls", "imports")
            - source_line: Line number where the call occurs
            - target_file: Target file path (may be None if unknown)
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            # First get the file_id
            file_row = conn.execute(
                "SELECT id FROM files WHERE full_path=?",
                (file_path_str,),
            ).fetchone()

            if not file_row:
                return []

            file_id = int(file_row["id"])

            if symbol_name:
                rows = conn.execute(
                    """
                    SELECT cr.target_qualified_name, cr.relationship_type,
                           cr.source_line, cr.target_file
                    FROM code_relationships cr
                    JOIN symbols s ON s.id = cr.source_symbol_id
                    WHERE s.file_id=? AND s.name=?
                    ORDER BY cr.source_line
                    """,
                    (file_id, symbol_name),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT cr.target_qualified_name, cr.relationship_type,
                           cr.source_line, cr.target_file
                    FROM code_relationships cr
                    JOIN symbols s ON s.id = cr.source_symbol_id
                    WHERE s.file_id=?
                    ORDER BY cr.source_line
                    """,
                    (file_id,),
                ).fetchall()

            return [
                (
                    row["target_qualified_name"],
                    row["relationship_type"],
                    int(row["source_line"]),
                    row["target_file"],
                )
                for row in rows
            ]

    def get_incoming_calls(
        self,
        target_name: str,
        limit: int = 100,
    ) -> List[Tuple[str, str, int, str]]:
        """Get incoming calls/references to a target symbol.

        Queries code_relationships table for references to the specified
        target symbol name.

        Args:
            target_name: Name of the target symbol to find references for.
                Matches against target_qualified_name (exact match,
                suffix match, or contains match).
            limit: Maximum number of results to return

        Returns:
            List of tuples: (source_symbol_name, relationship_type, source_line, source_file)
            - source_symbol_name: Name of the calling symbol
            - relationship_type: Type of relationship (e.g., "calls", "imports")
            - source_line: Line number where the call occurs
            - source_file: Full path to the source file
        """
        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT s.name AS source_name, cr.relationship_type,
                       cr.source_line, f.full_path AS source_file
                FROM code_relationships cr
                JOIN symbols s ON s.id = cr.source_symbol_id
                JOIN files f ON f.id = s.file_id
                WHERE cr.target_qualified_name = ?
                   OR cr.target_qualified_name LIKE ?
                   OR cr.target_qualified_name LIKE ?
                ORDER BY f.full_path, cr.source_line
                LIMIT ?
                """,
                (
                    target_name,
                    f"%.{target_name}",
                    f"%{target_name}",
                    limit,
                ),
            ).fetchall()

            return [
                (
                    row["source_name"],
                    row["relationship_type"],
                    int(row["source_line"]),
                    row["source_file"],
                )
                for row in rows
            ]

    # === Statistics ===

    def stats(self) -> Dict[str, Any]:
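A hypothetical walk over the two call-graph accessors added above; `store` is assumed to be an already-opened DirIndexStore and the paths and symbol names are made up:

```python
outgoing = store.get_outgoing_calls("D:/project/src/auth.py", symbol_name="login")
for target, rel_type, line, target_file in outgoing:
    print(f"login --{rel_type}--> {target} (line {line})")

incoming = store.get_incoming_calls("authenticate", limit=20)
for source, rel_type, line, source_file in incoming:
    print(f"{source_file}:{line}  {source} --{rel_type}--> authenticate")
```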
@@ -270,6 +270,39 @@ class GlobalSymbolIndex:
        symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
        return [(s.file or "", s.range) for s in symbols]

    def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
        """Get all symbols in a specific file, sorted by start_line.

        Args:
            file_path: Full path to the file

        Returns:
            List of Symbol objects sorted by start_line
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT symbol_name, symbol_kind, file_path, start_line, end_line
                FROM global_symbols
                WHERE project_id=? AND file_path=?
                ORDER BY start_line
                """,
                (self.project_id, file_path_str),
            ).fetchall()

            return [
                Symbol(
                    name=row["symbol_name"],
                    kind=row["symbol_kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=row["file_path"],
                )
                for row in rows
            ]

    def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
        with self._lock:
            conn = self._get_connection()
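Sketch of listing a file's symbols in definition order with the method above; `index` is assumed to be an already-constructed GlobalSymbolIndex and the path is made up:

```python
for sym in index.get_file_symbols("D:/project/src/auth.py"):
    start, end = sym.range
    print(f"{sym.kind:>9}  {sym.name}  (lines {start}-{end})")
```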
282
codex-lens/tests/api/test_references.py
Normal file
@@ -0,0 +1,282 @@
"""Tests for codexlens.api.references module."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.api.references import (
|
||||
find_references,
|
||||
_read_line_from_file,
|
||||
_proximity_score,
|
||||
_group_references_by_definition,
|
||||
_transform_to_reference_result,
|
||||
)
|
||||
from codexlens.api.models import (
|
||||
DefinitionResult,
|
||||
ReferenceResult,
|
||||
GroupedReferences,
|
||||
)
|
||||
|
||||
|
||||
class TestReadLineFromFile:
|
||||
"""Tests for _read_line_from_file helper."""
|
||||
|
||||
def test_read_existing_line(self, tmp_path):
|
||||
"""Test reading an existing line from a file."""
|
||||
test_file = tmp_path / "test.py"
|
||||
test_file.write_text("line 1\nline 2\nline 3\n")
|
||||
|
||||
assert _read_line_from_file(str(test_file), 1) == "line 1"
|
||||
assert _read_line_from_file(str(test_file), 2) == "line 2"
|
||||
assert _read_line_from_file(str(test_file), 3) == "line 3"
|
||||
|
||||
def test_read_nonexistent_line(self, tmp_path):
|
||||
"""Test reading a line that doesn't exist."""
|
||||
test_file = tmp_path / "test.py"
|
||||
test_file.write_text("line 1\nline 2\n")
|
||||
|
||||
assert _read_line_from_file(str(test_file), 10) == ""
|
||||
|
||||
def test_read_nonexistent_file(self):
|
||||
"""Test reading from a file that doesn't exist."""
|
||||
assert _read_line_from_file("/nonexistent/path/file.py", 1) == ""
|
||||
|
||||
def test_strips_trailing_whitespace(self, tmp_path):
|
||||
"""Test that trailing whitespace is stripped."""
|
||||
test_file = tmp_path / "test.py"
|
||||
test_file.write_text("line with spaces \n")
|
||||
|
||||
assert _read_line_from_file(str(test_file), 1) == "line with spaces"
|
||||
|
||||
|
||||
class TestProximityScore:
|
||||
"""Tests for _proximity_score helper."""
|
||||
|
||||
def test_same_file(self):
|
||||
"""Same file should return highest score."""
|
||||
score = _proximity_score("/a/b/c.py", "/a/b/c.py")
|
||||
assert score == 1000
|
||||
|
||||
def test_same_directory(self):
|
||||
"""Same directory should return 100."""
|
||||
score = _proximity_score("/a/b/x.py", "/a/b/y.py")
|
||||
assert score == 100
|
||||
|
||||
def test_different_directories(self):
|
||||
"""Different directories should return common prefix length."""
|
||||
score = _proximity_score("/a/b/c/x.py", "/a/b/d/y.py")
|
||||
# Common path is /a/b
|
||||
assert score > 0
|
||||
|
||||
def test_empty_paths(self):
|
||||
"""Empty paths should return 0."""
|
||||
assert _proximity_score("", "/a/b/c.py") == 0
|
||||
assert _proximity_score("/a/b/c.py", "") == 0
|
||||
assert _proximity_score("", "") == 0


class TestGroupReferencesByDefinition:
    """Tests for _group_references_by_definition helper."""

    def test_single_definition(self):
        """Single definition should have all references."""
        definition = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/a/b/c.py",
            line=10,
            end_line=20,
        )
        references = [
            ReferenceResult(
                file_path="/a/b/d.py",
                line=5,
                column=0,
                context_line="foo()",
                relationship="call",
            ),
            ReferenceResult(
                file_path="/a/x/y.py",
                line=10,
                column=0,
                context_line="foo()",
                relationship="call",
            ),
        ]

        result = _group_references_by_definition([definition], references)

        assert len(result) == 1
        assert result[0].definition == definition
        assert len(result[0].references) == 2

    def test_multiple_definitions(self):
        """Multiple definitions should group by proximity."""
        def1 = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/a/b/c.py",
            line=10,
            end_line=20,
        )
        def2 = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/x/y/z.py",
            line=10,
            end_line=20,
        )

        # Reference closer to def1
        ref1 = ReferenceResult(
            file_path="/a/b/d.py",
            line=5,
            column=0,
            context_line="foo()",
            relationship="call",
        )
        # Reference closer to def2
        ref2 = ReferenceResult(
            file_path="/x/y/w.py",
            line=10,
            column=0,
            context_line="foo()",
            relationship="call",
        )

        result = _group_references_by_definition(
            [def1, def2], [ref1, ref2], include_definition=True
        )

        assert len(result) == 2
        # Each definition should have the closer reference
        def1_refs = [g for g in result if g.definition == def1][0].references
        def2_refs = [g for g in result if g.definition == def2][0].references

        assert any(r.file_path == "/a/b/d.py" for r in def1_refs)
        assert any(r.file_path == "/x/y/w.py" for r in def2_refs)

    def test_empty_definitions(self):
        """Empty definitions should return empty result."""
        result = _group_references_by_definition([], [])
        assert result == []


class TestTransformToReferenceResult:
    """Tests for _transform_to_reference_result helper."""

    def test_normalizes_relationship_type(self, tmp_path):
        """Test that relationship type is normalized."""
        test_file = tmp_path / "test.py"
        test_file.write_text("def foo(): pass\n")

        # Create a mock raw reference
        raw_ref = MagicMock()
        raw_ref.file_path = str(test_file)
        raw_ref.line = 1
        raw_ref.column = 0
        raw_ref.relationship_type = "calls"  # Plural form

        result = _transform_to_reference_result(raw_ref)

        assert result.relationship == "call"  # Normalized form
        assert result.context_line == "def foo(): pass"
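
    # Inferred behavior: relationship types stored in the index may be plural
    # ("calls", "imports", ...) while the public API reports the singular form
    # ("call"); the exact normalization table lives in the implementation and
    # is only sampled here.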


class TestFindReferences:
    """Tests for find_references API function."""

    def test_raises_for_invalid_project_root(self):
        """Test that ValueError is raised for invalid project root."""
        with pytest.raises(ValueError, match="does not exist"):
            find_references("/nonexistent/path", "some_symbol")

    @patch("codexlens.search.chain_search.ChainSearchEngine")
    @patch("codexlens.storage.registry.RegistryStore")
    @patch("codexlens.storage.path_mapper.PathMapper")
    @patch("codexlens.config.Config")
    def test_returns_grouped_references(
        self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
    ):
        """Test that find_references returns GroupedReferences."""
        # Setup mocks
        mock_engine = MagicMock()
        mock_engine_class.return_value = mock_engine

        # Mock symbol search (for definitions)
        mock_symbol = MagicMock()
        mock_symbol.name = "test_func"
        mock_symbol.kind = "function"
        mock_symbol.file = str(tmp_path / "test.py")
        mock_symbol.range = (10, 20)
        mock_engine.search_symbols.return_value = [mock_symbol]

        # Mock reference search
        mock_ref = MagicMock()
        mock_ref.file_path = str(tmp_path / "caller.py")
        mock_ref.line = 5
        mock_ref.column = 0
        mock_ref.relationship_type = "call"
        mock_engine.search_references.return_value = [mock_ref]

        # Create test files
        test_file = tmp_path / "test.py"
        test_file.write_text("def test_func():\n    pass\n")
        caller_file = tmp_path / "caller.py"
        caller_file.write_text("test_func()\n")

        # Call find_references
        result = find_references(str(tmp_path), "test_func")

        # Verify result structure
        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], GroupedReferences)
        assert result[0].definition.name == "test_func"
        assert len(result[0].references) == 1

    @patch("codexlens.search.chain_search.ChainSearchEngine")
    @patch("codexlens.storage.registry.RegistryStore")
    @patch("codexlens.storage.path_mapper.PathMapper")
    @patch("codexlens.config.Config")
    def test_respects_include_definition_false(
        self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
    ):
        """Test include_definition=False behavior."""
        mock_engine = MagicMock()
        mock_engine_class.return_value = mock_engine
        mock_engine.search_symbols.return_value = []
        mock_engine.search_references.return_value = []

        result = find_references(
            str(tmp_path), "test_func", include_definition=False
        )

        # Should still return a result with placeholder definition
        assert len(result) == 1
        assert result[0].definition.name == "test_func"


class TestImports:
    """Tests for module imports and exports."""

    def test_find_references_exported_from_api(self):
        """Test that find_references is exported from codexlens.api."""
        from codexlens.api import find_references as api_find_references

        assert callable(api_find_references)

    def test_models_exported_from_api(self):
        """Test that result models are exported from codexlens.api."""
        from codexlens.api import (
            GroupedReferences,
            ReferenceResult,
            DefinitionResult,
        )

        assert GroupedReferences is not None
        assert ReferenceResult is not None
        assert DefinitionResult is not None
528
codex-lens/tests/api/test_semantic_search.py
Normal file
@@ -0,0 +1,528 @@
"""Tests for semantic_search API."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.api import SemanticResult
|
||||
from codexlens.api.semantic import (
|
||||
semantic_search,
|
||||
_build_search_options,
|
||||
_generate_match_reason,
|
||||
_split_camel_case,
|
||||
_transform_results,
|
||||
)
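
# Usage sketch (illustrative only; mirrors the signature asserted below):
#
#     results = semantic_search(
#         "/path/to/indexed/project",  # hypothetical project root
#         "token refresh logic",
#         mode="fusion",               # default: vector + structural + keyword
#         fusion_strategy="rrf",
#         limit=20,
#     )
#     for r in results:
#         print(r.symbol_name, r.fusion_score)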


class TestSemanticSearchFunctionSignature:
    """Test that semantic_search has the correct function signature."""

    def test_function_accepts_all_parameters(self):
        """Verify function signature matches spec."""
        import inspect
        sig = inspect.signature(semantic_search)
        params = list(sig.parameters.keys())

        expected_params = [
            "project_root",
            "query",
            "mode",
            "vector_weight",
            "structural_weight",
            "keyword_weight",
            "fusion_strategy",
            "kind_filter",
            "limit",
            "include_match_reason",
        ]

        assert params == expected_params

    def test_default_parameter_values(self):
        """Verify default parameter values match spec."""
        import inspect
        sig = inspect.signature(semantic_search)

        assert sig.parameters["mode"].default == "fusion"
        assert sig.parameters["vector_weight"].default == 0.5
        assert sig.parameters["structural_weight"].default == 0.3
        assert sig.parameters["keyword_weight"].default == 0.2
        assert sig.parameters["fusion_strategy"].default == "rrf"
        assert sig.parameters["kind_filter"].default is None
        assert sig.parameters["limit"].default == 20
        assert sig.parameters["include_match_reason"].default is False


class TestBuildSearchOptions:
    """Test _build_search_options helper function."""

    def test_vector_mode_options(self):
        """Test options for pure vector mode."""
        options = _build_search_options(
            mode="vector",
            vector_weight=1.0,
            structural_weight=0.0,
            keyword_weight=0.0,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is True
        assert options.pure_vector is True
        assert options.enable_fuzzy is False

    def test_structural_mode_options(self):
        """Test options for structural mode."""
        options = _build_search_options(
            mode="structural",
            vector_weight=0.0,
            structural_weight=1.0,
            keyword_weight=0.0,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is False
        assert options.enable_fuzzy is True
        assert options.include_symbols is True

    def test_fusion_mode_options(self):
        """Test options for fusion mode (default)."""
        options = _build_search_options(
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is True  # vector_weight > 0
        assert options.enable_fuzzy is True  # keyword_weight > 0
        assert options.include_symbols is True  # structural_weight > 0
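
    # Inferred mapping (from the three cases above): each mode keeps
    # hybrid_mode=True and toggles engines by weight -- enable_vector tracks
    # vector_weight > 0, enable_fuzzy tracks keyword_weight > 0, and
    # include_symbols tracks structural_weight > 0; "vector" additionally
    # sets pure_vector=True.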


class TestTransformResults:
    """Test _transform_results helper function."""

    def test_transforms_basic_result(self):
        """Test basic result transformation."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "def authenticate():"
        mock_result.symbol_name = "authenticate"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 10
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].symbol_name == "authenticate"
        assert results[0].kind == "function"
        assert results[0].file_path == "/project/src/auth.py"
        assert results[0].line == 10
        assert results[0].fusion_score == 0.85

    def test_kind_filter_excludes_non_matching(self):
        """Test that kind_filter excludes non-matching results."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "AUTH_TOKEN = 'secret'"
        mock_result.symbol_name = "AUTH_TOKEN"
        mock_result.symbol_kind = "variable"
        mock_result.start_line = 5
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=["function", "class"],  # Exclude variable
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 0

    def test_kind_filter_includes_matching(self):
        """Test that kind_filter includes matching results."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "class AuthManager:"
        mock_result.symbol_name = "AuthManager"
        mock_result.symbol_kind = "class"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=["function", "class"],  # Include class
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].symbol_name == "AuthManager"

    def test_include_match_reason_generates_reason(self):
        """Test that include_match_reason generates match reasons."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "def authenticate(user, password):"
        mock_result.symbol_name = "authenticate"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 10
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=True,
            query="authenticate",
        )

        assert len(results) == 1
        assert results[0].match_reason is not None
        assert "authenticate" in results[0].match_reason.lower()


class TestGenerateMatchReason:
    """Test _generate_match_reason helper function."""

    def test_direct_name_match(self):
        """Test match reason for direct name match."""
        reason = _generate_match_reason(
            query="authenticate",
            symbol_name="authenticate",
            symbol_kind="function",
            snippet="def authenticate(user): pass",
            vector_score=0.8,
            structural_score=None,
        )

        assert "authenticate" in reason.lower()

    def test_keyword_match(self):
        """Test match reason for keyword match in snippet."""
        reason = _generate_match_reason(
            query="password validation",
            symbol_name="verify_user",
            symbol_kind="function",
            snippet="def verify_user(password): validate(password)",
            vector_score=0.6,
            structural_score=None,
        )

        assert "password" in reason.lower() or "validation" in reason.lower()

    def test_high_semantic_similarity(self):
        """Test match reason mentions semantic similarity for high vector score."""
        reason = _generate_match_reason(
            query="authentication",
            symbol_name="login_handler",
            symbol_kind="function",
            snippet="def login_handler(): pass",
            vector_score=0.85,
            structural_score=None,
        )

        assert "semantic" in reason.lower()

    def test_returns_string_even_with_no_matches(self):
        """Test that a reason string is always returned."""
        reason = _generate_match_reason(
            query="xyz123",
            symbol_name="abc456",
            symbol_kind="function",
            snippet="completely unrelated code",
            vector_score=0.3,
            structural_score=None,
        )

        assert isinstance(reason, str)
        assert len(reason) > 0


class TestSplitCamelCase:
    """Test _split_camel_case helper function."""

    def test_camel_case(self):
        """Test splitting camelCase."""
        result = _split_camel_case("authenticateUser")
        assert "authenticate" in result.lower()
        assert "user" in result.lower()

    def test_pascal_case(self):
        """Test splitting PascalCase."""
        result = _split_camel_case("AuthManager")
        assert "auth" in result.lower()
        assert "manager" in result.lower()

    def test_snake_case(self):
        """Test splitting snake_case."""
        result = _split_camel_case("auth_manager")
        assert "auth" in result.lower()
        assert "manager" in result.lower()

    def test_mixed_case(self):
        """Test splitting mixed case."""
        result = _split_camel_case("HTTPRequestHandler")
        # Should handle acronyms
        assert "http" in result.lower() or "request" in result.lower()


class TestSemanticResultDataclass:
    """Test SemanticResult dataclass structure."""

    def test_semantic_result_fields(self):
        """Test SemanticResult has all required fields."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=0.8,
            structural_score=0.6,
            fusion_score=0.7,
            snippet="def test(): pass",
            match_reason="Test match",
        )

        assert result.symbol_name == "test"
        assert result.kind == "function"
        assert result.file_path == "/test.py"
        assert result.line == 1
        assert result.vector_score == 0.8
        assert result.structural_score == 0.6
        assert result.fusion_score == 0.7
        assert result.snippet == "def test(): pass"
        assert result.match_reason == "Test match"

    def test_semantic_result_optional_fields(self):
        """Test SemanticResult with optional None fields."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=None,  # Degraded - no vector index
            structural_score=None,  # Degraded - no relationships
            fusion_score=0.5,
            snippet="def test(): pass",
            match_reason=None,  # Not requested
        )

        assert result.vector_score is None
        assert result.structural_score is None
        assert result.match_reason is None

    def test_semantic_result_to_dict(self):
        """Test SemanticResult.to_dict() filters None values."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=None,
            structural_score=0.6,
            fusion_score=0.7,
            snippet="def test(): pass",
            match_reason=None,
        )

        d = result.to_dict()

        assert "symbol_name" in d
        assert "vector_score" not in d  # None values filtered
        assert "structural_score" in d
        assert "match_reason" not in d  # None values filtered


class TestFusionStrategyMapping:
    """Test fusion_strategy parameter mapping via _execute_search."""

    def test_rrf_strategy_calls_search(self):
        """Test that rrf strategy maps to standard search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="rrf",
            options=mock_options,
            limit=20,
        )

        mock_engine.search.assert_called_once()

    def test_staged_strategy_calls_staged_cascade_search(self):
        """Test that staged strategy maps to staged_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.staged_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="staged",
            options=mock_options,
            limit=20,
        )

        mock_engine.staged_cascade_search.assert_called_once()

    def test_binary_strategy_calls_binary_cascade_search(self):
        """Test that binary strategy maps to binary_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.binary_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="binary",
            options=mock_options,
            limit=20,
        )

        mock_engine.binary_cascade_search.assert_called_once()

    def test_hybrid_strategy_calls_hybrid_cascade_search(self):
        """Test that hybrid strategy maps to hybrid_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.hybrid_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="hybrid",
            options=mock_options,
            limit=20,
        )

        mock_engine.hybrid_cascade_search.assert_called_once()

    def test_unknown_strategy_defaults_to_rrf(self):
        """Test that unknown strategy defaults to standard search (rrf)."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="unknown_strategy",
            options=mock_options,
            limit=20,
        )

        mock_engine.search.assert_called_once()
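
    # Strategy dispatch summarized by these tests: "rrf" -> engine.search,
    # "staged" -> engine.staged_cascade_search, "binary" ->
    # engine.binary_cascade_search, "hybrid" -> engine.hybrid_cascade_search,
    # and any unrecognized value falls back to engine.search.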


class TestGracefulDegradation:
    """Test graceful degradation behavior."""

    def test_vector_score_none_when_no_vector_index(self):
        """Test vector_score=None when vector index unavailable."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.5
        mock_result.excerpt = "def auth(): pass"
        mock_result.symbol_name = "auth"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {}  # No vector score in metadata

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        # When no source_scores in metadata, vector_score should be None
        assert results[0].vector_score is None

    def test_structural_score_extracted_from_fts(self):
        """Test structural_score extracted from FTS scores."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.8
        mock_result.excerpt = "def auth(): pass"
        mock_result.symbol_name = "auth"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {
            "source_scores": {
                "exact": 0.9,
                "fuzzy": 0.7,
            }
        }

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].structural_score == 0.9  # max of exact/fuzzy
1
codex-lens/tests/lsp/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests package for LSP module."""
477
codex-lens/tests/lsp/test_hover.py
Normal file
@@ -0,0 +1,477 @@
"""Tests for hover provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock
|
||||
import tempfile
|
||||
|
||||
from codexlens.entities import Symbol
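
# Usage sketch (illustrative only; the provider API exercised below):
#
#     provider = HoverProvider(global_index, registry)  # registry may be None
#     info = provider.get_hover_info("my_func")         # None if no exact match
#     if info is not None:
#         markdown = provider.format_hover_markdown(info)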


class TestHoverInfo:
    """Test HoverInfo dataclass."""

    def test_hover_info_import(self):
        """HoverInfo can be imported."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from codexlens.lsp.providers import HoverInfo

        assert HoverInfo is not None

    def test_hover_info_fields(self):
        """HoverInfo has all required fields."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo

        info = HoverInfo(
            name="my_function",
            kind="function",
            signature="def my_function(x: int) -> str:",
            documentation="A test function.",
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        assert info.name == "my_function"
        assert info.kind == "function"
        assert info.signature == "def my_function(x: int) -> str:"
        assert info.documentation == "A test function."
        assert info.file_path == "/test/file.py"
        assert info.line_range == (10, 15)

    def test_hover_info_optional_documentation(self):
        """Documentation can be None."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="/test.py",
            line_range=(1, 2),
        )
        assert info.documentation is None


class TestHoverProvider:
    """Test HoverProvider class."""

    def test_provider_import(self):
        """HoverProvider can be imported."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        assert HoverProvider is not None

    def test_returns_none_for_unknown_symbol(self):
        """Returns None when symbol not found."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_index.search.return_value = []
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("unknown_symbol")

        assert result is None
        mock_index.search.assert_called_once_with(
            name="unknown_symbol", limit=1, prefix_mode=False
        )

    def test_returns_none_for_non_exact_match(self):
        """Returns None when search returns non-exact matches."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Return a symbol with different name (prefix match but not exact)
        mock_symbol = Mock()
        mock_symbol.name = "my_function_extended"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test/file.py"
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("my_function")

        assert result is None

    def test_returns_hover_info_for_known_symbol(self):
        """Returns HoverInfo for found symbol."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = None  # No file, will use fallback signature
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("my_func")

        assert result is not None
        assert result.name == "my_func"
        assert result.kind == "function"
        assert result.line_range == (10, 15)
        assert result.signature == "function my_func"

    def test_extracts_signature_from_file(self):
        """Extracts signature from actual file content."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Create a temporary file with Python content
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("# comment\n")
            f.write("def test_function(x: int, y: str) -> bool:\n")
            f.write("    return True\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "test_function"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (2, 3)  # Line 2 (1-based)

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("test_function")

            assert result is not None
            assert "def test_function(x: int, y: str) -> bool:" in result.signature
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def test_extracts_multiline_signature(self):
        """Extracts multiline function signature."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Create a temporary file with multiline signature
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("def complex_function(\n")
            f.write("    arg1: int,\n")
            f.write("    arg2: str,\n")
            f.write(") -> bool:\n")
            f.write("    return True\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "complex_function"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (1, 5)  # Line 1 (1-based)

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("complex_function")

            assert result is not None
            assert "def complex_function(" in result.signature
            # Should capture multiline signature
            assert "arg1: int" in result.signature
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def test_handles_nonexistent_file_gracefully(self):
        """Returns fallback signature when file doesn't exist."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/nonexistent/path/file.py"
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]

        provider = HoverProvider(mock_index, None)
        result = provider.get_hover_info("my_func")

        assert result is not None
        assert result.signature == "function my_func"

    def test_handles_invalid_line_range(self):
        """Returns fallback signature when line range is invalid."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("def test():\n")
            f.write("    pass\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "test"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (100, 105)  # Line beyond file length

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("test")

            assert result is not None
            assert result.signature == "function test"
        finally:
            Path(temp_path).unlink(missing_ok=True)


class TestFormatHoverMarkdown:
    """Test markdown formatting."""

    def test_format_python_signature(self):
        """Formats Python signature with python code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func(x: int) -> str:",
            documentation=None,
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```python" in result
        assert "def func(x: int) -> str:" in result
        assert "function" in result
        assert "file.py" in result
        assert "line 10" in result

    def test_format_javascript_signature(self):
        """Formats JavaScript signature with javascript code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="myFunc",
            kind="function",
            signature="function myFunc(x) {",
            documentation=None,
            file_path="/test/file.js",
            line_range=(5, 10),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```javascript" in result
        assert "function myFunc(x) {" in result

    def test_format_typescript_signature(self):
        """Formats TypeScript signature with typescript code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="myFunc",
            kind="function",
            signature="function myFunc(x: number): string {",
            documentation=None,
            file_path="/test/file.ts",
            line_range=(5, 10),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```typescript" in result

    def test_format_with_documentation(self):
        """Includes documentation when available."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation="This is a test function.",
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "This is a test function." in result
        assert "---" in result  # Separator before docs

    def test_format_without_documentation(self):
        """Does not include documentation section when None."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        # Should have one separator for location, not two
        # The result should not have duplicate doc separator
        lines = result.split("\n")
        separator_count = sum(1 for line in lines if line.strip() == "---")
        assert separator_count == 1  # Only location separator

    def test_format_unknown_extension(self):
        """Uses empty code fence for unknown file extensions."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="func code here",
            documentation=None,
            file_path="/test/file.xyz",
            line_range=(1, 2),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        # Should have code fence without language specifier
        assert "```\n" in result or "```xyz" not in result

    def test_format_class_symbol(self):
        """Formats class symbol correctly."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="MyClass",
            kind="class",
            signature="class MyClass:",
            documentation=None,
            file_path="/test/file.py",
            line_range=(1, 20),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "class MyClass:" in result
        assert "*class*" in result
        assert "line 1" in result

    def test_format_empty_file_path(self):
        """Handles empty file path gracefully."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="",
            line_range=(1, 2),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "unknown" in result or "```" in result


class TestHoverProviderRegistry:
    """Test HoverProvider with registry integration."""

    def test_provider_accepts_none_registry(self):
        """HoverProvider works without registry."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_index.search.return_value = []

        provider = HoverProvider(mock_index, None)
        result = provider.get_hover_info("test")

        assert result is None
        assert provider.registry is None

    def test_provider_stores_registry(self):
        """HoverProvider stores registry reference."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)

        assert provider.global_index is mock_index
        assert provider.registry is mock_registry
497
codex-lens/tests/lsp/test_references.py
Normal file
@@ -0,0 +1,497 @@
"""Tests for reference search functionality.
|
||||
|
||||
This module tests the ReferenceResult dataclass and search_references method
|
||||
in ChainSearchEngine, as well as the updated lsp_references handler.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
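
# Usage sketch (illustrative only; the engine API exercised below):
#
#     engine = ChainSearchEngine(registry, mapper)
#     refs = engine.search_references("target_func", Path("/project"), limit=10)
#     # -> list[ReferenceResult], deduplicated and sorted by (file_path, line)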


class TestReferenceResult:
    """Test ReferenceResult dataclass."""

    def test_reference_result_fields(self):
        """ReferenceResult has all required fields."""
        from codexlens.search.chain_search import ReferenceResult

        ref = ReferenceResult(
            file_path="/test/file.py",
            line=10,
            column=5,
            context="def foo():",
            relationship_type="call",
        )
        assert ref.file_path == "/test/file.py"
        assert ref.line == 10
        assert ref.column == 5
        assert ref.context == "def foo():"
        assert ref.relationship_type == "call"

    def test_reference_result_with_empty_context(self):
        """ReferenceResult can have empty context."""
        from codexlens.search.chain_search import ReferenceResult

        ref = ReferenceResult(
            file_path="/test/file.py",
            line=1,
            column=0,
            context="",
            relationship_type="import",
        )
        assert ref.context == ""

    def test_reference_result_different_relationship_types(self):
        """ReferenceResult supports different relationship types."""
        from codexlens.search.chain_search import ReferenceResult

        types = ["call", "import", "inheritance", "implementation", "usage"]
        for rel_type in types:
            ref = ReferenceResult(
                file_path="/test/file.py",
                line=1,
                column=0,
                context="test",
                relationship_type=rel_type,
            )
            assert ref.relationship_type == rel_type


class TestExtractContext:
    """Test the _extract_context helper method."""

    def test_extract_context_middle_of_file(self):
        """Extract context from middle of file."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        content = "\n".join([
            "line 1",
            "line 2",
            "line 3",
            "line 4",  # target line
            "line 5",
            "line 6",
            "line 7",
        ])

        # Create minimal mock engine to test _extract_context
        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=4, context_lines=2)

        assert "line 2" in context
        assert "line 3" in context
        assert "line 4" in context
        assert "line 5" in context
        assert "line 6" in context

    def test_extract_context_start_of_file(self):
        """Extract context at start of file."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "\n".join([
            "line 1",  # target
            "line 2",
            "line 3",
            "line 4",
        ])

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=1, context_lines=2)

        assert "line 1" in context
        assert "line 2" in context
        assert "line 3" in context

    def test_extract_context_end_of_file(self):
        """Extract context at end of file."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "\n".join([
            "line 1",
            "line 2",
            "line 3",
            "line 4",  # target
        ])

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=4, context_lines=2)

        assert "line 2" in context
        assert "line 3" in context
        assert "line 4" in context

    def test_extract_context_empty_content(self):
        """Extract context from empty content."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context("", line=1, context_lines=3)

        assert context == ""

    def test_extract_context_invalid_line(self):
        """Extract context with invalid line number."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "line 1\nline 2\nline 3"

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        # Line 0 (invalid)
        assert engine._extract_context(content, line=0, context_lines=1) == ""

        # Line beyond end
        assert engine._extract_context(content, line=100, context_lines=1) == ""


class TestSearchReferences:
    """Test search_references method."""

    def test_returns_empty_for_no_source_path_and_no_registry(self):
        """Returns empty list when no source path and registry has no mappings."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_registry.list_mappings.return_value = []
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        results = engine.search_references("test_symbol")

        assert results == []

    def test_returns_empty_for_no_indexes(self):
        """Returns empty list when no indexes found."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()
        mock_mapper.source_to_index_db.return_value = Path("/nonexistent/_index.db")

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with patch.object(engine, "_find_start_index", return_value=None):
            results = engine.search_references("test_symbol", Path("/some/path"))

        assert results == []

    def test_deduplicates_results(self):
        """Removes duplicate file:line references."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        # Create a temporary database with duplicate relationships
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'def test(): pass');
                INSERT INTO symbols VALUES (1, 1, 'test_func', 'function', 1, 1);
                INSERT INTO code_relationships VALUES (1, 1, 'target_func', 'call', 10, NULL);
                INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 10, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target_func", Path(tmpdir))

            # Should only have 1 result due to deduplication
            assert len(results) == 1
            assert results[0].line == 10
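
    # The CREATE TABLE statements above are a minimal stand-in for the
    # per-directory _index.db layout (files / symbols / code_relationships);
    # only the columns that search_references reads are reproduced here.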

    def test_sorts_by_file_and_line(self):
        """Results sorted by file path then line number."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/b_file.py', 'python', 'content');
                INSERT INTO files VALUES (2, '/test/a_file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'func1', 'function', 1, 1);
                INSERT INTO symbols VALUES (2, 2, 'func2', 'function', 1, 1);
                INSERT INTO code_relationships VALUES (1, 1, 'target', 'call', 20, NULL);
                INSERT INTO code_relationships VALUES (2, 1, 'target', 'call', 10, NULL);
                INSERT INTO code_relationships VALUES (3, 2, 'target', 'call', 5, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target", Path(tmpdir))

            # Should be sorted: a_file.py:5, b_file.py:10, b_file.py:20
            assert len(results) == 3
            assert results[0].file_path == "/test/a_file.py"
            assert results[0].line == 5
            assert results[1].file_path == "/test/b_file.py"
            assert results[1].line == 10
            assert results[2].file_path == "/test/b_file.py"
            assert results[2].line == 20

    def test_respects_limit(self):
        """Returns at most limit results."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'func', 'function', 1, 1);
            """)
            # Insert many relationships
            for i in range(50):
                conn.execute(
                    "INSERT INTO code_relationships VALUES (?, 1, 'target', 'call', ?, NULL)",
                    (i + 1, i + 1)
                )
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target", Path(tmpdir), limit=10)

            assert len(results) == 10

    def test_matches_qualified_name(self):
        """Matches symbols by qualified name suffix."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'caller', 'function', 1, 1);
                -- Fully qualified name
                INSERT INTO code_relationships VALUES (1, 1, 'module.submodule.target_func', 'call', 10, NULL);
                -- Simple name
                INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 20, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target_func", Path(tmpdir))

            # Should find both references
            assert len(results) == 2


class TestLspReferencesHandler:
    """Test the LSP references handler."""

    def test_handler_uses_search_engine(self):
        """Handler uses search_engine.search_references when available."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from lsprotocol import types as lsp
        from codexlens.lsp.handlers import _path_to_uri
        from codexlens.search.chain_search import ReferenceResult

        # Create mock references
        mock_references = [
            ReferenceResult(
                file_path="/test/file1.py",
                line=10,
                column=5,
                context="def foo():",
                relationship_type="call",
            ),
            ReferenceResult(
                file_path="/test/file2.py",
                line=20,
                column=0,
                context="import foo",
                relationship_type="import",
            ),
        ]

        # Verify conversion to LSP Location
        locations = []
        for ref in mock_references:
            locations.append(
                lsp.Location(
                    uri=_path_to_uri(ref.file_path),
                    range=lsp.Range(
                        start=lsp.Position(
                            line=max(0, ref.line - 1),
                            character=ref.column,
                        ),
                        end=lsp.Position(
                            line=max(0, ref.line - 1),
                            character=ref.column + len("foo"),
                        ),
                    ),
                )
            )

        assert len(locations) == 2
        # First reference at line 10 (0-indexed = 9)
        assert locations[0].range.start.line == 9
        assert locations[0].range.start.character == 5
        # Second reference at line 20 (0-indexed = 19)
        assert locations[1].range.start.line == 19
        assert locations[1].range.start.character == 0

    def test_handler_falls_back_to_global_index(self):
        """Handler falls back to global_index when search_engine unavailable."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from codexlens.lsp.handlers import symbol_to_location
        from codexlens.entities import Symbol

        # Test fallback path converts Symbol to Location
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(10, 15),
            file="/test/file.py",
        )

        location = symbol_to_location(symbol)
        assert location is not None
        # LSP uses 0-based lines
        assert location.range.start.line == 9
        assert location.range.end.line == 14
210
codex-lens/tests/lsp/test_server.py
Normal file
@@ -0,0 +1,210 @@
"""Tests for codex-lens LSP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from codexlens.entities import Symbol
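
# Conversion rule exercised below (LSP positions are 0-based, the index is
# 1-based): a Symbol with range=(10, 15) maps to Location.range.start.line == 9
# and Location.range.end.line == 14 via symbol_to_location.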
|
||||
|
||||
|
||||
class TestCodexLensLanguageServer:
|
||||
"""Tests for CodexLensLanguageServer."""
|
||||
|
||||
def test_server_import(self):
|
||||
"""Test that server module can be imported."""
|
||||
pytest.importorskip("pygls")
|
||||
pytest.importorskip("lsprotocol")
|
||||
|
||||
from codexlens.lsp.server import CodexLensLanguageServer, server
|
||||
|
||||
assert CodexLensLanguageServer is not None
|
||||
assert server is not None
|
||||
assert server.name == "codexlens-lsp"
|
||||
|
||||
def test_server_initialization(self):
|
||||
"""Test server instance creation."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.server import CodexLensLanguageServer
|
||||
|
||||
ls = CodexLensLanguageServer()
|
||||
assert ls.registry is None
|
||||
assert ls.mapper is None
|
||||
assert ls.global_index is None
|
||||
assert ls.search_engine is None
|
||||
assert ls.workspace_root is None
|
||||
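        # All collaborators start as None: the server defers building the
        # registry, path mapper, global index, and search engine until an LSP
        # `initialize` request supplies a workspace root (assumed wiring; the
        # assertions above only pin down the pre-initialization state).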


class TestDefinitionHandler:
    """Tests for definition handler."""

    def test_definition_lookup(self):
        """Test definition lookup returns location for known symbol."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from lsprotocol import types as lsp
        from codexlens.lsp.handlers import symbol_to_location

        symbol = Symbol(
            name="test_function",
            kind="function",
            range=(10, 15),
            file="/path/to/file.py",
        )

        location = symbol_to_location(symbol)

        assert location is not None
        assert isinstance(location, lsp.Location)
        # LSP uses 0-based lines
        assert location.range.start.line == 9
        assert location.range.end.line == 14

    def test_definition_no_file(self):
        """Test definition lookup returns None for symbol without file."""
        pytest.importorskip("pygls")

        from codexlens.lsp.handlers import symbol_to_location

        symbol = Symbol(
            name="test_function",
            kind="function",
            range=(10, 15),
            file=None,
        )

        location = symbol_to_location(symbol)
        assert location is None


class TestCompletionHandler:
    """Tests for completion handler."""

    def test_get_prefix_at_position(self):
        """Test extracting prefix at cursor position."""
        pytest.importorskip("pygls")

        from codexlens.lsp.handlers import _get_prefix_at_position

        document_text = "def hello_world():\n    print(hel"

        # Cursor at end of "hel"
        prefix = _get_prefix_at_position(document_text, 1, 14)
        assert prefix == "hel"

        # Cursor at beginning of line (after whitespace)
        prefix = _get_prefix_at_position(document_text, 1, 4)
        assert prefix == ""

        # Cursor after "he" in "hello_world" - returns text before cursor
        prefix = _get_prefix_at_position(document_text, 0, 6)
        assert prefix == "he"

        # Cursor at end of "hello_world"
        prefix = _get_prefix_at_position(document_text, 0, 15)
        assert prefix == "hello_world"
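        # Prefix extraction is completion-oriented: it only looks backwards
        # from the cursor, so "hel|lo" yields "hel". Word extraction (next
        # test) is hover/definition-oriented and expands in both directions,
        # so the same cursor yields "hello". The boundary rule assumed here
        # is identifier characters [A-Za-z0-9_].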

    def test_get_word_at_position(self):
        """Test extracting word at cursor position."""
        pytest.importorskip("pygls")

        from codexlens.lsp.handlers import _get_word_at_position

        document_text = "def hello_world():\n    print(msg)"

        # Cursor on "hello_world"
        word = _get_word_at_position(document_text, 0, 6)
        assert word == "hello_world"

        # Cursor on "print"
        word = _get_word_at_position(document_text, 1, 6)
        assert word == "print"

        # Cursor on "msg"
        word = _get_word_at_position(document_text, 1, 11)
        assert word == "msg"

    def test_symbol_kind_mapping(self):
        """Test symbol kind to completion kind mapping."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from lsprotocol import types as lsp
        from codexlens.lsp.handlers import _symbol_kind_to_completion_kind

        assert _symbol_kind_to_completion_kind("function") == lsp.CompletionItemKind.Function
        assert _symbol_kind_to_completion_kind("class") == lsp.CompletionItemKind.Class
        assert _symbol_kind_to_completion_kind("method") == lsp.CompletionItemKind.Method
        assert _symbol_kind_to_completion_kind("variable") == lsp.CompletionItemKind.Variable

        # Unknown kind should default to Text
        assert _symbol_kind_to_completion_kind("unknown") == lsp.CompletionItemKind.Text


class TestWorkspaceSymbolHandler:
    """Tests for workspace symbol handler."""

    def test_symbol_kind_to_lsp(self):
        """Test symbol kind to LSP SymbolKind mapping."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from lsprotocol import types as lsp
        from codexlens.lsp.handlers import _symbol_kind_to_lsp

        assert _symbol_kind_to_lsp("function") == lsp.SymbolKind.Function
        assert _symbol_kind_to_lsp("class") == lsp.SymbolKind.Class
        assert _symbol_kind_to_lsp("method") == lsp.SymbolKind.Method
        assert _symbol_kind_to_lsp("interface") == lsp.SymbolKind.Interface

        # Unknown kind should default to Variable
        assert _symbol_kind_to_lsp("unknown") == lsp.SymbolKind.Variable


class TestUriConversion:
    """Tests for URI path conversion."""

    def test_path_to_uri(self):
        """Test path to URI conversion."""
        pytest.importorskip("pygls")

        from codexlens.lsp.handlers import _path_to_uri

        # Unix path
        uri = _path_to_uri("/home/user/file.py")
        assert uri.startswith("file://")
        assert "file.py" in uri

    def test_uri_to_path(self):
        """Test URI to path conversion."""
        pytest.importorskip("pygls")

        from codexlens.lsp.handlers import _uri_to_path

        # Basic URI
        path = _uri_to_path("file:///home/user/file.py")
        assert path.name == "file.py"


class TestMainEntryPoint:
    """Tests for main entry point."""

    def test_main_help(self):
        """Test that main shows help without errors."""
        pytest.importorskip("pygls")

        import sys
        from unittest.mock import patch

        # Patch sys.argv to show help
        with patch.object(sys, 'argv', ['codexlens-lsp', '--help']):
            from codexlens.lsp.server import main

            with pytest.raises(SystemExit) as exc_info:
                main()

        # Help exits with 0
        assert exc_info.value.code == 0
1
codex-lens/tests/mcp/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for MCP (Model Context Protocol) module."""
208
codex-lens/tests/mcp/test_hooks.py
Normal file
@@ -0,0 +1,208 @@
"""Tests for MCP hooks module."""

import pytest
from unittest.mock import Mock, patch
from pathlib import Path

from codexlens.mcp.hooks import HookManager, create_context_for_prompt
from codexlens.mcp.schema import MCPContext, SymbolInfo


class TestHookManager:
    """Test HookManager class."""

    @pytest.fixture
    def mock_provider(self):
        """Create a mock MCP provider."""
        provider = Mock()
        provider.build_context.return_value = MCPContext(
            symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
            context_type="symbol_explanation",
        )
        provider.build_context_for_file.return_value = MCPContext(
            context_type="file_overview",
        )
        return provider

    @pytest.fixture
    def hook_manager(self, mock_provider):
        """Create a HookManager with mocked provider."""
        return HookManager(mock_provider)

    def test_default_hooks_registered(self, hook_manager):
        """Default hooks are registered on initialization."""
        assert "explain" in hook_manager._pre_hooks
        assert "refactor" in hook_manager._pre_hooks
        assert "document" in hook_manager._pre_hooks

    def test_execute_pre_hook_returns_context(self, hook_manager, mock_provider):
        """execute_pre_hook returns MCPContext for registered hook."""
        result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})

        assert result is not None
        assert isinstance(result, MCPContext)
        mock_provider.build_context.assert_called_once()

    def test_execute_pre_hook_returns_none_for_unknown_action(self, hook_manager):
        """execute_pre_hook returns None for unregistered action."""
        result = hook_manager.execute_pre_hook("unknown_action", {"symbol": "test"})

        assert result is None

    def test_execute_pre_hook_handles_exception(self, hook_manager, mock_provider):
        """execute_pre_hook handles provider exceptions gracefully."""
        mock_provider.build_context.side_effect = Exception("Provider failed")

        result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})

        assert result is None
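        # Pre-hook contract pinned down by the tests above: a hook returns an
        # MCPContext on success, None for unknown actions or missing params,
        # and never propagates provider exceptions. Context injection is
        # best-effort and must not break the calling tool.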

    def test_execute_post_hook_no_error_for_unregistered(self, hook_manager):
        """execute_post_hook doesn't error for unregistered action."""
        # Should not raise
        hook_manager.execute_post_hook("unknown", {"result": "data"})

    def test_pre_explain_hook_calls_build_context(self, hook_manager, mock_provider):
        """_pre_explain_hook calls build_context correctly."""
        hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})

        mock_provider.build_context.assert_called_with(
            symbol_name="my_func",
            context_type="symbol_explanation",
            include_references=True,
            include_related=True,
        )

    def test_pre_explain_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
        """_pre_explain_hook returns None when symbol param missing."""
        result = hook_manager.execute_pre_hook("explain", {})

        assert result is None
        mock_provider.build_context.assert_not_called()

    def test_pre_refactor_hook_calls_build_context(self, hook_manager, mock_provider):
        """_pre_refactor_hook calls build_context with refactor settings."""
        hook_manager.execute_pre_hook("refactor", {"symbol": "my_class"})

        mock_provider.build_context.assert_called_with(
            symbol_name="my_class",
            context_type="refactor_context",
            include_references=True,
            include_related=True,
            max_references=20,
        )

    def test_pre_refactor_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
        """_pre_refactor_hook returns None when symbol param missing."""
        result = hook_manager.execute_pre_hook("refactor", {})

        assert result is None
        mock_provider.build_context.assert_not_called()

    def test_pre_document_hook_with_symbol(self, hook_manager, mock_provider):
        """_pre_document_hook uses build_context when symbol provided."""
        hook_manager.execute_pre_hook("document", {"symbol": "my_func"})

        mock_provider.build_context.assert_called_with(
            symbol_name="my_func",
            context_type="documentation_context",
            include_references=False,
            include_related=True,
        )

    def test_pre_document_hook_with_file_path(self, hook_manager, mock_provider):
        """_pre_document_hook uses build_context_for_file when file_path provided."""
        hook_manager.execute_pre_hook("document", {"file_path": "/src/module.py"})

        mock_provider.build_context_for_file.assert_called_once()
        call_args = mock_provider.build_context_for_file.call_args
        assert call_args[0][0] == Path("/src/module.py")
        assert call_args[1].get("context_type") == "file_documentation"

    def test_pre_document_hook_prefers_symbol_over_file(self, hook_manager, mock_provider):
        """_pre_document_hook prefers symbol when both provided."""
        hook_manager.execute_pre_hook(
            "document", {"symbol": "my_func", "file_path": "/src/module.py"}
        )

        mock_provider.build_context.assert_called_once()
        mock_provider.build_context_for_file.assert_not_called()
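        # Resolution order when both params arrive: symbol wins over
        # file_path, so a symbol-level context is built and the file-level
        # branch is never taken. Callers wanting a file overview must omit
        # "symbol".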

    def test_pre_document_hook_returns_none_without_params(self, hook_manager, mock_provider):
        """_pre_document_hook returns None when neither symbol nor file_path provided."""
        result = hook_manager.execute_pre_hook("document", {})

        assert result is None
        mock_provider.build_context.assert_not_called()
        mock_provider.build_context_for_file.assert_not_called()

    def test_register_pre_hook(self, hook_manager):
        """register_pre_hook adds custom hook."""
        custom_hook = Mock(return_value=MCPContext())

        hook_manager.register_pre_hook("custom_action", custom_hook)

        assert "custom_action" in hook_manager._pre_hooks
        hook_manager.execute_pre_hook("custom_action", {"data": "value"})
        custom_hook.assert_called_once_with({"data": "value"})

    def test_register_post_hook(self, hook_manager):
        """register_post_hook adds custom hook."""
        custom_hook = Mock()

        hook_manager.register_post_hook("custom_action", custom_hook)

        assert "custom_action" in hook_manager._post_hooks
        hook_manager.execute_post_hook("custom_action", {"result": "data"})
        custom_hook.assert_called_once_with({"result": "data"})

    def test_execute_post_hook_handles_exception(self, hook_manager):
        """execute_post_hook handles hook exceptions gracefully."""
        failing_hook = Mock(side_effect=Exception("Hook failed"))
        hook_manager.register_post_hook("failing", failing_hook)

        # Should not raise
        hook_manager.execute_post_hook("failing", {"data": "value"})


class TestCreateContextForPrompt:
    """Test create_context_for_prompt function."""

    def test_returns_prompt_injection_string(self):
        """create_context_for_prompt returns formatted string."""
        mock_provider = Mock()
        mock_provider.build_context.return_value = MCPContext(
            symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
            definition="def test_func(): pass",
        )

        result = create_context_for_prompt(
            mock_provider, "explain", {"symbol": "test_func"}
        )

        assert isinstance(result, str)
        assert "<code_context>" in result
        assert "test_func" in result
        assert "</code_context>" in result

    def test_returns_empty_string_when_no_context(self):
        """create_context_for_prompt returns empty string when no context built."""
        mock_provider = Mock()
        mock_provider.build_context.return_value = None

        result = create_context_for_prompt(
            mock_provider, "explain", {"symbol": "nonexistent"}
        )

        assert result == ""

    def test_returns_empty_string_for_unknown_action(self):
        """create_context_for_prompt returns empty string for unregistered action."""
        mock_provider = Mock()

        result = create_context_for_prompt(
            mock_provider, "unknown_action", {"data": "value"}
        )

        assert result == ""
        mock_provider.build_context.assert_not_called()
383
codex-lens/tests/mcp/test_provider.py
Normal file
@@ -0,0 +1,383 @@
"""Tests for MCP provider."""

import pytest
from unittest.mock import Mock, MagicMock, patch
from pathlib import Path
import tempfile
import os

from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.schema import MCPContext, SymbolInfo, ReferenceInfo


class TestMCPProvider:
    """Test MCPProvider class."""

    @pytest.fixture
    def mock_global_index(self):
        """Create a mock global index."""
        return Mock()

    @pytest.fixture
    def mock_search_engine(self):
        """Create a mock search engine."""
        return Mock()

    @pytest.fixture
    def mock_registry(self):
        """Create a mock registry."""
        return Mock()

    @pytest.fixture
    def provider(self, mock_global_index, mock_search_engine, mock_registry):
        """Create an MCPProvider with mocked dependencies."""
        return MCPProvider(mock_global_index, mock_search_engine, mock_registry)

    def test_build_context_returns_none_for_unknown_symbol(self, provider, mock_global_index):
        """build_context returns None when symbol is not found."""
        mock_global_index.search.return_value = []

        result = provider.build_context("unknown_symbol")

        assert result is None
        mock_global_index.search.assert_called_once_with(
            "unknown_symbol", prefix_mode=False, limit=1
        )

    def test_build_context_returns_mcp_context(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context returns MCPContext for known symbol."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = []

        result = provider.build_context("my_func")

        assert result is not None
        assert isinstance(result, MCPContext)
        assert result.symbol is not None
        assert result.symbol.name == "my_func"
        assert result.symbol.kind == "function"
        assert result.context_type == "symbol_explanation"

    def test_build_context_with_custom_context_type(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context respects custom context_type."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = []

        result = provider.build_context("my_func", context_type="refactor_context")

        assert result is not None
        assert result.context_type == "refactor_context"

    def test_build_context_includes_references(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context includes references when include_references=True."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_ref = Mock()
        mock_ref.file_path = "/caller.py"
        mock_ref.line = 25
        mock_ref.column = 4
        mock_ref.context = "result = my_func()"
        mock_ref.relationship_type = "call"

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = [mock_ref]

        result = provider.build_context("my_func", include_references=True)

        assert result is not None
        assert len(result.references) == 1
        assert result.references[0].file_path == "/caller.py"
        assert result.references[0].line == 25
        assert result.references[0].relationship_type == "call"

    def test_build_context_excludes_references_when_disabled(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context excludes references when include_references=False."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = []

        # Disable both references and related to avoid any search_references calls
        result = provider.build_context(
            "my_func", include_references=False, include_related=False
        )

        assert result is not None
        assert len(result.references) == 0
        mock_search_engine.search_references.assert_not_called()

    def test_build_context_respects_max_references(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context passes max_references to search engine."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = []

        # Disable include_related to test only the references call
        provider.build_context("my_func", max_references=5, include_related=False)

        mock_search_engine.search_references.assert_called_once_with(
            "my_func", limit=5
        )

    def test_build_context_includes_metadata(
        self, provider, mock_global_index, mock_search_engine
    ):
        """build_context includes source metadata."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test.py"
        mock_symbol.range = (10, 20)

        mock_global_index.search.return_value = [mock_symbol]
        mock_search_engine.search_references.return_value = []

        result = provider.build_context("my_func")

        assert result is not None
        assert result.metadata.get("source") == "codex-lens"

    def test_extract_definition_with_valid_file(self, provider):
        """_extract_definition reads file content correctly."""
        # Create a temporary file with some content
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write("# Line 1\n")
            f.write("# Line 2\n")
            f.write("def my_func():\n")  # Line 3
            f.write("    pass\n")  # Line 4
            f.write("# Line 5\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.file = temp_path
            mock_symbol.range = (3, 4)  # 1-based line numbers

            definition = provider._extract_definition(mock_symbol)

            assert definition is not None
            assert "def my_func():" in definition
            assert "pass" in definition
        finally:
            os.unlink(temp_path)
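        # The range (3, 4) above is 1-based and inclusive, so the expected
        # extraction is (minimal sketch of the assumed logic, not the source):
        #
        #     start, end = symbol.range
        #     lines = Path(symbol.file).read_text().splitlines()
        #     definition = "\n".join(lines[start - 1:end])
        #
        # which selects exactly "def my_func():" and "    pass".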

    def test_extract_definition_returns_none_for_missing_file(self, provider):
        """_extract_definition returns None for non-existent file."""
        mock_symbol = Mock()
        mock_symbol.file = "/nonexistent/path/file.py"
        mock_symbol.range = (1, 5)

        definition = provider._extract_definition(mock_symbol)

        assert definition is None

    def test_extract_definition_returns_none_for_none_file(self, provider):
        """_extract_definition returns None when symbol.file is None."""
        mock_symbol = Mock()
        mock_symbol.file = None
        mock_symbol.range = (1, 5)

        definition = provider._extract_definition(mock_symbol)

        assert definition is None

    def test_build_context_for_file_returns_context(
        self, provider, mock_global_index
    ):
        """build_context_for_file returns MCPContext."""
        mock_global_index.search.return_value = []

        result = provider.build_context_for_file(
            Path("/test/file.py"),
            context_type="file_overview",
        )

        assert result is not None
        assert isinstance(result, MCPContext)
        assert result.context_type == "file_overview"
        assert result.metadata.get("file_path") == str(Path("/test/file.py"))

    def test_build_context_for_file_includes_symbols(
        self, provider, mock_global_index
    ):
        """build_context_for_file includes symbols from the file."""
        # Create temp file to get resolved path
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write("def func(): pass\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "func"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (1, 1)

            mock_global_index.search.return_value = [mock_symbol]

            result = provider.build_context_for_file(Path(temp_path))

            assert result is not None
            # Symbols from this file should be in related_symbols
            assert len(result.related_symbols) >= 0  # May be 0 if filtering doesn't match
        finally:
            os.unlink(temp_path)


class TestMCPProviderRelatedSymbols:
    """Test related symbols functionality."""

    @pytest.fixture
    def provider(self):
        """Create provider with mocks."""
        mock_global_index = Mock()
        mock_search_engine = Mock()
        mock_registry = Mock()
        return MCPProvider(mock_global_index, mock_search_engine, mock_registry)

    def test_get_related_symbols_from_references(self, provider):
        """_get_related_symbols extracts symbols from references."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.file = "/test.py"

        mock_ref1 = Mock()
        mock_ref1.file_path = "/caller1.py"
        mock_ref1.relationship_type = "call"

        mock_ref2 = Mock()
        mock_ref2.file_path = "/caller2.py"
        mock_ref2.relationship_type = "import"

        provider.search_engine.search_references.return_value = [mock_ref1, mock_ref2]

        related = provider._get_related_symbols(mock_symbol)

        assert len(related) == 2
        assert related[0].relationship == "call"
        assert related[1].relationship == "import"

    def test_get_related_symbols_limits_results(self, provider):
        """_get_related_symbols limits to 10 unique relationship types."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.file = "/test.py"

        # Create 15 references with unique relationship types
        refs = []
        for i in range(15):
            ref = Mock()
            ref.file_path = f"/file{i}.py"
            ref.relationship_type = f"type{i}"
            refs.append(ref)

        provider.search_engine.search_references.return_value = refs

        related = provider._get_related_symbols(mock_symbol)

        assert len(related) <= 10

    def test_get_related_symbols_handles_exception(self, provider):
        """_get_related_symbols handles exceptions gracefully."""
        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.file = "/test.py"

        provider.search_engine.search_references.side_effect = Exception("Search failed")

        related = provider._get_related_symbols(mock_symbol)

        assert related == []


class TestMCPProviderIntegration:
    """Integration-style tests for MCPProvider."""

    def test_full_context_workflow(self):
        """Test complete context building workflow."""
        # Create temp file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
            f.write("def my_function(arg1, arg2):\n")
            f.write("    '''This is my function.'''\n")
            f.write("    return arg1 + arg2\n")
            temp_path = f.name

        try:
            # Setup mocks
            mock_global_index = Mock()
            mock_search_engine = Mock()
            mock_registry = Mock()

            mock_symbol = Mock()
            mock_symbol.name = "my_function"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (1, 3)

            mock_ref = Mock()
            mock_ref.file_path = "/user.py"
            mock_ref.line = 10
            mock_ref.column = 4
            mock_ref.context = "result = my_function(1, 2)"
            mock_ref.relationship_type = "call"

            mock_global_index.search.return_value = [mock_symbol]
            mock_search_engine.search_references.return_value = [mock_ref]

            provider = MCPProvider(mock_global_index, mock_search_engine, mock_registry)
            context = provider.build_context("my_function")

            assert context is not None
            assert context.symbol.name == "my_function"
            assert context.definition is not None
            assert "def my_function" in context.definition
            assert len(context.references) == 1
            assert context.references[0].relationship_type == "call"

            # Test serialization
            json_str = context.to_json()
            assert "my_function" in json_str

            # Test prompt injection
            prompt = context.to_prompt_injection()
            assert "<code_context>" in prompt
            assert "my_function" in prompt
            assert "</code_context>" in prompt

        finally:
            os.unlink(temp_path)
288
codex-lens/tests/mcp/test_schema.py
Normal file
@@ -0,0 +1,288 @@
"""Tests for MCP schema."""

import pytest
import json

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)


class TestSymbolInfo:
    """Test SymbolInfo dataclass."""

    def test_to_dict_includes_all_fields(self):
        """SymbolInfo.to_dict() includes all non-None fields."""
        info = SymbolInfo(
            name="func",
            kind="function",
            file_path="/test.py",
            line_start=10,
            line_end=20,
            signature="def func():",
            documentation="Test doc",
        )
        d = info.to_dict()
        assert d["name"] == "func"
        assert d["kind"] == "function"
        assert d["file_path"] == "/test.py"
        assert d["line_start"] == 10
        assert d["line_end"] == 20
        assert d["signature"] == "def func():"
        assert d["documentation"] == "Test doc"

    def test_to_dict_excludes_none(self):
        """SymbolInfo.to_dict() excludes None fields."""
        info = SymbolInfo(
            name="func",
            kind="function",
            file_path="/test.py",
            line_start=10,
            line_end=20,
        )
        d = info.to_dict()
        assert "signature" not in d
        assert "documentation" not in d
        assert "name" in d
        assert "kind" in d

    def test_basic_creation(self):
        """SymbolInfo can be created with required fields only."""
        info = SymbolInfo(
            name="MyClass",
            kind="class",
            file_path="/src/module.py",
            line_start=1,
            line_end=50,
        )
        assert info.name == "MyClass"
        assert info.kind == "class"
        assert info.signature is None
        assert info.documentation is None


class TestReferenceInfo:
    """Test ReferenceInfo dataclass."""

    def test_to_dict(self):
        """ReferenceInfo.to_dict() returns all fields."""
        ref = ReferenceInfo(
            file_path="/src/main.py",
            line=25,
            column=4,
            context="result = func()",
            relationship_type="call",
        )
        d = ref.to_dict()
        assert d["file_path"] == "/src/main.py"
        assert d["line"] == 25
        assert d["column"] == 4
        assert d["context"] == "result = func()"
        assert d["relationship_type"] == "call"

    def test_all_fields_required(self):
        """ReferenceInfo requires all fields."""
        ref = ReferenceInfo(
            file_path="/test.py",
            line=10,
            column=0,
            context="import module",
            relationship_type="import",
        )
        assert ref.file_path == "/test.py"
        assert ref.relationship_type == "import"


class TestRelatedSymbol:
    """Test RelatedSymbol dataclass."""

    def test_to_dict_includes_all_fields(self):
        """RelatedSymbol.to_dict() includes all non-None fields."""
        sym = RelatedSymbol(
            name="BaseClass",
            kind="class",
            relationship="inherits",
            file_path="/src/base.py",
        )
        d = sym.to_dict()
        assert d["name"] == "BaseClass"
        assert d["kind"] == "class"
        assert d["relationship"] == "inherits"
        assert d["file_path"] == "/src/base.py"

    def test_to_dict_excludes_none(self):
        """RelatedSymbol.to_dict() excludes None file_path."""
        sym = RelatedSymbol(
            name="helper",
            kind="function",
            relationship="calls",
        )
        d = sym.to_dict()
        assert "file_path" not in d
        assert d["name"] == "helper"
        assert d["relationship"] == "calls"


class TestMCPContext:
    """Test MCPContext dataclass."""

    def test_to_dict_basic(self):
        """MCPContext.to_dict() returns basic structure."""
        ctx = MCPContext(context_type="test")
        d = ctx.to_dict()
        assert d["version"] == "1.0"
        assert d["context_type"] == "test"
        assert d["metadata"] == {}

    def test_to_dict_with_symbol(self):
        """MCPContext.to_dict() includes symbol when present."""
        ctx = MCPContext(
            context_type="test",
            symbol=SymbolInfo("f", "function", "/t.py", 1, 2),
        )
        d = ctx.to_dict()
        assert "symbol" in d
        assert d["symbol"]["name"] == "f"
        assert d["symbol"]["kind"] == "function"

    def test_to_dict_with_references(self):
        """MCPContext.to_dict() includes references when present."""
        ctx = MCPContext(
            context_type="test",
            references=[
                ReferenceInfo("/a.py", 10, 0, "call()", "call"),
                ReferenceInfo("/b.py", 20, 5, "import x", "import"),
            ],
        )
        d = ctx.to_dict()
        assert "references" in d
        assert len(d["references"]) == 2
        assert d["references"][0]["line"] == 10

    def test_to_dict_with_related_symbols(self):
        """MCPContext.to_dict() includes related_symbols when present."""
        ctx = MCPContext(
            context_type="test",
            related_symbols=[
                RelatedSymbol("Base", "class", "inherits"),
                RelatedSymbol("helper", "function", "calls"),
            ],
        )
        d = ctx.to_dict()
        assert "related_symbols" in d
        assert len(d["related_symbols"]) == 2

    def test_to_json(self):
        """MCPContext.to_json() returns valid JSON."""
        ctx = MCPContext(context_type="test")
        j = ctx.to_json()
        parsed = json.loads(j)
        assert parsed["version"] == "1.0"
        assert parsed["context_type"] == "test"

    def test_to_json_with_indent(self):
        """MCPContext.to_json() respects indent parameter."""
        ctx = MCPContext(context_type="test")
        j = ctx.to_json(indent=4)
        # Check it's properly indented
        assert "    " in j

    def test_to_prompt_injection_basic(self):
        """MCPContext.to_prompt_injection() returns formatted string."""
        ctx = MCPContext(
            symbol=SymbolInfo("my_func", "function", "/test.py", 10, 20),
            definition="def my_func(): pass",
        )
        prompt = ctx.to_prompt_injection()
        assert "<code_context>" in prompt
        assert "my_func" in prompt
        assert "def my_func()" in prompt
        assert "</code_context>" in prompt

    def test_to_prompt_injection_with_references(self):
        """MCPContext.to_prompt_injection() includes references."""
        ctx = MCPContext(
            symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
            references=[
                ReferenceInfo("/a.py", 10, 0, "func()", "call"),
                ReferenceInfo("/b.py", 20, 0, "from x import func", "import"),
            ],
        )
        prompt = ctx.to_prompt_injection()
        assert "References (2 found)" in prompt
        assert "/a.py:10" in prompt
        assert "call" in prompt

    def test_to_prompt_injection_limits_references(self):
        """MCPContext.to_prompt_injection() limits references to 5."""
        refs = [
            ReferenceInfo(f"/file{i}.py", i, 0, f"ref{i}", "call")
            for i in range(10)
        ]
        ctx = MCPContext(
            symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
            references=refs,
        )
        prompt = ctx.to_prompt_injection()
        # Should show "10 found" but only include 5
        assert "References (10 found)" in prompt
        assert "/file0.py" in prompt
        assert "/file4.py" in prompt
        assert "/file5.py" not in prompt
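        # Expected injection shape exercised above (a sketch; the exact
        # wording lives in MCPContext.to_prompt_injection and may differ):
        #
        #     <code_context>
        #     ...symbol header + definition...
        #     References (10 found):
        #     - /file0.py:0 [call] ref0
        #     ...capped at 5 entries...
        #     </code_context>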

    def test_to_prompt_injection_with_related_symbols(self):
        """MCPContext.to_prompt_injection() includes related symbols."""
        ctx = MCPContext(
            symbol=SymbolInfo("MyClass", "class", "/test.py", 1, 50),
            related_symbols=[
                RelatedSymbol("BaseClass", "class", "inherits"),
                RelatedSymbol("helper", "function", "calls"),
            ],
        )
        prompt = ctx.to_prompt_injection()
        assert "Related Symbols" in prompt
        assert "BaseClass (inherits)" in prompt
        assert "helper (calls)" in prompt

    def test_to_prompt_injection_limits_related_symbols(self):
        """MCPContext.to_prompt_injection() limits related symbols to 10."""
        related = [
            RelatedSymbol(f"sym{i}", "function", "calls")
            for i in range(15)
        ]
        ctx = MCPContext(
            symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
            related_symbols=related,
        )
        prompt = ctx.to_prompt_injection()
        assert "sym0 (calls)" in prompt
        assert "sym9 (calls)" in prompt
        assert "sym10 (calls)" not in prompt

    def test_empty_context(self):
        """MCPContext works with minimal data."""
        ctx = MCPContext()
        d = ctx.to_dict()
        assert d["version"] == "1.0"
        assert d["context_type"] == "code_context"

        prompt = ctx.to_prompt_injection()
        assert "<code_context>" in prompt
        assert "</code_context>" in prompt

    def test_metadata_preserved(self):
        """MCPContext preserves custom metadata."""
        ctx = MCPContext(
            context_type="custom",
            metadata={
                "source": "codex-lens",
                "indexed_at": "2024-01-01",
                "custom_key": "custom_value",
            },
        )
        d = ctx.to_dict()
        assert d["metadata"]["source"] == "codex-lens"
        assert d["metadata"]["custom_key"] == "custom_value"
@@ -79,3 +79,87 @@ def test_symbol_filtering_handles_path_failures(monkeypatch: pytest.MonkeyPatch,

    if os.name == "nt":
        assert "CrossDrive" in caplog.text


def test_cascade_search_strategy_routing(temp_paths: Path) -> None:
    """Test cascade_search() routes to correct strategy implementation."""
    from unittest.mock import patch
    from codexlens.search.chain_search import ChainSearchResult, SearchStats

    registry = RegistryStore(db_path=temp_paths / "registry.db")
    registry.initialize()
    mapper = PathMapper(index_root=temp_paths / "indexes")
    config = Config(data_dir=temp_paths / "data")

    engine = ChainSearchEngine(registry, mapper, config=config)
    source_path = temp_paths / "src"

    # Test strategy='staged' routing
    with patch.object(engine, "staged_cascade_search") as mock_staged:
        mock_staged.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="staged")
        mock_staged.assert_called_once()

    # Test strategy='binary' routing
    with patch.object(engine, "binary_cascade_search") as mock_binary:
        mock_binary.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="binary")
        mock_binary.assert_called_once()

    # Test strategy='hybrid' routing
    with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
        mock_hybrid.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="hybrid")
        mock_hybrid.assert_called_once()

    # Test strategy='binary_rerank' routing
    with patch.object(engine, "binary_rerank_cascade_search") as mock_br:
        mock_br.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="binary_rerank")
        mock_br.assert_called_once()

    # Test strategy='dense_rerank' routing
    with patch.object(engine, "dense_rerank_cascade_search") as mock_dr:
        mock_dr.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="dense_rerank")
        mock_dr.assert_called_once()

    # Test default routing (no strategy specified) - defaults to binary
    with patch.object(engine, "binary_cascade_search") as mock_default:
        mock_default.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path)
        mock_default.assert_called_once()
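    # The routing verified above suggests a simple dispatch table inside
    # cascade_search (a sketch under that assumption, not the actual source):
    #
    #     handlers = {
    #         "staged": self.staged_cascade_search,
    #         "binary": self.binary_cascade_search,
    #         "hybrid": self.hybrid_cascade_search,
    #         "binary_rerank": self.binary_rerank_cascade_search,
    #         "dense_rerank": self.dense_rerank_cascade_search,
    #     }
    #     return handlers.get(strategy, self.binary_cascade_search)(query, path)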


def test_cascade_search_invalid_strategy(temp_paths: Path) -> None:
    """Test cascade_search() defaults to 'binary' for invalid strategy."""
    from unittest.mock import patch
    from codexlens.search.chain_search import ChainSearchResult, SearchStats

    registry = RegistryStore(db_path=temp_paths / "registry.db")
    registry.initialize()
    mapper = PathMapper(index_root=temp_paths / "indexes")
    config = Config(data_dir=temp_paths / "data")

    engine = ChainSearchEngine(registry, mapper, config=config)
    source_path = temp_paths / "src"

    # Invalid strategy should default to binary
    with patch.object(engine, "binary_cascade_search") as mock_binary:
        mock_binary.return_value = ChainSearchResult(
            query="query", results=[], symbols=[], stats=SearchStats()
        )
        engine.cascade_search("query", source_path, strategy="invalid_strategy")
        mock_binary.assert_called_once()

766
codex-lens/tests/test_clustering_strategies.py
Normal file
@@ -0,0 +1,766 @@
"""Unit tests for clustering strategies in the hybrid search pipeline.

Tests cover:
1. HDBSCANStrategy - Primary HDBSCAN clustering
2. DBSCANStrategy - Fallback DBSCAN clustering
3. NoOpStrategy - No-op fallback when clustering unavailable
4. ClusteringStrategyFactory - Factory with fallback chain
"""

from __future__ import annotations

from typing import List
from unittest.mock import MagicMock, patch

import pytest

from codexlens.entities import SearchResult
from codexlens.search.clustering import (
    BaseClusteringStrategy,
    ClusteringConfig,
    ClusteringStrategyFactory,
    NoOpStrategy,
    check_clustering_strategy_available,
    get_strategy,
)


# =============================================================================
# Test Fixtures
# =============================================================================


@pytest.fixture
def sample_results() -> List[SearchResult]:
    """Create sample search results for testing."""
    return [
        SearchResult(path="a.py", score=0.9, excerpt="def foo(): pass"),
        SearchResult(path="b.py", score=0.8, excerpt="def foo(): pass"),
        SearchResult(path="c.py", score=0.7, excerpt="def bar(): pass"),
        SearchResult(path="d.py", score=0.6, excerpt="def bar(): pass"),
        SearchResult(path="e.py", score=0.5, excerpt="def baz(): pass"),
    ]


@pytest.fixture
def mock_embeddings():
    """Create mock embeddings for 5 results.

    Creates embeddings that should form 2 clusters:
    - Results 0, 1 (similar to each other)
    - Results 2, 3 (similar to each other)
    - Result 4 (noise/singleton)
    """
    import numpy as np

    # Create embeddings in 3D for simplicity
    return np.array(
        [
            [1.0, 0.0, 0.0],  # Result 0 - cluster A
            [0.9, 0.1, 0.0],  # Result 1 - cluster A
            [0.0, 1.0, 0.0],  # Result 2 - cluster B
            [0.1, 0.9, 0.0],  # Result 3 - cluster B
            [0.0, 0.0, 1.0],  # Result 4 - noise/singleton
        ],
        dtype=np.float32,
    )
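
# Geometry of the fixture: with metric="euclidean" (see default_config below),
# pairs (0,1) and (2,3) sit ~0.14 apart while cross-cluster distances are at
# least ~1.27, so both HDBSCAN and DBSCAN should recover clusters A and B and
# leave result 4 as a singleton. The near-orthogonal unit vectors make the
# same split hold under cosine distance as well.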


@pytest.fixture
def default_config() -> ClusteringConfig:
    """Create default clustering configuration."""
    return ClusteringConfig(
        min_cluster_size=2,
        min_samples=1,
        metric="euclidean",
    )


# =============================================================================
# Test ClusteringConfig
# =============================================================================


class TestClusteringConfig:
    """Tests for ClusteringConfig validation."""

    def test_default_values(self):
        """Test default configuration values."""
        config = ClusteringConfig()
        assert config.min_cluster_size == 3
        assert config.min_samples == 2
        assert config.metric == "cosine"
        assert config.cluster_selection_epsilon == 0.0
        assert config.allow_single_cluster is True
        assert config.prediction_data is False

    def test_custom_values(self):
        """Test custom configuration values."""
        config = ClusteringConfig(
            min_cluster_size=5,
            min_samples=3,
            metric="euclidean",
            cluster_selection_epsilon=0.1,
            allow_single_cluster=False,
            prediction_data=True,
        )
        assert config.min_cluster_size == 5
        assert config.min_samples == 3
        assert config.metric == "euclidean"

    def test_invalid_min_cluster_size(self):
        """Test validation rejects min_cluster_size < 2."""
        with pytest.raises(ValueError, match="min_cluster_size must be >= 2"):
            ClusteringConfig(min_cluster_size=1)

    def test_invalid_min_samples(self):
        """Test validation rejects min_samples < 1."""
        with pytest.raises(ValueError, match="min_samples must be >= 1"):
            ClusteringConfig(min_samples=0)

    def test_invalid_metric(self):
        """Test validation rejects invalid metric."""
        with pytest.raises(ValueError, match="metric must be one of"):
            ClusteringConfig(metric="invalid")

    def test_invalid_epsilon(self):
        """Test validation rejects negative epsilon."""
        with pytest.raises(ValueError, match="cluster_selection_epsilon must be >= 0"):
            ClusteringConfig(cluster_selection_epsilon=-0.1)
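    # A minimal guard sketch matching the messages asserted above (assumed to
    # live in ClusteringConfig.__post_init__; the real source may differ):
    #
    #     if self.min_cluster_size < 2:
    #         raise ValueError("min_cluster_size must be >= 2")
    #     if self.min_samples < 1:
    #         raise ValueError("min_samples must be >= 1")
    #     if self.metric not in VALID_METRICS:
    #         raise ValueError(f"metric must be one of {VALID_METRICS}")
    #     if self.cluster_selection_epsilon < 0:
    #         raise ValueError("cluster_selection_epsilon must be >= 0")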


# =============================================================================
# Test NoOpStrategy
# =============================================================================


class TestNoOpStrategy:
    """Tests for NoOpStrategy - always available."""

    def test_cluster_returns_singleton_clusters(
        self, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test cluster() returns each result as singleton cluster."""
        strategy = NoOpStrategy()
        clusters = strategy.cluster(mock_embeddings, sample_results)

        assert len(clusters) == 5
        for i, cluster in enumerate(clusters):
            assert cluster == [i]

    def test_cluster_empty_results(self):
        """Test cluster() with empty results."""
        import numpy as np

        strategy = NoOpStrategy()
        clusters = strategy.cluster(np.array([]), [])

        assert clusters == []

    def test_select_representatives_returns_all_sorted(
        self, sample_results: List[SearchResult]
    ):
        """Test select_representatives() returns all results sorted by score."""
        strategy = NoOpStrategy()
        clusters = [[i] for i in range(len(sample_results))]
        representatives = strategy.select_representatives(clusters, sample_results)

        assert len(representatives) == 5
        # Check sorted by score descending
        scores = [r.score for r in representatives]
        assert scores == sorted(scores, reverse=True)

    def test_select_representatives_empty(self):
        """Test select_representatives() with empty input."""
        strategy = NoOpStrategy()
        representatives = strategy.select_representatives([], [])
        assert representatives == []

    def test_fit_predict_convenience_method(
        self, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test fit_predict() convenience method."""
        strategy = NoOpStrategy()
        representatives = strategy.fit_predict(mock_embeddings, sample_results)

        assert len(representatives) == 5
        # All results returned, sorted by score
        assert representatives[0].score >= representatives[-1].score


# =============================================================================
# Test HDBSCANStrategy
# =============================================================================


class TestHDBSCANStrategy:
    """Tests for HDBSCANStrategy - requires hdbscan package."""

    @pytest.fixture
    def hdbscan_strategy(self, default_config):
        """Create HDBSCANStrategy if available."""
        try:
            from codexlens.search.clustering import HDBSCANStrategy

            return HDBSCANStrategy(default_config)
        except ImportError:
            pytest.skip("hdbscan not installed")

    def test_cluster_returns_list_of_lists(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test cluster() returns List[List[int]]."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)

        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for idx in cluster:
                assert isinstance(idx, int)
                assert 0 <= idx < len(sample_results)

    def test_cluster_covers_all_results(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test all result indices appear in clusters."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)

        all_indices = set()
        for cluster in clusters:
            all_indices.update(cluster)

        assert all_indices == set(range(len(sample_results)))

    def test_cluster_empty_results(self, hdbscan_strategy):
        """Test cluster() with empty results."""
        import numpy as np

        clusters = hdbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
        assert clusters == []

    def test_cluster_single_result(self, hdbscan_strategy):
        """Test cluster() with single result."""
        import numpy as np

        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = hdbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_cluster_fewer_than_min_cluster_size(self, hdbscan_strategy):
        """Test cluster() with fewer results than min_cluster_size."""
        import numpy as np

        # Strategy has min_cluster_size=2, so 1 result returns singleton
        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = hdbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_select_representatives_picks_highest_score(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test select_representatives() picks highest score per cluster."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = hdbscan_strategy.select_representatives(
            clusters, sample_results
        )

        # Each representative should be the highest-scored in its cluster
        for rep in representatives:
            # Find the cluster containing this representative
            rep_idx = next(
                i for i, r in enumerate(sample_results) if r.path == rep.path
            )
            for cluster in clusters:
                if rep_idx in cluster:
                    cluster_scores = [sample_results[i].score for i in cluster]
                    assert rep.score == max(cluster_scores)
                    break

    def test_select_representatives_sorted_by_score(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test representatives are sorted by score descending."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = hdbscan_strategy.select_representatives(
            clusters, sample_results
        )

        scores = [r.score for r in representatives]
        assert scores == sorted(scores, reverse=True)

    def test_fit_predict_end_to_end(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test fit_predict() end-to-end clustering."""
        representatives = hdbscan_strategy.fit_predict(mock_embeddings, sample_results)

        # Should have fewer or equal representatives than input
        assert len(representatives) <= len(sample_results)
        # All representatives should be from original results
        rep_paths = {r.path for r in representatives}
        original_paths = {r.path for r in sample_results}
        assert rep_paths.issubset(original_paths)


# =============================================================================
# Test DBSCANStrategy
# =============================================================================


class TestDBSCANStrategy:
    """Tests for DBSCANStrategy - requires sklearn."""

    @pytest.fixture
    def dbscan_strategy(self, default_config):
        """Create DBSCANStrategy if available."""
        try:
            from codexlens.search.clustering import DBSCANStrategy

            return DBSCANStrategy(default_config)
        except ImportError:
            pytest.skip("sklearn not installed")

    def test_cluster_returns_list_of_lists(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test cluster() returns List[List[int]]."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)

        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for idx in cluster:
                assert isinstance(idx, int)
                assert 0 <= idx < len(sample_results)

    def test_cluster_covers_all_results(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test all result indices appear in clusters."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)

        all_indices = set()
        for cluster in clusters:
            all_indices.update(cluster)

        assert all_indices == set(range(len(sample_results)))

    def test_cluster_empty_results(self, dbscan_strategy):
        """Test cluster() with empty results."""
        import numpy as np

        clusters = dbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
        assert clusters == []

    def test_cluster_single_result(self, dbscan_strategy):
        """Test cluster() with single result."""
        import numpy as np

        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = dbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_cluster_with_explicit_eps(self, default_config):
        """Test cluster() with explicit eps parameter."""
        try:
            from codexlens.search.clustering import DBSCANStrategy
        except ImportError:
            pytest.skip("sklearn not installed")

        import numpy as np

        strategy = DBSCANStrategy(default_config, eps=0.5)
        results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(3)]
        embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0]])

        clusters = strategy.cluster(embeddings, results)
        # With eps=0.5, first two should cluster, third should be separate
        assert len(clusters) >= 2

    def test_auto_compute_eps(self, dbscan_strategy, mock_embeddings):
        """Test eps auto-computation from distance distribution."""
        # Should not raise - eps is computed automatically
        results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(5)]
        clusters = dbscan_strategy.cluster(mock_embeddings, results)
        assert len(clusters) > 0
|
||||
def test_select_representatives_picks_highest_score(
|
||||
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
|
||||
):
|
||||
"""Test select_representatives() picks highest score per cluster."""
|
||||
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
|
||||
representatives = dbscan_strategy.select_representatives(
|
||||
clusters, sample_results
|
||||
)
|
||||
|
||||
# Each representative should be the highest-scored in its cluster
|
||||
for rep in representatives:
|
||||
rep_idx = next(
|
||||
i for i, r in enumerate(sample_results) if r.path == rep.path
|
||||
)
|
||||
for cluster in clusters:
|
||||
if rep_idx in cluster:
|
||||
cluster_scores = [sample_results[i].score for i in cluster]
|
||||
assert rep.score == max(cluster_scores)
|
||||
break
|
||||
|
||||
def test_select_representatives_sorted_by_score(
|
||||
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
|
||||
):
|
||||
"""Test representatives are sorted by score descending."""
|
||||
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
|
||||
representatives = dbscan_strategy.select_representatives(
|
||||
clusters, sample_results
|
||||
)
|
||||
|
||||
scores = [r.score for r in representatives]
|
||||
assert scores == sorted(scores, reverse=True)
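

# Editor's sketch (hypothetical; DBSCANStrategy's real internals are not part
# of this diff): one plausible way an eps value can be derived from the
# pairwise distance distribution, mirroring the eps / eps_percentile knobs
# exercised above and in the factory tests below.
def _sketch_auto_eps(embeddings, percentile=10.0):
    """Estimate a DBSCAN eps as a low percentile of pairwise distances."""
    import numpy as np

    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    # Keep the upper triangle only: drops zero self-distances and mirror pairs
    upper = dists[np.triu_indices(len(embeddings), k=1)]
    return float(np.percentile(upper, percentile))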


# =============================================================================
# Test ClusteringStrategyFactory
# =============================================================================


class TestClusteringStrategyFactory:
    """Tests for ClusteringStrategyFactory."""

    def test_check_noop_always_available(self):
        """Test noop strategy is always available."""
        ok, err = check_clustering_strategy_available("noop")
        assert ok is True
        assert err is None

    def test_check_invalid_strategy(self):
        """Test invalid strategy name returns error."""
        ok, err = check_clustering_strategy_available("invalid")
        assert ok is False
        assert "Invalid clustering strategy" in err

    def test_get_strategy_noop(self, default_config):
        """Test get_strategy('noop') returns NoOpStrategy."""
        strategy = get_strategy("noop", default_config)
        assert isinstance(strategy, NoOpStrategy)

    def test_get_strategy_auto_returns_something(self, default_config):
        """Test get_strategy('auto') returns a strategy."""
        strategy = get_strategy("auto", default_config)
        assert isinstance(strategy, BaseClusteringStrategy)

    def test_get_strategy_with_fallback_enabled(self, default_config):
        """Test fallback when primary strategy unavailable."""
        # Mock hdbscan unavailable
        with patch.dict("sys.modules", {"hdbscan": None}):
            # Should fall back to dbscan or noop
            strategy = get_strategy("hdbscan", default_config, fallback=True)
            assert isinstance(strategy, BaseClusteringStrategy)

    def test_get_strategy_fallback_disabled_raises(self, default_config):
        """Test ImportError when fallback disabled and strategy unavailable."""
        with patch(
            "codexlens.search.clustering.factory.check_clustering_strategy_available"
        ) as mock_check:
            mock_check.return_value = (False, "Test error")

            with pytest.raises(ImportError, match="Test error"):
                get_strategy("hdbscan", default_config, fallback=False)

    def test_get_strategy_invalid_raises(self, default_config):
        """Test ValueError for invalid strategy name."""
        with pytest.raises(ValueError, match="Unknown clustering strategy"):
            get_strategy("invalid", default_config)

    def test_factory_class_interface(self, default_config):
        """Test ClusteringStrategyFactory class interface."""
        strategy = ClusteringStrategyFactory.get_strategy("noop", default_config)
        assert isinstance(strategy, NoOpStrategy)

        ok, err = ClusteringStrategyFactory.check_available("noop")
        assert ok is True

    @pytest.mark.skipif(
        not check_clustering_strategy_available("hdbscan")[0],
        reason="hdbscan not installed",
    )
    def test_get_strategy_hdbscan(self, default_config):
        """Test get_strategy('hdbscan') returns HDBSCANStrategy."""
        from codexlens.search.clustering import HDBSCANStrategy

        strategy = get_strategy("hdbscan", default_config)
        assert isinstance(strategy, HDBSCANStrategy)

    @pytest.mark.skipif(
        not check_clustering_strategy_available("dbscan")[0],
        reason="sklearn not installed",
    )
    def test_get_strategy_dbscan(self, default_config):
        """Test get_strategy('dbscan') returns DBSCANStrategy."""
        from codexlens.search.clustering import DBSCANStrategy

        strategy = get_strategy("dbscan", default_config)
        assert isinstance(strategy, DBSCANStrategy)

    @pytest.mark.skipif(
        not check_clustering_strategy_available("dbscan")[0],
        reason="sklearn not installed",
    )
    def test_get_strategy_dbscan_with_kwargs(self, default_config):
        """Test DBSCANStrategy kwargs passed through factory."""
        strategy = get_strategy("dbscan", default_config, eps=0.3, eps_percentile=20.0)
        assert strategy.eps == 0.3
        assert strategy.eps_percentile == 20.0
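

# Editor's sketch (hypothetical helper built only from the factory calls
# exercised above): the typical call pattern for pruning redundant results
# with whatever clustering backend happens to be installed.
def _sketch_prune_results(embeddings, results, config=None):
    """Sketch: pick the best available strategy and keep representatives."""
    # "auto" selects the best installed backend; fallback=True degrades to a
    # weaker strategy (ultimately noop) instead of raising on missing deps.
    strategy = get_strategy("auto", config, fallback=True)
    return strategy.fit_predict(embeddings, results)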


# =============================================================================
# Integration Tests
# =============================================================================


class TestClusteringIntegration:
    """Integration tests for clustering strategies."""

    def test_all_strategies_same_interface(
        self, sample_results: List[SearchResult], mock_embeddings, default_config
    ):
        """Test all strategies have consistent interface."""
        strategies = [NoOpStrategy(default_config)]

        # Add available strategies
        try:
            from codexlens.search.clustering import HDBSCANStrategy

            strategies.append(HDBSCANStrategy(default_config))
        except ImportError:
            pass

        try:
            from codexlens.search.clustering import DBSCANStrategy

            strategies.append(DBSCANStrategy(default_config))
        except ImportError:
            pass

        for strategy in strategies:
            # All should implement cluster()
            clusters = strategy.cluster(mock_embeddings, sample_results)
            assert isinstance(clusters, list)

            # All should implement select_representatives()
            reps = strategy.select_representatives(clusters, sample_results)
            assert isinstance(reps, list)
            assert all(isinstance(r, SearchResult) for r in reps)

            # All should implement fit_predict()
            reps = strategy.fit_predict(mock_embeddings, sample_results)
            assert isinstance(reps, list)

    def test_clustering_reduces_redundancy(self, default_config):
        """Test clustering reduces redundant similar results."""
        import numpy as np

        # Create results with very similar embeddings
        results = [
            SearchResult(path=f"{i}.py", score=0.9 - i * 0.01, excerpt="def foo(): pass")
            for i in range(10)
        ]
        # Very similar embeddings - should cluster together
        embeddings = np.array(
            [[1.0 + i * 0.01, 0.0, 0.0] for i in range(10)], dtype=np.float32
        )

        strategy = get_strategy("auto", default_config)
        representatives = strategy.fit_predict(embeddings, results)

        # Should never return more representatives than inputs: NoOp returns
        # all of them, but HDBSCAN/DBSCAN should collapse these near-duplicates
        assert len(representatives) <= len(results)


# =============================================================================
# Test FrequencyStrategy
# =============================================================================


class TestFrequencyStrategy:
    """Tests for FrequencyStrategy - frequency-based clustering."""

    @pytest.fixture
    def frequency_config(self):
        """Create FrequencyConfig for testing."""
        from codexlens.search.clustering import FrequencyConfig
        return FrequencyConfig(min_frequency=1, max_representatives_per_group=3)

    @pytest.fixture
    def frequency_strategy(self, frequency_config):
        """Create FrequencyStrategy instance."""
        from codexlens.search.clustering import FrequencyStrategy
        return FrequencyStrategy(frequency_config)

    @pytest.fixture
    def symbol_results(self) -> List[SearchResult]:
        """Create sample results with symbol names for frequency testing."""
        return [
            SearchResult(path="auth.py", score=0.9, excerpt="authenticate user", symbol_name="authenticate"),
            SearchResult(path="login.py", score=0.85, excerpt="authenticate login", symbol_name="authenticate"),
            SearchResult(path="session.py", score=0.8, excerpt="authenticate session", symbol_name="authenticate"),
            SearchResult(path="utils.py", score=0.7, excerpt="helper function", symbol_name="helper_func"),
            SearchResult(path="validate.py", score=0.6, excerpt="validate input", symbol_name="validate"),
            SearchResult(path="check.py", score=0.55, excerpt="validate data", symbol_name="validate"),
        ]

    def test_frequency_strategy_available(self):
        """Test FrequencyStrategy is always available (no deps)."""
        ok, err = check_clustering_strategy_available("frequency")
        assert ok is True
        assert err is None

    def test_get_strategy_frequency(self):
        """Test get_strategy('frequency') returns FrequencyStrategy."""
        from codexlens.search.clustering import FrequencyStrategy
        strategy = get_strategy("frequency")
        assert isinstance(strategy, FrequencyStrategy)

    def test_cluster_groups_by_symbol(self, frequency_strategy, symbol_results):
        """Test cluster() groups results by symbol name."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)

        # Should have 3 groups: authenticate(3), validate(2), helper_func(1)
        assert len(clusters) == 3

        # First cluster should be authenticate (highest frequency)
        first_cluster_symbols = [symbol_results[i].symbol_name for i in clusters[0]]
        assert all(s == "authenticate" for s in first_cluster_symbols)
        assert len(clusters[0]) == 3

    def test_cluster_orders_by_frequency(self, frequency_strategy, symbol_results):
        """Test clusters are ordered by frequency (descending)."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)

        # Verify frequency ordering
        frequencies = [len(c) for c in clusters]
        assert frequencies == sorted(frequencies, reverse=True)

    def test_select_representatives_adds_frequency_metadata(self, frequency_strategy, symbol_results):
        """Test representatives have frequency metadata."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)
        reps = frequency_strategy.select_representatives(clusters, symbol_results, embeddings)

        # Check frequency metadata
        for rep in reps:
            assert "frequency" in rep.metadata
            assert rep.metadata["frequency"] >= 1

    def test_min_frequency_filter_mode(self, symbol_results):
        """Test min_frequency with filter mode removes low-frequency results."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(min_frequency=2, keep_mode="filter")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # helper_func (freq=1) should be filtered out
        rep_symbols = [r.symbol_name for r in reps]
        assert "helper_func" not in rep_symbols
        assert "authenticate" in rep_symbols
        assert "validate" in rep_symbols

    def test_min_frequency_demote_mode(self, symbol_results):
        """Test min_frequency with demote mode keeps but deprioritizes low-frequency."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(min_frequency=2, keep_mode="demote")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # helper_func should still be present but at the end
        rep_symbols = [r.symbol_name for r in reps]
        assert "helper_func" in rep_symbols
        # Should be demoted to end
        helper_idx = rep_symbols.index("helper_func")
        assert helper_idx == len(rep_symbols) - 1
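
    # Note on the keep_mode semantics exercised above: "filter" drops groups
    # whose frequency is below min_frequency, while "demote" keeps them but
    # orders them behind every qualifying group.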

    def test_group_by_file(self, symbol_results):
        """Test grouping by file path instead of symbol."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(group_by="file")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = strategy.cluster(embeddings, symbol_results)

        # Each file should be its own group (all unique paths)
        assert len(clusters) == 6

    def test_max_representatives_per_group(self, symbol_results):
        """Test max_representatives_per_group limits output per symbol."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(max_representatives_per_group=1)
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # Should have at most 1 per group = 3 groups = 3 reps
        assert len(reps) == 3

    def test_frequency_boost_score(self, symbol_results):
        """Test frequency_weight boosts high-frequency results."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(frequency_weight=0.5)  # Strong boost
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # High-frequency results should have boosted scores in metadata
        for rep in reps:
            if rep.metadata.get("frequency", 1) > 1:
                assert rep.metadata.get("frequency_boosted_score", 0) > rep.score
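
    # The boost formula itself is not shown in this diff; one plausible scheme
    # consistent with the assertion above (purely an editor's assumption) is
    #   boosted = score * (1 + frequency_weight * (frequency - 1))
    # which leaves frequency-1 results untouched and scales up repeats.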

    def test_empty_results(self, frequency_strategy):
        """Test handling of empty results."""
        import numpy as np

        clusters = frequency_strategy.cluster(np.array([]).reshape(0, 128), [])
        assert clusters == []

        reps = frequency_strategy.select_representatives([], [], None)
        assert reps == []

    def test_factory_with_kwargs(self):
        """Test factory passes kwargs to FrequencyConfig."""
        strategy = get_strategy("frequency", min_frequency=3, group_by="file")
        assert strategy.config.min_frequency == 3
        assert strategy.config.group_by == "file"
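

# Editor's sketch (hypothetical usage assembled only from the knobs the tests
# above exercise): a FrequencyConfig tuned to surface symbols that recur
# across files while capping per-symbol noise.
def _sketch_frequency_pruning(embeddings, results):
    from codexlens.search.clustering import FrequencyConfig, FrequencyStrategy

    config = FrequencyConfig(
        min_frequency=2,  # groups seen fewer times are demoted (or filtered)
        keep_mode="demote",  # "filter" would drop them outright
        max_representatives_per_group=2,  # cap output per symbol group
        frequency_weight=0.2,  # mild score boost for recurring symbols
    )
    return FrequencyStrategy(config).fit_predict(embeddings, results)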
698
codex-lens/tests/test_staged_cascade.py
Normal file
@@ -0,0 +1,698 @@
"""Integration tests for staged cascade search pipeline.

Tests the 4-stage pipeline:
1. Stage 1: Binary coarse search
2. Stage 2: LSP graph expansion
3. Stage 3: Clustering and representative selection
4. Stage 4: Optional cross-encoder reranking
"""

from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import List
from unittest.mock import MagicMock, Mock, patch

import pytest

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
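
# Editor's note - a sketch of how the stages compose, inferred from the stage
# methods mocked in the tests below (the real wiring lives in
# ChainSearchEngine.staged_cascade_search and may differ in detail):
#
#   candidates, index_root = engine._stage1_binary_search(query, index_paths, coarse_k=..., stats=stats)
#   expanded = engine._stage2_lsp_expand(candidates, index_root=index_root)
#   pruned = engine._stage3_cluster_prune(expanded, target_count=k)
#   final = engine._stage4_optional_rerank(query, pruned, k=k)  # only if enabled
#
# When numpy is unavailable or Stage 1 yields nothing, the engine falls back
# to hybrid_cascade_search (see the fallback tests at the end of this file).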


# =============================================================================
# Test Fixtures
# =============================================================================


@pytest.fixture
def temp_paths():
    """Create temporary directory structure."""
    tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
    root = Path(tmpdir.name)
    yield root
    try:
        tmpdir.cleanup()
    except (PermissionError, OSError):
        pass


@pytest.fixture
def mock_registry(temp_paths: Path):
    """Create mock registry store."""
    registry = RegistryStore(db_path=temp_paths / "registry.db")
    registry.initialize()
    return registry


@pytest.fixture
def mock_mapper(temp_paths: Path):
    """Create path mapper."""
    return PathMapper(index_root=temp_paths / "indexes")


@pytest.fixture
def mock_config():
    """Create mock config with staged cascade settings."""
    config = MagicMock(spec=Config)
    config.cascade_coarse_k = 100
    config.cascade_fine_k = 10
    config.enable_staged_rerank = False
    config.staged_clustering_strategy = "auto"
    config.staged_clustering_min_size = 3
    config.graph_expansion_depth = 2
    return config


@pytest.fixture
def sample_binary_results() -> List[SearchResult]:
    """Create sample binary search results for testing."""
    return [
        SearchResult(
            path="a.py",
            score=0.95,
            excerpt="def authenticate_user(username, password):",
            symbol_name="authenticate_user",
            symbol_kind="function",
            start_line=10,
            end_line=15,
        ),
        SearchResult(
            path="b.py",
            score=0.85,
            excerpt="class AuthManager:",
            symbol_name="AuthManager",
            symbol_kind="class",
            start_line=5,
            end_line=20,
        ),
        SearchResult(
            path="c.py",
            score=0.75,
            excerpt="def check_credentials(user, pwd):",
            symbol_name="check_credentials",
            symbol_kind="function",
            start_line=30,
            end_line=35,
        ),
    ]


@pytest.fixture
def sample_expanded_results() -> List[SearchResult]:
    """Create sample expanded results (after LSP expansion)."""
    return [
        SearchResult(
            path="a.py",
            score=0.95,
            excerpt="def authenticate_user(username, password):",
            symbol_name="authenticate_user",
            symbol_kind="function",
        ),
        SearchResult(
            path="a.py",
            score=0.90,
            excerpt="def verify_password(pwd):",
            symbol_name="verify_password",
            symbol_kind="function",
        ),
        SearchResult(
            path="b.py",
            score=0.85,
            excerpt="class AuthManager:",
            symbol_name="AuthManager",
            symbol_kind="class",
        ),
        SearchResult(
            path="b.py",
            score=0.80,
            excerpt="def login(self, user):",
            symbol_name="login",
            symbol_kind="function",
        ),
        SearchResult(
            path="c.py",
            score=0.75,
            excerpt="def check_credentials(user, pwd):",
            symbol_name="check_credentials",
            symbol_kind="function",
        ),
        SearchResult(
            path="d.py",
            score=0.70,
            excerpt="class UserModel:",
            symbol_name="UserModel",
            symbol_kind="class",
        ),
    ]


# =============================================================================
# Test Stage Methods
# =============================================================================


class TestStage1BinarySearch:
    """Tests for Stage 1: Binary coarse search."""

    def test_stage1_returns_results_with_index_root(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search returns results and index_root."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock the binary embedding backend (import is inside the method)
        with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
            with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
                mock_index = MagicMock()
                mock_index.count.return_value = 10
                mock_index.search.return_value = ([1, 2, 3], [10, 20, 30])
                mock_binary_idx.return_value = mock_index

                index_paths = [Path("/fake/index1/_index.db")]
                stats = SearchStats()

                results, index_root = engine._stage1_binary_search(
                    "query", index_paths, coarse_k=10, stats=stats
                )

                assert isinstance(results, list)
                assert isinstance(index_root, (Path, type(None)))

    def test_stage1_handles_empty_index_paths(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search handles empty index paths."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        index_paths = []
        stats = SearchStats()

        results, index_root = engine._stage1_binary_search(
            "query", index_paths, coarse_k=10, stats=stats
        )

        assert results == []
        assert index_root is None

    def test_stage1_aggregates_results_from_multiple_indexes(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search aggregates results from multiple indexes."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
            with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
                mock_index = MagicMock()
                mock_index.count.return_value = 10
                # Return different results for different calls
                mock_index.search.side_effect = [
                    ([1, 2], [10, 20]),
                    ([3, 4], [15, 25]),
                ]
                mock_binary_idx.return_value = mock_index

                index_paths = [
                    Path("/fake/index1/_index.db"),
                    Path("/fake/index2/_index.db"),
                ]
                stats = SearchStats()

                results, _ = engine._stage1_binary_search(
                    "query", index_paths, coarse_k=10, stats=stats
                )

                # Should aggregate candidates from both indexes
                assert isinstance(results, list)


class TestStage2LSPExpand:
    """Tests for Stage 2: LSP graph expansion."""

    def test_stage2_returns_expanded_results(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand returns expanded results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Import is inside the method, so we need to patch it there
        with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
            mock_expander = MagicMock()
            mock_expander.expand.return_value = [
                SearchResult(path="related.py", score=0.7, excerpt="related")
            ]
            mock_expander_cls.return_value = mock_expander

            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake/index")
            )

            assert isinstance(expanded, list)
            # Should include original results
            assert len(expanded) >= len(sample_binary_results)

    def test_stage2_handles_no_index_root(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand handles missing index_root."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        expanded = engine._stage2_lsp_expand(sample_binary_results, index_root=None)

        # Should return original results unchanged
        assert expanded == sample_binary_results

    def test_stage2_handles_empty_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage2_lsp_expand handles empty input."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        expanded = engine._stage2_lsp_expand([], index_root=Path("/fake"))

        assert expanded == []

    def test_stage2_deduplicates_results(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand deduplicates by (path, symbol_name, start_line)."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock expander to return duplicate of first result
        with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
            mock_expander = MagicMock()
            duplicate = SearchResult(
                path=sample_binary_results[0].path,
                score=0.5,
                excerpt="duplicate",
                symbol_name=sample_binary_results[0].symbol_name,
                start_line=sample_binary_results[0].start_line,
            )
            mock_expander.expand.return_value = [duplicate]
            mock_expander_cls.return_value = mock_expander

            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake")
            )

            # Should not include duplicate
            assert len(expanded) == len(sample_binary_results)
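
            # Dedup key sketch (per the docstring above; implementation
            # assumed): a candidate is dropped when its
            # (path, symbol_name, start_line) triple has already been seen.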


class TestStage3ClusterPrune:
    """Tests for Stage 3: Clustering and representative selection."""

    def test_stage3_returns_representatives(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test _stage3_cluster_prune returns representative results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            import numpy as np

            # Mock embeddings
            mock_embed.return_value = np.random.rand(
                len(sample_expanded_results), 128
            ).astype(np.float32)

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            assert isinstance(clustered, list)
            assert len(clustered) <= len(sample_expanded_results)
            assert all(isinstance(r, SearchResult) for r in clustered)

    def test_stage3_handles_few_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage3_cluster_prune skips clustering for few results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        few_results = [
            SearchResult(path="a.py", score=0.9, excerpt="a"),
            SearchResult(path="b.py", score=0.8, excerpt="b"),
        ]

        clustered = engine._stage3_cluster_prune(few_results, target_count=5)

        # Should return all results unchanged
        assert clustered == few_results

    def test_stage3_handles_no_embeddings(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test _stage3_cluster_prune falls back to score-based selection without embeddings."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            mock_embed.return_value = None

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should return top-scored results
            assert len(clustered) <= 3
            # Should be sorted by score descending
            scores = [r.score for r in clustered]
            assert scores == sorted(scores, reverse=True)

    def test_stage3_uses_config_clustering_strategy(
        self, mock_registry, mock_mapper, sample_expanded_results
    ):
        """Test _stage3_cluster_prune uses config clustering strategy."""
        config = MagicMock(spec=Config)
        config.staged_clustering_strategy = "auto"
        config.staged_clustering_min_size = 2

        engine = ChainSearchEngine(mock_registry, PathMapper(), config=config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            import numpy as np

            mock_embed.return_value = np.random.rand(
                len(sample_expanded_results), 128
            ).astype(np.float32)

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should use clustering (auto will pick best available)
            # Result should be a list of SearchResult objects
            assert isinstance(clustered, list)
            assert all(isinstance(r, SearchResult) for r in clustered)


class TestStage4OptionalRerank:
    """Tests for Stage 4: Optional cross-encoder reranking."""

    def test_stage4_reranks_with_reranker(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage4_optional_rerank uses _cross_encoder_rerank."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        results = [
            SearchResult(path="a.py", score=0.9, excerpt="a"),
            SearchResult(path="b.py", score=0.8, excerpt="b"),
            SearchResult(path="c.py", score=0.7, excerpt="c"),
        ]

        # Mock the _cross_encoder_rerank method that _stage4 calls
        with patch.object(engine, "_cross_encoder_rerank") as mock_rerank:
            mock_rerank.return_value = [
                SearchResult(path="c.py", score=0.95, excerpt="c"),
                SearchResult(path="a.py", score=0.85, excerpt="a"),
            ]

            reranked = engine._stage4_optional_rerank("query", results, k=2)

            mock_rerank.assert_called_once_with("query", results, 2)
            assert len(reranked) <= 2
            # First result should be reranked winner
            assert reranked[0].path == "c.py"

    def test_stage4_handles_empty_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage4_optional_rerank handles empty input."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        reranked = engine._stage4_optional_rerank("query", [], k=2)

        # Should return empty list
        assert reranked == []


# =============================================================================
# Integration Tests
# =============================================================================


class TestStagedCascadeIntegration:
    """Integration tests for staged_cascade_search() end-to-end."""

    def test_staged_cascade_returns_chain_result(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search returns ChainSearchResult."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock all stages
        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src", k=10, coarse_k=100
                            )

                            from codexlens.search.chain_search import ChainSearchResult

                            assert isinstance(result, ChainSearchResult)
                            assert result.query == "query"
                            assert len(result.results) <= 10

    def test_staged_cascade_includes_stage_stats(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search includes per-stage timing stats."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src"
                            )

                            # Check for stage stats in errors field
                            stage_stats = None
                            for err in result.stats.errors:
                                if err.startswith("STAGE_STATS:"):
                                    stage_stats = json.loads(err.replace("STAGE_STATS:", ""))
                                    break

                            assert stage_stats is not None
                            assert "stage_times" in stage_stats
                            assert "stage_counts" in stage_stats
                            assert "stage1_binary_ms" in stage_stats["stage_times"]
                            assert "stage1_candidates" in stage_stats["stage_counts"]

    def test_staged_cascade_with_rerank_enabled(
        self, mock_registry, mock_mapper, temp_paths
    ):
        """Test staged_cascade_search with reranking enabled."""
        config = MagicMock(spec=Config)
        config.cascade_coarse_k = 100
        config.cascade_fine_k = 10
        config.enable_staged_rerank = True
        config.staged_clustering_strategy = "auto"
        config.graph_expansion_depth = 2

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            with patch.object(engine, "_stage4_optional_rerank") as mock_stage4:
                                mock_stage4.return_value = [
                                    SearchResult(path="a.py", score=0.95, excerpt="a")
                                ]

                                result = engine.staged_cascade_search(
                                    "query", temp_paths / "src"
                                )

                                # Verify stage 4 was called
                                mock_stage4.assert_called_once()

    def test_staged_cascade_fallback_to_hybrid(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search falls back to hybrid when numpy unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch("codexlens.search.chain_search.NUMPY_AVAILABLE", False):
            with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
                mock_hybrid.return_value = MagicMock()

                engine.staged_cascade_search("query", temp_paths / "src")

                # Should fall back to hybrid cascade
                mock_hybrid.assert_called_once()

    def test_staged_cascade_deduplicates_final_results(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search deduplicates results by path."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            # Return duplicates with different scores
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a"),
                                SearchResult(path="a.py", score=0.8, excerpt="a duplicate"),
                                SearchResult(path="b.py", score=0.7, excerpt="b"),
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src", k=10
                            )

                            # Should deduplicate a.py (keep higher score)
                            paths = [r.path for r in result.results]
                            assert len(paths) == len(set(paths))
                            # a.py should have score 0.9
                            a_result = next(r for r in result.results if r.path == "a.py")
                            assert a_result.score == 0.9
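
                            # Final-dedup sketch (behavior asserted above;
                            # implementation assumed): keep one result per
                            # path, preferring the higher score, e.g.
                            #   best = {}
                            #   for r in candidates:
                            #       if r.path not in best or r.score > best[r.path].score:
                            #           best[r.path] = r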


# =============================================================================
# Graceful Degradation Tests
# =============================================================================


class TestStagedCascadeGracefulDegradation:
    """Tests for graceful degradation when dependencies unavailable."""

    def test_falls_back_when_clustering_unavailable(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test clustering stage falls back gracefully when clustering unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            mock_embed.return_value = None

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should fall back to score-based selection
            assert len(clustered) <= 3

    def test_falls_back_when_graph_expander_unavailable(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test LSP expansion falls back when GraphExpander unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Patch the import inside the method
        with patch("codexlens.search.graph_expander.GraphExpander", side_effect=ImportError):
            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake")
            )

            # Should return original results
            assert expanded == sample_binary_results

    def test_handles_stage_failures_gracefully(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged pipeline handles stage failures gracefully."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    # Stage 1 returns no results
                    mock_stage1.return_value = ([], None)

                    with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
                        mock_hybrid.return_value = MagicMock()

                        engine.staged_cascade_search("query", temp_paths / "src")

                        # Should fall back to hybrid when stage 1 fails
                        mock_hybrid.assert_called_once()