Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-05 01:50:27 +08:00
feat(cli): add --rule option with automatic template discovery

Refactor the ccw CLI template system:
- Add a template-discovery.ts module supporting auto-discovery of flat-layout templates
- Add a --rule <template> option that automatically loads the protocol and the template
- Migrate the template directory from a nested layout (prompts/category/file.txt) to a flat layout (prompts/category-function.txt)
- Update all agent/command files to use the $PROTO and $TMPL environment variables instead of the $(cat ...) pattern
- Support fuzzy matching: --rule 02-review-architecture matches analysis-review-architecture.txt

Other updates:
- Dashboard: add Claude Manager and Issue Manager pages
- Codex-lens: enhance the chain_search and clustering modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
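The fuzzy-matching rule for --rule can be pictured with a minimal Python sketch (the helper name, prompts directory, and tie-breaking are illustrative assumptions; the actual logic lives in template-discovery.ts):

from pathlib import Path
from typing import Optional

def resolve_template(rule: str, prompts_dir: str = "prompts") -> Optional[Path]:
    """Resolve a --rule value against flat template files such as analysis-review-architecture.txt."""
    candidates = sorted(Path(prompts_dir).glob("*.txt"))
    # Exact stem match wins outright.
    for path in candidates:
        if path.stem == rule:
            return path
    # Fuzzy match: every non-numeric hyphen-separated token of the rule must appear in the stem,
    # so "02-review-architecture" matches "analysis-review-architecture.txt".
    tokens = [t for t in rule.split("-") if not t.isdigit()]
    for path in candidates:
        if tokens and all(t in path.stem.split("-") for t in tokens):
            return path
    return None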
@@ -52,5 +52,10 @@ Requires-Dist: transformers>=4.36; extra == "splade-gpu"
Requires-Dist: optimum[onnxruntime-gpu]>=1.16; extra == "splade-gpu"
Provides-Extra: encoding
Requires-Dist: chardet>=5.0; extra == "encoding"
Provides-Extra: clustering
Requires-Dist: hdbscan>=0.8.1; extra == "clustering"
Requires-Dist: scikit-learn>=1.3.0; extra == "clustering"
Provides-Extra: full
Requires-Dist: tiktoken>=0.5.0; extra == "full"
Provides-Extra: lsp
Requires-Dist: pygls>=1.3.0; extra == "lsp"
@@ -2,6 +2,7 @@ pyproject.toml
src/codex_lens.egg-info/PKG-INFO
src/codex_lens.egg-info/SOURCES.txt
src/codex_lens.egg-info/dependency_links.txt
src/codex_lens.egg-info/entry_points.txt
src/codex_lens.egg-info/requires.txt
src/codex_lens.egg-info/top_level.txt
src/codexlens/__init__.py
@@ -18,6 +19,14 @@ src/codexlens/cli/output.py
src/codexlens/indexing/__init__.py
src/codexlens/indexing/embedding.py
src/codexlens/indexing/symbol_extractor.py
src/codexlens/lsp/__init__.py
src/codexlens/lsp/handlers.py
src/codexlens/lsp/providers.py
src/codexlens/lsp/server.py
src/codexlens/mcp/__init__.py
src/codexlens/mcp/hooks.py
src/codexlens/mcp/provider.py
src/codexlens/mcp/schema.py
src/codexlens/parsers/__init__.py
src/codexlens/parsers/encoding.py
src/codexlens/parsers/factory.py
@@ -31,6 +40,13 @@ src/codexlens/search/graph_expander.py
src/codexlens/search/hybrid_search.py
src/codexlens/search/query_parser.py
src/codexlens/search/ranking.py
src/codexlens/search/clustering/__init__.py
src/codexlens/search/clustering/base.py
src/codexlens/search/clustering/dbscan_strategy.py
src/codexlens/search/clustering/factory.py
src/codexlens/search/clustering/frequency_strategy.py
src/codexlens/search/clustering/hdbscan_strategy.py
src/codexlens/search/clustering/noop_strategy.py
src/codexlens/semantic/__init__.py
src/codexlens/semantic/ann_index.py
src/codexlens/semantic/base.py
@@ -84,6 +100,7 @@ tests/test_api_reranker.py
tests/test_chain_search.py
tests/test_cli_hybrid_search.py
tests/test_cli_output.py
tests/test_clustering_strategies.py
tests/test_code_extractor.py
tests/test_config.py
tests/test_dual_fts.py
@@ -122,6 +139,7 @@ tests/test_search_performance.py
tests/test_semantic.py
tests/test_semantic_search.py
tests/test_sqlite_store.py
tests/test_staged_cascade.py
tests/test_storage.py
tests/test_storage_concurrency.py
tests/test_symbol_extractor.py
2 codex-lens/src/codex_lens.egg-info/entry_points.txt Normal file
@@ -0,0 +1,2 @@
[console_scripts]
codexlens-lsp = codexlens.lsp:main
@@ -8,12 +8,19 @@ tree-sitter-typescript>=0.23
pathspec>=0.11
watchdog>=3.0

[clustering]
hdbscan>=0.8.1
scikit-learn>=1.3.0

[encoding]
chardet>=5.0

[full]
tiktoken>=0.5.0

[lsp]
pygls>=1.3.0

[reranker]
optimum>=1.16
onnxruntime>=1.15
88 codex-lens/src/codexlens/api/__init__.py Normal file
@@ -0,0 +1,88 @@
"""Codexlens Public API Layer.
|
||||
|
||||
This module exports all public API functions and dataclasses for the
|
||||
codexlens LSP-like functionality.
|
||||
|
||||
Dataclasses (from models.py):
|
||||
- CallInfo: Call relationship information
|
||||
- MethodContext: Method context with call relationships
|
||||
- FileContextResult: File context result with method summaries
|
||||
- DefinitionResult: Definition lookup result
|
||||
- ReferenceResult: Reference lookup result
|
||||
- GroupedReferences: References grouped by definition
|
||||
- SymbolInfo: Symbol information for workspace search
|
||||
- HoverInfo: Hover information for a symbol
|
||||
- SemanticResult: Semantic search result
|
||||
|
||||
Utility functions (from utils.py):
|
||||
- resolve_project: Resolve and validate project root path
|
||||
- normalize_relationship_type: Normalize relationship type to canonical form
|
||||
- rank_by_proximity: Rank results by file path proximity
|
||||
|
||||
Example:
|
||||
>>> from codexlens.api import (
|
||||
... DefinitionResult,
|
||||
... resolve_project,
|
||||
... normalize_relationship_type
|
||||
... )
|
||||
>>> project = resolve_project("/path/to/project")
|
||||
>>> rel_type = normalize_relationship_type("calls")
|
||||
>>> print(rel_type)
|
||||
'call'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Dataclasses
|
||||
from .models import (
|
||||
CallInfo,
|
||||
MethodContext,
|
||||
FileContextResult,
|
||||
DefinitionResult,
|
||||
ReferenceResult,
|
||||
GroupedReferences,
|
||||
SymbolInfo,
|
||||
HoverInfo,
|
||||
SemanticResult,
|
||||
)
|
||||
|
||||
# Utility functions
|
||||
from .utils import (
|
||||
resolve_project,
|
||||
normalize_relationship_type,
|
||||
rank_by_proximity,
|
||||
rank_by_score,
|
||||
)
|
||||
|
||||
# API functions
|
||||
from .definition import find_definition
|
||||
from .symbols import workspace_symbols
|
||||
from .hover import get_hover
|
||||
from .file_context import file_context
|
||||
from .references import find_references
|
||||
from .semantic import semantic_search
|
||||
|
||||
__all__ = [
|
||||
# Dataclasses
|
||||
"CallInfo",
|
||||
"MethodContext",
|
||||
"FileContextResult",
|
||||
"DefinitionResult",
|
||||
"ReferenceResult",
|
||||
"GroupedReferences",
|
||||
"SymbolInfo",
|
||||
"HoverInfo",
|
||||
"SemanticResult",
|
||||
# Utility functions
|
||||
"resolve_project",
|
||||
"normalize_relationship_type",
|
||||
"rank_by_proximity",
|
||||
"rank_by_score",
|
||||
# API functions
|
||||
"find_definition",
|
||||
"workspace_symbols",
|
||||
"get_hover",
|
||||
"file_context",
|
||||
"find_references",
|
||||
"semantic_search",
|
||||
]
|
||||
126 codex-lens/src/codexlens/api/definition.py Normal file
@@ -0,0 +1,126 @@
"""find_definition API implementation.
|
||||
|
||||
This module provides the find_definition() function for looking up
|
||||
symbol definitions with a 3-stage fallback strategy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import DefinitionResult
|
||||
from .utils import resolve_project, rank_by_proximity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_definition(
|
||||
project_root: str,
|
||||
symbol_name: str,
|
||||
symbol_kind: Optional[str] = None,
|
||||
file_context: Optional[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[DefinitionResult]:
|
||||
"""Find definition locations for a symbol.
|
||||
|
||||
Uses a 3-stage fallback strategy:
|
||||
1. Exact match with kind filter
|
||||
2. Exact match without kind filter
|
||||
3. Prefix match
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
symbol_name: Name of the symbol to find
|
||||
symbol_kind: Optional symbol kind filter (class, function, etc.)
|
||||
file_context: Optional file path for proximity ranking
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of DefinitionResult sorted by proximity if file_context provided
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
"""
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project_by_source(str(project_path))
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Stage 1: Exact match with kind filter
|
||||
results = _search_with_kind(global_index, symbol_name, symbol_kind, limit)
|
||||
if results:
|
||||
logger.debug(f"Stage 1 (exact+kind): Found {len(results)} results for {symbol_name}")
|
||||
return _rank_and_convert(results, file_context)
|
||||
|
||||
# Stage 2: Exact match without kind (if kind was specified)
|
||||
if symbol_kind:
|
||||
results = _search_with_kind(global_index, symbol_name, None, limit)
|
||||
if results:
|
||||
logger.debug(f"Stage 2 (exact): Found {len(results)} results for {symbol_name}")
|
||||
return _rank_and_convert(results, file_context)
|
||||
|
||||
# Stage 3: Prefix match
|
||||
results = global_index.search(
|
||||
name=symbol_name,
|
||||
kind=None,
|
||||
limit=limit,
|
||||
prefix_mode=True
|
||||
)
|
||||
if results:
|
||||
logger.debug(f"Stage 3 (prefix): Found {len(results)} results for {symbol_name}")
|
||||
return _rank_and_convert(results, file_context)
|
||||
|
||||
logger.debug(f"No definitions found for {symbol_name}")
|
||||
return []
|
||||
|
||||
|
||||
def _search_with_kind(
|
||||
global_index: GlobalSymbolIndex,
|
||||
symbol_name: str,
|
||||
symbol_kind: Optional[str],
|
||||
limit: int
|
||||
) -> List[Symbol]:
|
||||
"""Search for symbols with optional kind filter."""
|
||||
return global_index.search(
|
||||
name=symbol_name,
|
||||
kind=symbol_kind,
|
||||
limit=limit,
|
||||
prefix_mode=False
|
||||
)
|
||||
|
||||
|
||||
def _rank_and_convert(
|
||||
symbols: List[Symbol],
|
||||
file_context: Optional[str]
|
||||
) -> List[DefinitionResult]:
|
||||
"""Convert symbols to DefinitionResult and rank by proximity."""
|
||||
results = [
|
||||
DefinitionResult(
|
||||
name=sym.name,
|
||||
kind=sym.kind,
|
||||
file_path=sym.file or "",
|
||||
line=sym.range[0] if sym.range else 1,
|
||||
end_line=sym.range[1] if sym.range else 1,
|
||||
signature=None, # Could extract from file if needed
|
||||
container=None, # Could extract from parent symbol
|
||||
score=1.0
|
||||
)
|
||||
for sym in symbols
|
||||
]
|
||||
return rank_by_proximity(results, file_context)
|
||||
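A minimal usage sketch of find_definition as defined above (the project path, symbol name, and file are placeholders, and the project must already be indexed):

from codexlens.api import find_definition

# Stage 1 tries an exact name+kind match, stage 2 drops the kind filter,
# and stage 3 falls back to a prefix match before returning an empty list.
results = find_definition(
    project_root="/path/to/project",
    symbol_name="ChainSearchEngine",
    symbol_kind="class",
    file_context="src/codexlens/api/definition.py",  # only used for proximity ranking
)
for d in results:
    print(f"{d.file_path}:{d.line}-{d.end_line} {d.kind} {d.name}")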
271 codex-lens/src/codexlens/api/file_context.py Normal file
@@ -0,0 +1,271 @@
"""file_context API implementation.
|
||||
|
||||
This module provides the file_context() function for retrieving
|
||||
method call graphs from a source file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.dir_index import DirIndexStore
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import (
|
||||
FileContextResult,
|
||||
MethodContext,
|
||||
CallInfo,
|
||||
)
|
||||
from .utils import resolve_project, normalize_relationship_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def file_context(
|
||||
project_root: str,
|
||||
file_path: str,
|
||||
include_calls: bool = True,
|
||||
include_callers: bool = True,
|
||||
max_depth: int = 1,
|
||||
format: str = "brief"
|
||||
) -> FileContextResult:
|
||||
"""Get method call context for a code file.
|
||||
|
||||
Retrieves all methods/functions in the file along with their
|
||||
outgoing calls and incoming callers.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
file_path: Path to the code file to analyze
|
||||
include_calls: Whether to include outgoing calls
|
||||
include_callers: Whether to include incoming callers
|
||||
max_depth: Call chain depth (V1 only supports 1)
|
||||
format: Output format (brief | detailed | tree)
|
||||
|
||||
Returns:
|
||||
FileContextResult with method contexts and summary
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
FileNotFoundError: If file does not exist
|
||||
ValueError: If max_depth > 1 (V1 limitation)
|
||||
"""
|
||||
# V1 limitation: only depth=1 supported
|
||||
if max_depth > 1:
|
||||
raise ValueError(
|
||||
f"max_depth > 1 not supported in V1. "
|
||||
f"Requested: {max_depth}, supported: 1"
|
||||
)
|
||||
|
||||
project_path = resolve_project(project_root)
|
||||
file_path_resolved = Path(file_path).resolve()
|
||||
|
||||
# Validate file exists
|
||||
if not file_path_resolved.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path_resolved}")
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project_by_source(str(project_path))
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Get all symbols in the file
|
||||
symbols = global_index.get_file_symbols(str(file_path_resolved))
|
||||
|
||||
# Filter to functions, methods, and classes
|
||||
method_symbols = [
|
||||
s for s in symbols
|
||||
if s.kind in ("function", "method", "class")
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(method_symbols)} methods in {file_path}")
|
||||
|
||||
# Try to find dir_index for relationship queries
|
||||
dir_index = _find_dir_index(project_info, file_path_resolved)
|
||||
|
||||
# Build method contexts
|
||||
methods: List[MethodContext] = []
|
||||
outgoing_resolved = True
|
||||
incoming_resolved = True
|
||||
targets_resolved = True
|
||||
|
||||
for symbol in method_symbols:
|
||||
calls: List[CallInfo] = []
|
||||
callers: List[CallInfo] = []
|
||||
|
||||
if include_calls and dir_index:
|
||||
try:
|
||||
outgoing = dir_index.get_outgoing_calls(
|
||||
str(file_path_resolved),
|
||||
symbol.name
|
||||
)
|
||||
for target_name, rel_type, line, target_file in outgoing:
|
||||
calls.append(CallInfo(
|
||||
symbol_name=target_name,
|
||||
file_path=target_file,
|
||||
line=line,
|
||||
relationship=normalize_relationship_type(rel_type)
|
||||
))
|
||||
if target_file is None:
|
||||
targets_resolved = False
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get outgoing calls: {e}")
|
||||
outgoing_resolved = False
|
||||
|
||||
if include_callers and dir_index:
|
||||
try:
|
||||
incoming = dir_index.get_incoming_calls(symbol.name)
|
||||
for source_name, rel_type, line, source_file in incoming:
|
||||
callers.append(CallInfo(
|
||||
symbol_name=source_name,
|
||||
file_path=source_file,
|
||||
line=line,
|
||||
relationship=normalize_relationship_type(rel_type)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get incoming calls: {e}")
|
||||
incoming_resolved = False
|
||||
|
||||
methods.append(MethodContext(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
line_range=symbol.range if symbol.range else (1, 1),
|
||||
signature=None, # Could extract from source
|
||||
calls=calls,
|
||||
callers=callers
|
||||
))
|
||||
|
||||
# Detect language from file extension
|
||||
language = _detect_language(file_path_resolved)
|
||||
|
||||
# Generate summary
|
||||
summary = _generate_summary(file_path_resolved, methods, format)
|
||||
|
||||
return FileContextResult(
|
||||
file_path=str(file_path_resolved),
|
||||
language=language,
|
||||
methods=methods,
|
||||
summary=summary,
|
||||
discovery_status={
|
||||
"outgoing_resolved": outgoing_resolved,
|
||||
"incoming_resolved": incoming_resolved,
|
||||
"targets_resolved": targets_resolved
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _find_dir_index(project_info, file_path: Path) -> Optional[DirIndexStore]:
|
||||
"""Find the dir_index that contains the file.
|
||||
|
||||
Args:
|
||||
project_info: Project information from registry
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
DirIndexStore if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Look for _index.db in file's directory or parent directories
|
||||
current = file_path.parent
|
||||
while current != current.parent:
|
||||
index_db = current / "_index.db"
|
||||
if index_db.exists():
|
||||
return DirIndexStore(str(index_db))
|
||||
|
||||
# Also check in project's index_root
|
||||
relative = current.relative_to(project_info.source_root)
|
||||
index_in_cache = project_info.index_root / relative / "_index.db"
|
||||
if index_in_cache.exists():
|
||||
return DirIndexStore(str(index_in_cache))
|
||||
|
||||
current = current.parent
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to find dir_index: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _detect_language(file_path: Path) -> str:
|
||||
"""Detect programming language from file extension.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
Language name
|
||||
"""
|
||||
ext_map = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".jsx": "javascript",
|
||||
".tsx": "typescript",
|
||||
".go": "go",
|
||||
".rs": "rust",
|
||||
".java": "java",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
}
|
||||
return ext_map.get(file_path.suffix.lower(), "unknown")
|
||||
|
||||
|
||||
def _generate_summary(
|
||||
file_path: Path,
|
||||
methods: List[MethodContext],
|
||||
format: str
|
||||
) -> str:
|
||||
"""Generate human-readable summary of file context.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
methods: List of method contexts
|
||||
format: Output format (brief | detailed | tree)
|
||||
|
||||
Returns:
|
||||
Markdown-formatted summary
|
||||
"""
|
||||
lines = [f"## {file_path.name} ({len(methods)} methods)\n"]
|
||||
|
||||
for method in methods:
|
||||
start, end = method.line_range
|
||||
lines.append(f"### {method.name} (line {start}-{end})")
|
||||
|
||||
if method.calls:
|
||||
calls_str = ", ".join(
|
||||
f"{c.symbol_name} ({c.file_path or 'unresolved'}:{c.line})"
|
||||
if format == "detailed"
|
||||
else c.symbol_name
|
||||
for c in method.calls
|
||||
)
|
||||
lines.append(f"- Calls: {calls_str}")
|
||||
|
||||
if method.callers:
|
||||
callers_str = ", ".join(
|
||||
f"{c.symbol_name} ({c.file_path}:{c.line})"
|
||||
if format == "detailed"
|
||||
else c.symbol_name
|
||||
for c in method.callers
|
||||
)
|
||||
lines.append(f"- Called by: {callers_str}")
|
||||
|
||||
if not method.calls and not method.callers:
|
||||
lines.append("- (no call relationships)")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
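A short usage sketch of file_context as implemented above (paths are placeholders; note the V1 restriction that max_depth must stay at 1):

from codexlens.api import file_context

ctx = file_context(
    project_root="/path/to/project",
    file_path="/path/to/project/src/codexlens/search/chain_search.py",
    include_calls=True,
    include_callers=True,
    max_depth=1,        # values greater than 1 raise ValueError in this version
    format="detailed",  # brief | detailed | tree
)
print(ctx.summary)           # Markdown summary built by _generate_summary
print(ctx.discovery_status)  # e.g. {"outgoing_resolved": True, "incoming_resolved": True, ...}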
148 codex-lens/src/codexlens/api/hover.py Normal file
@@ -0,0 +1,148 @@
"""get_hover API implementation.
|
||||
|
||||
This module provides the get_hover() function for retrieving
|
||||
detailed hover information for symbols.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import HoverInfo
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hover(
|
||||
project_root: str,
|
||||
symbol_name: str,
|
||||
file_path: Optional[str] = None
|
||||
) -> Optional[HoverInfo]:
|
||||
"""Get detailed hover information for a symbol.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
symbol_name: Name of the symbol to look up
|
||||
file_path: Optional file path to disambiguate when symbol
|
||||
appears in multiple files
|
||||
|
||||
Returns:
|
||||
HoverInfo if symbol found, None otherwise
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
"""
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project_by_source(str(project_path))
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Search for the symbol
|
||||
results = global_index.search(
|
||||
name=symbol_name,
|
||||
kind=None,
|
||||
limit=50,
|
||||
prefix_mode=False
|
||||
)
|
||||
|
||||
if not results:
|
||||
logger.debug(f"No hover info found for {symbol_name}")
|
||||
return None
|
||||
|
||||
# If file_path provided, filter to that file
|
||||
if file_path:
|
||||
file_path_resolved = str(Path(file_path).resolve())
|
||||
matching = [s for s in results if s.file == file_path_resolved]
|
||||
if matching:
|
||||
results = matching
|
||||
|
||||
# Take the first result
|
||||
symbol = results[0]
|
||||
|
||||
# Build hover info
|
||||
return HoverInfo(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
signature=_extract_signature(symbol),
|
||||
documentation=_extract_documentation(symbol),
|
||||
file_path=symbol.file or "",
|
||||
line_range=symbol.range if symbol.range else (1, 1),
|
||||
type_info=_extract_type_info(symbol)
|
||||
)
|
||||
|
||||
|
||||
def _extract_signature(symbol: Symbol) -> str:
|
||||
"""Extract signature from symbol.
|
||||
|
||||
For now, generates a basic signature based on kind and name.
|
||||
In a full implementation, this would parse the actual source code.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract signature from
|
||||
|
||||
Returns:
|
||||
Signature string
|
||||
"""
|
||||
if symbol.kind == "function":
|
||||
return f"def {symbol.name}(...)"
|
||||
elif symbol.kind == "method":
|
||||
return f"def {symbol.name}(self, ...)"
|
||||
elif symbol.kind == "class":
|
||||
return f"class {symbol.name}"
|
||||
elif symbol.kind == "variable":
|
||||
return symbol.name
|
||||
elif symbol.kind == "constant":
|
||||
return f"{symbol.name} = ..."
|
||||
else:
|
||||
return f"{symbol.kind} {symbol.name}"
|
||||
|
||||
|
||||
def _extract_documentation(symbol: Symbol) -> Optional[str]:
|
||||
"""Extract documentation from symbol.
|
||||
|
||||
In a full implementation, this would parse docstrings from source.
|
||||
For now, returns None.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract documentation from
|
||||
|
||||
Returns:
|
||||
Documentation string if available, None otherwise
|
||||
"""
|
||||
# Would need to read source file and parse docstring
|
||||
# For V1, return None
|
||||
return None
|
||||
|
||||
|
||||
def _extract_type_info(symbol: Symbol) -> Optional[str]:
|
||||
"""Extract type information from symbol.
|
||||
|
||||
In a full implementation, this would parse type annotations.
|
||||
For now, returns None.
|
||||
|
||||
Args:
|
||||
symbol: The symbol to extract type info from
|
||||
|
||||
Returns:
|
||||
Type info string if available, None otherwise
|
||||
"""
|
||||
# Would need to parse type annotations from source
|
||||
# For V1, return None
|
||||
return None
|
||||
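Because _extract_documentation and _extract_type_info return None in V1, a hover lookup currently yields only the name, kind, a synthesized signature, and the location. A minimal sketch (placeholders as before):

from codexlens.api import get_hover

info = get_hover("/path/to/project", "find_definition", file_path="src/codexlens/api/definition.py")
if info is not None:
    print(info.signature)  # e.g. "def find_definition(...)" (synthesized, not parsed from source)
    print(info.to_dict())  # documentation and type_info are omitted while they are None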
281 codex-lens/src/codexlens/api/models.py Normal file
@@ -0,0 +1,281 @@
"""API dataclass definitions for codexlens LSP API.
|
||||
|
||||
This module defines all result dataclasses used by the public API layer,
|
||||
following the patterns established in mcp/schema.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.2: file_context dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class CallInfo:
|
||||
"""Call relationship information.
|
||||
|
||||
Attributes:
|
||||
symbol_name: Name of the called/calling symbol
|
||||
file_path: Target file path (may be None if unresolved)
|
||||
line: Line number of the call
|
||||
relationship: Type of relationship (call | import | inheritance)
|
||||
"""
|
||||
symbol_name: str
|
||||
file_path: Optional[str]
|
||||
line: int
|
||||
relationship: str # call | import | inheritance
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodContext:
|
||||
"""Method context with call relationships.
|
||||
|
||||
Attributes:
|
||||
name: Method/function name
|
||||
kind: Symbol kind (function | method | class)
|
||||
line_range: Start and end line numbers
|
||||
signature: Function signature (if available)
|
||||
calls: List of outgoing calls
|
||||
callers: List of incoming calls
|
||||
"""
|
||||
name: str
|
||||
kind: str # function | method | class
|
||||
line_range: Tuple[int, int]
|
||||
signature: Optional[str]
|
||||
calls: List[CallInfo] = field(default_factory=list)
|
||||
callers: List[CallInfo] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"line_range": list(self.line_range),
|
||||
"calls": [c.to_dict() for c in self.calls],
|
||||
"callers": [c.to_dict() for c in self.callers],
|
||||
}
|
||||
if self.signature is not None:
|
||||
result["signature"] = self.signature
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileContextResult:
|
||||
"""File context result with method summaries.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the analyzed file
|
||||
language: Programming language
|
||||
methods: List of method contexts
|
||||
summary: Human-readable summary
|
||||
discovery_status: Status flags for call resolution
|
||||
"""
|
||||
file_path: str
|
||||
language: str
|
||||
methods: List[MethodContext]
|
||||
summary: str
|
||||
discovery_status: Dict[str, bool] = field(default_factory=lambda: {
|
||||
"outgoing_resolved": False,
|
||||
"incoming_resolved": True,
|
||||
"targets_resolved": False
|
||||
})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"language": self.language,
|
||||
"methods": [m.to_dict() for m in self.methods],
|
||||
"summary": self.summary,
|
||||
"discovery_status": self.discovery_status,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.3: find_definition dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class DefinitionResult:
|
||||
"""Definition lookup result.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind (class, function, method, etc.)
|
||||
file_path: File where symbol is defined
|
||||
line: Start line number
|
||||
end_line: End line number
|
||||
signature: Symbol signature (if available)
|
||||
container: Containing class/module (if any)
|
||||
score: Match score for ranking
|
||||
"""
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line: int
|
||||
end_line: int
|
||||
signature: Optional[str] = None
|
||||
container: Optional[str] = None
|
||||
score: float = 1.0
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.4: find_references dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class ReferenceResult:
|
||||
"""Reference lookup result.
|
||||
|
||||
Attributes:
|
||||
file_path: File containing the reference
|
||||
line: Line number
|
||||
column: Column number
|
||||
context_line: The line of code containing the reference
|
||||
relationship: Type of reference (call | import | type_annotation | inheritance)
|
||||
"""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context_line: str
|
||||
relationship: str # call | import | type_annotation | inheritance
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupedReferences:
|
||||
"""References grouped by definition.
|
||||
|
||||
Used when a symbol has multiple definitions (e.g., overloads).
|
||||
|
||||
Attributes:
|
||||
definition: The definition this group refers to
|
||||
references: List of references to this definition
|
||||
"""
|
||||
definition: DefinitionResult
|
||||
references: List[ReferenceResult] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"definition": self.definition.to_dict(),
|
||||
"references": [r.to_dict() for r in self.references],
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.5: workspace_symbols dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class SymbolInfo:
|
||||
"""Symbol information for workspace search.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind
|
||||
file_path: File where symbol is defined
|
||||
line: Line number
|
||||
container: Containing class/module (if any)
|
||||
score: Match score for ranking
|
||||
"""
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line: int
|
||||
container: Optional[str] = None
|
||||
score: float = 1.0
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.6: get_hover dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class HoverInfo:
|
||||
"""Hover information for a symbol.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name
|
||||
kind: Symbol kind
|
||||
signature: Symbol signature
|
||||
documentation: Documentation string (if available)
|
||||
file_path: File where symbol is defined
|
||||
line_range: Start and end line numbers
|
||||
type_info: Type information (if available)
|
||||
"""
|
||||
name: str
|
||||
kind: str
|
||||
signature: str
|
||||
documentation: Optional[str]
|
||||
file_path: str
|
||||
line_range: Tuple[int, int]
|
||||
type_info: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"signature": self.signature,
|
||||
"file_path": self.file_path,
|
||||
"line_range": list(self.line_range),
|
||||
}
|
||||
if self.documentation is not None:
|
||||
result["documentation"] = self.documentation
|
||||
if self.type_info is not None:
|
||||
result["type_info"] = self.type_info
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Section 4.7: semantic_search dataclasses
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class SemanticResult:
|
||||
"""Semantic search result.
|
||||
|
||||
Attributes:
|
||||
symbol_name: Name of the matched symbol
|
||||
kind: Symbol kind
|
||||
file_path: File where symbol is defined
|
||||
line: Line number
|
||||
vector_score: Vector similarity score (None if not available)
|
||||
structural_score: Structural match score (None if not available)
|
||||
fusion_score: Combined fusion score
|
||||
snippet: Code snippet
|
||||
match_reason: Explanation of why this matched (optional)
|
||||
"""
|
||||
symbol_name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line: int
|
||||
vector_score: Optional[float]
|
||||
structural_score: Optional[float]
|
||||
fusion_score: float
|
||||
snippet: str
|
||||
match_reason: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary, filtering None values."""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
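All of the to_dict helpers above follow the same convention of dropping None-valued fields, which keeps JSON payloads compact. A small illustration using DefinitionResult:

from codexlens.api.models import DefinitionResult

d = DefinitionResult(name="parse", kind="function", file_path="src/parser.py", line=10, end_line=42)
# signature and container default to None, so they are filtered out of the dict
assert d.to_dict() == {
    "name": "parse", "kind": "function", "file_path": "src/parser.py",
    "line": 10, "end_line": 42, "score": 1.0,
}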
345 codex-lens/src/codexlens/api/references.py Normal file
@@ -0,0 +1,345 @@
"""Find references API for codexlens.
|
||||
|
||||
This module implements the find_references() function that wraps
|
||||
ChainSearchEngine.search_references() with grouped result structure
|
||||
for multi-definition symbols.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from .models import (
|
||||
DefinitionResult,
|
||||
ReferenceResult,
|
||||
GroupedReferences,
|
||||
)
|
||||
from .utils import (
|
||||
resolve_project,
|
||||
normalize_relationship_type,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _read_line_from_file(file_path: str, line: int) -> str:
|
||||
"""Read a specific line from a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file
|
||||
line: Line number (1-based)
|
||||
|
||||
Returns:
|
||||
The line content, stripped of trailing whitespace.
|
||||
Returns empty string if file cannot be read or line doesn't exist.
|
||||
"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return ""
|
||||
|
||||
with path.open("r", encoding="utf-8", errors="replace") as f:
|
||||
for i, content in enumerate(f, 1):
|
||||
if i == line:
|
||||
return content.rstrip()
|
||||
return ""
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to read line %d from %s: %s", line, file_path, exc)
|
||||
return ""
|
||||
|
||||
|
||||
def _transform_to_reference_result(
|
||||
raw_ref: "RawReferenceResult",
|
||||
) -> ReferenceResult:
|
||||
"""Transform raw ChainSearchEngine reference to API ReferenceResult.
|
||||
|
||||
Args:
|
||||
raw_ref: Raw reference result from ChainSearchEngine
|
||||
|
||||
Returns:
|
||||
API ReferenceResult with context_line and normalized relationship
|
||||
"""
|
||||
# Read the actual line from the file
|
||||
context_line = _read_line_from_file(raw_ref.file_path, raw_ref.line)
|
||||
|
||||
# Normalize relationship type
|
||||
relationship = normalize_relationship_type(raw_ref.relationship_type)
|
||||
|
||||
return ReferenceResult(
|
||||
file_path=raw_ref.file_path,
|
||||
line=raw_ref.line,
|
||||
column=raw_ref.column,
|
||||
context_line=context_line,
|
||||
relationship=relationship,
|
||||
)
|
||||
|
||||
|
||||
def find_references(
|
||||
project_root: str,
|
||||
symbol_name: str,
|
||||
symbol_kind: Optional[str] = None,
|
||||
include_definition: bool = True,
|
||||
group_by_definition: bool = True,
|
||||
limit: int = 100,
|
||||
) -> List[GroupedReferences]:
|
||||
"""Find all reference locations for a symbol.
|
||||
|
||||
Multi-definition case returns grouped results to resolve ambiguity.
|
||||
|
||||
This function wraps ChainSearchEngine.search_references() and groups
|
||||
the results by definition location. Each GroupedReferences contains
|
||||
a definition and all references that point to it.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory path
|
||||
symbol_name: Name of the symbol to find references for
|
||||
symbol_kind: Optional symbol kind filter (e.g., 'function', 'class')
|
||||
include_definition: Whether to include the definition location
|
||||
in the result (default True)
|
||||
group_by_definition: Whether to group references by definition.
|
||||
If False, returns a single group with all references.
|
||||
(default True)
|
||||
limit: Maximum number of references to return (default 100)
|
||||
|
||||
Returns:
|
||||
List of GroupedReferences. Each group contains:
|
||||
- definition: The DefinitionResult for this symbol definition
|
||||
- references: List of ReferenceResult pointing to this definition
|
||||
|
||||
Raises:
|
||||
ValueError: If project_root does not exist or is not a directory
|
||||
|
||||
Examples:
|
||||
>>> refs = find_references("/path/to/project", "authenticate")
|
||||
>>> for group in refs:
|
||||
... print(f"Definition: {group.definition.file_path}:{group.definition.line}")
|
||||
... for ref in group.references:
|
||||
... print(f" Reference: {ref.file_path}:{ref.line} ({ref.relationship})")
|
||||
|
||||
Note:
|
||||
Reference relationship types are normalized:
|
||||
- 'calls' -> 'call'
|
||||
- 'imports' -> 'import'
|
||||
- 'inherits' -> 'inheritance'
|
||||
"""
|
||||
# Validate and resolve project root
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.search.chain_search import ReferenceResult as RawReferenceResult
|
||||
from codexlens.entities import Symbol
|
||||
|
||||
# Initialize infrastructure
|
||||
config = Config()
|
||||
registry = RegistryStore(config.registry_db_path)
|
||||
mapper = PathMapper(config.index_root)
|
||||
|
||||
# Create chain search engine
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
|
||||
try:
|
||||
# Step 1: Find definitions for the symbol
|
||||
definitions: List[DefinitionResult] = []
|
||||
|
||||
if include_definition or group_by_definition:
|
||||
# Search for symbol definitions
|
||||
symbols = engine.search_symbols(
|
||||
name=symbol_name,
|
||||
source_path=project_path,
|
||||
kind=symbol_kind,
|
||||
)
|
||||
|
||||
# Convert Symbol to DefinitionResult
|
||||
for sym in symbols:
|
||||
# Only include exact name matches for definitions
|
||||
if sym.name != symbol_name:
|
||||
continue
|
||||
|
||||
# Optionally filter by kind
|
||||
if symbol_kind and sym.kind != symbol_kind:
|
||||
continue
|
||||
|
||||
definitions.append(DefinitionResult(
|
||||
name=sym.name,
|
||||
kind=sym.kind,
|
||||
file_path=sym.file or "",
|
||||
line=sym.range[0] if sym.range else 1,
|
||||
end_line=sym.range[1] if sym.range else 1,
|
||||
signature=None, # Not available from Symbol
|
||||
container=None, # Not available from Symbol
|
||||
score=1.0,
|
||||
))
|
||||
|
||||
# Step 2: Get all references using ChainSearchEngine
|
||||
raw_references = engine.search_references(
|
||||
symbol_name=symbol_name,
|
||||
source_path=project_path,
|
||||
depth=-1,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Step 3: Transform raw references to API ReferenceResult
|
||||
api_references: List[ReferenceResult] = []
|
||||
for raw_ref in raw_references:
|
||||
api_ref = _transform_to_reference_result(raw_ref)
|
||||
api_references.append(api_ref)
|
||||
|
||||
# Step 4: Group references by definition
|
||||
if group_by_definition and definitions:
|
||||
return _group_references_by_definition(
|
||||
definitions=definitions,
|
||||
references=api_references,
|
||||
include_definition=include_definition,
|
||||
)
|
||||
else:
|
||||
# Return single group with placeholder definition or first definition
|
||||
if definitions:
|
||||
definition = definitions[0]
|
||||
else:
|
||||
# Create placeholder definition when no definition found
|
||||
definition = DefinitionResult(
|
||||
name=symbol_name,
|
||||
kind=symbol_kind or "unknown",
|
||||
file_path="",
|
||||
line=0,
|
||||
end_line=0,
|
||||
signature=None,
|
||||
container=None,
|
||||
score=0.0,
|
||||
)
|
||||
|
||||
return [GroupedReferences(
|
||||
definition=definition,
|
||||
references=api_references,
|
||||
)]
|
||||
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
|
||||
def _group_references_by_definition(
|
||||
definitions: List[DefinitionResult],
|
||||
references: List[ReferenceResult],
|
||||
include_definition: bool = True,
|
||||
) -> List[GroupedReferences]:
|
||||
"""Group references by their likely definition.
|
||||
|
||||
Uses file proximity heuristic to assign references to definitions.
|
||||
References in the same file or directory as a definition are
|
||||
assigned to that definition.
|
||||
|
||||
Args:
|
||||
definitions: List of definition locations
|
||||
references: List of reference locations
|
||||
include_definition: Whether to include definition in results
|
||||
|
||||
Returns:
|
||||
List of GroupedReferences with references assigned to definitions
|
||||
"""
|
||||
import os
|
||||
|
||||
if not definitions:
|
||||
return []
|
||||
|
||||
if len(definitions) == 1:
|
||||
# Single definition - all references belong to it
|
||||
return [GroupedReferences(
|
||||
definition=definitions[0],
|
||||
references=references,
|
||||
)]
|
||||
|
||||
# Multiple definitions - group by proximity
|
||||
groups: Dict[int, List[ReferenceResult]] = {
|
||||
i: [] for i in range(len(definitions))
|
||||
}
|
||||
|
||||
for ref in references:
|
||||
# Find the closest definition by file proximity
|
||||
best_def_idx = 0
|
||||
best_score = -1
|
||||
|
||||
for i, defn in enumerate(definitions):
|
||||
score = _proximity_score(ref.file_path, defn.file_path)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_def_idx = i
|
||||
|
||||
groups[best_def_idx].append(ref)
|
||||
|
||||
# Build result groups
|
||||
result: List[GroupedReferences] = []
|
||||
for i, defn in enumerate(definitions):
|
||||
# Skip definitions with no references if not including definition itself
|
||||
if not include_definition and not groups[i]:
|
||||
continue
|
||||
|
||||
result.append(GroupedReferences(
|
||||
definition=defn,
|
||||
references=groups[i],
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _proximity_score(ref_path: str, def_path: str) -> int:
|
||||
"""Calculate proximity score between two file paths.
|
||||
|
||||
Args:
|
||||
ref_path: Reference file path
|
||||
def_path: Definition file path
|
||||
|
||||
Returns:
|
||||
Proximity score (higher = closer):
|
||||
- Same file: 1000
|
||||
- Same directory: 100
|
||||
- Otherwise: common path prefix length
|
||||
"""
|
||||
import os
|
||||
|
||||
if not ref_path or not def_path:
|
||||
return 0
|
||||
|
||||
# Normalize paths
|
||||
ref_path = os.path.normpath(ref_path)
|
||||
def_path = os.path.normpath(def_path)
|
||||
|
||||
# Same file
|
||||
if ref_path == def_path:
|
||||
return 1000
|
||||
|
||||
ref_dir = os.path.dirname(ref_path)
|
||||
def_dir = os.path.dirname(def_path)
|
||||
|
||||
# Same directory
|
||||
if ref_dir == def_dir:
|
||||
return 100
|
||||
|
||||
# Common path prefix
|
||||
try:
|
||||
common = os.path.commonpath([ref_path, def_path])
|
||||
return len(common)
|
||||
except ValueError:
|
||||
# No common path (different drives on Windows)
|
||||
return 0
|
||||
|
||||
|
||||
# Type alias for the raw reference from ChainSearchEngine
|
||||
class RawReferenceResult:
|
||||
"""Type stub for ChainSearchEngine.ReferenceResult.
|
||||
|
||||
This is only used for type hints and is replaced at runtime
|
||||
by the actual import.
|
||||
"""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context: str
|
||||
relationship_type: str
|
||||
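To make the grouping heuristic concrete: _proximity_score prefers same-file (1000) over same-directory (100) over the length of the shared path prefix, so each reference is attached to the nearest definition. A small worked sketch with hypothetical paths:

from codexlens.api.references import _proximity_score

# Same directory as the definition -> 100
_proximity_score("src/codexlens/storage/sqlite_store.py", "src/codexlens/storage/loader.py")
# Different package -> length of the common prefix "src/codexlens" -> 13
_proximity_score("src/codexlens/storage/sqlite_store.py", "src/codexlens/parsers/loader.py")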
471 codex-lens/src/codexlens/api/semantic.py Normal file
@@ -0,0 +1,471 @@
"""Semantic search API with RRF fusion.
|
||||
|
||||
This module provides the semantic_search() function for combining
|
||||
vector, structural, and keyword search with configurable fusion strategies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from .models import SemanticResult
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def semantic_search(
|
||||
project_root: str,
|
||||
query: str,
|
||||
mode: str = "fusion",
|
||||
vector_weight: float = 0.5,
|
||||
structural_weight: float = 0.3,
|
||||
keyword_weight: float = 0.2,
|
||||
fusion_strategy: str = "rrf",
|
||||
kind_filter: Optional[List[str]] = None,
|
||||
limit: int = 20,
|
||||
include_match_reason: bool = False,
|
||||
) -> List[SemanticResult]:
|
||||
"""Semantic search - combining vector and structural search.
|
||||
|
||||
This function provides a high-level API for semantic code search,
|
||||
combining vector similarity, structural (symbol + relationships),
|
||||
and keyword-based search methods with configurable fusion.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory
|
||||
query: Natural language query
|
||||
mode: Search mode
|
||||
- vector: Vector search only
|
||||
- structural: Structural search only (symbol + relationships)
|
||||
- fusion: Fusion search (default)
|
||||
vector_weight: Vector search weight [0, 1] (default 0.5)
|
||||
structural_weight: Structural search weight [0, 1] (default 0.3)
|
||||
keyword_weight: Keyword search weight [0, 1] (default 0.2)
|
||||
fusion_strategy: Fusion strategy (maps to chain_search.py)
|
||||
- rrf: Reciprocal Rank Fusion (recommended, default)
|
||||
- staged: Staged cascade -> staged_cascade_search
|
||||
- binary: Binary rerank cascade -> binary_cascade_search
|
||||
- hybrid: Hybrid cascade -> hybrid_cascade_search
|
||||
kind_filter: Symbol type filter (e.g., ["function", "class"])
|
||||
limit: Max return count (default 20)
|
||||
include_match_reason: Generate match reason (heuristic, not LLM)
|
||||
|
||||
Returns:
|
||||
Results sorted by fusion_score
|
||||
|
||||
Degradation:
|
||||
- No vector index: vector_score=None, uses FTS + structural search
|
||||
- No relationship data: structural_score=None, vector search only
|
||||
|
||||
Examples:
|
||||
>>> results = semantic_search(
|
||||
... "/path/to/project",
|
||||
... "authentication handler",
|
||||
... mode="fusion",
|
||||
... fusion_strategy="rrf"
|
||||
... )
|
||||
>>> for r in results:
|
||||
... print(f"{r.symbol_name}: {r.fusion_score:.3f}")
|
||||
"""
|
||||
# Validate and resolve project path
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Normalize weights to sum to 1.0
|
||||
total_weight = vector_weight + structural_weight + keyword_weight
|
||||
if total_weight > 0:
|
||||
vector_weight = vector_weight / total_weight
|
||||
structural_weight = structural_weight / total_weight
|
||||
keyword_weight = keyword_weight / total_weight
|
||||
else:
|
||||
# Default to equal weights if all zero
|
||||
vector_weight = structural_weight = keyword_weight = 1.0 / 3.0
|
||||
|
||||
# Initialize search infrastructure
|
||||
try:
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
|
||||
except ImportError as exc:
|
||||
logger.error("Failed to import search dependencies: %s", exc)
|
||||
return []
|
||||
|
||||
# Load config
|
||||
config = Config.load()
|
||||
|
||||
# Get or create registry and mapper
|
||||
try:
|
||||
registry = RegistryStore.default()
|
||||
mapper = PathMapper(registry)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize search infrastructure: %s", exc)
|
||||
return []
|
||||
|
||||
# Build search options based on mode
|
||||
search_options = _build_search_options(
|
||||
mode=mode,
|
||||
vector_weight=vector_weight,
|
||||
structural_weight=structural_weight,
|
||||
keyword_weight=keyword_weight,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Execute search based on fusion_strategy
|
||||
try:
|
||||
with ChainSearchEngine(registry, mapper, config=config) as engine:
|
||||
chain_result = _execute_search(
|
||||
engine=engine,
|
||||
query=query,
|
||||
source_path=project_path,
|
||||
fusion_strategy=fusion_strategy,
|
||||
options=search_options,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Search execution failed: %s", exc)
|
||||
return []
|
||||
|
||||
# Transform results to SemanticResult
|
||||
semantic_results = _transform_results(
|
||||
results=chain_result.results,
|
||||
mode=mode,
|
||||
vector_weight=vector_weight,
|
||||
structural_weight=structural_weight,
|
||||
keyword_weight=keyword_weight,
|
||||
kind_filter=kind_filter,
|
||||
include_match_reason=include_match_reason,
|
||||
query=query,
|
||||
)
|
||||
|
||||
return semantic_results[:limit]
|
||||
|
||||
|
||||
def _build_search_options(
|
||||
mode: str,
|
||||
vector_weight: float,
|
||||
structural_weight: float,
|
||||
keyword_weight: float,
|
||||
limit: int,
|
||||
) -> "SearchOptions":
|
||||
"""Build SearchOptions based on mode and weights.
|
||||
|
||||
Args:
|
||||
mode: Search mode (vector, structural, fusion)
|
||||
vector_weight: Vector search weight
|
||||
structural_weight: Structural search weight
|
||||
keyword_weight: Keyword search weight
|
||||
limit: Result limit
|
||||
|
||||
Returns:
|
||||
Configured SearchOptions
|
||||
"""
|
||||
from codexlens.search.chain_search import SearchOptions
|
||||
|
||||
# Default options
|
||||
options = SearchOptions(
|
||||
total_limit=limit * 2, # Fetch extra for filtering
|
||||
limit_per_dir=limit,
|
||||
include_symbols=True, # Always include symbols for structural
|
||||
)
|
||||
|
||||
if mode == "vector":
|
||||
# Pure vector mode
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = True
|
||||
options.pure_vector = True
|
||||
options.enable_fuzzy = False
|
||||
elif mode == "structural":
|
||||
# Structural only - use FTS + symbols
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = False
|
||||
options.enable_fuzzy = True
|
||||
options.include_symbols = True
|
||||
else:
|
||||
# Fusion mode (default)
|
||||
options.hybrid_mode = True
|
||||
options.enable_vector = vector_weight > 0
|
||||
options.enable_fuzzy = keyword_weight > 0
|
||||
options.include_symbols = structural_weight > 0
|
||||
|
||||
# Set custom weights for RRF
|
||||
if options.enable_vector and keyword_weight > 0:
|
||||
options.hybrid_weights = {
|
||||
"vector": vector_weight,
|
||||
"exact": keyword_weight * 0.7,
|
||||
"fuzzy": keyword_weight * 0.3,
|
||||
}
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def _execute_search(
|
||||
engine: "ChainSearchEngine",
|
||||
query: str,
|
||||
source_path: Path,
|
||||
fusion_strategy: str,
|
||||
options: "SearchOptions",
|
||||
limit: int,
|
||||
) -> "ChainSearchResult":
|
||||
"""Execute search using appropriate strategy.
|
||||
|
||||
Maps fusion_strategy to ChainSearchEngine methods:
|
||||
- rrf: Standard hybrid search with RRF fusion
|
||||
- staged: staged_cascade_search
|
||||
- binary: binary_cascade_search
|
||||
- hybrid: hybrid_cascade_search
|
||||
|
||||
Args:
|
||||
engine: ChainSearchEngine instance
|
||||
query: Search query
|
||||
source_path: Project root path
|
||||
fusion_strategy: Strategy name
|
||||
options: Search options
|
||||
limit: Result limit
|
||||
|
||||
Returns:
|
||||
        ChainSearchResult from the search
    """
    from codexlens.search.chain_search import ChainSearchResult

    if fusion_strategy == "staged":
        # Use staged cascade search (4-stage pipeline)
        return engine.staged_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "binary":
        # Use binary cascade search (binary coarse + dense fine)
        return engine.binary_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    elif fusion_strategy == "hybrid":
        # Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
        return engine.hybrid_cascade_search(
            query=query,
            source_path=source_path,
            k=limit,
            coarse_k=limit * 5,
            options=options,
        )
    else:
        # Default: rrf - Standard search with RRF fusion
        return engine.search(
            query=query,
            source_path=source_path,
            options=options,
        )


def _transform_results(
    results: List,
    mode: str,
    vector_weight: float,
    structural_weight: float,
    keyword_weight: float,
    kind_filter: Optional[List[str]],
    include_match_reason: bool,
    query: str,
) -> List[SemanticResult]:
    """Transform ChainSearchEngine results to SemanticResult.

    Args:
        results: List of SearchResult objects
        mode: Search mode
        vector_weight: Vector weight used
        structural_weight: Structural weight used
        keyword_weight: Keyword weight used
        kind_filter: Optional symbol kind filter
        include_match_reason: Whether to generate match reasons
        query: Original query (for match reason generation)

    Returns:
        List of SemanticResult objects
    """
    semantic_results = []

    for result in results:
        # Extract symbol info
        symbol_name = getattr(result, "symbol_name", None)
        symbol_kind = getattr(result, "symbol_kind", None)
        start_line = getattr(result, "start_line", None)

        # Use symbol object if available
        if hasattr(result, "symbol") and result.symbol:
            symbol_name = symbol_name or result.symbol.name
            symbol_kind = symbol_kind or result.symbol.kind
            if hasattr(result.symbol, "range") and result.symbol.range:
                start_line = start_line or result.symbol.range[0]

        # Filter by kind if specified
        if kind_filter and symbol_kind:
            if symbol_kind.lower() not in [k.lower() for k in kind_filter]:
                continue

        # Determine scores based on mode and metadata
        metadata = getattr(result, "metadata", {}) or {}
        fusion_score = result.score

        # Try to extract source scores from metadata
        source_scores = metadata.get("source_scores", {})
        vector_score: Optional[float] = None
        structural_score: Optional[float] = None

        if mode == "vector":
            # In pure vector mode, the main score is the vector score
            vector_score = result.score
            structural_score = None
        elif mode == "structural":
            # In structural mode, no vector score
            vector_score = None
            structural_score = result.score
        else:
            # Fusion mode - try to extract individual scores
            if "vector" in source_scores:
                vector_score = source_scores["vector"]
            elif metadata.get("fusion_method") == "simple_weighted":
                # From weighted fusion
                vector_score = source_scores.get("vector")

            # Structural score approximation (from exact/fuzzy FTS)
            fts_scores = []
            if "exact" in source_scores:
                fts_scores.append(source_scores["exact"])
            if "fuzzy" in source_scores:
                fts_scores.append(source_scores["fuzzy"])
            if "splade" in source_scores:
                fts_scores.append(source_scores["splade"])

            if fts_scores:
                structural_score = max(fts_scores)

        # Build snippet
        snippet = getattr(result, "excerpt", "") or getattr(result, "content", "")
        if len(snippet) > 500:
            snippet = snippet[:500] + "..."

        # Generate match reason if requested
        match_reason = None
        if include_match_reason:
            match_reason = _generate_match_reason(
                query=query,
                symbol_name=symbol_name,
                symbol_kind=symbol_kind,
                snippet=snippet,
                vector_score=vector_score,
                structural_score=structural_score,
            )

        semantic_result = SemanticResult(
            symbol_name=symbol_name or Path(result.path).stem,
            kind=symbol_kind or "unknown",
            file_path=result.path,
            line=start_line or 1,
            vector_score=vector_score,
            structural_score=structural_score,
            fusion_score=fusion_score,
            snippet=snippet,
            match_reason=match_reason,
        )

        semantic_results.append(semantic_result)

    # Sort by fusion_score descending
    semantic_results.sort(key=lambda r: r.fusion_score, reverse=True)

    return semantic_results


def _generate_match_reason(
    query: str,
    symbol_name: Optional[str],
    symbol_kind: Optional[str],
    snippet: str,
    vector_score: Optional[float],
    structural_score: Optional[float],
) -> str:
    """Generate human-readable match reason heuristically.

    This is a simple heuristic-based approach, not LLM-powered.

    Args:
        query: Original search query
        symbol_name: Symbol name if available
        symbol_kind: Symbol kind if available
        snippet: Code snippet
        vector_score: Vector similarity score
        structural_score: Structural match score

    Returns:
        Human-readable explanation string
    """
    reasons = []

    # Check for direct name match
    query_lower = query.lower()
    query_words = set(query_lower.split())

    if symbol_name:
        name_lower = symbol_name.lower()
        # Direct substring match
        if query_lower in name_lower or name_lower in query_lower:
            reasons.append(f"Symbol name '{symbol_name}' matches query")
        # Word overlap
        name_words = set(_split_camel_case(symbol_name).lower().split())
        overlap = query_words & name_words
        if overlap and not reasons:
            reasons.append(f"Symbol name contains: {', '.join(overlap)}")

    # Check snippet for keyword matches
    snippet_lower = snippet.lower()
    matching_words = [w for w in query_words if w in snippet_lower and len(w) > 2]
    if matching_words and len(reasons) < 2:
        reasons.append(f"Code contains keywords: {', '.join(matching_words[:3])}")

    # Add score-based reasoning
    if vector_score is not None and vector_score > 0.7:
        reasons.append("High semantic similarity")
    elif vector_score is not None and vector_score > 0.5:
        reasons.append("Moderate semantic similarity")

    if structural_score is not None and structural_score > 0.8:
        reasons.append("Strong structural match")

    # Symbol kind context
    if symbol_kind and len(reasons) < 3:
        reasons.append(f"Matched {symbol_kind}")

    if not reasons:
        reasons.append("Partial relevance based on content analysis")

    return "; ".join(reasons[:3])


def _split_camel_case(name: str) -> str:
    """Split camelCase and PascalCase to words.

    Args:
        name: Symbol name in camelCase or PascalCase

    Returns:
        Space-separated words
    """
    import re

    # Insert space before uppercase letters
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    # Insert space before uppercase followed by lowercase
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
    # Replace underscores with spaces
    result = result.replace("_", " ")

    return result
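
The word-overlap heuristic above relies on _split_camel_case turning identifiers into query-comparable words. A minimal standalone sketch of the same splitting rules (same regexes, no codexlens import needed):

import re

def split_camel_case(name: str) -> str:
    # camelCase boundary, then acronym boundary, then underscores
    result = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
    result = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1 \2", result)
    return result.replace("_", " ")

assert split_camel_case("HTTPServerError") == "HTTP Server Error"
assert split_camel_case("parseJSONResponse") == "parse JSON Response"
assert split_camel_case("snake_case_name") == "snake case name"
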
146
codex-lens/src/codexlens/api/symbols.py
Normal file
@@ -0,0 +1,146 @@
"""workspace_symbols API implementation.
|
||||
|
||||
This module provides the workspace_symbols() function for searching
|
||||
symbols across the entire workspace with prefix matching.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from ..entities import Symbol
|
||||
from ..storage.global_index import GlobalSymbolIndex
|
||||
from ..storage.registry import RegistryStore
|
||||
from ..errors import IndexNotFoundError
|
||||
from .models import SymbolInfo
|
||||
from .utils import resolve_project
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def workspace_symbols(
|
||||
project_root: str,
|
||||
query: str,
|
||||
kind_filter: Optional[List[str]] = None,
|
||||
file_pattern: Optional[str] = None,
|
||||
limit: int = 50
|
||||
) -> List[SymbolInfo]:
|
||||
"""Search for symbols across the entire workspace.
|
||||
|
||||
Uses prefix matching for efficient searching.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory (for index location)
|
||||
query: Search query (prefix match)
|
||||
kind_filter: Optional list of symbol kinds to include
|
||||
(e.g., ["class", "function"])
|
||||
file_pattern: Optional glob pattern to filter by file path
|
||||
(e.g., "*.py", "src/**/*.ts")
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of SymbolInfo sorted by score
|
||||
|
||||
Raises:
|
||||
IndexNotFoundError: If project is not indexed
|
||||
"""
|
||||
project_path = resolve_project(project_root)
|
||||
|
||||
# Get project info from registry
|
||||
registry = RegistryStore()
|
||||
project_info = registry.get_project_by_source(str(project_path))
|
||||
if project_info is None:
|
||||
raise IndexNotFoundError(f"Project not indexed: {project_path}")
|
||||
|
||||
# Open global symbol index
|
||||
index_db = project_info.index_root / "_global_symbols.db"
|
||||
if not index_db.exists():
|
||||
raise IndexNotFoundError(f"Global symbol index not found: {index_db}")
|
||||
|
||||
global_index = GlobalSymbolIndex(str(index_db), project_info.id)
|
||||
|
||||
# Search with prefix matching
|
||||
# If kind_filter has multiple kinds, we need to search for each
|
||||
all_results: List[Symbol] = []
|
||||
|
||||
if kind_filter and len(kind_filter) > 0:
|
||||
# Search for each kind separately
|
||||
for kind in kind_filter:
|
||||
results = global_index.search(
|
||||
name=query,
|
||||
kind=kind,
|
||||
limit=limit,
|
||||
prefix_mode=True
|
||||
)
|
||||
all_results.extend(results)
|
||||
else:
|
||||
# Search without kind filter
|
||||
all_results = global_index.search(
|
||||
name=query,
|
||||
kind=None,
|
||||
limit=limit,
|
||||
prefix_mode=True
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(all_results)} symbols matching '{query}'")
|
||||
|
||||
# Apply file pattern filter if specified
|
||||
if file_pattern:
|
||||
all_results = [
|
||||
sym for sym in all_results
|
||||
if sym.file and fnmatch.fnmatch(sym.file, file_pattern)
|
||||
]
|
||||
logger.debug(f"After file filter '{file_pattern}': {len(all_results)} symbols")
|
||||
|
||||
# Convert to SymbolInfo and sort by relevance
|
||||
symbols = [
|
||||
SymbolInfo(
|
||||
name=sym.name,
|
||||
kind=sym.kind,
|
||||
file_path=sym.file or "",
|
||||
line=sym.range[0] if sym.range else 1,
|
||||
container=None, # Could extract from parent
|
||||
score=_calculate_score(sym.name, query)
|
||||
)
|
||||
for sym in all_results
|
||||
]
|
||||
|
||||
# Sort by score (exact matches first)
|
||||
symbols.sort(key=lambda s: s.score, reverse=True)
|
||||
|
||||
return symbols[:limit]
|
||||
|
||||
|
||||
def _calculate_score(symbol_name: str, query: str) -> float:
|
||||
"""Calculate relevance score for a symbol match.
|
||||
|
||||
Scoring:
|
||||
- Exact match: 1.0
|
||||
- Prefix match: 0.8 + 0.2 * (query_len / symbol_len)
|
||||
- Case-insensitive match: 0.6
|
||||
|
||||
Args:
|
||||
symbol_name: The matched symbol name
|
||||
query: The search query
|
||||
|
||||
Returns:
|
||||
Score between 0.0 and 1.0
|
||||
"""
|
||||
if symbol_name == query:
|
||||
return 1.0
|
||||
|
||||
if symbol_name.lower() == query.lower():
|
||||
return 0.9
|
||||
|
||||
if symbol_name.startswith(query):
|
||||
ratio = len(query) / len(symbol_name)
|
||||
return 0.8 + 0.2 * ratio
|
||||
|
||||
if symbol_name.lower().startswith(query.lower()):
|
||||
ratio = len(query) / len(symbol_name)
|
||||
return 0.6 + 0.2 * ratio
|
||||
|
||||
return 0.5
|
||||
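
A minimal sketch of how the new workspace_symbols() API might be called. It assumes the project has already been indexed (for example with `codexlens index .`); otherwise IndexNotFoundError is raised:

from codexlens.api.symbols import workspace_symbols
from codexlens.errors import IndexNotFoundError

try:
    hits = workspace_symbols(
        project_root=".",
        query="Search",                     # prefix match on symbol names
        kind_filter=["class", "function"],  # each kind is searched separately
        file_pattern="*.py",                # fnmatch glob on the symbol's file path
        limit=20,
    )
    for info in hits:
        print(f"{info.score:.2f}  {info.kind:<8} {info.name}  ({info.file_path}:{info.line})")
except IndexNotFoundError as exc:
    print(f"Index missing: {exc}")
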
153
codex-lens/src/codexlens/api/utils.py
Normal file
@@ -0,0 +1,153 @@
"""Utility functions for the codexlens API.
|
||||
|
||||
This module provides helper functions for:
|
||||
- Project resolution
|
||||
- Relationship type normalization
|
||||
- Result ranking by proximity
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, TypeVar, Callable
|
||||
|
||||
from .models import DefinitionResult
|
||||
|
||||
|
||||
# Type variable for generic ranking
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def resolve_project(project_root: str) -> Path:
|
||||
"""Resolve and validate project root path.
|
||||
|
||||
Args:
|
||||
project_root: Path to project root (relative or absolute)
|
||||
|
||||
Returns:
|
||||
Resolved absolute Path
|
||||
|
||||
Raises:
|
||||
ValueError: If path does not exist or is not a directory
|
||||
"""
|
||||
path = Path(project_root).resolve()
|
||||
if not path.exists():
|
||||
raise ValueError(f"Project root does not exist: {path}")
|
||||
if not path.is_dir():
|
||||
raise ValueError(f"Project root is not a directory: {path}")
|
||||
return path
|
||||
|
||||
|
||||
# Relationship type normalization mapping
|
||||
_RELATIONSHIP_NORMALIZATION = {
|
||||
# Plural to singular
|
||||
"calls": "call",
|
||||
"imports": "import",
|
||||
"inherits": "inheritance",
|
||||
"uses": "use",
|
||||
# Already normalized (passthrough)
|
||||
"call": "call",
|
||||
"import": "import",
|
||||
"inheritance": "inheritance",
|
||||
"use": "use",
|
||||
"type_annotation": "type_annotation",
|
||||
}
|
||||
|
||||
|
||||
def normalize_relationship_type(relationship: str) -> str:
|
||||
"""Normalize relationship type to canonical form.
|
||||
|
||||
Converts plural forms and variations to standard singular forms:
|
||||
- 'calls' -> 'call'
|
||||
- 'imports' -> 'import'
|
||||
- 'inherits' -> 'inheritance'
|
||||
- 'uses' -> 'use'
|
||||
|
||||
Args:
|
||||
relationship: Raw relationship type string
|
||||
|
||||
Returns:
|
||||
Normalized relationship type
|
||||
|
||||
Examples:
|
||||
>>> normalize_relationship_type('calls')
|
||||
'call'
|
||||
>>> normalize_relationship_type('inherits')
|
||||
'inheritance'
|
||||
>>> normalize_relationship_type('call')
|
||||
'call'
|
||||
"""
|
||||
return _RELATIONSHIP_NORMALIZATION.get(relationship.lower(), relationship)
|
||||
|
||||
|
||||
def rank_by_proximity(
|
||||
results: List[DefinitionResult],
|
||||
file_context: Optional[str] = None
|
||||
) -> List[DefinitionResult]:
|
||||
"""Rank results by file path proximity to context.
|
||||
|
||||
V1 Implementation: Uses path-based proximity scoring.
|
||||
|
||||
Scoring algorithm:
|
||||
1. Same directory: highest score (100)
|
||||
2. Otherwise: length of common path prefix
|
||||
|
||||
Args:
|
||||
results: List of definition results to rank
|
||||
file_context: Reference file path for proximity calculation.
|
||||
If None, returns results unchanged.
|
||||
|
||||
Returns:
|
||||
Results sorted by proximity score (highest first)
|
||||
|
||||
Examples:
|
||||
>>> results = [
|
||||
... DefinitionResult(name="foo", kind="function",
|
||||
... file_path="/a/b/c.py", line=1, end_line=10),
|
||||
... DefinitionResult(name="foo", kind="function",
|
||||
... file_path="/a/x/y.py", line=1, end_line=10),
|
||||
... ]
|
||||
>>> ranked = rank_by_proximity(results, "/a/b/test.py")
|
||||
>>> ranked[0].file_path
|
||||
'/a/b/c.py'
|
||||
"""
|
||||
if not file_context or not results:
|
||||
return results
|
||||
|
||||
def proximity_score(result: DefinitionResult) -> int:
|
||||
"""Calculate proximity score for a result."""
|
||||
result_dir = os.path.dirname(result.file_path)
|
||||
context_dir = os.path.dirname(file_context)
|
||||
|
||||
# Same directory gets highest score
|
||||
if result_dir == context_dir:
|
||||
return 100
|
||||
|
||||
# Otherwise, score by common path prefix length
|
||||
try:
|
||||
common = os.path.commonpath([result.file_path, file_context])
|
||||
return len(common)
|
||||
except ValueError:
|
||||
# No common path (different drives on Windows)
|
||||
return 0
|
||||
|
||||
return sorted(results, key=proximity_score, reverse=True)
|
||||
|
||||
|
||||
def rank_by_score(
|
||||
results: List[T],
|
||||
score_fn: Callable[[T], float],
|
||||
reverse: bool = True
|
||||
) -> List[T]:
|
||||
"""Generic ranking function by custom score.
|
||||
|
||||
Args:
|
||||
results: List of items to rank
|
||||
score_fn: Function to extract score from item
|
||||
reverse: If True, highest scores first (default)
|
||||
|
||||
Returns:
|
||||
Sorted list
|
||||
"""
|
||||
return sorted(results, key=score_fn, reverse=reverse)
|
||||
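
To make the proximity ranking concrete, here is a standalone illustration of the same rule: same-directory hits score 100, everything else is ordered by the length of the shared path prefix.

import os

context = "/repo/src/app/views.py"
candidates = ["/repo/src/app/models.py", "/repo/src/lib/utils.py", "/repo/docs/index.md"]

def proximity(path: str) -> int:
    if os.path.dirname(path) == os.path.dirname(context):
        return 100
    try:
        return len(os.path.commonpath([path, context]))
    except ValueError:  # e.g. different drives on Windows
        return 0

print(sorted(candidates, key=proximity, reverse=True))
# ['/repo/src/app/models.py', '/repo/src/lib/utils.py', '/repo/docs/index.md']
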
@@ -154,6 +154,13 @@ class Config:
    cascade_fine_k: int = 10  # Number of final results after reranking
    cascade_strategy: str = "binary"  # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)

    # Staged cascade search configuration (4-stage pipeline)
    staged_coarse_k: int = 200  # Number of coarse candidates from Stage 1 binary search
    staged_lsp_depth: int = 2  # LSP relationship expansion depth in Stage 2
    staged_clustering_strategy: str = "auto"  # "auto", "hdbscan", "dbscan", "frequency", "noop"
    staged_clustering_min_size: int = 3  # Minimum cluster size for Stage 3 grouping
    enable_staged_rerank: bool = True  # Enable optional cross-encoder reranking in Stage 4

    # RRF fusion configuration
    fusion_method: str = "rrf"  # "simple" (weighted sum) or "rrf" (reciprocal rank fusion)
    rrf_k: int = 60  # RRF constant (default 60)
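
A sketch of tuning the new staged-cascade knobs. The field names come from the hunk above, and constructing Config() with no arguments mirrors how the LSP server does it later in this commit; whether instances are meant to be mutated like this is an assumption here.

from codexlens.config import Config

config = Config()
config.cascade_strategy = "binary"          # or "hybrid"
config.staged_coarse_k = 200                # Stage 1: binary coarse candidates
config.staged_lsp_depth = 2                 # Stage 2: LSP relationship expansion depth
config.staged_clustering_strategy = "auto"  # Stage 3: hdbscan / dbscan / frequency / noop
config.staged_clustering_min_size = 3
config.enable_staged_rerank = True          # Stage 4: optional cross-encoder rerank
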
7
codex-lens/src/codexlens/lsp/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""codex-lens Language Server Protocol implementation."""

from __future__ import annotations

from codexlens.lsp.server import CodexLensLanguageServer, main

__all__ = ["CodexLensLanguageServer", "main"]
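
Since the package exports both the server class and main(), a minimal launcher script is just the following; editors would normally start the codexlens-lsp console script instead, which wraps the same entry point (requires the lsp extra: pip install codex-lens[lsp]).

import sys

from codexlens.lsp import main

if __name__ == "__main__":
    sys.exit(main())  # stdio transport by default, --tcp/--port for sockets
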
551
codex-lens/src/codexlens/lsp/handlers.py
Normal file
@@ -0,0 +1,551 @@
|
||||
"""LSP request handlers for codex-lens.
|
||||
|
||||
This module contains handlers for LSP requests:
|
||||
- textDocument/definition
|
||||
- textDocument/completion
|
||||
- workspace/symbol
|
||||
- textDocument/didSave
|
||||
- textDocument/hover
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
try:
|
||||
from lsprotocol import types as lsp
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||
) from exc
|
||||
|
||||
from codexlens.entities import Symbol
|
||||
from codexlens.lsp.server import server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Symbol kind mapping from codex-lens to LSP
|
||||
SYMBOL_KIND_MAP = {
|
||||
"class": lsp.SymbolKind.Class,
|
||||
"function": lsp.SymbolKind.Function,
|
||||
"method": lsp.SymbolKind.Method,
|
||||
"variable": lsp.SymbolKind.Variable,
|
||||
"constant": lsp.SymbolKind.Constant,
|
||||
"property": lsp.SymbolKind.Property,
|
||||
"field": lsp.SymbolKind.Field,
|
||||
"interface": lsp.SymbolKind.Interface,
|
||||
"module": lsp.SymbolKind.Module,
|
||||
"namespace": lsp.SymbolKind.Namespace,
|
||||
"package": lsp.SymbolKind.Package,
|
||||
"enum": lsp.SymbolKind.Enum,
|
||||
"enum_member": lsp.SymbolKind.EnumMember,
|
||||
"struct": lsp.SymbolKind.Struct,
|
||||
"type": lsp.SymbolKind.TypeParameter,
|
||||
"type_alias": lsp.SymbolKind.TypeParameter,
|
||||
}
|
||||
|
||||
# Completion kind mapping from codex-lens to LSP
|
||||
COMPLETION_KIND_MAP = {
|
||||
"class": lsp.CompletionItemKind.Class,
|
||||
"function": lsp.CompletionItemKind.Function,
|
||||
"method": lsp.CompletionItemKind.Method,
|
||||
"variable": lsp.CompletionItemKind.Variable,
|
||||
"constant": lsp.CompletionItemKind.Constant,
|
||||
"property": lsp.CompletionItemKind.Property,
|
||||
"field": lsp.CompletionItemKind.Field,
|
||||
"interface": lsp.CompletionItemKind.Interface,
|
||||
"module": lsp.CompletionItemKind.Module,
|
||||
"enum": lsp.CompletionItemKind.Enum,
|
||||
"enum_member": lsp.CompletionItemKind.EnumMember,
|
||||
"struct": lsp.CompletionItemKind.Struct,
|
||||
"type": lsp.CompletionItemKind.TypeParameter,
|
||||
"type_alias": lsp.CompletionItemKind.TypeParameter,
|
||||
}
|
||||
|
||||
|
||||
def _path_to_uri(path: Union[str, Path]) -> str:
|
||||
"""Convert a file path to a URI.
|
||||
|
||||
Args:
|
||||
path: File path (string or Path object)
|
||||
|
||||
Returns:
|
||||
File URI string
|
||||
"""
|
||||
path_str = str(Path(path).resolve())
|
||||
# Handle Windows paths
|
||||
if path_str.startswith("/"):
|
||||
return f"file://{quote(path_str)}"
|
||||
else:
|
||||
return f"file:///{quote(path_str.replace(chr(92), '/'))}"
|
||||
|
||||
|
||||
def _uri_to_path(uri: str) -> Path:
|
||||
"""Convert a URI to a file path.
|
||||
|
||||
Args:
|
||||
uri: File URI string
|
||||
|
||||
Returns:
|
||||
Path object
|
||||
"""
|
||||
path = uri.replace("file:///", "").replace("file://", "")
|
||||
return Path(unquote(path))
|
||||
|
||||
|
||||
def _get_word_at_position(document_text: str, line: int, character: int) -> Optional[str]:
|
||||
"""Extract the word at the given position in the document.
|
||||
|
||||
Args:
|
||||
document_text: Full document text
|
||||
line: 0-based line number
|
||||
character: 0-based character position
|
||||
|
||||
Returns:
|
||||
Word at position, or None if no word found
|
||||
"""
|
||||
lines = document_text.splitlines()
|
||||
if line >= len(lines):
|
||||
return None
|
||||
|
||||
line_text = lines[line]
|
||||
if character > len(line_text):
|
||||
return None
|
||||
|
||||
# Find word boundaries
|
||||
word_pattern = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")
|
||||
for match in word_pattern.finditer(line_text):
|
||||
if match.start() <= character <= match.end():
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_prefix_at_position(document_text: str, line: int, character: int) -> str:
|
||||
"""Extract the incomplete word prefix at the given position.
|
||||
|
||||
Args:
|
||||
document_text: Full document text
|
||||
line: 0-based line number
|
||||
character: 0-based character position
|
||||
|
||||
Returns:
|
||||
Prefix string (may be empty)
|
||||
"""
|
||||
lines = document_text.splitlines()
|
||||
if line >= len(lines):
|
||||
return ""
|
||||
|
||||
line_text = lines[line]
|
||||
if character > len(line_text):
|
||||
character = len(line_text)
|
||||
|
||||
# Extract text before cursor
|
||||
before_cursor = line_text[:character]
|
||||
|
||||
# Find the start of the current word
|
||||
match = re.search(r"[a-zA-Z_][a-zA-Z0-9_]*$", before_cursor)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def symbol_to_location(symbol: Symbol) -> Optional[lsp.Location]:
|
||||
"""Convert a codex-lens Symbol to an LSP Location.
|
||||
|
||||
Args:
|
||||
symbol: codex-lens Symbol object
|
||||
|
||||
Returns:
|
||||
LSP Location, or None if symbol has no file
|
||||
"""
|
||||
if not symbol.file:
|
||||
return None
|
||||
|
||||
# LSP uses 0-based lines, codex-lens uses 1-based
|
||||
start_line = max(0, symbol.range[0] - 1)
|
||||
end_line = max(0, symbol.range[1] - 1)
|
||||
|
||||
return lsp.Location(
|
||||
uri=_path_to_uri(symbol.file),
|
||||
range=lsp.Range(
|
||||
start=lsp.Position(line=start_line, character=0),
|
||||
end=lsp.Position(line=end_line, character=0),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _symbol_kind_to_lsp(kind: str) -> lsp.SymbolKind:
|
||||
"""Map codex-lens symbol kind to LSP SymbolKind.
|
||||
|
||||
Args:
|
||||
kind: codex-lens symbol kind string
|
||||
|
||||
Returns:
|
||||
LSP SymbolKind
|
||||
"""
|
||||
return SYMBOL_KIND_MAP.get(kind.lower(), lsp.SymbolKind.Variable)
|
||||
|
||||
|
||||
def _symbol_kind_to_completion_kind(kind: str) -> lsp.CompletionItemKind:
|
||||
"""Map codex-lens symbol kind to LSP CompletionItemKind.
|
||||
|
||||
Args:
|
||||
kind: codex-lens symbol kind string
|
||||
|
||||
Returns:
|
||||
LSP CompletionItemKind
|
||||
"""
|
||||
return COMPLETION_KIND_MAP.get(kind.lower(), lsp.CompletionItemKind.Text)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# LSP Request Handlers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DEFINITION)
|
||||
def lsp_definition(
|
||||
params: lsp.DefinitionParams,
|
||||
) -> Optional[Union[lsp.Location, List[lsp.Location]]]:
|
||||
"""Handle textDocument/definition request.
|
||||
|
||||
Finds the definition of the symbol at the cursor position.
|
||||
"""
|
||||
if not server.global_index:
|
||||
logger.debug("No global index available for definition lookup")
|
||||
return None
|
||||
|
||||
# Get document
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
# Get word at position
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
logger.debug("No word found at position")
|
||||
return None
|
||||
|
||||
logger.debug("Looking up definition for: %s", word)
|
||||
|
||||
# Search for exact symbol match
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=10,
|
||||
prefix_mode=False, # Exact match preferred
|
||||
)
|
||||
|
||||
# Filter for exact name match
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
if not exact_matches:
|
||||
# Fall back to prefix search
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=10,
|
||||
prefix_mode=True,
|
||||
)
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
|
||||
if not exact_matches:
|
||||
logger.debug("No definition found for: %s", word)
|
||||
return None
|
||||
|
||||
# Convert to LSP locations
|
||||
locations = []
|
||||
for sym in exact_matches:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
locations.append(loc)
|
||||
|
||||
if len(locations) == 1:
|
||||
return locations[0]
|
||||
elif locations:
|
||||
return locations
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error looking up definition: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_REFERENCES)
|
||||
def lsp_references(params: lsp.ReferenceParams) -> Optional[List[lsp.Location]]:
|
||||
"""Handle textDocument/references request.
|
||||
|
||||
Finds all references to the symbol at the cursor position using
|
||||
the code_relationships table for accurate call-site tracking.
|
||||
Falls back to same-name symbol search if search_engine is unavailable.
|
||||
"""
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
return None
|
||||
|
||||
logger.debug("Finding references for: %s", word)
|
||||
|
||||
try:
|
||||
# Try using search_engine.search_references() for accurate reference tracking
|
||||
if server.search_engine and server.workspace_root:
|
||||
references = server.search_engine.search_references(
|
||||
symbol_name=word,
|
||||
source_path=server.workspace_root,
|
||||
limit=200,
|
||||
)
|
||||
|
||||
if references:
|
||||
locations = []
|
||||
for ref in references:
|
||||
locations.append(
|
||||
lsp.Location(
|
||||
uri=_path_to_uri(ref.file_path),
|
||||
range=lsp.Range(
|
||||
start=lsp.Position(
|
||||
line=max(0, ref.line - 1),
|
||||
character=ref.column,
|
||||
),
|
||||
end=lsp.Position(
|
||||
line=max(0, ref.line - 1),
|
||||
character=ref.column + len(word),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return locations if locations else None
|
||||
|
||||
# Fallback: search for symbols with same name using global_index
|
||||
if server.global_index:
|
||||
symbols = server.global_index.search(
|
||||
name=word,
|
||||
limit=100,
|
||||
prefix_mode=False,
|
||||
)
|
||||
|
||||
# Filter for exact matches
|
||||
exact_matches = [s for s in symbols if s.name == word]
|
||||
|
||||
locations = []
|
||||
for sym in exact_matches:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
locations.append(loc)
|
||||
|
||||
return locations if locations else None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error finding references: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_COMPLETION)
|
||||
def lsp_completion(params: lsp.CompletionParams) -> Optional[lsp.CompletionList]:
|
||||
"""Handle textDocument/completion request.
|
||||
|
||||
Provides code completion suggestions based on indexed symbols.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
prefix = _get_prefix_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not prefix or len(prefix) < 2:
|
||||
# Require at least 2 characters for completion
|
||||
return None
|
||||
|
||||
logger.debug("Completing prefix: %s", prefix)
|
||||
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=prefix,
|
||||
limit=50,
|
||||
prefix_mode=True,
|
||||
)
|
||||
|
||||
if not symbols:
|
||||
return None
|
||||
|
||||
# Convert to completion items
|
||||
items = []
|
||||
seen_names = set()
|
||||
|
||||
for sym in symbols:
|
||||
if sym.name in seen_names:
|
||||
continue
|
||||
seen_names.add(sym.name)
|
||||
|
||||
items.append(
|
||||
lsp.CompletionItem(
|
||||
label=sym.name,
|
||||
kind=_symbol_kind_to_completion_kind(sym.kind),
|
||||
detail=f"{sym.kind} - {Path(sym.file).name if sym.file else 'unknown'}",
|
||||
sort_text=sym.name.lower(),
|
||||
)
|
||||
)
|
||||
|
||||
return lsp.CompletionList(
|
||||
is_incomplete=len(symbols) >= 50,
|
||||
items=items,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error getting completions: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_HOVER)
|
||||
def lsp_hover(params: lsp.HoverParams) -> Optional[lsp.Hover]:
|
||||
"""Handle textDocument/hover request.
|
||||
|
||||
Provides hover information for the symbol at the cursor position
|
||||
using HoverProvider for rich symbol information including
|
||||
signature, documentation, and location.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
document = server.workspace.get_text_document(params.text_document.uri)
|
||||
if not document:
|
||||
return None
|
||||
|
||||
word = _get_word_at_position(
|
||||
document.source,
|
||||
params.position.line,
|
||||
params.position.character,
|
||||
)
|
||||
|
||||
if not word:
|
||||
return None
|
||||
|
||||
logger.debug("Hover for: %s", word)
|
||||
|
||||
try:
|
||||
# Use HoverProvider for rich symbol information
|
||||
from codexlens.lsp.providers import HoverProvider
|
||||
|
||||
provider = HoverProvider(server.global_index, server.registry)
|
||||
info = provider.get_hover_info(word)
|
||||
|
||||
if not info:
|
||||
return None
|
||||
|
||||
# Format as markdown with signature and location
|
||||
content = provider.format_hover_markdown(info)
|
||||
|
||||
return lsp.Hover(
|
||||
contents=lsp.MarkupContent(
|
||||
kind=lsp.MarkupKind.Markdown,
|
||||
value=content,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error getting hover info: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.WORKSPACE_SYMBOL)
|
||||
def lsp_workspace_symbol(
|
||||
params: lsp.WorkspaceSymbolParams,
|
||||
) -> Optional[List[lsp.SymbolInformation]]:
|
||||
"""Handle workspace/symbol request.
|
||||
|
||||
Searches for symbols across the workspace.
|
||||
"""
|
||||
if not server.global_index:
|
||||
return None
|
||||
|
||||
query = params.query
|
||||
if not query or len(query) < 2:
|
||||
return None
|
||||
|
||||
logger.debug("Workspace symbol search: %s", query)
|
||||
|
||||
try:
|
||||
symbols = server.global_index.search(
|
||||
name=query,
|
||||
limit=100,
|
||||
prefix_mode=True,
|
||||
)
|
||||
|
||||
if not symbols:
|
||||
return None
|
||||
|
||||
result = []
|
||||
for sym in symbols:
|
||||
loc = symbol_to_location(sym)
|
||||
if loc:
|
||||
result.append(
|
||||
lsp.SymbolInformation(
|
||||
name=sym.name,
|
||||
kind=_symbol_kind_to_lsp(sym.kind),
|
||||
location=loc,
|
||||
container_name=Path(sym.file).parent.name if sym.file else None,
|
||||
)
|
||||
)
|
||||
|
||||
return result if result else None
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error searching workspace symbols: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_SAVE)
|
||||
def lsp_did_save(params: lsp.DidSaveTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didSave notification.
|
||||
|
||||
Triggers incremental re-indexing of the saved file.
|
||||
Note: Full incremental indexing requires WatcherManager integration,
|
||||
which is planned for Phase 2.
|
||||
"""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.info("File saved: %s", file_path)
|
||||
|
||||
# Phase 1: Just log the save event
|
||||
# Phase 2 will integrate with WatcherManager for incremental indexing
|
||||
# if server.watcher_manager:
|
||||
# server.watcher_manager.trigger_reindex(file_path)
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_OPEN)
|
||||
def lsp_did_open(params: lsp.DidOpenTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didOpen notification."""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.debug("File opened: %s", file_path)
|
||||
|
||||
|
||||
@server.feature(lsp.TEXT_DOCUMENT_DID_CLOSE)
|
||||
def lsp_did_close(params: lsp.DidCloseTextDocumentParams) -> None:
|
||||
"""Handle textDocument/didClose notification."""
|
||||
file_path = _uri_to_path(params.text_document.uri)
|
||||
logger.debug("File closed: %s", file_path)
|
||||
177
codex-lens/src/codexlens/lsp/providers.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""LSP feature providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HoverInfo:
|
||||
"""Hover information for a symbol."""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
signature: str
|
||||
documentation: Optional[str]
|
||||
file_path: str
|
||||
line_range: tuple # (start_line, end_line)
|
||||
|
||||
|
||||
class HoverProvider:
|
||||
"""Provides hover information for symbols."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
global_index: "GlobalSymbolIndex",
|
||||
registry: Optional["RegistryStore"] = None,
|
||||
) -> None:
|
||||
"""Initialize hover provider.
|
||||
|
||||
Args:
|
||||
global_index: Global symbol index for lookups
|
||||
registry: Optional registry store for index path resolution
|
||||
"""
|
||||
self.global_index = global_index
|
||||
self.registry = registry
|
||||
|
||||
def get_hover_info(self, symbol_name: str) -> Optional[HoverInfo]:
|
||||
"""Get hover information for a symbol.
|
||||
|
||||
Args:
|
||||
symbol_name: Name of the symbol to look up
|
||||
|
||||
Returns:
|
||||
HoverInfo or None if symbol not found
|
||||
"""
|
||||
# Look up symbol in global index using exact match
|
||||
symbols = self.global_index.search(
|
||||
name=symbol_name,
|
||||
limit=1,
|
||||
prefix_mode=False,
|
||||
)
|
||||
|
||||
# Filter for exact name match
|
||||
exact_matches = [s for s in symbols if s.name == symbol_name]
|
||||
|
||||
if not exact_matches:
|
||||
return None
|
||||
|
||||
symbol = exact_matches[0]
|
||||
|
||||
# Extract signature from source file
|
||||
signature = self._extract_signature(symbol)
|
||||
|
||||
# Symbol uses 'file' attribute and 'range' tuple
|
||||
file_path = symbol.file or ""
|
||||
start_line, end_line = symbol.range
|
||||
|
||||
return HoverInfo(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
signature=signature,
|
||||
documentation=None, # Symbol doesn't have docstring field
|
||||
file_path=file_path,
|
||||
line_range=(start_line, end_line),
|
||||
)
|
||||
|
||||
def _extract_signature(self, symbol) -> str:
|
||||
"""Extract function/class signature from source file.
|
||||
|
||||
Args:
|
||||
symbol: Symbol object with file and range information
|
||||
|
||||
Returns:
|
||||
Extracted signature string or fallback kind + name
|
||||
"""
|
||||
try:
|
||||
file_path = Path(symbol.file) if symbol.file else None
|
||||
if not file_path or not file_path.exists():
|
||||
return f"{symbol.kind} {symbol.name}"
|
||||
|
||||
content = file_path.read_text(encoding="utf-8", errors="ignore")
|
||||
lines = content.split("\n")
|
||||
|
||||
# Extract signature lines (first line of definition + continuation)
|
||||
start_line = symbol.range[0] - 1 # Convert 1-based to 0-based
|
||||
if start_line >= len(lines) or start_line < 0:
|
||||
return f"{symbol.kind} {symbol.name}"
|
||||
|
||||
signature_lines = []
|
||||
first_line = lines[start_line]
|
||||
signature_lines.append(first_line)
|
||||
|
||||
# Continue if multiline signature (no closing paren + colon yet)
|
||||
# Look for patterns like "def func(", "class Foo(", etc.
|
||||
i = start_line + 1
|
||||
max_lines = min(start_line + 5, len(lines))
|
||||
while i < max_lines:
|
||||
line = signature_lines[-1]
|
||||
# Stop if we see closing pattern
|
||||
if "):" in line or line.rstrip().endswith(":"):
|
||||
break
|
||||
signature_lines.append(lines[i])
|
||||
i += 1
|
||||
|
||||
return "\n".join(signature_lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract signature for {symbol.name}: {e}")
|
||||
return f"{symbol.kind} {symbol.name}"
|
||||
|
||||
def format_hover_markdown(self, info: HoverInfo) -> str:
|
||||
"""Format hover info as Markdown.
|
||||
|
||||
Args:
|
||||
info: HoverInfo object to format
|
||||
|
||||
Returns:
|
||||
Markdown-formatted hover content
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Detect language for code fence based on file extension
|
||||
ext = Path(info.file_path).suffix.lower() if info.file_path else ""
|
||||
lang_map = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "typescript",
|
||||
".jsx": "javascript",
|
||||
".java": "java",
|
||||
".go": "go",
|
||||
".rs": "rust",
|
||||
".c": "c",
|
||||
".cpp": "cpp",
|
||||
".h": "c",
|
||||
".hpp": "cpp",
|
||||
".cs": "csharp",
|
||||
".rb": "ruby",
|
||||
".php": "php",
|
||||
}
|
||||
lang = lang_map.get(ext, "")
|
||||
|
||||
# Code block with signature
|
||||
parts.append(f"```{lang}\n{info.signature}\n```")
|
||||
|
||||
# Documentation if available
|
||||
if info.documentation:
|
||||
parts.append(f"\n---\n\n{info.documentation}")
|
||||
|
||||
# Location info
|
||||
file_name = Path(info.file_path).name if info.file_path else "unknown"
|
||||
parts.append(
|
||||
f"\n---\n\n*{info.kind}* defined in "
|
||||
f"`{file_name}` "
|
||||
f"(line {info.line_range[0]})"
|
||||
)
|
||||
|
||||
return "\n".join(parts)
|
||||
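
A hypothetical use of HoverProvider outside the LSP handlers. The database path and project id below are placeholders; a real index lives under the project's index_root, as shown in the handlers above.

from codexlens.lsp.providers import HoverProvider
from codexlens.storage.global_index import GlobalSymbolIndex

global_index = GlobalSymbolIndex("/path/to/_global_symbols.db", 1)  # placeholder path/id
provider = HoverProvider(global_index)  # registry argument is optional
info = provider.get_hover_info("ChainSearchEngine")
if info is not None:
    print(provider.format_hover_markdown(info))  # fenced signature plus location footer
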
263
codex-lens/src/codexlens/lsp/server.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""codex-lens LSP Server implementation using pygls.
|
||||
|
||||
This module provides the main Language Server class and entry point.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from lsprotocol import types as lsp
|
||||
from pygls.lsp.server import LanguageServer
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"LSP dependencies not installed. Install with: pip install codex-lens[lsp]"
|
||||
) from exc
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodexLensLanguageServer(LanguageServer):
|
||||
"""Language Server for codex-lens code indexing.
|
||||
|
||||
Provides IDE features using codex-lens symbol index:
|
||||
- Go to Definition
|
||||
- Find References
|
||||
- Code Completion
|
||||
- Hover Information
|
||||
- Workspace Symbol Search
|
||||
|
||||
Attributes:
|
||||
registry: Global project registry for path lookups
|
||||
mapper: Path mapper for source/index conversions
|
||||
global_index: Project-wide symbol index
|
||||
search_engine: Chain search engine for symbol search
|
||||
workspace_root: Current workspace root path
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name="codexlens-lsp", version="0.1.0")
|
||||
|
||||
self.registry: Optional[RegistryStore] = None
|
||||
self.mapper: Optional[PathMapper] = None
|
||||
self.global_index: Optional[GlobalSymbolIndex] = None
|
||||
self.search_engine: Optional[ChainSearchEngine] = None
|
||||
self.workspace_root: Optional[Path] = None
|
||||
self._config: Optional[Config] = None
|
||||
|
||||
def initialize_components(self, workspace_root: Path) -> bool:
|
||||
"""Initialize codex-lens components for the workspace.
|
||||
|
||||
Args:
|
||||
workspace_root: Root path of the workspace
|
||||
|
||||
Returns:
|
||||
True if initialization succeeded, False otherwise
|
||||
"""
|
||||
self.workspace_root = workspace_root.resolve()
|
||||
logger.info("Initializing codex-lens for workspace: %s", self.workspace_root)
|
||||
|
||||
try:
|
||||
# Initialize registry
|
||||
self.registry = RegistryStore()
|
||||
self.registry.initialize()
|
||||
|
||||
# Initialize path mapper
|
||||
self.mapper = PathMapper()
|
||||
|
||||
# Try to find project in registry
|
||||
project_info = self.registry.find_by_source_path(str(self.workspace_root))
|
||||
|
||||
if project_info:
|
||||
project_id = int(project_info["id"])
|
||||
index_root = Path(project_info["index_root"])
|
||||
|
||||
# Initialize global symbol index
|
||||
global_db = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
|
||||
self.global_index = GlobalSymbolIndex(global_db, project_id)
|
||||
self.global_index.initialize()
|
||||
|
||||
# Initialize search engine
|
||||
self._config = Config()
|
||||
self.search_engine = ChainSearchEngine(
|
||||
registry=self.registry,
|
||||
mapper=self.mapper,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
logger.info("codex-lens initialized for project: %s", project_info["source_root"])
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Workspace not indexed by codex-lens: %s. "
|
||||
"Run 'codexlens index %s' to index first.",
|
||||
self.workspace_root,
|
||||
self.workspace_root,
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to initialize codex-lens: %s", exc)
|
||||
return False
|
||||
|
||||
def shutdown_components(self) -> None:
|
||||
"""Clean up codex-lens components."""
|
||||
if self.global_index:
|
||||
try:
|
||||
self.global_index.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing global index: %s", exc)
|
||||
self.global_index = None
|
||||
|
||||
if self.search_engine:
|
||||
try:
|
||||
self.search_engine.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing search engine: %s", exc)
|
||||
self.search_engine = None
|
||||
|
||||
if self.registry:
|
||||
try:
|
||||
self.registry.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing registry: %s", exc)
|
||||
self.registry = None
|
||||
|
||||
|
||||
# Create server instance
|
||||
server = CodexLensLanguageServer()
|
||||
|
||||
|
||||
@server.feature(lsp.INITIALIZE)
|
||||
def lsp_initialize(params: lsp.InitializeParams) -> lsp.InitializeResult:
|
||||
"""Handle LSP initialize request."""
|
||||
logger.info("LSP initialize request received")
|
||||
|
||||
# Get workspace root
|
||||
workspace_root: Optional[Path] = None
|
||||
if params.root_uri:
|
||||
workspace_root = Path(params.root_uri.replace("file://", "").replace("file:", ""))
|
||||
elif params.root_path:
|
||||
workspace_root = Path(params.root_path)
|
||||
|
||||
if workspace_root:
|
||||
server.initialize_components(workspace_root)
|
||||
|
||||
# Declare server capabilities
|
||||
return lsp.InitializeResult(
|
||||
capabilities=lsp.ServerCapabilities(
|
||||
text_document_sync=lsp.TextDocumentSyncOptions(
|
||||
open_close=True,
|
||||
change=lsp.TextDocumentSyncKind.Incremental,
|
||||
save=lsp.SaveOptions(include_text=False),
|
||||
),
|
||||
definition_provider=True,
|
||||
references_provider=True,
|
||||
completion_provider=lsp.CompletionOptions(
|
||||
trigger_characters=[".", ":"],
|
||||
resolve_provider=False,
|
||||
),
|
||||
hover_provider=True,
|
||||
workspace_symbol_provider=True,
|
||||
),
|
||||
server_info=lsp.ServerInfo(
|
||||
name="codexlens-lsp",
|
||||
version="0.1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@server.feature(lsp.SHUTDOWN)
|
||||
def lsp_shutdown(params: None) -> None:
|
||||
"""Handle LSP shutdown request."""
|
||||
logger.info("LSP shutdown request received")
|
||||
server.shutdown_components()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Entry point for codexlens-lsp command.
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success)
|
||||
"""
|
||||
# Import handlers to register them with the server
|
||||
# This must be done before starting the server
|
||||
import codexlens.lsp.handlers # noqa: F401
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="codex-lens Language Server",
|
||||
prog="codexlens-lsp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stdio",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use stdio for communication (default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tcp",
|
||||
action="store_true",
|
||||
help="Use TCP for communication",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
default="127.0.0.1",
|
||||
help="TCP host (default: 127.0.0.1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=2087,
|
||||
help="TCP port (default: 2087)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
default="INFO",
|
||||
help="Log level (default: INFO)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
help="Log file path (optional)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging
|
||||
log_handlers = []
|
||||
if args.log_file:
|
||||
log_handlers.append(logging.FileHandler(args.log_file))
|
||||
else:
|
||||
log_handlers.append(logging.StreamHandler(sys.stderr))
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=log_handlers,
|
||||
)
|
||||
|
||||
logger.info("Starting codexlens-lsp server")
|
||||
|
||||
if args.tcp:
|
||||
logger.info("Starting TCP server on %s:%d", args.host, args.port)
|
||||
server.start_tcp(args.host, args.port)
|
||||
else:
|
||||
logger.info("Starting stdio server")
|
||||
server.start_io()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
20
codex-lens/src/codexlens/mcp/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""Model Context Protocol implementation for Claude Code integration."""

from codexlens.mcp.schema import (
    MCPContext,
    SymbolInfo,
    ReferenceInfo,
    RelatedSymbol,
)
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.hooks import HookManager, create_context_for_prompt

__all__ = [
    "MCPContext",
    "SymbolInfo",
    "ReferenceInfo",
    "RelatedSymbol",
    "MCPProvider",
    "HookManager",
    "create_context_for_prompt",
]
170
codex-lens/src/codexlens/mcp/hooks.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Hook interfaces for Claude Code integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Callable, TYPE_CHECKING
|
||||
|
||||
from codexlens.mcp.schema import MCPContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.mcp.provider import MCPProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Manages hook registration and execution."""
|
||||
|
||||
def __init__(self, mcp_provider: "MCPProvider") -> None:
|
||||
self.mcp_provider = mcp_provider
|
||||
self._pre_hooks: Dict[str, Callable] = {}
|
||||
self._post_hooks: Dict[str, Callable] = {}
|
||||
|
||||
# Register default hooks
|
||||
self._register_default_hooks()
|
||||
|
||||
def _register_default_hooks(self) -> None:
|
||||
"""Register built-in hooks."""
|
||||
self._pre_hooks["explain"] = self._pre_explain_hook
|
||||
self._pre_hooks["refactor"] = self._pre_refactor_hook
|
||||
self._pre_hooks["document"] = self._pre_document_hook
|
||||
|
||||
def execute_pre_hook(
|
||||
self,
|
||||
action: str,
|
||||
params: Dict[str, Any],
|
||||
) -> Optional[MCPContext]:
|
||||
"""Execute pre-tool hook to gather context.
|
||||
|
||||
Args:
|
||||
action: The action being performed (e.g., "explain", "refactor")
|
||||
params: Parameters for the action
|
||||
|
||||
Returns:
|
||||
MCPContext to inject into prompt, or None
|
||||
"""
|
||||
hook = self._pre_hooks.get(action)
|
||||
|
||||
if not hook:
|
||||
logger.debug(f"No pre-hook for action: {action}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return hook(params)
|
||||
except Exception as e:
|
||||
logger.error(f"Pre-hook failed for {action}: {e}")
|
||||
return None
|
||||
|
||||
def execute_post_hook(
|
||||
self,
|
||||
action: str,
|
||||
result: Any,
|
||||
) -> None:
|
||||
"""Execute post-tool hook for proactive caching.
|
||||
|
||||
Args:
|
||||
action: The action that was performed
|
||||
result: Result of the action
|
||||
"""
|
||||
hook = self._post_hooks.get(action)
|
||||
|
||||
if not hook:
|
||||
return
|
||||
|
||||
try:
|
||||
hook(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Post-hook failed for {action}: {e}")
|
||||
|
||||
def _pre_explain_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||
"""Pre-hook for 'explain' action."""
|
||||
symbol_name = params.get("symbol")
|
||||
|
||||
if not symbol_name:
|
||||
return None
|
||||
|
||||
return self.mcp_provider.build_context(
|
||||
symbol_name=symbol_name,
|
||||
context_type="symbol_explanation",
|
||||
include_references=True,
|
||||
include_related=True,
|
||||
)
|
||||
|
||||
def _pre_refactor_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||
"""Pre-hook for 'refactor' action."""
|
||||
symbol_name = params.get("symbol")
|
||||
|
||||
if not symbol_name:
|
||||
return None
|
||||
|
||||
return self.mcp_provider.build_context(
|
||||
symbol_name=symbol_name,
|
||||
context_type="refactor_context",
|
||||
include_references=True,
|
||||
include_related=True,
|
||||
max_references=20,
|
||||
)
|
||||
|
||||
def _pre_document_hook(self, params: Dict[str, Any]) -> Optional[MCPContext]:
|
||||
"""Pre-hook for 'document' action."""
|
||||
symbol_name = params.get("symbol")
|
||||
file_path = params.get("file_path")
|
||||
|
||||
if symbol_name:
|
||||
return self.mcp_provider.build_context(
|
||||
symbol_name=symbol_name,
|
||||
context_type="documentation_context",
|
||||
include_references=False,
|
||||
include_related=True,
|
||||
)
|
||||
elif file_path:
|
||||
return self.mcp_provider.build_context_for_file(
|
||||
Path(file_path),
|
||||
context_type="file_documentation",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def register_pre_hook(
|
||||
self,
|
||||
action: str,
|
||||
hook: Callable[[Dict[str, Any]], Optional[MCPContext]],
|
||||
) -> None:
|
||||
"""Register a custom pre-tool hook."""
|
||||
self._pre_hooks[action] = hook
|
||||
|
||||
def register_post_hook(
|
||||
self,
|
||||
action: str,
|
||||
hook: Callable[[Any], None],
|
||||
) -> None:
|
||||
"""Register a custom post-tool hook."""
|
||||
self._post_hooks[action] = hook
|
||||
|
||||
|
||||
def create_context_for_prompt(
|
||||
mcp_provider: "MCPProvider",
|
||||
action: str,
|
||||
params: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Create context string for prompt injection.
|
||||
|
||||
This is the main entry point for Claude Code hook integration.
|
||||
|
||||
Args:
|
||||
mcp_provider: The MCP provider instance
|
||||
action: Action being performed
|
||||
params: Action parameters
|
||||
|
||||
Returns:
|
||||
Formatted context string for prompt injection
|
||||
"""
|
||||
manager = HookManager(mcp_provider)
|
||||
context = manager.execute_pre_hook(action, params)
|
||||
|
||||
if context:
|
||||
return context.to_prompt_injection()
|
||||
|
||||
return ""
|
||||
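
A sketch of extending HookManager with a custom action. Here mcp_provider stands for an MCPProvider built elsewhere from the project's index, search engine and registry, and the "trace" action with its "trace_context" label is made up for illustration.

from codexlens.mcp import HookManager, create_context_for_prompt

manager = HookManager(mcp_provider)
manager.register_pre_hook(
    "trace",
    lambda params: mcp_provider.build_context(
        symbol_name=params["symbol"],
        context_type="trace_context",
        include_references=True,
    ),
)
context = manager.execute_pre_hook("trace", {"symbol": "workspace_symbols"})

# One-shot helper for prompt injection; it builds a fresh HookManager internally,
# so only the built-in actions (explain/refactor/document) are available here.
snippet = create_context_for_prompt(mcp_provider, "explain", {"symbol": "HookManager"})
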
202
codex-lens/src/codexlens/mcp/provider.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""MCP context provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from codexlens.mcp.schema import (
|
||||
MCPContext,
|
||||
SymbolInfo,
|
||||
ReferenceInfo,
|
||||
RelatedSymbol,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPProvider:
|
||||
"""Builds MCP context objects from codex-lens data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
global_index: "GlobalSymbolIndex",
|
||||
search_engine: "ChainSearchEngine",
|
||||
registry: "RegistryStore",
|
||||
) -> None:
|
||||
self.global_index = global_index
|
||||
self.search_engine = search_engine
|
||||
self.registry = registry
|
||||
|
||||
def build_context(
|
||||
self,
|
||||
symbol_name: str,
|
||||
context_type: str = "symbol_explanation",
|
||||
include_references: bool = True,
|
||||
include_related: bool = True,
|
||||
max_references: int = 10,
|
||||
) -> Optional[MCPContext]:
|
||||
"""Build comprehensive context for a symbol.
|
||||
|
||||
Args:
|
||||
symbol_name: Name of the symbol to contextualize
|
||||
context_type: Type of context being requested
|
||||
include_references: Whether to include reference locations
|
||||
include_related: Whether to include related symbols
|
||||
max_references: Maximum number of references to include
|
||||
|
||||
Returns:
|
||||
MCPContext object or None if symbol not found
|
||||
"""
|
||||
# Look up symbol
|
||||
symbols = self.global_index.search(symbol_name, prefix_mode=False, limit=1)
|
||||
|
||||
if not symbols:
|
||||
logger.debug(f"Symbol not found for MCP context: {symbol_name}")
|
||||
return None
|
||||
|
||||
symbol = symbols[0]
|
||||
|
||||
# Build SymbolInfo
|
||||
symbol_info = SymbolInfo(
|
||||
name=symbol.name,
|
||||
kind=symbol.kind,
|
||||
file_path=symbol.file or "",
|
||||
line_start=symbol.range[0],
|
||||
line_end=symbol.range[1],
|
||||
signature=None, # Symbol entity doesn't have signature
|
||||
documentation=None, # Symbol entity doesn't have docstring
|
||||
)
|
||||
|
||||
# Extract definition source code
|
||||
definition = self._extract_definition(symbol)
|
||||
|
||||
# Get references
|
||||
references = []
|
||||
if include_references:
|
||||
refs = self.search_engine.search_references(
|
||||
symbol_name,
|
||||
limit=max_references,
|
||||
)
|
||||
references = [
|
||||
ReferenceInfo(
|
||||
file_path=r.file_path,
|
||||
line=r.line,
|
||||
column=r.column,
|
||||
context=r.context,
|
||||
relationship_type=r.relationship_type,
|
||||
)
|
||||
for r in refs
|
||||
]
|
||||
|
||||
# Get related symbols
|
||||
related_symbols = []
|
||||
if include_related:
|
||||
related_symbols = self._get_related_symbols(symbol)
|
||||
|
||||
return MCPContext(
|
||||
context_type=context_type,
|
||||
symbol=symbol_info,
|
||||
definition=definition,
|
||||
references=references,
|
||||
related_symbols=related_symbols,
|
||||
metadata={
|
||||
"source": "codex-lens",
|
||||
},
|
||||
)
|
||||
|
||||
def _extract_definition(self, symbol) -> Optional[str]:
|
||||
"""Extract source code for symbol definition."""
|
||||
try:
|
||||
file_path = Path(symbol.file) if symbol.file else None
|
||||
if not file_path or not file_path.exists():
|
||||
return None
|
||||
|
||||
content = file_path.read_text(encoding='utf-8', errors='ignore')
|
||||
lines = content.split("\n")
|
||||
|
||||
start = symbol.range[0] - 1
|
||||
end = symbol.range[1]
|
||||
|
||||
if start >= len(lines):
|
||||
return None
|
||||
|
||||
return "\n".join(lines[start:end])
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract definition: {e}")
|
||||
return None
|
||||
|
||||
def _get_related_symbols(self, symbol) -> List[RelatedSymbol]:
|
||||
"""Get symbols related to the given symbol."""
|
||||
related = []
|
||||
|
||||
try:
|
||||
# Search for symbols that might be related by name patterns
|
||||
# This is a simplified implementation - could be enhanced with relationship data
|
||||
|
||||
# Look for imports/callers via reference search
|
||||
refs = self.search_engine.search_references(symbol.name, limit=20)
|
||||
|
||||
seen_names = set()
|
||||
for ref in refs:
|
||||
# Extract potential symbol name from context
|
||||
if ref.relationship_type and ref.relationship_type not in seen_names:
|
||||
related.append(RelatedSymbol(
|
||||
name=f"{Path(ref.file_path).stem}",
|
||||
kind="module",
|
||||
relationship=ref.relationship_type,
|
||||
file_path=ref.file_path,
|
||||
))
|
||||
seen_names.add(ref.relationship_type)
|
||||
if len(related) >= 10:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get related symbols: {e}")
|
||||
|
||||
return related
|
||||
|
||||
def build_context_for_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
context_type: str = "file_overview",
|
||||
) -> MCPContext:
|
||||
"""Build context for an entire file."""
|
||||
# Try to get symbols by searching with file path
|
||||
# Note: GlobalSymbolIndex doesn't have search_by_file, so we use a different approach
|
||||
symbols = []
|
||||
|
||||
# Search for common symbols that might be in this file
|
||||
# This is a simplified approach - a full implementation would query by file path
|
||||
try:
|
||||
# Use the global index to search for symbols from this file
|
||||
file_str = str(file_path.resolve())
|
||||
# Get all symbols and filter by file path (not efficient but works)
|
||||
all_symbols = self.global_index.search("", prefix_mode=True, limit=1000)
|
||||
symbols = [s for s in all_symbols if s.file and str(Path(s.file).resolve()) == file_str]
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get file symbols: {e}")
|
||||
|
||||
related = [
|
||||
RelatedSymbol(
|
||||
name=s.name,
|
||||
kind=s.kind,
|
||||
relationship="defines",
|
||||
)
|
||||
for s in symbols
|
||||
]
|
||||
|
||||
return MCPContext(
|
||||
context_type=context_type,
|
||||
related_symbols=related,
|
||||
metadata={
|
||||
"file_path": str(file_path),
|
||||
"symbol_count": len(symbols),
|
||||
},
|
||||
)
|
||||
113
codex-lens/src/codexlens/mcp/schema.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""MCP data models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class SymbolInfo:
|
||||
"""Information about a code symbol."""
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
line_start: int
|
||||
line_end: int
|
||||
signature: Optional[str] = None
|
||||
documentation: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceInfo:
|
||||
"""Information about a symbol reference."""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context: str
|
||||
relationship_type: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelatedSymbol:
|
||||
"""Related symbol (import, call target, etc.)."""
|
||||
name: str
|
||||
kind: str
|
||||
relationship: str # "imports", "calls", "inherits", "uses"
|
||||
file_path: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPContext:
|
||||
"""Model Context Protocol context object.
|
||||
|
||||
This is the structured context that gets injected into
|
||||
LLM prompts to provide code understanding.
|
||||
"""
|
||||
version: str = "1.0"
|
||||
context_type: str = "code_context"
|
||||
symbol: Optional[SymbolInfo] = None
|
||||
definition: Optional[str] = None
|
||||
references: List[ReferenceInfo] = field(default_factory=list)
|
||||
related_symbols: List[RelatedSymbol] = field(default_factory=list)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result = {
|
||||
"version": self.version,
|
||||
"context_type": self.context_type,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
if self.symbol:
|
||||
result["symbol"] = self.symbol.to_dict()
|
||||
if self.definition:
|
||||
result["definition"] = self.definition
|
||||
if self.references:
|
||||
result["references"] = [r.to_dict() for r in self.references]
|
||||
if self.related_symbols:
|
||||
result["related_symbols"] = [s.to_dict() for s in self.related_symbols]
|
||||
|
||||
return result
|
||||
|
||||
def to_json(self, indent: int = 2) -> str:
|
||||
"""Serialize to JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=indent)
|
||||
|
||||
def to_prompt_injection(self) -> str:
|
||||
"""Format for injection into LLM prompt."""
|
||||
parts = ["<code_context>"]
|
||||
|
||||
if self.symbol:
|
||||
parts.append(f"## Symbol: {self.symbol.name}")
|
||||
parts.append(f"Type: {self.symbol.kind}")
|
||||
parts.append(f"Location: {self.symbol.file_path}:{self.symbol.line_start}")
|
||||
|
||||
if self.definition:
|
||||
parts.append("\n## Definition")
|
||||
parts.append(f"```\n{self.definition}\n```")
|
||||
|
||||
if self.references:
|
||||
parts.append(f"\n## References ({len(self.references)} found)")
|
||||
for ref in self.references[:5]: # Limit to 5
|
||||
parts.append(f"- {ref.file_path}:{ref.line} ({ref.relationship_type})")
|
||||
parts.append(f" ```\n {ref.context}\n ```")
|
||||
|
||||
if self.related_symbols:
|
||||
parts.append("\n## Related Symbols")
|
||||
for sym in self.related_symbols[:10]: # Limit to 10
|
||||
parts.append(f"- {sym.name} ({sym.relationship})")
|
||||
|
||||
parts.append("</code_context>")
|
||||
return "\n".join(parts)
|
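Illustration (not part of the commit): a minimal sketch of how the schema above can be consumed; the symbol values are invented for the example.

    from codexlens.mcp.schema import MCPContext, SymbolInfo

    # Hypothetical symbol data, for illustration only
    info = SymbolInfo(
        name="authenticate",
        kind="function",
        file_path="src/auth/handlers.py",
        line_start=42,
        line_end=80,
    )
    ctx = MCPContext(
        context_type="symbol_context",
        symbol=info,
        definition="def authenticate(request): ...",
    )
    print(ctx.to_json())              # structured JSON for MCP consumers
    print(ctx.to_prompt_injection())  # <code_context> block for LLM prompts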
||||
@@ -6,10 +6,48 @@ from .chain_search import (
|
||||
quick_search,
|
||||
)
|
||||
|
||||
# Clustering availability flag (lazy import pattern)
|
||||
CLUSTERING_AVAILABLE = False
|
||||
_clustering_import_error: str | None = None
|
||||
|
||||
try:
|
||||
from .clustering import CLUSTERING_AVAILABLE as _clustering_flag
|
||||
from .clustering import check_clustering_available
|
||||
CLUSTERING_AVAILABLE = _clustering_flag
|
||||
except ImportError as e:
|
||||
_clustering_import_error = str(e)
|
||||
|
||||
def check_clustering_available() -> tuple[bool, str | None]:
|
||||
"""Fallback when clustering module not loadable."""
|
||||
return False, _clustering_import_error
|
||||
|
||||
|
||||
# Clustering module exports (conditional)
|
||||
try:
|
||||
from .clustering import (
|
||||
BaseClusteringStrategy,
|
||||
ClusteringConfig,
|
||||
ClusteringStrategyFactory,
|
||||
get_strategy,
|
||||
)
|
||||
_clustering_exports = [
|
||||
"BaseClusteringStrategy",
|
||||
"ClusteringConfig",
|
||||
"ClusteringStrategyFactory",
|
||||
"get_strategy",
|
||||
]
|
||||
except ImportError:
|
||||
_clustering_exports = []
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChainSearchEngine",
|
||||
"SearchOptions",
|
||||
"SearchStats",
|
||||
"ChainSearchResult",
|
||||
"quick_search",
|
||||
# Clustering
|
||||
"CLUSTERING_AVAILABLE",
|
||||
"check_clustering_available",
|
||||
*_clustering_exports,
|
||||
]
|
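Illustrative sketch (not from the commit) of how a caller would typically consume this lazy-import feature flag before using clustering:

    from codexlens.search import check_clustering_available

    ok, err = check_clustering_available()
    if not ok:
        # err carries the install hint, e.g. "pip install codexlens[clustering]"
        print(f"clustering disabled: {err}")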
||||
|
||||
@@ -116,6 +116,24 @@ class ChainSearchResult:
|
||||
related_results: List[SearchResult] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReferenceResult:
|
||||
"""Result from reference search in code_relationships table.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the file containing the reference
|
||||
line: Line number where the reference occurs (1-based)
|
||||
column: Column number where the reference occurs (0-based)
|
||||
context: Surrounding code snippet for context
|
||||
relationship_type: Type of relationship (call, import, inheritance, etc.)
|
||||
"""
|
||||
file_path: str
|
||||
line: int
|
||||
column: int
|
||||
context: str
|
||||
relationship_type: str
|
||||
|
||||
|
||||
class ChainSearchEngine:
|
||||
"""Parallel chain search engine for hierarchical directory indexes.
|
||||
|
||||
@@ -810,7 +828,7 @@ class ChainSearchEngine:
|
||||
k: int = 10,
|
||||
coarse_k: int = 100,
|
||||
options: Optional[SearchOptions] = None,
|
||||
strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank"]] = None,
|
||||
strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None,
|
||||
) -> ChainSearchResult:
|
||||
"""Unified cascade search entry point with strategy selection.
|
||||
|
||||
@@ -819,6 +837,7 @@ class ChainSearchEngine:
|
||||
- "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original)
|
||||
- "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance)
|
||||
- "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking
|
||||
- "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank
|
||||
|
||||
The strategy is determined with the following priority:
|
||||
1. The `strategy` parameter (e.g., from CLI --cascade-strategy option)
|
||||
@@ -831,7 +850,7 @@ class ChainSearchEngine:
|
||||
k: Number of final results to return (default 10)
|
||||
coarse_k: Number of coarse candidates from first stage (default 100)
|
||||
options: Search configuration (uses defaults if None)
|
||||
strategy: Cascade strategy - "binary", "hybrid", or "binary_rerank".
|
||||
strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged".
|
||||
|
||||
Returns:
|
||||
ChainSearchResult with reranked results and statistics
|
||||
@@ -844,10 +863,12 @@ class ChainSearchEngine:
|
||||
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
|
||||
>>> # Use binary + cross-encoder (best balance of speed and quality)
|
||||
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
|
||||
>>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
|
||||
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="staged")
|
||||
"""
|
||||
# Strategy priority: parameter > config > default
|
||||
effective_strategy = strategy
|
||||
valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank")
|
||||
valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
|
||||
if effective_strategy is None:
|
||||
# Not passed via parameter, check config
|
||||
if self._config is not None:
|
||||
@@ -865,9 +886,635 @@ class ChainSearchEngine:
|
||||
return self.binary_rerank_cascade_search(query, source_path, k, coarse_k, options)
|
||||
elif effective_strategy == "dense_rerank":
|
||||
return self.dense_rerank_cascade_search(query, source_path, k, coarse_k, options)
|
||||
elif effective_strategy == "staged":
|
||||
return self.staged_cascade_search(query, source_path, k, coarse_k, options)
|
||||
else:
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
|
||||
def staged_cascade_search(
|
||||
self,
|
||||
query: str,
|
||||
source_path: Path,
|
||||
k: int = 10,
|
||||
coarse_k: int = 100,
|
||||
options: Optional[SearchOptions] = None,
|
||||
) -> ChainSearchResult:
|
||||
"""Execute 4-stage cascade search pipeline with binary, LSP expansion, clustering, and optional reranking.
|
||||
|
||||
Staged cascade search process:
|
||||
1. Stage 1 (Binary Coarse): Fast binary vector search using Hamming distance
|
||||
to quickly filter to coarse_k candidates (256-bit binary vectors)
|
||||
2. Stage 2 (LSP Expansion): Expand coarse candidates using GraphExpander to
|
||||
include related symbols (definitions, references, callers/callees)
|
||||
3. Stage 3 (Clustering): Use configurable clustering strategy to group similar
|
||||
results and select representative results from each cluster
|
||||
4. Stage 4 (Optional Rerank): If config.enable_staged_rerank is True, apply
|
||||
cross-encoder reranking for final precision
|
||||
|
||||
This approach combines the speed of binary search with graph-based context
|
||||
expansion and diversity-preserving clustering for high-quality results.
|
||||
|
||||
Performance characteristics:
|
||||
- Stage 1: O(N) binary search with SIMD acceleration (~8ms)
|
||||
- Stage 2: O(k * d) graph traversal where d is expansion depth
|
||||
- Stage 3: O(n^2) clustering on expanded candidates
|
||||
- Stage 4: Optional cross-encoder reranking (API call)
|
||||
|
||||
Args:
|
||||
query: Natural language or keyword query string
|
||||
source_path: Starting directory path
|
||||
k: Number of final results to return (default 10)
|
||||
coarse_k: Number of coarse candidates from first stage (default 100)
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
ChainSearchResult with per-stage statistics
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
>>> result = engine.staged_cascade_search(
|
||||
... "authentication handler",
|
||||
... Path("D:/project/src"),
|
||||
... k=10,
|
||||
... coarse_k=100
|
||||
... )
|
||||
>>> for r in result.results:
|
||||
... print(f"{r.path}: {r.score:.3f}")
|
||||
"""
|
||||
if not NUMPY_AVAILABLE:
|
||||
self.logger.warning(
|
||||
"NumPy not available, falling back to hybrid cascade search"
|
||||
)
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
|
||||
options = options or SearchOptions()
|
||||
start_time = time.time()
|
||||
stats = SearchStats()
|
||||
|
||||
# Per-stage timing stats
|
||||
stage_times: Dict[str, float] = {}
|
||||
stage_counts: Dict[str, int] = {}
|
||||
|
||||
# Use config defaults if available
|
||||
if self._config is not None:
|
||||
if hasattr(self._config, "cascade_coarse_k"):
|
||||
coarse_k = coarse_k or self._config.cascade_coarse_k
|
||||
if hasattr(self._config, "cascade_fine_k"):
|
||||
k = k or self._config.cascade_fine_k
|
||||
|
||||
# Step 1: Find starting index
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
stats.time_ms = (time.time() - start_time) * 1000
|
||||
return ChainSearchResult(
|
||||
query=query,
|
||||
results=[],
|
||||
symbols=[],
|
||||
stats=stats
|
||||
)
|
||||
|
||||
# Step 2: Collect all index paths
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
stats.dirs_searched = len(index_paths)
|
||||
|
||||
if not index_paths:
|
||||
self.logger.warning(f"No indexes collected from {start_index}")
|
||||
stats.time_ms = (time.time() - start_time) * 1000
|
||||
return ChainSearchResult(
|
||||
query=query,
|
||||
results=[],
|
||||
symbols=[],
|
||||
stats=stats
|
||||
)
|
||||
|
||||
# ========== Stage 1: Binary Coarse Search ==========
|
||||
stage1_start = time.time()
|
||||
coarse_results, index_root = self._stage1_binary_search(
|
||||
query, index_paths, coarse_k, stats
|
||||
)
|
||||
stage_times["stage1_binary_ms"] = (time.time() - stage1_start) * 1000
|
||||
stage_counts["stage1_candidates"] = len(coarse_results)
|
||||
|
||||
self.logger.debug(
|
||||
"Staged Stage 1: Binary search found %d candidates in %.2fms",
|
||||
len(coarse_results), stage_times["stage1_binary_ms"]
|
||||
)
|
||||
|
||||
if not coarse_results:
|
||||
self.logger.debug("No binary candidates found, falling back to hybrid cascade")
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
|
||||
# ========== Stage 2: LSP Graph Expansion ==========
|
||||
stage2_start = time.time()
|
||||
expanded_results = self._stage2_lsp_expand(coarse_results, index_root)
|
||||
stage_times["stage2_expand_ms"] = (time.time() - stage2_start) * 1000
|
||||
stage_counts["stage2_expanded"] = len(expanded_results)
|
||||
|
||||
self.logger.debug(
|
||||
"Staged Stage 2: LSP expansion %d -> %d results in %.2fms",
|
||||
len(coarse_results), len(expanded_results), stage_times["stage2_expand_ms"]
|
||||
)
|
||||
|
||||
# ========== Stage 3: Clustering and Representative Selection ==========
|
||||
stage3_start = time.time()
|
||||
clustered_results = self._stage3_cluster_prune(expanded_results, k * 2)
|
||||
stage_times["stage3_cluster_ms"] = (time.time() - stage3_start) * 1000
|
||||
stage_counts["stage3_clustered"] = len(clustered_results)
|
||||
|
||||
self.logger.debug(
|
||||
"Staged Stage 3: Clustering %d -> %d representatives in %.2fms",
|
||||
len(expanded_results), len(clustered_results), stage_times["stage3_cluster_ms"]
|
||||
)
|
||||
|
||||
# ========== Stage 4: Optional Cross-Encoder Reranking ==========
|
||||
enable_rerank = False
|
||||
if self._config is not None:
|
||||
enable_rerank = getattr(self._config, "enable_staged_rerank", False)
|
||||
|
||||
if enable_rerank:
|
||||
stage4_start = time.time()
|
||||
final_results = self._stage4_optional_rerank(query, clustered_results, k)
|
||||
stage_times["stage4_rerank_ms"] = (time.time() - stage4_start) * 1000
|
||||
stage_counts["stage4_reranked"] = len(final_results)
|
||||
|
||||
self.logger.debug(
|
||||
"Staged Stage 4: Reranking %d -> %d results in %.2fms",
|
||||
len(clustered_results), len(final_results), stage_times["stage4_rerank_ms"]
|
||||
)
|
||||
else:
|
||||
# Skip reranking, just take top-k by score
|
||||
final_results = sorted(
|
||||
clustered_results, key=lambda r: r.score, reverse=True
|
||||
)[:k]
|
||||
stage_counts["stage4_reranked"] = len(final_results)
|
||||
|
||||
# Deduplicate by path (keep highest score)
|
||||
path_to_result: Dict[str, SearchResult] = {}
|
||||
for result in final_results:
|
||||
if result.path not in path_to_result or result.score > path_to_result[result.path].score:
|
||||
path_to_result[result.path] = result
|
||||
|
||||
final_results = list(path_to_result.values())[:k]
|
||||
|
||||
# Optional: grouping of similar results
|
||||
if options.group_results:
|
||||
from codexlens.search.ranking import group_similar_results
|
||||
final_results = group_similar_results(
|
||||
final_results, score_threshold_abs=options.grouping_threshold
|
||||
)
|
||||
|
||||
stats.files_matched = len(final_results)
|
||||
stats.time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Add per-stage stats to errors field (as JSON for now, will be proper field later)
|
||||
stage_stats_json = json.dumps({
|
||||
"stage_times": stage_times,
|
||||
"stage_counts": stage_counts,
|
||||
})
|
||||
stats.errors.append(f"STAGE_STATS:{stage_stats_json}")
|
||||
|
||||
self.logger.debug(
|
||||
"Staged cascade search complete: %d results in %.2fms "
|
||||
"(stage1=%.1fms, stage2=%.1fms, stage3=%.1fms)",
|
||||
len(final_results),
|
||||
stats.time_ms,
|
||||
stage_times.get("stage1_binary_ms", 0),
|
||||
stage_times.get("stage2_expand_ms", 0),
|
||||
stage_times.get("stage3_cluster_ms", 0),
|
||||
)
|
||||
|
||||
return ChainSearchResult(
|
||||
query=query,
|
||||
results=final_results,
|
||||
symbols=[],
|
||||
stats=stats,
|
||||
)
|
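Because the per-stage timings are carried through stats.errors as a tagged JSON entry (the STAGE_STATS append above), a caller can recover them with a small helper. Illustrative sketch only, not part of the commit:

    import json

    def extract_stage_stats(stats) -> dict | None:
        # Find the "STAGE_STATS:{...}" entry appended by staged_cascade_search
        for entry in stats.errors:
            if entry.startswith("STAGE_STATS:"):
                return json.loads(entry[len("STAGE_STATS:"):])
        return None

    # stage = extract_stage_stats(result.stats)
    # stage["stage_times"]["stage1_binary_ms"], stage["stage_counts"]["stage3_clustered"]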
||||
|
||||
def _stage1_binary_search(
|
||||
self,
|
||||
query: str,
|
||||
index_paths: List[Path],
|
||||
coarse_k: int,
|
||||
stats: SearchStats,
|
||||
) -> Tuple[List[SearchResult], Optional[Path]]:
|
||||
"""Stage 1: Binary vector coarse search using Hamming distance.
|
||||
|
||||
Reuses the binary coarse search logic from binary_cascade_search.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
index_paths: List of index database paths to search
|
||||
coarse_k: Number of coarse candidates to retrieve
|
||||
stats: SearchStats to update with errors
|
||||
|
||||
Returns:
|
||||
Tuple of (list of SearchResult objects, index_root path or None)
|
||||
"""
|
||||
# Initialize binary embedding backend
|
||||
try:
|
||||
from codexlens.indexing.embedding import BinaryEmbeddingBackend
|
||||
except ImportError as exc:
|
||||
self.logger.warning(
|
||||
"BinaryEmbeddingBackend not available: %s", exc
|
||||
)
|
||||
return [], None
|
||||
|
||||
# Try centralized BinarySearcher first (preferred for mmap indexes)
|
||||
index_root = index_paths[0].parent if index_paths else None
|
||||
coarse_candidates: List[Tuple[int, int, Path]] = [] # (chunk_id, distance, index_path)
|
||||
used_centralized = False
|
||||
|
||||
if index_root:
|
||||
binary_searcher = self._get_centralized_binary_searcher(index_root)
|
||||
if binary_searcher is not None:
|
||||
try:
|
||||
from codexlens.semantic.embedder import Embedder
|
||||
embedder = Embedder()
|
||||
query_dense = embedder.embed_to_numpy([query])[0]
|
||||
|
||||
results = binary_searcher.search(query_dense, top_k=coarse_k)
|
||||
for chunk_id, distance in results:
|
||||
coarse_candidates.append((chunk_id, distance, index_root))
|
||||
if coarse_candidates:
|
||||
used_centralized = True
|
||||
self.logger.debug(
|
||||
"Stage 1 centralized binary search: %d candidates", len(results)
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Centralized binary search failed: {exc}")
|
||||
|
||||
if not used_centralized:
|
||||
# Fallback to per-directory binary indexes
|
||||
use_gpu = True
|
||||
if self._config is not None:
|
||||
use_gpu = getattr(self._config, "embedding_use_gpu", True)
|
||||
|
||||
try:
|
||||
binary_backend = BinaryEmbeddingBackend(use_gpu=use_gpu)
|
||||
query_binary = binary_backend.embed_packed([query])[0]
|
||||
except Exception as exc:
|
||||
self.logger.warning(f"Failed to generate binary query embedding: {exc}")
|
||||
return [], index_root
|
||||
|
||||
for index_path in index_paths:
|
||||
try:
|
||||
binary_index = self._get_or_create_binary_index(index_path)
|
||||
if binary_index is None or binary_index.count() == 0:
|
||||
continue
|
||||
ids, distances = binary_index.search(query_binary, coarse_k)
|
||||
for chunk_id, dist in zip(ids, distances):
|
||||
coarse_candidates.append((chunk_id, dist, index_path))
|
||||
except Exception as exc:
|
||||
self.logger.debug(
|
||||
"Binary search failed for %s: %s", index_path, exc
|
||||
)
|
||||
|
||||
if not coarse_candidates:
|
||||
return [], index_root
|
||||
|
||||
# Sort by Hamming distance and take top coarse_k
|
||||
coarse_candidates.sort(key=lambda x: x[1])
|
||||
coarse_candidates = coarse_candidates[:coarse_k]
|
||||
|
||||
# Build SearchResult objects from candidates
|
||||
coarse_results: List[SearchResult] = []
|
||||
|
||||
# Group candidates by index path for efficient retrieval
|
||||
candidates_by_index: Dict[Path, List[int]] = {}
|
||||
for chunk_id, _, idx_path in coarse_candidates:
|
||||
if idx_path not in candidates_by_index:
|
||||
candidates_by_index[idx_path] = []
|
||||
candidates_by_index[idx_path].append(chunk_id)
|
||||
|
||||
# Retrieve chunk content
|
||||
import sqlite3
|
||||
central_meta_path = index_root / VECTORS_META_DB_NAME if index_root else None
|
||||
central_meta_store = None
|
||||
if central_meta_path and central_meta_path.exists():
|
||||
central_meta_store = VectorMetadataStore(central_meta_path)
|
||||
|
||||
for idx_path, chunk_ids in candidates_by_index.items():
|
||||
try:
|
||||
chunks_data = []
|
||||
if central_meta_store:
|
||||
chunks_data = central_meta_store.get_chunks_by_ids(chunk_ids)
|
||||
|
||||
if not chunks_data and used_centralized:
|
||||
meta_db_path = idx_path / VECTORS_META_DB_NAME
|
||||
if meta_db_path.exists():
|
||||
meta_store = VectorMetadataStore(meta_db_path)
|
||||
chunks_data = meta_store.get_chunks_by_ids(chunk_ids)
|
||||
|
||||
if not chunks_data:
|
||||
try:
|
||||
conn = sqlite3.connect(str(idx_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
placeholders = ",".join("?" * len(chunk_ids))
|
||||
cursor = conn.execute(
|
||||
f"""
|
||||
SELECT id, file_path, content, metadata, category
|
||||
FROM semantic_chunks
|
||||
WHERE id IN ({placeholders})
|
||||
""",
|
||||
chunk_ids
|
||||
)
|
||||
chunks_data = [
|
||||
{
|
||||
"id": row["id"],
|
||||
"file_path": row["file_path"],
|
||||
"content": row["content"],
|
||||
"metadata": row["metadata"],
|
||||
"category": row["category"],
|
||||
}
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for chunk in chunks_data:
|
||||
chunk_id = chunk.get("id") or chunk.get("chunk_id")
|
||||
distance = next(
|
||||
(d for cid, d, _ in coarse_candidates if cid == chunk_id),
|
||||
256
|
||||
)
|
||||
score = 1.0 - (distance / 256.0)
|
||||
|
||||
content = chunk.get("content", "")
|
||||
|
||||
# Extract symbol info from metadata if available
|
||||
metadata = chunk.get("metadata")
|
||||
symbol_name = None
|
||||
symbol_kind = None
|
||||
start_line = None
|
||||
end_line = None
|
||||
if metadata:
|
||||
try:
|
||||
meta_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
|
||||
symbol_name = meta_dict.get("symbol_name")
|
||||
symbol_kind = meta_dict.get("symbol_kind")
|
||||
start_line = meta_dict.get("start_line")
|
||||
end_line = meta_dict.get("end_line")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result = SearchResult(
|
||||
path=chunk.get("file_path", ""),
|
||||
score=float(score),
|
||||
excerpt=content[:500] if content else "",
|
||||
content=content,
|
||||
symbol_name=symbol_name,
|
||||
symbol_kind=symbol_kind,
|
||||
start_line=start_line,
|
||||
end_line=end_line,
|
||||
)
|
||||
coarse_results.append(result)
|
||||
except Exception as exc:
|
||||
self.logger.debug(
|
||||
"Failed to retrieve chunks from %s: %s", idx_path, exc
|
||||
)
|
||||
stats.errors.append(f"Stage 1 chunk retrieval failed for {idx_path}: {exc}")
|
||||
|
||||
return coarse_results, index_root
|
||||
|
||||
def _stage2_lsp_expand(
|
||||
self,
|
||||
coarse_results: List[SearchResult],
|
||||
index_root: Optional[Path],
|
||||
) -> List[SearchResult]:
|
||||
"""Stage 2: LSP-based graph expansion using GraphExpander.
|
||||
|
||||
Expands coarse results with related symbols (definitions, references,
|
||||
callers, callees) using precomputed graph neighbors.
|
||||
|
||||
Args:
|
||||
coarse_results: Results from Stage 1 binary search
|
||||
index_root: Root path of the index (for graph database access)
|
||||
|
||||
Returns:
|
||||
Combined list of original results plus expanded related results
|
||||
"""
|
||||
if not coarse_results or index_root is None:
|
||||
return coarse_results
|
||||
|
||||
try:
|
||||
from codexlens.search.graph_expander import GraphExpander
|
||||
|
||||
# Get expansion depth from config
|
||||
depth = 2
|
||||
if self._config is not None:
|
||||
depth = getattr(self._config, "graph_expansion_depth", 2)
|
||||
|
||||
expander = GraphExpander(self.mapper, config=self._config)
|
||||
|
||||
# Expand top results (limit expansion to avoid explosion)
|
||||
max_expand = min(10, len(coarse_results))
|
||||
max_related = 50
|
||||
|
||||
related_results = expander.expand(
|
||||
coarse_results,
|
||||
depth=depth,
|
||||
max_expand=max_expand,
|
||||
max_related=max_related,
|
||||
)
|
||||
|
||||
if related_results:
|
||||
self.logger.debug(
|
||||
"Stage 2 expanded %d base results to %d related symbols",
|
||||
len(coarse_results), len(related_results)
|
||||
)
|
||||
|
||||
# Combine: original results + related results
|
||||
# Keep original results first (higher relevance)
|
||||
combined = list(coarse_results)
|
||||
seen_keys = {(r.path, r.symbol_name, r.start_line) for r in coarse_results}
|
||||
|
||||
for related in related_results:
|
||||
key = (related.path, related.symbol_name, related.start_line)
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
combined.append(related)
|
||||
|
||||
return combined
|
||||
|
||||
except ImportError as exc:
|
||||
self.logger.debug("GraphExpander not available: %s", exc)
|
||||
return coarse_results
|
||||
except Exception as exc:
|
||||
self.logger.debug("Stage 2 LSP expansion failed: %s", exc)
|
||||
return coarse_results
|
||||
|
||||
def _stage3_cluster_prune(
|
||||
self,
|
||||
expanded_results: List[SearchResult],
|
||||
target_count: int,
|
||||
) -> List[SearchResult]:
|
||||
"""Stage 3: Cluster expanded results and select representatives.
|
||||
|
||||
Uses the extensible clustering infrastructure from codexlens.search.clustering
|
||||
to group similar results and select the best representative from each cluster.
|
||||
|
||||
Args:
|
||||
expanded_results: Results from Stage 2 expansion
|
||||
target_count: Target number of representative results
|
||||
|
||||
Returns:
|
||||
List of representative results (one per cluster)
|
||||
"""
|
||||
if not expanded_results:
|
||||
return []
|
||||
|
||||
# If few results, skip clustering
|
||||
if len(expanded_results) <= target_count:
|
||||
return expanded_results
|
||||
|
||||
try:
|
||||
from codexlens.search.clustering import (
|
||||
ClusteringConfig,
|
||||
get_strategy,
|
||||
)
|
||||
|
||||
# Get clustering config from config
|
||||
strategy_name = "auto"
|
||||
min_cluster_size = 3
|
||||
|
||||
if self._config is not None:
|
||||
strategy_name = getattr(self._config, "staged_clustering_strategy", "auto")
|
||||
min_cluster_size = getattr(self._config, "staged_clustering_min_size", 3)
|
||||
|
||||
# Get embeddings for clustering
|
||||
# Try to get dense embeddings from results' content
|
||||
embeddings = self._get_embeddings_for_clustering(expanded_results)
|
||||
|
||||
if embeddings is None or len(embeddings) == 0:
|
||||
# No embeddings available, fall back to score-based selection
|
||||
self.logger.debug("No embeddings for clustering, using score-based selection")
|
||||
return sorted(
|
||||
expanded_results, key=lambda r: r.score, reverse=True
|
||||
)[:target_count]
|
||||
|
||||
# Create clustering config
|
||||
config = ClusteringConfig(
|
||||
min_cluster_size=min(min_cluster_size, max(2, len(expanded_results) // 5)),
|
||||
min_samples=2,
|
||||
metric="cosine",
|
||||
)
|
||||
|
||||
# Get strategy with fallback
|
||||
strategy = get_strategy(strategy_name, config, fallback=True)
|
||||
|
||||
# Cluster and select representatives
|
||||
representatives = strategy.fit_predict(embeddings, expanded_results)
|
||||
|
||||
self.logger.debug(
|
||||
"Stage 3 clustered %d results into %d representatives using %s",
|
||||
len(expanded_results), len(representatives), type(strategy).__name__
|
||||
)
|
||||
|
||||
# If clustering returned too few, supplement with top-scored unclustered
|
||||
if len(representatives) < target_count:
|
||||
rep_paths = {r.path for r in representatives}
|
||||
remaining = [r for r in expanded_results if r.path not in rep_paths]
|
||||
remaining_sorted = sorted(remaining, key=lambda r: r.score, reverse=True)
|
||||
representatives.extend(remaining_sorted[:target_count - len(representatives)])
|
||||
|
||||
return representatives[:target_count]
|
||||
|
||||
except ImportError as exc:
|
||||
self.logger.debug("Clustering not available: %s", exc)
|
||||
return sorted(
|
||||
expanded_results, key=lambda r: r.score, reverse=True
|
||||
)[:target_count]
|
||||
except Exception as exc:
|
||||
self.logger.debug("Stage 3 clustering failed: %s", exc)
|
||||
return sorted(
|
||||
expanded_results, key=lambda r: r.score, reverse=True
|
||||
)[:target_count]
|
||||
|
||||
def _stage4_optional_rerank(
|
||||
self,
|
||||
query: str,
|
||||
clustered_results: List[SearchResult],
|
||||
k: int,
|
||||
) -> List[SearchResult]:
|
||||
"""Stage 4: Optional cross-encoder reranking.
|
||||
|
||||
Applies cross-encoder reranking if enabled in config.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
clustered_results: Results from Stage 3 clustering
|
||||
k: Number of final results to return
|
||||
|
||||
Returns:
|
||||
Reranked results sorted by cross-encoder score
|
||||
"""
|
||||
if not clustered_results:
|
||||
return []
|
||||
|
||||
# Use existing _cross_encoder_rerank method
|
||||
return self._cross_encoder_rerank(query, clustered_results, k)
|
||||
|
||||
def _get_embeddings_for_clustering(
|
||||
self,
|
||||
results: List[SearchResult],
|
||||
) -> Optional["np.ndarray"]:
|
||||
"""Get dense embeddings for clustering results.
|
||||
|
||||
Tries to generate embeddings from result content for clustering.
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
|
||||
Returns:
|
||||
NumPy array of embeddings or None if not available
|
||||
"""
|
||||
if not NUMPY_AVAILABLE:
|
||||
return None
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
try:
|
||||
from codexlens.semantic.factory import get_embedder
|
||||
|
||||
# Get embedding settings from config
|
||||
embedding_backend = "fastembed"
|
||||
embedding_model = "code"
|
||||
use_gpu = True
|
||||
|
||||
if self._config is not None:
|
||||
embedding_backend = getattr(self._config, "embedding_backend", "fastembed")
|
||||
embedding_model = getattr(self._config, "embedding_model", "code")
|
||||
use_gpu = getattr(self._config, "embedding_use_gpu", True)
|
||||
|
||||
# Create embedder
|
||||
if embedding_backend == "litellm":
|
||||
embedder = get_embedder(backend="litellm", model=embedding_model)
|
||||
else:
|
||||
embedder = get_embedder(backend="fastembed", profile=embedding_model, use_gpu=use_gpu)
|
||||
|
||||
# Extract text content from results
|
||||
texts = []
|
||||
for result in results:
|
||||
# Use content if available, otherwise use excerpt
|
||||
text = result.content or result.excerpt or ""
|
||||
if not text and result.path:
|
||||
text = result.path
|
||||
texts.append(text[:2000]) # Limit text length
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = embedder.embed_to_numpy(texts)
|
||||
return embeddings
|
||||
|
||||
except ImportError as exc:
|
||||
self.logger.debug("Embedder not available for clustering: %s", exc)
|
||||
return None
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to generate embeddings for clustering: %s", exc)
|
||||
return None
|
||||
|
||||
def binary_rerank_cascade_search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -1990,6 +2637,220 @@ class ChainSearchEngine:
|
||||
index_paths, name, kind, options.total_limit
|
||||
)
|
||||
|
||||
def search_references(
|
||||
self,
|
||||
symbol_name: str,
|
||||
source_path: Optional[Path] = None,
|
||||
depth: int = -1,
|
||||
limit: int = 100,
|
||||
) -> List[ReferenceResult]:
|
||||
"""Find all references to a symbol across the project.
|
||||
|
||||
Searches the code_relationships table in all index databases to find
|
||||
where the given symbol is referenced (called, imported, inherited, etc.).
|
||||
|
||||
Args:
|
||||
symbol_name: Fully qualified or simple name of the symbol to find references to
|
||||
source_path: Starting path for search (default: workspace root from registry)
|
||||
depth: Search depth (-1 = unlimited, 0 = current dir only)
|
||||
limit: Maximum results to return (default 100)
|
||||
|
||||
Returns:
|
||||
List of ReferenceResult objects sorted by file path and line number
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> refs = engine.search_references("authenticate", Path("D:/project/src"))
|
||||
>>> for ref in refs[:10]:
|
||||
... print(f"{ref.file_path}:{ref.line} ({ref.relationship_type})")
|
||||
"""
|
||||
import sqlite3
|
||||
from concurrent.futures import as_completed
|
||||
|
||||
# Determine starting path
|
||||
if source_path is None:
|
||||
# Try to get workspace root from registry
|
||||
mappings = self.registry.list_mappings()
|
||||
if mappings:
|
||||
source_path = Path(mappings[0].source_path)
|
||||
else:
|
||||
self.logger.warning("No source path provided and no mappings in registry")
|
||||
return []
|
||||
|
||||
# Find starting index
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
# Collect all index paths
|
||||
index_paths = self._collect_index_paths(start_index, depth)
|
||||
if not index_paths:
|
||||
self.logger.debug(f"No indexes collected from {start_index}")
|
||||
return []
|
||||
|
||||
self.logger.debug(
|
||||
"Searching %d indexes for references to '%s'",
|
||||
len(index_paths), symbol_name
|
||||
)
|
||||
|
||||
# Search in parallel
|
||||
all_results: List[ReferenceResult] = []
|
||||
executor = self._get_executor()
|
||||
|
||||
def search_single_index(index_path: Path) -> List[ReferenceResult]:
|
||||
"""Search a single index for references."""
|
||||
results: List[ReferenceResult] = []
|
||||
try:
|
||||
conn = sqlite3.connect(str(index_path), check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Query code_relationships for references to this symbol
|
||||
# Match target_qualified_name values that end with the symbol name
|
||||
# (bare or dot-qualified) or that match it exactly
|
||||
# Try full_path first (new schema), fall back to path (old schema)
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT
|
||||
f.full_path as source_file,
|
||||
cr.source_line,
|
||||
cr.relationship_type,
|
||||
f.content
|
||||
FROM code_relationships cr
|
||||
JOIN symbols s ON s.id = cr.source_symbol_id
|
||||
JOIN files f ON f.id = s.file_id
|
||||
WHERE cr.target_qualified_name LIKE ?
|
||||
OR cr.target_qualified_name LIKE ?
|
||||
OR cr.target_qualified_name = ?
|
||||
ORDER BY f.full_path, cr.source_line
|
||||
LIMIT ?
|
||||
""",
|
||||
(
|
||||
f"%{symbol_name}", # Ends with symbol name
|
||||
f"%.{symbol_name}", # Qualified name ending with .symbol_name
|
||||
symbol_name, # Exact match
|
||||
limit,
|
||||
)
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
# Fallback for old schema with 'path' column
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT
|
||||
f.path as source_file,
|
||||
cr.source_line,
|
||||
cr.relationship_type,
|
||||
f.content
|
||||
FROM code_relationships cr
|
||||
JOIN symbols s ON s.id = cr.source_symbol_id
|
||||
JOIN files f ON f.id = s.file_id
|
||||
WHERE cr.target_qualified_name LIKE ?
|
||||
OR cr.target_qualified_name LIKE ?
|
||||
OR cr.target_qualified_name = ?
|
||||
ORDER BY f.path, cr.source_line
|
||||
LIMIT ?
|
||||
""",
|
||||
(
|
||||
f"%{symbol_name}", # Ends with symbol name
|
||||
f"%.{symbol_name}", # Qualified name ending with .symbol_name
|
||||
symbol_name, # Exact match
|
||||
limit,
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row in rows:
|
||||
file_path = row["source_file"]
|
||||
line = row["source_line"] or 1
|
||||
rel_type = row["relationship_type"]
|
||||
content = row["content"] or ""
|
||||
|
||||
# Extract context (3 lines around reference)
|
||||
context = self._extract_context(content, line, context_lines=3)
|
||||
|
||||
results.append(ReferenceResult(
|
||||
file_path=file_path,
|
||||
line=line,
|
||||
column=0, # Column info not stored in code_relationships
|
||||
context=context,
|
||||
relationship_type=rel_type,
|
||||
))
|
||||
|
||||
conn.close()
|
||||
except sqlite3.DatabaseError as exc:
|
||||
self.logger.debug(
|
||||
"Failed to search references in %s: %s", index_path, exc
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.debug(
|
||||
"Unexpected error searching references in %s: %s", index_path, exc
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
# Submit parallel searches
|
||||
futures = {
|
||||
executor.submit(search_single_index, idx_path): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
results = future.result()
|
||||
all_results.extend(results)
|
||||
except Exception as exc:
|
||||
idx_path = futures[future]
|
||||
self.logger.debug(
|
||||
"Reference search failed for %s: %s", idx_path, exc
|
||||
)
|
||||
|
||||
# Deduplicate by (file_path, line)
|
||||
seen: set = set()
|
||||
unique_results: List[ReferenceResult] = []
|
||||
for ref in all_results:
|
||||
key = (ref.file_path, ref.line)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_results.append(ref)
|
||||
|
||||
# Sort by file path and line
|
||||
unique_results.sort(key=lambda r: (r.file_path, r.line))
|
||||
|
||||
# Apply limit
|
||||
return unique_results[:limit]
|
||||
|
||||
def _extract_context(
|
||||
self,
|
||||
content: str,
|
||||
line: int,
|
||||
context_lines: int = 3
|
||||
) -> str:
|
||||
"""Extract lines around a given line number from file content.
|
||||
|
||||
Args:
|
||||
content: Full file content
|
||||
line: Target line number (1-based)
|
||||
context_lines: Number of lines to include before and after
|
||||
|
||||
Returns:
|
||||
Context snippet as a string
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
lines = content.splitlines()
|
||||
total_lines = len(lines)
|
||||
|
||||
if line < 1 or line > total_lines:
|
||||
return ""
|
||||
|
||||
# Calculate range (0-indexed internally)
|
||||
start = max(0, line - 1 - context_lines)
|
||||
end = min(total_lines, line + context_lines)
|
||||
|
||||
context = lines[start:end]
|
||||
return "\n".join(context)
|
||||
|
||||
# === Internal Methods ===
|
||||
|
||||
def _find_start_index(self, source_path: Path) -> Optional[Path]:
|
||||
|
||||
124
codex-lens/src/codexlens/search/clustering/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Clustering strategies for the staged hybrid search pipeline.
|
||||
|
||||
This module provides extensible clustering infrastructure for grouping
|
||||
similar search results and selecting representative results.
|
||||
|
||||
Install with: pip install codexlens[clustering]
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import (
|
||||
... CLUSTERING_AVAILABLE,
|
||||
... ClusteringConfig,
|
||||
... get_strategy,
|
||||
... )
|
||||
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||
>>> # Auto-select best available strategy with fallback
|
||||
>>> strategy = get_strategy("auto", config)
|
||||
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||
>>>
|
||||
>>> # Or explicitly use a specific strategy
|
||||
>>> if CLUSTERING_AVAILABLE:
|
||||
... from codexlens.search.clustering import HDBSCANStrategy
|
||||
... strategy = HDBSCANStrategy(config)
|
||||
... representatives = strategy.fit_predict(embeddings, results)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Always export base classes and factory (no heavy dependencies)
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
from .factory import (
|
||||
ClusteringStrategyFactory,
|
||||
check_clustering_strategy_available,
|
||||
get_strategy,
|
||||
)
|
||||
from .noop_strategy import NoOpStrategy
|
||||
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||
|
||||
# Feature flag for clustering availability (hdbscan + sklearn)
|
||||
CLUSTERING_AVAILABLE = False
|
||||
HDBSCAN_AVAILABLE = False
|
||||
DBSCAN_AVAILABLE = False
|
||||
_import_error: str | None = None
|
||||
|
||||
|
||||
def _detect_clustering_available() -> tuple[bool, bool, bool, str | None]:
|
||||
"""Detect if clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_available, hdbscan_available, dbscan_available, error_message).
|
||||
"""
|
||||
hdbscan_ok = False
|
||||
dbscan_ok = False
|
||||
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
hdbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
dbscan_ok = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
all_ok = hdbscan_ok and dbscan_ok
|
||||
error = None
|
||||
if not all_ok:
|
||||
missing = []
|
||||
if not hdbscan_ok:
|
||||
missing.append("hdbscan")
|
||||
if not dbscan_ok:
|
||||
missing.append("scikit-learn")
|
||||
error = f"{', '.join(missing)} not available. Install with: pip install codexlens[clustering]"
|
||||
|
||||
return all_ok, hdbscan_ok, dbscan_ok, error
|
||||
|
||||
|
||||
# Initialize on module load
|
||||
CLUSTERING_AVAILABLE, HDBSCAN_AVAILABLE, DBSCAN_AVAILABLE, _import_error = (
|
||||
_detect_clustering_available()
|
||||
)
|
||||
|
||||
|
||||
def check_clustering_available() -> tuple[bool, str | None]:
|
||||
"""Check if all clustering dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, error_message).
|
||||
error_message is None if available, otherwise contains install instructions.
|
||||
"""
|
||||
return CLUSTERING_AVAILABLE, _import_error
|
||||
|
||||
|
||||
# Conditionally export strategy implementations
|
||||
__all__ = [
|
||||
# Feature flags
|
||||
"CLUSTERING_AVAILABLE",
|
||||
"HDBSCAN_AVAILABLE",
|
||||
"DBSCAN_AVAILABLE",
|
||||
"check_clustering_available",
|
||||
# Base classes
|
||||
"BaseClusteringStrategy",
|
||||
"ClusteringConfig",
|
||||
# Factory
|
||||
"ClusteringStrategyFactory",
|
||||
"get_strategy",
|
||||
"check_clustering_strategy_available",
|
||||
# Always-available strategies
|
||||
"NoOpStrategy",
|
||||
"FrequencyStrategy",
|
||||
"FrequencyConfig",
|
||||
]
|
||||
|
||||
# Conditionally add strategy classes to __all__ and module namespace
|
||||
if HDBSCAN_AVAILABLE:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
|
||||
__all__.append("HDBSCANStrategy")
|
||||
|
||||
if DBSCAN_AVAILABLE:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
|
||||
__all__.append("DBSCANStrategy")
|
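Illustrative sketch (not from the commit) of probing a specific backend before constructing it, using the factory exports above:

    from codexlens.search.clustering import (
        ClusteringConfig,
        check_clustering_strategy_available,
        get_strategy,
    )

    ok, _err = check_clustering_strategy_available("hdbscan")
    name = "hdbscan" if ok else "frequency"   # frequency needs no optional deps
    strategy = get_strategy(name, ClusteringConfig(min_cluster_size=3), fallback=True)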
||||
153
codex-lens/src/codexlens/search/clustering/base.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Base classes for clustering strategies in the hybrid search pipeline.
|
||||
|
||||
This module defines the abstract base class for clustering strategies used
|
||||
in the staged hybrid search pipeline. Strategies cluster search results
|
||||
based on their embeddings and select representative results from each cluster.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusteringConfig:
|
||||
"""Configuration parameters for clustering strategies.
|
||||
|
||||
Attributes:
|
||||
min_cluster_size: Minimum number of results to form a cluster.
|
||||
HDBSCAN default is 5, but for search results 2-3 is often better.
|
||||
min_samples: Number of samples in a neighborhood for a point to be
|
||||
considered a core point. Lower values allow more clusters.
|
||||
metric: Distance metric for clustering. Common options:
|
||||
- 'euclidean': Standard L2 distance
|
||||
- 'cosine': Cosine distance (1 - cosine_similarity)
|
||||
- 'manhattan': L1 distance
|
||||
cluster_selection_epsilon: Distance threshold for cluster selection.
|
||||
Results within this distance may be merged into the same cluster.
|
||||
allow_single_cluster: If True, allow all results to form one cluster.
|
||||
Useful when results are very similar.
|
||||
prediction_data: If True, generate prediction data for new points.
|
||||
"""
|
||||
|
||||
min_cluster_size: int = 3
|
||||
min_samples: int = 2
|
||||
metric: str = "cosine"
|
||||
cluster_selection_epsilon: float = 0.0
|
||||
allow_single_cluster: bool = True
|
||||
prediction_data: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
if self.min_cluster_size < 2:
|
||||
raise ValueError("min_cluster_size must be >= 2")
|
||||
if self.min_samples < 1:
|
||||
raise ValueError("min_samples must be >= 1")
|
||||
if self.metric not in ("euclidean", "cosine", "manhattan"):
|
||||
raise ValueError(f"metric must be one of: euclidean, cosine, manhattan; got {self.metric}")
|
||||
if self.cluster_selection_epsilon < 0:
|
||||
raise ValueError("cluster_selection_epsilon must be >= 0")
|
||||
|
||||
|
||||
class BaseClusteringStrategy(ABC):
|
||||
"""Abstract base class for clustering strategies.
|
||||
|
||||
Clustering strategies are used in the staged hybrid search pipeline to
|
||||
group similar search results and select representative results from each
|
||||
cluster, reducing redundancy while maintaining diversity.
|
||||
|
||||
Subclasses must implement:
|
||||
- cluster(): Group results into clusters based on embeddings
|
||||
- select_representatives(): Choose best result(s) from each cluster
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize the clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
"""
|
||||
self.config = config or ClusteringConfig()
|
||||
|
||||
@abstractmethod
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results based on their embeddings.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
Used for additional metadata during clustering.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Results not assigned to any cluster
|
||||
(noise points) should be returned as single-element clusters.
|
||||
|
||||
Example:
|
||||
>>> strategy = HDBSCANStrategy()
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> # clusters = [[0, 2, 5], [1, 3], [4], [6, 7, 8]]
|
||||
>>> # Result indices 0, 2, 5 are in cluster 0
|
||||
>>> # Result indices 1, 3 are in cluster 1
|
||||
>>> # Result index 4 is a noise point (singleton cluster)
|
||||
>>> # Result indices 6, 7, 8 are in cluster 2
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
This method chooses the best result(s) from each cluster to include
|
||||
in the final search results. The selection can be based on:
|
||||
- Highest score within cluster
|
||||
- Closest to cluster centroid
|
||||
- Custom selection logic
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings array for centroid-based selection.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one or more per cluster,
|
||||
ordered by relevance (highest score first).
|
||||
|
||||
Example:
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
>>> # Returns best result from each cluster
|
||||
"""
|
||||
...
|
||||
|
||||
def fit_predict(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List["SearchResult"]:
|
||||
"""Convenience method to cluster and select representatives in one call.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects.
|
||||
"""
|
||||
clusters = self.cluster(embeddings, results)
|
||||
return self.select_representatives(clusters, results, embeddings)
|
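To make the contract concrete, here is a minimal illustrative subclass (a sketch only, distinct from the shipped NoOpStrategy and FrequencyStrategy):

    from codexlens.search.clustering.base import BaseClusteringStrategy

    class SingletonStrategy(BaseClusteringStrategy):
        """Illustrative only: treat every result as its own cluster."""

        def cluster(self, embeddings, results):
            return [[i] for i in range(len(results))]

        def select_representatives(self, clusters, results, embeddings=None):
            reps = [results[c[0]] for c in clusters if c]
            reps.sort(key=lambda r: r.score, reverse=True)
            return reps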
||||
197
codex-lens/src/codexlens/search/clustering/dbscan_strategy.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""DBSCAN-based clustering strategy for search results.
|
||||
|
||||
DBSCAN (Density-Based Spatial Clustering of Applications with Noise)
|
||||
is the fallback clustering strategy when HDBSCAN is not available.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class DBSCANStrategy(BaseClusteringStrategy):
|
||||
"""DBSCAN-based clustering strategy.
|
||||
|
||||
Uses sklearn's DBSCAN algorithm as a fallback when HDBSCAN is not available.
|
||||
DBSCAN requires an explicit eps parameter, which is auto-computed from the
|
||||
distance distribution if not provided.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import DBSCANStrategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||
>>> strategy = DBSCANStrategy(config)
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
"""
|
||||
|
||||
# Default eps percentile for auto-computation
|
||||
DEFAULT_EPS_PERCENTILE: float = 15.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
eps: Optional[float] = None,
|
||||
eps_percentile: float = DEFAULT_EPS_PERCENTILE,
|
||||
) -> None:
|
||||
"""Initialize DBSCAN clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
eps: Explicit eps parameter for DBSCAN. If None, auto-computed
|
||||
from the distance distribution.
|
||||
eps_percentile: Percentile of pairwise distances to use for
|
||||
auto-computing eps. Default is 15th percentile.
|
||||
|
||||
Raises:
|
||||
ImportError: If sklearn is not installed.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.eps = eps
|
||||
self.eps_percentile = eps_percentile
|
||||
|
||||
# Validate sklearn is available
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"scikit-learn package is required for DBSCANStrategy. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
) from exc
|
||||
|
||||
def _compute_eps(self, embeddings: "np.ndarray") -> float:
|
||||
"""Auto-compute eps from pairwise distance distribution.
|
||||
|
||||
Uses the specified percentile of pairwise distances as eps,
|
||||
which typically captures local density well.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
|
||||
Returns:
|
||||
Computed eps value.
|
||||
"""
|
||||
import numpy as np
|
||||
from sklearn.metrics import pairwise_distances
|
||||
|
||||
# Compute pairwise distances
|
||||
distances = pairwise_distances(embeddings, metric=self.config.metric)
|
||||
|
||||
# Get upper triangle (excluding diagonal)
|
||||
upper_tri = distances[np.triu_indices_from(distances, k=1)]
|
||||
|
||||
if len(upper_tri) == 0:
|
||||
# Only one point, return a default small eps
|
||||
return 0.1
|
||||
|
||||
# Use percentile of distances as eps
|
||||
eps = float(np.percentile(upper_tri, self.eps_percentile))
|
||||
|
||||
# Ensure eps is positive
|
||||
return max(eps, 1e-6)
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results using DBSCAN algorithm.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Noise points are returned as singleton clusters.
|
||||
"""
|
||||
from sklearn.cluster import DBSCAN
|
||||
import numpy as np
|
||||
|
||||
n_results = len(results)
|
||||
if n_results == 0:
|
||||
return []
|
||||
|
||||
# Handle edge case: single result
|
||||
if n_results == 1:
|
||||
return [[0]]
|
||||
|
||||
# Determine eps value
|
||||
eps = self.eps if self.eps is not None else self._compute_eps(embeddings)
|
||||
|
||||
# Configure DBSCAN clusterer
|
||||
# Note: DBSCAN has no min_cluster_size; min_samples plays the closest analogous role
|
||||
clusterer = DBSCAN(
|
||||
eps=eps,
|
||||
min_samples=self.config.min_samples,
|
||||
metric=self.config.metric,
|
||||
)
|
||||
|
||||
# Fit and get cluster labels
|
||||
# Labels: -1 = noise, 0+ = cluster index
|
||||
labels = clusterer.fit_predict(embeddings)
|
||||
|
||||
# Group indices by cluster label
|
||||
cluster_map: dict[int, list[int]] = {}
|
||||
for idx, label in enumerate(labels):
|
||||
if label not in cluster_map:
|
||||
cluster_map[label] = []
|
||||
cluster_map[label].append(idx)
|
||||
|
||||
# Build result: non-noise clusters first, then noise as singletons
|
||||
clusters: List[List[int]] = []
|
||||
|
||||
# Add proper clusters (label >= 0)
|
||||
for label in sorted(cluster_map.keys()):
|
||||
if label >= 0:
|
||||
clusters.append(cluster_map[label])
|
||||
|
||||
# Add noise points as singleton clusters (label == -1)
|
||||
if -1 in cluster_map:
|
||||
for idx in cluster_map[-1]:
|
||||
clusters.append([idx])
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
Selects the result with the highest score from each cluster.
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used in score-based selection).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one per cluster,
|
||||
ordered by score (highest first).
|
||||
"""
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
# Find the result with the highest score in this cluster
|
||||
best_idx = max(cluster_indices, key=lambda i: results[i].score)
|
||||
representatives.append(results[best_idx])
|
||||
|
||||
# Sort by score descending
|
||||
representatives.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return representatives
|
||||
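A minimal sketch of the percentile-based eps heuristic above, runnable on its own (assumes numpy and scikit-learn from the [clustering] extra; the 75th percentile is an illustrative value, not necessarily the shipped default):

import numpy as np
from sklearn.metrics import pairwise_distances

# Three nearby points and one outlier: most pairwise distances are small,
# so a mid-range percentile lands eps near the local cluster scale.
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0]])
distances = pairwise_distances(embeddings, metric="euclidean")
upper_tri = distances[np.triu_indices_from(distances, k=1)]
eps = max(float(np.percentile(upper_tri, 75.0)), 1e-6)
print(eps)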
codex-lens/src/codexlens/search/clustering/factory.py (new file, 202 lines)
@@ -0,0 +1,202 @@
"""Factory for creating clustering strategies.
|
||||
|
||||
Provides a unified interface for instantiating different clustering backends
|
||||
with automatic fallback chain: hdbscan -> dbscan -> noop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
from .noop_strategy import NoOpStrategy
|
||||
|
||||
|
||||
def check_clustering_strategy_available(strategy: str) -> tuple[bool, str | None]:
|
||||
"""Check whether a specific clustering strategy can be used.
|
||||
|
||||
Args:
|
||||
strategy: Strategy name to check. Options:
|
||||
- "hdbscan": HDBSCAN clustering (requires hdbscan package)
|
||||
- "dbscan": DBSCAN clustering (requires sklearn)
|
||||
- "frequency": Frequency-based clustering (always available)
|
||||
- "noop": No-op strategy (always available)
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, error_message).
|
||||
error_message is None if available, otherwise contains install instructions.
|
||||
"""
|
||||
strategy = (strategy or "").strip().lower()
|
||||
|
||||
if strategy == "hdbscan":
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
except ImportError:
|
||||
return False, (
|
||||
"hdbscan package not available. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
)
|
||||
return True, None
|
||||
|
||||
if strategy == "dbscan":
|
||||
try:
|
||||
from sklearn.cluster import DBSCAN # noqa: F401
|
||||
except ImportError:
|
||||
return False, (
|
||||
"scikit-learn package not available. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
)
|
||||
return True, None
|
||||
|
||||
if strategy == "frequency":
|
||||
# Frequency strategy is always available (no external deps)
|
||||
return True, None
|
||||
|
||||
if strategy == "noop":
|
||||
return True, None
|
||||
|
||||
return False, (
|
||||
f"Invalid clustering strategy: {strategy}. "
|
||||
"Must be 'hdbscan', 'dbscan', 'frequency', or 'noop'."
|
||||
)
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy: str = "hdbscan",
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
*,
|
||||
fallback: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> BaseClusteringStrategy:
|
||||
"""Factory function to create clustering strategy with fallback chain.
|
||||
|
||||
The fallback chain is: hdbscan -> dbscan -> frequency -> noop
|
||||
|
||||
Args:
|
||||
strategy: Clustering strategy to use. Options:
|
||||
- "hdbscan": HDBSCAN clustering (default, recommended)
|
||||
- "dbscan": DBSCAN clustering (fallback)
|
||||
- "frequency": Frequency-based clustering (groups by symbol occurrence)
|
||||
- "noop": No-op strategy (returns all results ungrouped)
|
||||
- "auto": Try hdbscan, then dbscan, then noop
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
For frequency strategy, pass FrequencyConfig for full control.
|
||||
fallback: If True (default), automatically fall back to next strategy
|
||||
in the chain when primary is unavailable. If False, raise ImportError
|
||||
when requested strategy is unavailable.
|
||||
**kwargs: Additional strategy-specific arguments.
|
||||
For DBSCANStrategy: eps, eps_percentile
|
||||
For FrequencyStrategy: group_by, min_frequency, etc.
|
||||
|
||||
Returns:
|
||||
BaseClusteringStrategy: Configured clustering strategy instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If strategy is not recognized.
|
||||
ImportError: If required dependencies are not installed and fallback=False.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import get_strategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3)
|
||||
>>> # Auto-select best available strategy
|
||||
>>> strategy = get_strategy("auto", config)
|
||||
>>> # Explicitly use HDBSCAN (will fall back if unavailable)
|
||||
>>> strategy = get_strategy("hdbscan", config)
|
||||
>>> # Use frequency-based strategy
|
||||
>>> from codexlens.search.clustering import FrequencyConfig
|
||||
>>> freq_config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||
>>> strategy = get_strategy("frequency", freq_config)
|
||||
"""
|
||||
strategy = (strategy or "").strip().lower()
|
||||
|
||||
# Handle "auto" - try strategies in order
|
||||
if strategy == "auto":
|
||||
return _get_best_available_strategy(config, **kwargs)
|
||||
|
||||
if strategy == "hdbscan":
|
||||
ok, err = check_clustering_strategy_available("hdbscan")
|
||||
if ok:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
return HDBSCANStrategy(config)
|
||||
|
||||
if fallback:
|
||||
# Try dbscan fallback
|
||||
ok_dbscan, _ = check_clustering_strategy_available("dbscan")
|
||||
if ok_dbscan:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
# Final fallback to noop
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ImportError(err)
|
||||
|
||||
if strategy == "dbscan":
|
||||
ok, err = check_clustering_strategy_available("dbscan")
|
||||
if ok:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
|
||||
if fallback:
|
||||
# Fallback to noop
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ImportError(err)
|
||||
|
||||
if strategy == "frequency":
|
||||
from .frequency_strategy import FrequencyStrategy, FrequencyConfig
|
||||
# If config is ClusteringConfig but not FrequencyConfig, create default FrequencyConfig
|
||||
if config is None or not isinstance(config, FrequencyConfig):
|
||||
freq_config = FrequencyConfig(**kwargs) if kwargs else FrequencyConfig()
|
||||
else:
|
||||
freq_config = config
|
||||
return FrequencyStrategy(freq_config)
|
||||
|
||||
if strategy == "noop":
|
||||
return NoOpStrategy(config)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown clustering strategy: {strategy}. "
|
||||
"Supported strategies: 'hdbscan', 'dbscan', 'frequency', 'noop', 'auto'"
|
||||
)
|
||||
|
||||
|
||||
def _get_best_available_strategy(
|
||||
config: Optional[ClusteringConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseClusteringStrategy:
|
||||
"""Get the best available clustering strategy.
|
||||
|
||||
Tries strategies in order: hdbscan -> dbscan -> noop
|
||||
|
||||
Args:
|
||||
config: Clustering configuration.
|
||||
**kwargs: Additional strategy-specific arguments.
|
||||
|
||||
Returns:
|
||||
Best available clustering strategy instance.
|
||||
"""
|
||||
# Try HDBSCAN first
|
||||
ok, _ = check_clustering_strategy_available("hdbscan")
|
||||
if ok:
|
||||
from .hdbscan_strategy import HDBSCANStrategy
|
||||
return HDBSCANStrategy(config)
|
||||
|
||||
# Try DBSCAN second
|
||||
ok, _ = check_clustering_strategy_available("dbscan")
|
||||
if ok:
|
||||
from .dbscan_strategy import DBSCANStrategy
|
||||
return DBSCANStrategy(config, **kwargs)
|
||||
|
||||
# Fallback to NoOp
|
||||
return NoOpStrategy(config)
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
ClusteringStrategyFactory = type(
|
||||
"ClusteringStrategyFactory",
|
||||
(),
|
||||
{
|
||||
"get_strategy": staticmethod(get_strategy),
|
||||
"check_available": staticmethod(check_clustering_strategy_available),
|
||||
},
|
||||
)
|
||||
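A short usage sketch of the strict vs. fallback behaviour documented above (assumes the codexlens package from this change is importable; which strategy you end up with depends on the optional dependencies installed):

from codexlens.search.clustering import get_strategy

try:
    strategy = get_strategy("hdbscan", fallback=False)  # raises if hdbscan is missing
except ImportError as exc:
    print(f"hdbscan unavailable: {exc}")
    strategy = get_strategy("hdbscan", fallback=True)   # degrades to DBSCAN or NoOp
print(type(strategy).__name__)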
codex-lens/src/codexlens/search/clustering/frequency_strategy.py (new file, 263 lines)
@@ -0,0 +1,263 @@
"""Frequency-based clustering strategy for search result deduplication.
|
||||
|
||||
This strategy groups search results by symbol/method name and prunes based on
|
||||
occurrence frequency. High-frequency symbols (frequently referenced methods)
|
||||
are considered more important and retained, while low-frequency results
|
||||
(potentially noise) can be filtered out.
|
||||
|
||||
Use cases:
|
||||
- Prioritize commonly called methods/functions
|
||||
- Filter out one-off results that may be less relevant
|
||||
- Deduplicate results pointing to the same symbol from different locations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Literal
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrequencyConfig(ClusteringConfig):
|
||||
"""Configuration for frequency-based clustering strategy.
|
||||
|
||||
Attributes:
|
||||
group_by: Field to group results by for frequency counting.
|
||||
- 'symbol': Group by symbol_name (default, for method/function dedup)
|
||||
- 'file': Group by file path
|
||||
- 'symbol_kind': Group by symbol type (function, class, etc.)
|
||||
min_frequency: Minimum occurrence count to keep a result.
|
||||
Results appearing less than this are considered noise and pruned.
|
||||
max_representatives_per_group: Maximum results to keep per symbol group.
|
||||
frequency_weight: How much to boost score based on frequency.
|
||||
Final score = original_score * (1 + frequency_weight * log(frequency))
|
||||
keep_mode: How to handle low-frequency results.
|
||||
- 'filter': Remove results below min_frequency
|
||||
- 'demote': Keep but lower their score ranking
|
||||
"""
|
||||
|
||||
group_by: Literal["symbol", "file", "symbol_kind"] = "symbol"
|
||||
min_frequency: int = 1 # 1 means keep all, 2+ filters singletons
|
||||
max_representatives_per_group: int = 3
|
||||
frequency_weight: float = 0.1 # Boost factor for frequency
|
||||
keep_mode: Literal["filter", "demote"] = "demote"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration parameters."""
|
||||
# Skip parent validation since we don't use HDBSCAN params
|
||||
if self.min_frequency < 1:
|
||||
raise ValueError("min_frequency must be >= 1")
|
||||
if self.max_representatives_per_group < 1:
|
||||
raise ValueError("max_representatives_per_group must be >= 1")
|
||||
if self.frequency_weight < 0:
|
||||
raise ValueError("frequency_weight must be >= 0")
|
||||
if self.group_by not in ("symbol", "file", "symbol_kind"):
|
||||
raise ValueError(f"group_by must be one of: symbol, file, symbol_kind; got {self.group_by}")
|
||||
if self.keep_mode not in ("filter", "demote"):
|
||||
raise ValueError(f"keep_mode must be one of: filter, demote; got {self.keep_mode}")
|
||||
|
||||
|
||||
class FrequencyStrategy(BaseClusteringStrategy):
|
||||
"""Frequency-based clustering strategy for search result deduplication.
|
||||
|
||||
This strategy groups search results by symbol name (or file/kind) and:
|
||||
1. Counts how many times each symbol appears in results
|
||||
2. Higher frequency = more important (frequently referenced method)
|
||||
3. Filters or demotes low-frequency results
|
||||
4. Selects top representatives from each frequency group
|
||||
|
||||
Unlike embedding-based strategies (HDBSCAN, DBSCAN), this strategy:
|
||||
- Does NOT require embeddings (works with metadata only)
|
||||
- Is very fast (O(n) complexity)
|
||||
- Is deterministic (no random initialization)
|
||||
- Works well for symbol-level deduplication
|
||||
|
||||
Example:
|
||||
>>> config = FrequencyConfig(min_frequency=2, group_by="symbol")
|
||||
>>> strategy = FrequencyStrategy(config)
|
||||
>>> # Results with symbol "authenticate" appearing 5 times
|
||||
>>> # will be prioritized over "helper_func" appearing once
|
||||
>>> representatives = strategy.fit_predict(embeddings, results)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[FrequencyConfig] = None) -> None:
|
||||
"""Initialize the frequency strategy.
|
||||
|
||||
Args:
|
||||
config: Frequency configuration. Uses defaults if not provided.
|
||||
"""
|
||||
self.config: FrequencyConfig = config or FrequencyConfig()
|
||||
|
||||
def _get_group_key(self, result: "SearchResult") -> str:
|
||||
"""Extract grouping key from a search result.
|
||||
|
||||
Args:
|
||||
result: SearchResult to extract key from.
|
||||
|
||||
Returns:
|
||||
String key for grouping (symbol name, file path, or kind).
|
||||
"""
|
||||
if self.config.group_by == "symbol":
|
||||
# Use symbol_name if available, otherwise fall back to file:line
|
||||
symbol = getattr(result, "symbol_name", None)
|
||||
if symbol:
|
||||
return str(symbol)
|
||||
# Fallback: use file path + start_line as pseudo-symbol
|
||||
start_line = getattr(result, "start_line", 0) or 0
|
||||
return f"{result.path}:{start_line}"
|
||||
|
||||
elif self.config.group_by == "file":
|
||||
return str(result.path)
|
||||
|
||||
elif self.config.group_by == "symbol_kind":
|
||||
kind = getattr(result, "symbol_kind", None)
|
||||
return str(kind) if kind else "unknown"
|
||||
|
||||
return str(result.path) # Default fallback
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Group search results by frequency of occurrence.
|
||||
|
||||
Note: This method ignores embeddings and groups by metadata only.
|
||||
The embeddings parameter is kept for interface compatibility.
|
||||
|
||||
Args:
|
||||
embeddings: Ignored (kept for interface compatibility).
|
||||
results: List of SearchResult objects to cluster.
|
||||
|
||||
Returns:
|
||||
List of clusters (groups), where each cluster contains indices
|
||||
of results with the same grouping key. Clusters are ordered by
|
||||
frequency (highest frequency first).
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Group results by key
|
||||
groups: Dict[str, List[int]] = defaultdict(list)
|
||||
for idx, result in enumerate(results):
|
||||
key = self._get_group_key(result)
|
||||
groups[key].append(idx)
|
||||
|
||||
# Sort groups by frequency (descending) then by key (for stability)
|
||||
sorted_groups = sorted(
|
||||
groups.items(),
|
||||
key=lambda x: (-len(x[1]), x[0]) # -frequency, then alphabetical
|
||||
)
|
||||
|
||||
# Convert to list of clusters
|
||||
clusters = [indices for _, indices in sorted_groups]
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results based on frequency and score.
|
||||
|
||||
For each frequency group:
|
||||
1. If frequency < min_frequency: filter or demote based on keep_mode
|
||||
2. Sort by score within group
|
||||
3. Apply frequency boost to scores
|
||||
4. Select top N representatives
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (used for tie-breaking if provided).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, ordered by
|
||||
frequency-adjusted score (highest first).
|
||||
"""
|
||||
import math
|
||||
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
demoted: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
frequency = len(cluster_indices)
|
||||
|
||||
# Get results in this cluster, sorted by score
|
||||
cluster_results = [results[i] for i in cluster_indices]
|
||||
cluster_results.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||
|
||||
# Check frequency threshold
|
||||
if frequency < self.config.min_frequency:
|
||||
if self.config.keep_mode == "filter":
|
||||
# Skip low-frequency results entirely
|
||||
continue
|
||||
else: # demote mode
|
||||
# Keep but add to demoted list (lower priority)
|
||||
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||
demoted.append(result)
|
||||
continue
|
||||
|
||||
# Apply frequency boost and select top representatives
|
||||
for result in cluster_results[: self.config.max_representatives_per_group]:
|
||||
# Calculate frequency-boosted score
|
||||
original_score = getattr(result, "score", 0.0)
|
||||
# log(frequency + 1) to handle frequency=1 case smoothly
|
||||
frequency_boost = 1.0 + self.config.frequency_weight * math.log(frequency + 1)
|
||||
boosted_score = original_score * frequency_boost
|
||||
|
||||
# Create new result with boosted score and frequency metadata
|
||||
# Note: SearchResult might be immutable, so we preserve original
|
||||
# and track boosted score in metadata
|
||||
if hasattr(result, "metadata") and isinstance(result.metadata, dict):
|
||||
result.metadata["frequency"] = frequency
|
||||
result.metadata["frequency_boosted_score"] = boosted_score
|
||||
|
||||
representatives.append(result)
|
||||
|
||||
# Sort representatives by boosted score (or original score as fallback)
|
||||
def get_sort_score(r: "SearchResult") -> float:
|
||||
if hasattr(r, "metadata") and isinstance(r.metadata, dict):
|
||||
return r.metadata.get("frequency_boosted_score", getattr(r, "score", 0.0))
|
||||
return getattr(r, "score", 0.0)
|
||||
|
||||
representatives.sort(key=get_sort_score, reverse=True)
|
||||
|
||||
# Add demoted results at the end
|
||||
if demoted:
|
||||
demoted.sort(key=lambda r: getattr(r, "score", 0.0), reverse=True)
|
||||
representatives.extend(demoted)
|
||||
|
||||
return representatives
|
||||
|
||||
def fit_predict(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List["SearchResult"]:
|
||||
"""Convenience method to cluster and select representatives in one call.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array (may be ignored for frequency-based clustering).
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects.
|
||||
"""
|
||||
clusters = self.cluster(embeddings, results)
|
||||
return self.select_representatives(clusters, results, embeddings)
|
||||
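A worked example of the frequency boost applied in select_representatives above (plain Python, no codexlens imports needed):

import math

frequency_weight = 0.1   # FrequencyConfig default
frequency = 5            # symbol appears five times in the result set
original_score = 0.80
boost = 1.0 + frequency_weight * math.log(frequency + 1)
print(round(original_score * boost, 3))  # ~0.943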
codex-lens/src/codexlens/search/clustering/hdbscan_strategy.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""HDBSCAN-based clustering strategy for search results.
|
||||
|
||||
HDBSCAN (Hierarchical Density-Based Spatial Clustering of Applications with Noise)
|
||||
is the primary clustering strategy for grouping similar search results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class HDBSCANStrategy(BaseClusteringStrategy):
|
||||
"""HDBSCAN-based clustering strategy.
|
||||
|
||||
Uses HDBSCAN algorithm to cluster search results based on embedding similarity.
|
||||
HDBSCAN is preferred over DBSCAN because it:
|
||||
- Automatically determines the number of clusters
|
||||
- Handles varying density clusters well
|
||||
- Identifies noise points (outliers) effectively
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import HDBSCANStrategy, ClusteringConfig
|
||||
>>> config = ClusteringConfig(min_cluster_size=3, metric='cosine')
|
||||
>>> strategy = HDBSCANStrategy(config)
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize HDBSCAN clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Uses defaults if not provided.
|
||||
|
||||
Raises:
|
||||
ImportError: If hdbscan package is not installed.
|
||||
"""
|
||||
super().__init__(config)
|
||||
# Validate hdbscan is available
|
||||
try:
|
||||
import hdbscan # noqa: F401
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"hdbscan package is required for HDBSCANStrategy. "
|
||||
"Install with: pip install codexlens[clustering]"
|
||||
) from exc
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Cluster search results using HDBSCAN algorithm.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim)
|
||||
containing the embedding vectors for each result.
|
||||
results: List of SearchResult objects corresponding to embeddings.
|
||||
|
||||
Returns:
|
||||
List of clusters, where each cluster is a list of indices
|
||||
into the results list. Noise points are returned as singleton clusters.
|
||||
"""
|
||||
import hdbscan
|
||||
import numpy as np
|
||||
|
||||
n_results = len(results)
|
||||
if n_results == 0:
|
||||
return []
|
||||
|
||||
# Handle edge case: fewer results than min_cluster_size
|
||||
if n_results < self.config.min_cluster_size:
|
||||
# Return each result as its own singleton cluster
|
||||
return [[i] for i in range(n_results)]
|
||||
|
||||
# Configure HDBSCAN clusterer
|
||||
clusterer = hdbscan.HDBSCAN(
|
||||
min_cluster_size=self.config.min_cluster_size,
|
||||
min_samples=self.config.min_samples,
|
||||
metric=self.config.metric,
|
||||
cluster_selection_epsilon=self.config.cluster_selection_epsilon,
|
||||
allow_single_cluster=self.config.allow_single_cluster,
|
||||
prediction_data=self.config.prediction_data,
|
||||
)
|
||||
|
||||
# Fit and get cluster labels
|
||||
# Labels: -1 = noise, 0+ = cluster index
|
||||
labels = clusterer.fit_predict(embeddings)
|
||||
|
||||
# Group indices by cluster label
|
||||
cluster_map: dict[int, list[int]] = {}
|
||||
for idx, label in enumerate(labels):
|
||||
if label not in cluster_map:
|
||||
cluster_map[label] = []
|
||||
cluster_map[label].append(idx)
|
||||
|
||||
# Build result: non-noise clusters first, then noise as singletons
|
||||
clusters: List[List[int]] = []
|
||||
|
||||
# Add proper clusters (label >= 0)
|
||||
for label in sorted(cluster_map.keys()):
|
||||
if label >= 0:
|
||||
clusters.append(cluster_map[label])
|
||||
|
||||
# Add noise points as singleton clusters (label == -1)
|
||||
if -1 in cluster_map:
|
||||
for idx in cluster_map[-1]:
|
||||
clusters.append([idx])
|
||||
|
||||
return clusters
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Select representative results from each cluster.
|
||||
|
||||
Selects the result with the highest score from each cluster.
|
||||
|
||||
Args:
|
||||
clusters: List of clusters from cluster() method.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used in score-based selection).
|
||||
|
||||
Returns:
|
||||
List of representative SearchResult objects, one per cluster,
|
||||
ordered by score (highest first).
|
||||
"""
|
||||
if not clusters or not results:
|
||||
return []
|
||||
|
||||
representatives: List["SearchResult"] = []
|
||||
|
||||
for cluster_indices in clusters:
|
||||
if not cluster_indices:
|
||||
continue
|
||||
|
||||
# Find the result with the highest score in this cluster
|
||||
best_idx = max(cluster_indices, key=lambda i: results[i].score)
|
||||
representatives.append(results[best_idx])
|
||||
|
||||
# Sort by score descending
|
||||
representatives.sort(key=lambda r: r.score, reverse=True)
|
||||
|
||||
return representatives
|
||||
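A small sketch of the label convention cluster() relies on, using hdbscan directly (requires the optional hdbscan dependency; the exact labels can vary between runs and versions):

import numpy as np
import hdbscan

emb = np.array([[0.0, 0.0], [0.05, 0.0], [0.0, 0.05],
                [1.0, 1.0], [1.05, 1.0], [1.0, 1.05],
                [5.0, 5.0]])
labels = hdbscan.HDBSCAN(min_cluster_size=2).fit_predict(emb)
print(labels)  # e.g. [0 0 0 1 1 1 -1]; -1 marks noise, kept as a singleton cluster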
codex-lens/src/codexlens/search/clustering/noop_strategy.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""No-op clustering strategy for search results.
|
||||
|
||||
NoOpStrategy returns all results ungrouped when clustering dependencies
|
||||
are not available or clustering is disabled.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from .base import BaseClusteringStrategy, ClusteringConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
from codexlens.entities import SearchResult
|
||||
|
||||
|
||||
class NoOpStrategy(BaseClusteringStrategy):
|
||||
"""No-op clustering strategy that returns all results ungrouped.
|
||||
|
||||
This strategy is used as a final fallback when no clustering dependencies
|
||||
are available, or when clustering is explicitly disabled. Each result
|
||||
is treated as its own singleton cluster.
|
||||
|
||||
Example:
|
||||
>>> from codexlens.search.clustering import NoOpStrategy
|
||||
>>> strategy = NoOpStrategy()
|
||||
>>> clusters = strategy.cluster(embeddings, results)
|
||||
>>> # Returns [[0], [1], [2], ...] - each result in its own cluster
|
||||
>>> representatives = strategy.select_representatives(clusters, results)
|
||||
>>> # Returns all results sorted by score
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[ClusteringConfig] = None) -> None:
|
||||
"""Initialize NoOp clustering strategy.
|
||||
|
||||
Args:
|
||||
config: Clustering configuration. Ignored for NoOpStrategy
|
||||
but accepted for interface compatibility.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
def cluster(
|
||||
self,
|
||||
embeddings: "np.ndarray",
|
||||
results: List["SearchResult"],
|
||||
) -> List[List[int]]:
|
||||
"""Return each result as its own singleton cluster.
|
||||
|
||||
Args:
|
||||
embeddings: NumPy array of shape (n_results, embedding_dim).
|
||||
Not used but accepted for interface compatibility.
|
||||
results: List of SearchResult objects.
|
||||
|
||||
Returns:
|
||||
List of singleton clusters, one per result.
|
||||
"""
|
||||
return [[i] for i in range(len(results))]
|
||||
|
||||
def select_representatives(
|
||||
self,
|
||||
clusters: List[List[int]],
|
||||
results: List["SearchResult"],
|
||||
embeddings: Optional["np.ndarray"] = None,
|
||||
) -> List["SearchResult"]:
|
||||
"""Return all results sorted by score.
|
||||
|
||||
Since each cluster is a singleton, this effectively returns all
|
||||
results sorted by score descending.
|
||||
|
||||
Args:
|
||||
clusters: List of singleton clusters.
|
||||
results: Original list of SearchResult objects.
|
||||
embeddings: Optional embeddings (not used).
|
||||
|
||||
Returns:
|
||||
All SearchResult objects sorted by score (highest first).
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
# Return all results sorted by score
|
||||
return sorted(results, key=lambda r: r.score, reverse=True)
|
||||
@@ -1807,6 +1807,178 @@ class DirIndexStore:
                for row in rows
            ]

    def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
        """Get all symbols in a specific file, sorted by start_line.

        Args:
            file_path: Full path to the file

        Returns:
            List of Symbol objects sorted by start_line
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            # First get the file_id
            file_row = conn.execute(
                "SELECT id FROM files WHERE full_path=?",
                (file_path_str,),
            ).fetchone()

            if not file_row:
                return []

            file_id = int(file_row["id"])

            rows = conn.execute(
                """
                SELECT s.name, s.kind, s.start_line, s.end_line
                FROM symbols s
                WHERE s.file_id=?
                ORDER BY s.start_line
                """,
                (file_id,),
            ).fetchall()

            return [
                Symbol(
                    name=row["name"],
                    kind=row["kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=file_path_str,
                )
                for row in rows
            ]

    def get_outgoing_calls(
        self,
        file_path: str | Path,
        symbol_name: Optional[str] = None,
    ) -> List[Tuple[str, str, int, Optional[str]]]:
        """Get outgoing calls from symbols in a file.

        Queries code_relationships table for calls originating from symbols
        in the specified file.

        Args:
            file_path: Full path to the source file
            symbol_name: Optional symbol name to filter by. If None, returns
                calls from all symbols in the file.

        Returns:
            List of tuples: (target_name, relationship_type, source_line, target_file)
            - target_name: Qualified name of the call target
            - relationship_type: Type of relationship (e.g., "calls", "imports")
            - source_line: Line number where the call occurs
            - target_file: Target file path (may be None if unknown)
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            # First get the file_id
            file_row = conn.execute(
                "SELECT id FROM files WHERE full_path=?",
                (file_path_str,),
            ).fetchone()

            if not file_row:
                return []

            file_id = int(file_row["id"])

            if symbol_name:
                rows = conn.execute(
                    """
                    SELECT cr.target_qualified_name, cr.relationship_type,
                           cr.source_line, cr.target_file
                    FROM code_relationships cr
                    JOIN symbols s ON s.id = cr.source_symbol_id
                    WHERE s.file_id=? AND s.name=?
                    ORDER BY cr.source_line
                    """,
                    (file_id, symbol_name),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT cr.target_qualified_name, cr.relationship_type,
                           cr.source_line, cr.target_file
                    FROM code_relationships cr
                    JOIN symbols s ON s.id = cr.source_symbol_id
                    WHERE s.file_id=?
                    ORDER BY cr.source_line
                    """,
                    (file_id,),
                ).fetchall()

            return [
                (
                    row["target_qualified_name"],
                    row["relationship_type"],
                    int(row["source_line"]),
                    row["target_file"],
                )
                for row in rows
            ]

    def get_incoming_calls(
        self,
        target_name: str,
        limit: int = 100,
    ) -> List[Tuple[str, str, int, str]]:
        """Get incoming calls/references to a target symbol.

        Queries code_relationships table for references to the specified
        target symbol name.

        Args:
            target_name: Name of the target symbol to find references for.
                Matches against target_qualified_name (exact match,
                suffix match, or contains match).
            limit: Maximum number of results to return

        Returns:
            List of tuples: (source_symbol_name, relationship_type, source_line, source_file)
            - source_symbol_name: Name of the calling symbol
            - relationship_type: Type of relationship (e.g., "calls", "imports")
            - source_line: Line number where the call occurs
            - source_file: Full path to the source file
        """
        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT s.name AS source_name, cr.relationship_type,
                       cr.source_line, f.full_path AS source_file
                FROM code_relationships cr
                JOIN symbols s ON s.id = cr.source_symbol_id
                JOIN files f ON f.id = s.file_id
                WHERE cr.target_qualified_name = ?
                   OR cr.target_qualified_name LIKE ?
                   OR cr.target_qualified_name LIKE ?
                ORDER BY f.full_path, cr.source_line
                LIMIT ?
                """,
                (
                    target_name,
                    f"%.{target_name}",
                    f"%{target_name}",
                    limit,
                ),
            ).fetchall()

            return [
                (
                    row["source_name"],
                    row["relationship_type"],
                    int(row["source_line"]),
                    row["source_file"],
                )
                for row in rows
            ]

    # === Statistics ===

    def stats(self) -> Dict[str, Any]:
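A hedged sketch of one call-graph hop using the new DirIndexStore helpers; the constructor and index path below are assumptions for illustration, while the method names and tuple shapes come from the hunk above:

store = DirIndexStore("/path/to/.codexlens/index.db")  # hypothetical constructor/path
for target, rel_type, line, target_file in store.get_outgoing_calls("src/app.py", "main"):
    print(f"{line}: {rel_type} -> {target} ({target_file or 'unknown'})")
for caller, rel_type, line, source_file in store.get_incoming_calls("authenticate", limit=20):
    print(f"{source_file}:{line} {caller} --{rel_type}--> authenticate")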
@@ -270,6 +270,39 @@ class GlobalSymbolIndex:
        symbols = self.search(name=name, kind=kind, limit=limit, prefix_mode=prefix_mode)
        return [(s.file or "", s.range) for s in symbols]

    def get_file_symbols(self, file_path: str | Path) -> List[Symbol]:
        """Get all symbols in a specific file, sorted by start_line.

        Args:
            file_path: Full path to the file

        Returns:
            List of Symbol objects sorted by start_line
        """
        file_path_str = str(Path(file_path).resolve())

        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT symbol_name, symbol_kind, file_path, start_line, end_line
                FROM global_symbols
                WHERE project_id=? AND file_path=?
                ORDER BY start_line
                """,
                (self.project_id, file_path_str),
            ).fetchall()

            return [
                Symbol(
                    name=row["symbol_name"],
                    kind=row["symbol_kind"],
                    range=(row["start_line"], row["end_line"]),
                    file=row["file_path"],
                )
                for row in rows
            ]

    def _get_existing_index_path(self, file_path_str: str) -> Optional[str]:
        with self._lock:
            conn = self._get_connection()
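And the analogous sketch for the per-file symbol listing added to GlobalSymbolIndex; how the index instance is constructed is an assumption, while the Symbol fields used (name, kind, range) are from the hunk above:

index = GlobalSymbolIndex(project_id="demo", db_path=".codexlens/global.db")  # hypothetical
for sym in index.get_file_symbols("src/codexlens/search/ranking.py"):
    start, end = sym.range
    print(f"{sym.kind:10s} {sym.name} (lines {start}-{end})")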