Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-05 01:50:27 +08:00)
Add graph expansion and cross-encoder reranking features
- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors.
- Added CrossEncoderReranker for second-stage search ranking to improve result scoring.
- Created migrations to establish the database tables needed for relationships and graph neighbors.
- Developed tests for graph expansion, ensuring related results are populated correctly.
- Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead.
- Updated schema cleanup tests to reflect changes in versioning and deprecated fields.
- Added new test cases for the tree-sitter parser to validate relationship extraction with alias resolution.
@@ -49,6 +49,12 @@ semantic-directml = [
    "onnxruntime-directml>=1.15.0", # DirectML support
]

# Cross-encoder reranking (second-stage, optional)
# Install with: pip install codexlens[reranker]
reranker = [
    "sentence-transformers>=2.2",
]

# Encoding detection for non-UTF8 files
encoding = [
    "chardet>=5.0",
@@ -407,7 +407,7 @@ def search(
    registry.initialize()
    mapper = PathMapper()

    engine = ChainSearchEngine(registry, mapper)
    engine = ChainSearchEngine(registry, mapper, config=config)

    # Auto-detect mode if set to "auto"
    actual_mode = mode
@@ -550,7 +550,7 @@ def symbol(
    registry.initialize()
    mapper = PathMapper()

    engine = ChainSearchEngine(registry, mapper)
    engine = ChainSearchEngine(registry, mapper, config=config)
    options = SearchOptions(depth=depth, total_limit=limit)

    syms = engine.search_symbols(name, search_path, kind=kind, options=options)
@@ -105,12 +105,22 @@ class Config:

    # Indexing/search optimizations
    global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
    enable_merkle_detection: bool = True # Enable content-hash based incremental indexing

    # Graph expansion (search-time, uses precomputed neighbors)
    enable_graph_expansion: bool = False
    graph_expansion_depth: int = 2

    # Optional search reranking (disabled by default)
    enable_reranking: bool = False
    reranking_top_k: int = 50
    symbol_boost_factor: float = 1.5

    # Optional cross-encoder reranking (second stage, requires codexlens[reranker])
    enable_cross_encoder_rerank: bool = False
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    reranker_top_k: int = 50

    # Multi-endpoint configuration for litellm backend
    embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
    # List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
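For orientation, here is a hypothetical sketch of turning the new options on. The field names come from the Config hunk above; the construction itself is illustrative and assumes Config is a dataclass whose other fields keep their defaults.

# Hypothetical configuration sketch; field names are from the Config hunk above.
from codexlens.config import Config

config = Config(
    enable_graph_expansion=True,       # expand top hits with precomputed graph neighbors
    graph_expansion_depth=2,           # GraphExpander caps traversal at 2 hops
    enable_reranking=True,             # first-stage reranking over fused results
    enable_cross_encoder_rerank=True,  # second stage; needs `pip install codexlens[reranker]`
    reranker_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    reranker_top_k=50,
)

# The CLI threads this through the engine, as the cli hunks above show:
#   engine = ChainSearchEngine(registry, mapper, config=config)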
@@ -58,6 +58,7 @@ class IndexedFile(BaseModel):
    language: str = Field(..., min_length=1)
    symbols: List[Symbol] = Field(default_factory=list)
    chunks: List[SemanticChunk] = Field(default_factory=list)
    relationships: List["CodeRelationship"] = Field(default_factory=list)

    @field_validator("path", "language")
    @classmethod
@@ -70,7 +71,7 @@ class IndexedFile(BaseModel):

class RelationshipType(str, Enum):
    """Types of code relationships."""
    CALL = "call"
    CALL = "calls"
    INHERITS = "inherits"
    IMPORTS = "imports"
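A relationship edge ties a source symbol to a target symbol with a type and a source location. A hypothetical usage sketch, using only the field names that appear in constructions later in this diff (the concrete CodeRelationship model definition is not shown in this excerpt):

# Hypothetical sketch of a relationship edge as used throughout this commit.
from codexlens.entities import CodeRelationship, RelationshipType

edge = CodeRelationship(
    source_symbol="OrderService.checkout",
    target_symbol="validate_cart",
    relationship_type=RelationshipType.CALL,   # serialized as "calls" after this change
    source_file="/repo/services/orders.py",
    target_file=None,                          # unresolved until cross-file linking
    source_line=42,
)
assert edge.relationship_type.value == "calls"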
@@ -4,6 +4,11 @@ import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception: # pragma: no cover - optional dependency / platform variance
    TreeSitterSymbolParser = None # type: ignore[assignment]


class SymbolExtractor:
    """Extract symbols and relationships from source code using regex patterns."""
@@ -118,7 +123,7 @@ class SymbolExtractor:

        patterns = self.PATTERNS[lang]
        symbols = []
        relationships = []
        relationships: List[Dict] = []
        lines = content.split('\n')

        current_scope = None
@@ -141,33 +146,62 @@ class SymbolExtractor:
|
||||
})
|
||||
current_scope = name
|
||||
|
||||
# Extract imports
|
||||
if 'import' in patterns:
|
||||
match = re.search(patterns['import'], line)
|
||||
if match:
|
||||
import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
if import_target and current_scope:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': import_target.strip(),
|
||||
'type': 'imports',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
if TreeSitterSymbolParser is not None:
|
||||
try:
|
||||
ts_parser = TreeSitterSymbolParser(lang, file_path)
|
||||
if ts_parser.is_available():
|
||||
indexed = ts_parser.parse(content, file_path)
|
||||
if indexed is not None and indexed.relationships:
|
||||
relationships = [
|
||||
{
|
||||
"source_scope": r.source_symbol,
|
||||
"target": r.target_symbol,
|
||||
"type": r.relationship_type.value,
|
||||
"file_path": str(file_path),
|
||||
"line": r.source_line,
|
||||
}
|
||||
for r in indexed.relationships
|
||||
]
|
||||
except Exception:
|
||||
relationships = []
|
||||
|
||||
# Extract function calls (simplified)
|
||||
if 'call' in patterns and current_scope:
|
||||
for match in re.finditer(patterns['call'], line):
|
||||
call_name = match.group(1)
|
||||
# Skip common keywords and the current function
|
||||
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': call_name,
|
||||
'type': 'calls',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
# Regex fallback for relationships (when tree-sitter is unavailable)
|
||||
if not relationships:
|
||||
current_scope = None
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
for kind in ['function', 'class']:
|
||||
if kind in patterns:
|
||||
match = re.search(patterns[kind], line)
|
||||
if match:
|
||||
current_scope = match.group(1)
|
||||
|
||||
# Extract imports
|
||||
if 'import' in patterns:
|
||||
match = re.search(patterns['import'], line)
|
||||
if match:
|
||||
import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
|
||||
if import_target and current_scope:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': import_target.strip(),
|
||||
'type': 'imports',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
|
||||
# Extract function calls (simplified)
|
||||
if 'call' in patterns and current_scope:
|
||||
for match in re.finditer(patterns['call'], line):
|
||||
call_name = match.group(1)
|
||||
# Skip common keywords and the current function
|
||||
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
|
||||
relationships.append({
|
||||
'source_scope': current_scope,
|
||||
'target': call_name,
|
||||
'type': 'calls',
|
||||
'file_path': str(file_path),
|
||||
'line': line_num,
|
||||
})
|
||||
|
||||
return symbols, relationships
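The hunk above makes SymbolExtractor prefer the tree-sitter parser and fall back to the line-by-line regex scan only when tree-sitter yields nothing. A minimal, self-contained sketch of that "optional dependency first, regex fallback" flow; `regex_fallback` is a hypothetical stand-in for the regex path shown above:

# Minimal sketch of the tree-sitter-first / regex-fallback flow used above.
from pathlib import Path
from typing import Dict, List

try:
    from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception:  # optional dependency / platform variance
    TreeSitterSymbolParser = None


def regex_fallback(lang: str, file_path: Path, content: str) -> List[Dict]:
    return []  # placeholder for the line-by-line regex scan shown in the diff


def extract_relationships(lang: str, file_path: Path, content: str) -> List[Dict]:
    relationships: List[Dict] = []
    if TreeSitterSymbolParser is not None:
        try:
            parser = TreeSitterSymbolParser(lang, file_path)
            if parser.is_available():
                indexed = parser.parse(content, file_path)
                if indexed is not None and indexed.relationships:
                    relationships = [
                        {
                            "source_scope": r.source_symbol,
                            "target": r.target_symbol,
                            "type": r.relationship_type.value,
                            "file_path": str(file_path),
                            "line": r.source_line,
                        }
                        for r in indexed.relationships
                    ]
        except Exception:
            relationships = []
    if not relationships:
        relationships = regex_fallback(lang, file_path, content)
    return relationships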
@@ -13,7 +13,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import IndexedFile, Symbol
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
|
||||
|
||||
@@ -30,36 +30,39 @@ class SimpleRegexParser:
|
||||
if self.language_id in {"python", "javascript", "typescript"}:
|
||||
ts_parser = TreeSitterSymbolParser(self.language_id, path)
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
)
|
||||
indexed = ts_parser.parse(text, path)
|
||||
if indexed is not None:
|
||||
return indexed
|
||||
|
||||
# Fallback to regex parsing
|
||||
if self.language_id == "python":
|
||||
symbols = _parse_python_symbols_regex(text)
|
||||
relationships = _parse_python_relationships_regex(text, path)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
symbols = _parse_js_ts_symbols_regex(text)
|
||||
relationships = _parse_js_ts_relationships_regex(text, path)
|
||||
elif self.language_id == "java":
|
||||
symbols = _parse_java_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "go":
|
||||
symbols = _parse_go_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "markdown":
|
||||
symbols = _parse_markdown_symbols(text)
|
||||
relationships = []
|
||||
elif self.language_id == "text":
|
||||
symbols = _parse_text_symbols(text)
|
||||
relationships = []
|
||||
else:
|
||||
symbols = _parse_generic_symbols(text)
|
||||
relationships = []
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
relationships=relationships,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,6 +81,9 @@ class ParserFactory:
|
||||
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
|
||||
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
|
||||
|
||||
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
|
||||
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -127,12 +133,81 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||
relationships: List[CodeRelationship] = []
|
||||
current_scope: str | None = None
|
||||
source_file = str(path.resolve())
|
||||
|
||||
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _PY_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
current_scope = class_match.group(1)
|
||||
continue
|
||||
|
||||
def_match = _PY_DEF_RE.match(line)
|
||||
if def_match:
|
||||
current_scope = def_match.group(1)
|
||||
continue
|
||||
|
||||
if current_scope is None:
|
||||
continue
|
||||
|
||||
import_match = _PY_IMPORT_RE.search(line)
|
||||
if import_match:
|
||||
import_target = import_match.group(1) or import_match.group(2)
|
||||
if import_target:
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=import_target.strip(),
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
for call_match in _PY_CALL_RE.finditer(line):
|
||||
call_name = call_match.group(1)
|
||||
if call_name in {
|
||||
"if",
|
||||
"for",
|
||||
"while",
|
||||
"return",
|
||||
"print",
|
||||
"len",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"list",
|
||||
"dict",
|
||||
"set",
|
||||
"tuple",
|
||||
current_scope,
|
||||
}:
|
||||
continue
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=call_name,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
|
||||
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
|
||||
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
|
||||
_JS_ARROW_RE = re.compile(
|
||||
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
|
||||
)
|
||||
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
|
||||
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
|
||||
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
|
||||
|
||||
|
||||
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||
@@ -174,6 +249,61 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
|
||||
relationships: List[CodeRelationship] = []
|
||||
current_scope: str | None = None
|
||||
source_file = str(path.resolve())
|
||||
|
||||
for line_num, line in enumerate(text.splitlines(), start=1):
|
||||
class_match = _JS_CLASS_RE.match(line)
|
||||
if class_match:
|
||||
current_scope = class_match.group(1)
|
||||
continue
|
||||
|
||||
func_match = _JS_FUNC_RE.match(line)
|
||||
if func_match:
|
||||
current_scope = func_match.group(1)
|
||||
continue
|
||||
|
||||
arrow_match = _JS_ARROW_RE.match(line)
|
||||
if arrow_match:
|
||||
current_scope = arrow_match.group(1)
|
||||
continue
|
||||
|
||||
if current_scope is None:
|
||||
continue
|
||||
|
||||
import_match = _JS_IMPORT_RE.search(line)
|
||||
if import_match:
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=import_match.group(1),
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
for call_match in _JS_CALL_RE.finditer(line):
|
||||
call_name = call_match.group(1)
|
||||
if call_name in {current_scope}:
|
||||
continue
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=current_scope,
|
||||
target_symbol=call_name,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=line_num,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
|
||||
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
|
||||
_JAVA_METHOD_RE = re.compile(
|
||||
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
|
||||
@@ -253,4 +383,3 @@ def _parse_text_symbols(text: str) -> List[Symbol]:
|
||||
# Text files don't have structured symbols, return empty list
|
||||
# The file content will still be indexed for FTS search
|
||||
return []
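To make the regex fallback concrete, here is a small worked example with a hypothetical input file; the module path is taken from the class names in this hunk, and the expected output is approximate:

# Worked example (hypothetical input) for the regex fallback defined above.
from pathlib import Path

from codexlens.parsers.factory import _parse_python_relationships_regex  # module path assumed

text = """import json

def load_settings(path):
    raw = Path(path).read_text()
    return json.loads(raw)
"""

for rel in _parse_python_relationships_regex(text, Path("settings.py")):
    print(rel.source_symbol, rel.relationship_type.value, rel.target_symbol)

# Expected (roughly): only `load_settings calls Path`.
# The module-level `import json` is skipped because no enclosing scope is set yet,
# and `json.loads(...)` is skipped because the call regex ignores dotted calls;
# the tree-sitter path does capture those as resolved dotted names.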
@@ -1,12 +1,17 @@
"""Tree-sitter based parser for CodexLens.

Provides precise AST-level parsing with fallback to regex-based parsing.
Provides precise AST-level parsing via tree-sitter.

Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional

try:
    from tree_sitter import Language as TreeSitterLanguage
@@ -19,7 +24,7 @@ except ImportError:
    TreeSitterParser = None # type: ignore[assignment]
    TREE_SITTER_AVAILABLE = False

from codexlens.entities import IndexedFile, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@@ -85,6 +90,16 @@ class TreeSitterSymbolParser:
|
||||
"""
|
||||
return self._parser is not None and self._language is not None
|
||||
|
||||
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
|
||||
if not self.is_available() or self._parser is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
return source_bytes, tree.root_node
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
|
||||
"""Parse source code and extract symbols without creating IndexedFile.
|
||||
@@ -95,17 +110,15 @@ class TreeSitterSymbolParser:
|
||||
Returns:
|
||||
List of symbols if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
if not self.is_available() or self._parser is None:
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
return self._extract_symbols(source_bytes, root)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
# Gracefully handle extraction errors
|
||||
return None
|
||||
|
||||
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
|
||||
@@ -118,19 +131,21 @@ class TreeSitterSymbolParser:
|
||||
Returns:
|
||||
IndexedFile if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
if not self.is_available() or self._parser is None:
|
||||
parsed = self._parse_tree(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
source_bytes, root = parsed
|
||||
try:
|
||||
symbols = self.parse_symbols(text)
|
||||
if symbols is None:
|
||||
return None
|
||||
symbols = self._extract_symbols(source_bytes, root)
|
||||
relationships = self._extract_relationships(source_bytes, root, path)
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
relationships=relationships,
|
||||
)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
@@ -153,6 +168,465 @@ class TreeSitterSymbolParser:
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_relationships(source_bytes, root, path)
|
||||
if self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_relationships(source_bytes, root, path)
|
||||
return []
|
||||
|
||||
def _extract_python_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"self", "cls"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "class_definition" and pushed_scope:
|
||||
superclasses = node.child_by_field_name("superclasses")
|
||||
if superclasses is not None:
|
||||
for child in superclasses.children:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, child)
|
||||
if not dotted:
|
||||
continue
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type in {"import_statement", "import_from_statement"}:
|
||||
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
if node.type == "call":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _extract_js_ts_relationships(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
root: TreeSitterNode,
|
||||
path: Path,
|
||||
) -> List[CodeRelationship]:
|
||||
source_file = str(path.resolve())
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
scope_stack: List[str] = []
|
||||
alias_stack: List[Dict[str, str]] = [{}]
|
||||
|
||||
def record_import(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.IMPORTS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_call(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
base = target_symbol.split(".", 1)[0]
|
||||
if base in {"this", "super"}:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def record_inherits(target_symbol: str, source_line: int) -> None:
|
||||
if not target_symbol.strip() or not scope_stack:
|
||||
return
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=scope_stack[-1],
|
||||
target_symbol=target_symbol,
|
||||
relationship_type=RelationshipType.INHERITS,
|
||||
source_file=source_file,
|
||||
target_file=None,
|
||||
source_line=source_line,
|
||||
)
|
||||
)
|
||||
|
||||
def visit(node: TreeSitterNode) -> None:
|
||||
pushed_scope = False
|
||||
pushed_aliases = False
|
||||
|
||||
if node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if pushed_scope:
|
||||
superclass = node.child_by_field_name("superclass")
|
||||
if superclass is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, superclass)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_inherits(resolved, self._node_start_line(node))
|
||||
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type in {"identifier", "property_identifier"}
|
||||
and value_node.type == "arrow_function"
|
||||
):
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name:
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is not None:
|
||||
scope_name = self._node_text(source_bytes, name_node).strip()
|
||||
if scope_name and scope_name != "constructor":
|
||||
scope_stack.append(scope_name)
|
||||
pushed_scope = True
|
||||
alias_stack.append(dict(alias_stack[-1]))
|
||||
pushed_aliases = True
|
||||
|
||||
if node.type in {"import_declaration", "import_statement"}:
|
||||
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
|
||||
if updates:
|
||||
alias_stack[-1].update(updates)
|
||||
for target_symbol in imported_targets:
|
||||
record_import(target_symbol, self._node_start_line(node))
|
||||
|
||||
# Best-effort support for CommonJS require() imports:
|
||||
# const fs = require("fs")
|
||||
if node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is not None
|
||||
and value_node is not None
|
||||
and name_node.type == "identifier"
|
||||
and value_node.type == "call_expression"
|
||||
):
|
||||
callee = value_node.child_by_field_name("function")
|
||||
args = value_node.child_by_field_name("arguments")
|
||||
if (
|
||||
callee is not None
|
||||
and self._node_text(source_bytes, callee).strip() == "require"
|
||||
and args is not None
|
||||
):
|
||||
module_name = self._js_first_string_argument(source_bytes, args)
|
||||
if module_name:
|
||||
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
|
||||
record_import(module_name, self._node_start_line(node))
|
||||
|
||||
if node.type == "call_expression":
|
||||
fn_node = node.child_by_field_name("function")
|
||||
if fn_node is not None:
|
||||
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
|
||||
if dotted:
|
||||
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
|
||||
record_call(resolved, self._node_start_line(node))
|
||||
|
||||
for child in node.children:
|
||||
visit(child)
|
||||
|
||||
if pushed_aliases:
|
||||
alias_stack.pop()
|
||||
if pushed_scope:
|
||||
scope_stack.pop()
|
||||
|
||||
visit(root)
|
||||
return relationships
|
||||
|
||||
def _node_start_line(self, node: TreeSitterNode) -> int:
|
||||
return node.start_point[0] + 1
|
||||
|
||||
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
|
||||
dotted = (dotted or "").strip()
|
||||
if not dotted:
|
||||
return ""
|
||||
|
||||
base, sep, rest = dotted.partition(".")
|
||||
resolved_base = aliases.get(base, base)
|
||||
if not rest:
|
||||
return resolved_base
|
||||
if resolved_base and rest:
|
||||
return f"{resolved_base}.{rest}"
|
||||
return resolved_base
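_resolve_alias_dotted rewrites only the leading segment of a dotted name through the alias map built from imports, which is what lets `np.array(...)` resolve back to `numpy.array`. A small worked example, re-implemented standalone so it runs without tree-sitter; the logic mirrors the method above:

# Worked example of alias resolution (standalone mirror of _resolve_alias_dotted).
from typing import Dict


def resolve_alias_dotted(dotted: str, aliases: Dict[str, str]) -> str:
    dotted = (dotted or "").strip()
    if not dotted:
        return ""
    base, _, rest = dotted.partition(".")
    resolved_base = aliases.get(base, base)
    return f"{resolved_base}.{rest}" if rest else resolved_base


# Aliases as _python_import_aliases_and_targets would build them for:
#   import numpy as np
#   from os import path as p
aliases = {"np": "numpy", "p": "os.path"}

assert resolve_alias_dotted("np.array", aliases) == "numpy.array"
assert resolve_alias_dotted("p.join", aliases) == "os.path.join"
assert resolve_alias_dotted("local_helper", aliases) == "local_helper"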
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"identifier", "dotted_name"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "attribute":
|
||||
obj = node.child_by_field_name("object")
|
||||
attr = node.child_by_field_name("attribute")
|
||||
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
|
||||
if obj_text and attr_text:
|
||||
return f"{obj_text}.{attr_text}"
|
||||
return obj_text or attr_text
|
||||
return ""
|
||||
|
||||
def _python_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
if node.type == "import_statement":
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
module_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else module_name.split(".", 1)[0]
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = module_name
|
||||
targets.append(module_name)
|
||||
elif child.type == "dotted_name":
|
||||
module_name = self._node_text(source_bytes, child).strip()
|
||||
if not module_name:
|
||||
continue
|
||||
bound_name = module_name.split(".", 1)[0]
|
||||
if bound_name:
|
||||
aliases[bound_name] = bound_name
|
||||
targets.append(module_name)
|
||||
|
||||
if node.type == "import_from_statement":
|
||||
module_name = ""
|
||||
module_node = node.child_by_field_name("module_name")
|
||||
if module_node is None:
|
||||
for child in node.children:
|
||||
if child.type == "dotted_name":
|
||||
module_node = child
|
||||
break
|
||||
if module_node is not None:
|
||||
module_name = self._node_text(source_bytes, module_node).strip()
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "aliased_import":
|
||||
name_node = child.child_by_field_name("name")
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported_name = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported_name or imported_name == "*":
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
bound_name = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported_name
|
||||
)
|
||||
if bound_name:
|
||||
aliases[bound_name] = target
|
||||
targets.append(target)
|
||||
elif child.type == "identifier":
|
||||
imported_name = self._node_text(source_bytes, child).strip()
|
||||
if not imported_name or imported_name in {"from", "import", "*"}:
|
||||
continue
|
||||
target = f"{module_name}.{imported_name}" if module_name else imported_name
|
||||
aliases[imported_name] = target
|
||||
targets.append(target)
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
if node.type in {"this", "super"}:
|
||||
return node.type
|
||||
if node.type in {"identifier", "property_identifier"}:
|
||||
return self._node_text(source_bytes, node).strip()
|
||||
if node.type == "member_expression":
|
||||
obj = node.child_by_field_name("object")
|
||||
prop = node.child_by_field_name("property")
|
||||
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
|
||||
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
|
||||
if obj_text and prop_text:
|
||||
return f"{obj_text}.{prop_text}"
|
||||
return obj_text or prop_text
|
||||
return ""
|
||||
|
||||
def _js_import_aliases_and_targets(
|
||||
self,
|
||||
source_bytes: bytes,
|
||||
node: TreeSitterNode,
|
||||
) -> tuple[Dict[str, str], List[str]]:
|
||||
aliases: Dict[str, str] = {}
|
||||
targets: List[str] = []
|
||||
|
||||
module_name = ""
|
||||
source_node = node.child_by_field_name("source")
|
||||
if source_node is not None:
|
||||
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
|
||||
if module_name:
|
||||
targets.append(module_name)
|
||||
|
||||
for child in node.children:
|
||||
if child.type == "import_clause":
|
||||
for clause_child in child.children:
|
||||
if clause_child.type == "identifier":
|
||||
# Default import: import React from "react"
|
||||
local = self._node_text(source_bytes, clause_child).strip()
|
||||
if local and module_name:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "namespace_import":
|
||||
# Namespace import: import * as fs from "fs"
|
||||
name_node = clause_child.child_by_field_name("name")
|
||||
if name_node is not None and module_name:
|
||||
local = self._node_text(source_bytes, name_node).strip()
|
||||
if local:
|
||||
aliases[local] = module_name
|
||||
if clause_child.type == "named_imports":
|
||||
for spec in clause_child.children:
|
||||
if spec.type != "import_specifier":
|
||||
continue
|
||||
name_node = spec.child_by_field_name("name")
|
||||
alias_node = spec.child_by_field_name("alias")
|
||||
if name_node is None:
|
||||
continue
|
||||
imported = self._node_text(source_bytes, name_node).strip()
|
||||
if not imported:
|
||||
continue
|
||||
local = (
|
||||
self._node_text(source_bytes, alias_node).strip()
|
||||
if alias_node is not None
|
||||
else imported
|
||||
)
|
||||
if local and module_name:
|
||||
aliases[local] = f"{module_name}.{imported}"
|
||||
targets.append(f"{module_name}.{imported}")
|
||||
|
||||
return aliases, targets
|
||||
|
||||
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
|
||||
for child in args_node.children:
|
||||
if child.type == "string":
|
||||
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
|
||||
return ""
|
||||
|
||||
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract Python symbols from AST.
|
||||
|
||||
|
||||
@@ -83,6 +83,7 @@ class ChainSearchResult:
    Attributes:
        query: Original search query
        results: List of SearchResult objects
        related_results: Expanded results from graph neighbors (optional)
        symbols: List of Symbol objects (if include_symbols=True)
        stats: SearchStats with execution metrics
    """
@@ -90,6 +91,7 @@ class ChainSearchResult:
    results: List[SearchResult]
    symbols: List[Symbol]
    stats: SearchStats
    related_results: List[SearchResult] = field(default_factory=list)


class ChainSearchEngine:
@@ -236,13 +238,26 @@ class ChainSearchEngine:
            index_paths, query, None, options.total_limit
        )

        # Optional: graph expansion using precomputed neighbors
        related_results: List[SearchResult] = []
        if self._config is not None and getattr(self._config, "enable_graph_expansion", False):
            try:
                from codexlens.search.enrichment import SearchEnrichmentPipeline

                pipeline = SearchEnrichmentPipeline(self.mapper, config=self._config)
                related_results = pipeline.expand_related_results(final_results)
            except Exception as exc:
                self.logger.debug("Graph expansion failed: %s", exc)
                related_results = []

        stats.time_ms = (time.time() - start_time) * 1000

        return ChainSearchResult(
            query=query,
            results=final_results,
            symbols=symbols,
            stats=stats
            stats=stats,
            related_results=related_results,
        )

    def search_files_only(self, query: str,
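Callers see the expansion through the new `related_results` field rather than through the primary ranking. A hypothetical consumption sketch; `engine` is a ChainSearchEngine constructed with config= as in the CLI hunk, and the search() call shape is assumed for illustration:

# Hypothetical sketch of consuming the new related_results field.
result = engine.search("rank fusion", search_path)  # call shape assumed

for hit in result.results:
    print(f"{hit.path}:{hit.start_line}  score={hit.score:.3f}")

# Graph-expanded hits are returned separately instead of being interleaved with
# the primary ranking; each carries its hop distance in metadata.
for rel in result.related_results:
    depth = rel.metadata.get("relationship_depth")
    print(f"  related (depth {depth}): {rel.symbol_name} in {rel.path}")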
@@ -4,6 +4,11 @@ import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper


class RelationshipEnricher:
    """Enriches search results with code graph relationships."""
@@ -148,3 +153,19 @@ class RelationshipEnricher:

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()


class SearchEnrichmentPipeline:
    """Search post-processing pipeline (optional enrichments)."""

    def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
        self._config = config
        self._graph_expander = GraphExpander(mapper, config=config)

    def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
        """Expand base results with related symbols when enabled in config."""
        if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
            return []

        depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
        return self._graph_expander.expand(results, depth=depth)

codex-lens/src/codexlens/search/graph_expander.py (new file, 264 lines)
@@ -0,0 +1,264 @@
|
||||
"""Graph expansion for search results using precomputed neighbors.
|
||||
|
||||
Expands top search results with related symbol definitions by traversing
|
||||
precomputed N-hop neighbors stored in the per-directory index databases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
|
||||
return (result.path, result.symbol_name, result.start_line, result.end_line)
|
||||
|
||||
|
||||
def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
|
||||
if content is None:
|
||||
return None
|
||||
if start_line is None or end_line is None:
|
||||
return None
|
||||
if start_line < 1 or end_line < start_line:
|
||||
return None
|
||||
|
||||
lines = content.splitlines()
|
||||
start_idx = max(0, start_line - 1)
|
||||
end_idx = min(len(lines), end_line)
|
||||
if start_idx >= len(lines):
|
||||
return None
|
||||
return "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
|
||||
class GraphExpander:
|
||||
"""Expands SearchResult lists with related symbols from the code graph."""
|
||||
|
||||
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
|
||||
self._mapper = mapper
|
||||
self._config = config
|
||||
self._logger = logging.getLogger(__name__)
|
||||
|
||||
def expand(
|
||||
self,
|
||||
results: Sequence[SearchResult],
|
||||
*,
|
||||
depth: Optional[int] = None,
|
||||
max_expand: int = 10,
|
||||
max_related: int = 50,
|
||||
) -> List[SearchResult]:
|
||||
"""Expand top results with related symbols.
|
||||
|
||||
Args:
|
||||
results: Base ranked results.
|
||||
depth: Maximum relationship depth to include (defaults to Config or 2).
|
||||
max_expand: Only expand the top-N base results to bound cost.
|
||||
max_related: Maximum related results to return.
|
||||
|
||||
Returns:
|
||||
A list of related SearchResult objects with relationship_depth metadata.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
|
||||
max_depth = int(depth if depth is not None else configured_depth)
|
||||
if max_depth <= 0:
|
||||
return []
|
||||
max_depth = min(max_depth, 2)
|
||||
|
||||
expand_count = max(0, int(max_expand))
|
||||
related_limit = max(0, int(max_related))
|
||||
if expand_count == 0 or related_limit == 0:
|
||||
return []
|
||||
|
||||
seen = {_result_key(r) for r in results}
|
||||
related_results: List[SearchResult] = []
|
||||
conn_cache: Dict[Path, sqlite3.Connection] = {}
|
||||
|
||||
try:
|
||||
for base in list(results)[:expand_count]:
|
||||
if len(related_results) >= related_limit:
|
||||
break
|
||||
|
||||
if not base.symbol_name or not base.path:
|
||||
continue
|
||||
|
||||
index_path = self._mapper.source_to_index_db(Path(base.path).parent)
|
||||
conn = conn_cache.get(index_path)
|
||||
if conn is None:
|
||||
conn = self._connect_readonly(index_path)
|
||||
if conn is None:
|
||||
continue
|
||||
conn_cache[index_path] = conn
|
||||
|
||||
source_ids = self._resolve_source_symbol_ids(
|
||||
conn,
|
||||
file_path=base.path,
|
||||
symbol_name=base.symbol_name,
|
||||
symbol_kind=base.symbol_kind,
|
||||
)
|
||||
if not source_ids:
|
||||
continue
|
||||
|
||||
for source_id in source_ids:
|
||||
neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
|
||||
for neighbor_id, rel_depth in neighbors:
|
||||
if len(related_results) >= related_limit:
|
||||
break
|
||||
row = self._get_symbol_details(conn, neighbor_id)
|
||||
if row is None:
|
||||
continue
|
||||
|
||||
path = str(row["full_path"])
|
||||
symbol_name = str(row["name"])
|
||||
symbol_kind = str(row["kind"])
|
||||
start_line = int(row["start_line"]) if row["start_line"] is not None else None
|
||||
end_line = int(row["end_line"]) if row["end_line"] is not None else None
|
||||
content_block = _slice_content_block(
|
||||
str(row["content"]) if row["content"] is not None else "",
|
||||
start_line,
|
||||
end_line,
|
||||
)
|
||||
|
||||
score = float(base.score) * (0.5 ** int(rel_depth))
|
||||
candidate = SearchResult(
|
||||
path=path,
|
||||
score=max(0.0, score),
|
||||
excerpt=None,
|
||||
content=content_block,
|
||||
start_line=start_line,
|
||||
end_line=end_line,
|
||||
symbol_name=symbol_name,
|
||||
symbol_kind=symbol_kind,
|
||||
metadata={"relationship_depth": int(rel_depth)},
|
||||
)
|
||||
|
||||
key = _result_key(candidate)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
related_results.append(candidate)
|
||||
|
||||
finally:
|
||||
for conn in conn_cache.values():
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return related_results
|
||||
|
||||
def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
|
||||
try:
|
||||
if not index_path.exists() or index_path.stat().st_size == 0:
|
||||
return None
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
except Exception as exc:
|
||||
self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
|
||||
return None
|
||||
|
||||
def _resolve_source_symbol_ids(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
*,
|
||||
file_path: str,
|
||||
symbol_name: str,
|
||||
symbol_kind: Optional[str],
|
||||
) -> List[int]:
|
||||
try:
|
||||
if symbol_kind:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT s.id
|
||||
FROM symbols s
|
||||
JOIN files f ON f.id = s.file_id
|
||||
WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
|
||||
""",
|
||||
(file_path, symbol_name, symbol_kind),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT s.id
|
||||
FROM symbols s
|
||||
JOIN files f ON f.id = s.file_id
|
||||
WHERE f.full_path = ? AND s.name = ?
|
||||
""",
|
||||
(file_path, symbol_name),
|
||||
).fetchall()
|
||||
except sqlite3.Error:
|
||||
return []
|
||||
|
||||
ids: List[int] = []
|
||||
for row in rows:
|
||||
try:
|
||||
ids.append(int(row["id"]))
|
||||
except Exception:
|
||||
continue
|
||||
return ids
|
||||
|
||||
def _get_neighbors(
|
||||
self,
|
||||
conn: sqlite3.Connection,
|
||||
source_symbol_id: int,
|
||||
*,
|
||||
max_depth: int,
|
||||
limit: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT neighbor_symbol_id, relationship_depth
|
||||
FROM graph_neighbors
|
||||
WHERE source_symbol_id = ? AND relationship_depth <= ?
|
||||
ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
|
||||
LIMIT ?
|
||||
""",
|
||||
(int(source_symbol_id), int(max_depth), int(limit)),
|
||||
).fetchall()
|
||||
except sqlite3.Error:
|
||||
return []
|
||||
|
||||
neighbors: List[Tuple[int, int]] = []
|
||||
for row in rows:
|
||||
try:
|
||||
neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
|
||||
except Exception:
|
||||
continue
|
||||
return neighbors
|
||||
|
||||
def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
|
||||
try:
|
||||
return conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
s.id,
|
||||
s.name,
|
||||
s.kind,
|
||||
s.start_line,
|
||||
s.end_line,
|
||||
f.full_path,
|
||||
f.content
|
||||
FROM symbols s
|
||||
JOIN files f ON f.id = s.file_id
|
||||
WHERE s.id = ?
|
||||
""",
|
||||
(int(symbol_id),),
|
||||
).fetchone()
|
||||
except sqlite3.Error:
|
||||
return None
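GraphExpander scores each related hit by halving the base result's score per hop, so neighbors never outrank the result that pulled them in. A small worked example of the decay (the numbers follow `score = base.score * 0.5 ** depth` above):

# Worked example of the depth-based score decay used by GraphExpander.expand().
base_score = 0.82

for depth in (1, 2):
    related_score = base_score * (0.5 ** depth)
    print(f"depth {depth}: {related_score:.3f}")

# depth 1: 0.410
# depth 2: 0.205
#
# Because max_depth is clamped to 2 and candidates are deduplicated by
# (path, symbol_name, start_line, end_line), expansion stays bounded even
# when the relationship graph is dense.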
@@ -34,6 +34,7 @@ from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import (
|
||||
apply_symbol_boost,
|
||||
cross_encoder_rerank,
|
||||
get_rrf_weights,
|
||||
reciprocal_rank_fusion,
|
||||
rerank_results,
|
||||
@@ -77,6 +78,7 @@ class HybridSearchEngine:
|
||||
self.weights = weights or self.DEFAULT_WEIGHTS.copy()
|
||||
self._config = config
|
||||
self.embedder = embedder
|
||||
self.reranker: Any = None
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -112,6 +114,14 @@ class HybridSearchEngine:
|
||||
>>> for r in results[:5]:
|
||||
... print(f"{r.path}: {r.score:.3f}")
|
||||
"""
|
||||
# Defensive: avoid creating/locking an index database when callers pass
|
||||
# an empty placeholder file (common in tests and misconfigured callers).
|
||||
try:
|
||||
if index_path.exists() and index_path.stat().st_size == 0:
|
||||
return []
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
# Determine which backends to use
|
||||
backends = {}
|
||||
|
||||
@@ -180,9 +190,30 @@ class HybridSearchEngine:
|
||||
query,
|
||||
fused_results[:100],
|
||||
self.embedder,
|
||||
top_k=self._config.reranking_top_k,
|
||||
top_k=(
|
||||
100
|
||||
if self._config.enable_cross_encoder_rerank
|
||||
else self._config.reranking_top_k
|
||||
),
|
||||
)
|
||||
|
||||
# Optional: cross-encoder reranking as a second stage
|
||||
if (
|
||||
self._config is not None
|
||||
and self._config.enable_reranking
|
||||
and self._config.enable_cross_encoder_rerank
|
||||
):
|
||||
with timer("cross_encoder_rerank", self.logger):
|
||||
if self.reranker is None:
|
||||
self.reranker = self._get_cross_encoder_reranker()
|
||||
if self.reranker is not None:
|
||||
fused_results = cross_encoder_rerank(
|
||||
query,
|
||||
fused_results,
|
||||
self.reranker,
|
||||
top_k=self._config.reranker_top_k,
|
||||
)
|
||||
|
||||
# Apply final limit
|
||||
return fused_results[:limit]
|
||||
|
||||
@@ -222,6 +253,27 @@ class HybridSearchEngine:
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_cross_encoder_reranker(self) -> Any:
|
||||
if self._config is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available
|
||||
except Exception as exc:
|
||||
self.logger.debug("Cross-encoder reranker unavailable: %s", exc)
|
||||
return None
|
||||
|
||||
ok, err = check_cross_encoder_available()
|
||||
if not ok:
|
||||
self.logger.debug("Cross-encoder reranker unavailable: %s", err)
|
||||
return None
|
||||
|
||||
try:
|
||||
return CrossEncoderReranker(model_name=self._config.reranker_model)
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc)
|
||||
return None
|
||||
|
||||
def _search_parallel(
|
||||
self,
|
||||
index_path: Path,
|
||||
|
||||
@@ -379,6 +379,117 @@ def rerank_results(
|
||||
return reranked_results
|
||||
|
||||
|
||||
def cross_encoder_rerank(
|
||||
query: str,
|
||||
results: List[SearchResult],
|
||||
reranker: Any,
|
||||
top_k: int = 50,
|
||||
batch_size: int = 32,
|
||||
) -> List[SearchResult]:
|
||||
"""Second-stage reranking using a cross-encoder model.
|
||||
|
||||
This function is dependency-agnostic: callers can pass any object that exposes
|
||||
a compatible `score_pairs(pairs, batch_size=...)` method.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if reranker is None or top_k <= 0:
|
||||
return results
|
||||
|
||||
rerank_count = min(int(top_k), len(results))
|
||||
|
||||
def text_for_pair(r: SearchResult) -> str:
|
||||
if r.excerpt and r.excerpt.strip():
|
||||
return r.excerpt
|
||||
if r.content and r.content.strip():
|
||||
return r.content
|
||||
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||
return r.chunk.content
|
||||
return r.symbol_name or r.path
|
||||
|
||||
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
|
||||
|
||||
try:
|
||||
if hasattr(reranker, "score_pairs"):
|
||||
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
|
||||
elif hasattr(reranker, "predict"):
|
||||
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
|
||||
else:
|
||||
return results
|
||||
except Exception:
|
||||
return results
|
||||
|
||||
if not raw_scores or len(raw_scores) != rerank_count:
|
||||
return results
|
||||
|
||||
scores = [float(s) for s in raw_scores]
|
||||
min_s = min(scores)
|
||||
max_s = max(scores)
|
||||
|
||||
def sigmoid(x: float) -> float:
|
||||
# Clamp to keep exp() stable.
|
||||
x = max(-50.0, min(50.0, x))
|
||||
return 1.0 / (1.0 + math.exp(-x))
|
||||
|
||||
if 0.0 <= min_s and max_s <= 1.0:
|
||||
probs = scores
|
||||
else:
|
||||
probs = [sigmoid(s) for s in scores]
|
||||
|
||||
reranked_results: List[SearchResult] = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if idx < rerank_count:
|
||||
prev_score = float(result.score)
|
||||
ce_score = scores[idx]
|
||||
ce_prob = probs[idx]
|
||||
combined_score = 0.5 * prev_score + 0.5 * ce_prob
|
||||
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=combined_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={
|
||||
**result.metadata,
|
||||
"pre_cross_encoder_score": prev_score,
|
||||
"cross_encoder_score": ce_score,
|
||||
"cross_encoder_prob": ce_prob,
|
||||
"cross_encoder_reranked": True,
|
||||
},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
else:
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=result.score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={**result.metadata},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return reranked_results
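cross_encoder_rerank blends the prior fused score with the cross-encoder output: raw logits are squashed with a clamped sigmoid (scores already in [0, 1] are used as-is) and combined 50/50. A short worked example of that arithmetic:

# Worked example of the score blend in cross_encoder_rerank above.
import math


def sigmoid(x: float) -> float:
    x = max(-50.0, min(50.0, x))  # clamp to keep exp() stable
    return 1.0 / (1.0 + math.exp(-x))


prev_score = 0.62   # fused RRF score before the second stage
raw_logit = 1.8     # typical cross-encoder output (not a probability)

ce_prob = sigmoid(raw_logit)                  # ~0.858
combined = 0.5 * prev_score + 0.5 * ce_prob   # ~0.739

print(round(ce_prob, 3), round(combined, 3))

Because the function only requires an object exposing score_pairs() or predict(), tests can pass a lightweight fake instead of loading a real model.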
def normalize_bm25_score(score: float) -> float:
|
||||
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
|
||||
|
||||
|
||||
codex-lens/src/codexlens/semantic/reranker.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""Optional cross-encoder reranker for second-stage search ranking.

Install with: pip install codexlens[reranker]
"""

from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder as _CrossEncoder
|
||||
|
||||
CROSS_ENCODER_AVAILABLE = True
|
||||
_import_error: str | None = None
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
_CrossEncoder = None # type: ignore[assignment]
|
||||
CROSS_ENCODER_AVAILABLE = False
|
||||
_import_error = str(exc)
|
||||
|
||||
|
||||
def check_cross_encoder_available() -> tuple[bool, str | None]:
|
||||
if CROSS_ENCODER_AVAILABLE:
|
||||
return True, None
|
||||
return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
|
||||
|
||||
|
||||
class CrossEncoderReranker:
|
||||
"""Cross-encoder reranker with lazy model loading."""
|
||||
|
||||
def __init__(self, model_name: str, *, device: str | None = None) -> None:
|
||||
self.model_name = (model_name or "").strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.device = (device or "").strip() or None
|
||||
self._model = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
ok, err = check_cross_encoder_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.device:
|
||||
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
|
||||
else:
|
||||
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
|
||||
raise
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[Tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> List[float]:
|
||||
"""Score (query, doc) pairs using the cross-encoder.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair) in the model's native scale (usually logits).
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
|
||||
return [float(s) for s in scores]
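A hypothetical usage sketch for the reranker; it requires the optional `codexlens[reranker]` extra, downloads the model on first use, and should be gated on check_cross_encoder_available() as the search engine does:

# Hypothetical usage sketch for CrossEncoderReranker.
from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available

ok, err = check_cross_encoder_available()
if not ok:
    print(f"cross-encoder unavailable: {err}")
else:
    reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = reranker.score_pairs(
        [
            ("where is reciprocal rank fusion implemented?", "def reciprocal_rank_fusion(...): ..."),
            ("where is reciprocal rank fusion implemented?", "def normalize_bm25_score(score): ..."),
        ],
        batch_size=32,
    )
    # Higher score = more relevant; values are raw logits, not probabilities.
    print(scores)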
@@ -10,15 +10,17 @@ Each directory maintains its own _index.db with:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult, Symbol
|
||||
from codexlens.entities import CodeRelationship, SearchResult, Symbol
|
||||
from codexlens.errors import StorageError
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
|
||||
@@ -60,7 +62,7 @@ class DirIndexStore:

    # Schema version for migration tracking
    # Increment this when schema changes require migration
    SCHEMA_VERSION = 5
    SCHEMA_VERSION = 8

    def __init__(
        self,
@@ -150,6 +152,21 @@ class DirIndexStore:
            from codexlens.storage.migrations.migration_005_cleanup_unused_fields import upgrade
            upgrade(conn)

        # Migration v5 -> v6: Ensure relationship tables/indexes exist
        if from_version < 6:
            from codexlens.storage.migrations.migration_006_enhance_relationships import upgrade
            upgrade(conn)

        # Migration v6 -> v7: Add graph neighbor cache for search expansion
        if from_version < 7:
            from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade
            upgrade(conn)

        # Migration v7 -> v8: Add Merkle hashes for incremental change detection
        if from_version < 8:
            from codexlens.storage.migrations.migration_008_add_merkle_hashes import upgrade
            upgrade(conn)

def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
with self._lock:
|
||||
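For context, this chain is the usual SQLite `PRAGMA user_version` pattern: each step is gated on the stored version, and the version is only bumped once every applicable upgrade has run. The wrapper that drives the chain is not shown in this hunk, so the names below are illustrative, not the actual DirIndexStore method:

import sqlite3

def run_migrations(conn: sqlite3.Connection, target_version: int) -> None:
    """Sketch of a user_version-gated migration driver (assumed wiring)."""
    from_version = conn.execute("PRAGMA user_version").fetchone()[0]
    if from_version >= target_version:
        return  # already up to date

    # ... per-version upgrade steps, as in the hunk above ...

    # Record the new schema version once all steps have applied.
    conn.execute(f"PRAGMA user_version = {target_version}")
    conn.commit()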
@@ -179,6 +196,7 @@ class DirIndexStore:
        content: str,
        language: str,
        symbols: Optional[List[Symbol]] = None,
        relationships: Optional[List[CodeRelationship]] = None,
    ) -> int:
        """Add or update a file in the current directory index.

@@ -188,6 +206,7 @@ class DirIndexStore:
            content: File content for indexing
            language: Programming language identifier
            symbols: List of Symbol objects from the file
            relationships: Optional list of CodeRelationship edges from this file

        Returns:
            Database file_id
@@ -240,6 +259,8 @@ class DirIndexStore:
                    symbol_rows,
                )

                self._save_merkle_hash(conn, file_id=file_id, content=content)
                self._save_relationships(conn, file_id=file_id, relationships=relationships)
                conn.commit()
                self._maybe_update_global_symbols(full_path_str, symbols or [])
                return file_id
@@ -248,6 +269,96 @@ class DirIndexStore:
                conn.rollback()
                raise StorageError(f"Failed to add file {name}: {exc}") from exc
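With the new parameter, callers can persist relationship edges in the same transaction as the file's symbols. A short sketch using the entity models that appear later in this diff (the store is assumed to be an initialized DirIndexStore; the file content is inlined to keep the example self-contained):

from pathlib import Path

from codexlens.entities import CodeRelationship, RelationshipType, Symbol

content = "def a():\n    b()\n"
symbols = [Symbol(name="a", kind="function", range=(1, 2), file="module.py")]
relationships = [
    CodeRelationship(
        source_symbol="a",
        target_symbol="b",
        relationship_type=RelationshipType.CALL,
        source_file="module.py",
        target_file=None,
        source_line=2,
    )
]

# `store` is an already-initialized DirIndexStore for this directory.
file_id = store.add_file(
    name="module.py",
    full_path=Path("module.py"),
    content=content,
    language="python",
    symbols=symbols,
    relationships=relationships,
)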
    def save_relationships(self, file_id: int, relationships: List[CodeRelationship]) -> None:
        """Save relationships for an already-indexed file.

        Args:
            file_id: Database file id
            relationships: Relationship edges to persist
        """
        if not relationships:
            return
        with self._lock:
            conn = self._get_connection()
            self._save_relationships(conn, file_id=file_id, relationships=relationships)
            conn.commit()

    def _save_relationships(
        self,
        conn: sqlite3.Connection,
        file_id: int,
        relationships: Optional[List[CodeRelationship]],
    ) -> None:
        if not relationships:
            return

        rows = conn.execute(
            "SELECT id, name FROM symbols WHERE file_id=? ORDER BY start_line, id",
            (file_id,),
        ).fetchall()

        name_to_id: Dict[str, int] = {}
        for row in rows:
            name = row["name"]
            if name not in name_to_id:
                name_to_id[name] = int(row["id"])

        if not name_to_id:
            return

        rel_rows: List[Tuple[int, str, str, int, Optional[str]]] = []
        seen: set[tuple[int, str, str, int, Optional[str]]] = set()

        for rel in relationships:
            source_id = name_to_id.get(rel.source_symbol)
            if source_id is None:
                continue

            target = (rel.target_symbol or "").strip()
            if not target:
                continue

            rel_type = rel.relationship_type.value
            source_line = int(rel.source_line)
            key = (source_id, target, rel_type, source_line, rel.target_file)
            if key in seen:
                continue
            seen.add(key)

            rel_rows.append((source_id, target, rel_type, source_line, rel.target_file))

        if not rel_rows:
            return

        conn.executemany(
            """
            INSERT INTO code_relationships(
                source_symbol_id, target_qualified_name,
                relationship_type, source_line, target_file
            )
            VALUES(?, ?, ?, ?, ?)
            """,
            rel_rows,
        )
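Because targets are stored as qualified-name strings rather than resolved ids, the table can be queried directly to answer "who references X" without materializing a full graph. An illustrative query against the columns created by migration 006 below (the helper name is made up for this sketch):

import sqlite3

def find_callers(conn: sqlite3.Connection, target_fragment: str):
    """Illustrative lookup: symbols that call something whose qualified name contains the fragment."""
    return conn.execute(
        """
        SELECT s.name AS caller, f.full_path, r.source_line
        FROM code_relationships r
        JOIN symbols s ON s.id = r.source_symbol_id
        JOIN files f ON f.id = s.file_id
        WHERE r.relationship_type = 'calls'
          AND r.target_qualified_name LIKE ?
        """,
        (f"%{target_fragment}%",),
    ).fetchall()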
    def _save_merkle_hash(self, conn: sqlite3.Connection, file_id: int, content: str) -> None:
        """Upsert a SHA-256 content hash for the given file_id (best-effort)."""
        try:
            digest = hashlib.sha256(content.encode("utf-8", errors="ignore")).hexdigest()
            now = time.time()
            conn.execute(
                """
                INSERT INTO merkle_hashes(file_id, sha256, updated_at)
                VALUES(?, ?, ?)
                ON CONFLICT(file_id) DO UPDATE SET
                    sha256=excluded.sha256,
                    updated_at=excluded.updated_at
                """,
                (file_id, digest, now),
            )
        except sqlite3.Error:
            return

    def add_files_batch(
        self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
    ) -> int:
@@ -312,6 +423,8 @@ class DirIndexStore:
                    symbol_rows,
                )

                self._save_merkle_hash(conn, file_id=file_id, content=content)

            conn.commit()
            return count

@@ -395,9 +508,13 @@ class DirIndexStore:
        return float(row["mtime"]) if row and row["mtime"] else None

    def needs_reindex(self, full_path: str | Path) -> bool:
        """Check if a file needs reindexing.

        Default behavior uses mtime comparison (with 1ms tolerance).

        When `Config.enable_merkle_detection` is enabled and Merkle metadata is
        available, uses SHA-256 content hash comparison (with mtime as a fast
        path to avoid hashing unchanged files).

        Args:
            full_path: Complete source file path
@@ -415,16 +532,154 @@ class DirIndexStore:
        except OSError:
            return False  # Can't read file stats, skip

        MTIME_TOLERANCE = 0.001

        # Fast path: mtime-only mode (default / backward-compatible)
        if self._config is None or not getattr(self._config, "enable_merkle_detection", False):
            stored_mtime = self.get_file_mtime(full_path_obj)
            if stored_mtime is None:
                return True
            return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE

        full_path_str = str(full_path_obj)

        # Hash-based change detection (best-effort, falls back to mtime when metadata missing)
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute(
                    """
                    SELECT f.id AS file_id, f.mtime AS mtime, mh.sha256 AS sha256
                    FROM files f
                    LEFT JOIN merkle_hashes mh ON mh.file_id = f.id
                    WHERE f.full_path=?
                    """,
                    (full_path_str,),
                ).fetchone()
            except sqlite3.Error:
                row = None

        if row is None:
            return True

        stored_mtime = float(row["mtime"]) if row["mtime"] else None
        stored_hash = row["sha256"] if row["sha256"] else None
        file_id = int(row["file_id"])

        # Missing Merkle data: fall back to mtime
        if stored_hash is None:
            if stored_mtime is None:
                return True
            return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE

        # If mtime is unchanged within tolerance, assume unchanged without hashing.
        if stored_mtime is not None and abs(current_mtime - stored_mtime) <= MTIME_TOLERANCE:
            return False

        try:
            current_text = full_path_obj.read_text(encoding="utf-8", errors="ignore")
            current_hash = hashlib.sha256(current_text.encode("utf-8", errors="ignore")).hexdigest()
        except OSError:
            return False

        if current_hash == stored_hash:
            # Content unchanged, but mtime drifted: update stored mtime to avoid repeated hashing.
            with self._lock:
                conn = self._get_connection()
                conn.execute("UPDATE files SET mtime=? WHERE id=?", (current_mtime, file_id))
                conn.commit()
            return False

        return True

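In practice the hash-aware check and the incremental entry point compose into a simple scan loop: files whose stored mtime/sha256 still match are skipped without re-parsing. A sketch under the assumption that `store` is an initialized DirIndexStore for the directory (parser output is elided):

from pathlib import Path

def reindex_directory(store, source_dir: Path) -> int:
    """Re-index only files whose content changed; returns number of files re-indexed."""
    count = 0
    for file_path in sorted(source_dir.glob("*.py")):
        text = file_path.read_text(encoding="utf-8", errors="ignore")
        # add_file_incremental consults needs_reindex() internally and returns None
        # when the stored mtime/sha256 show the file is unchanged.
        file_id = store.add_file_incremental(
            name=file_path.name,
            full_path=file_path,
            content=text,
            language="python",
            symbols=[],          # parser output elided in this sketch
            relationships=[],
        )
        if file_id is not None:
            count += 1
    store.update_merkle_root()
    return count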
def get_merkle_root_hash(self) -> Optional[str]:
|
||||
"""Return the stored Merkle root hash for this directory index (if present)."""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
row = conn.execute(
|
||||
"SELECT root_hash FROM merkle_state WHERE id=1"
|
||||
).fetchone()
|
||||
except sqlite3.Error:
|
||||
return None
|
||||
|
||||
return row["root_hash"] if row and row["root_hash"] else None
|
||||
|
||||
def update_merkle_root(self) -> Optional[str]:
|
||||
"""Compute and persist the Merkle root hash for this directory index.
|
||||
|
||||
The root hash includes:
|
||||
- Direct file hashes from `merkle_hashes`
|
||||
- Direct subdirectory root hashes (read from child `_index.db` files)
|
||||
"""
|
||||
if self._config is None or not getattr(self._config, "enable_merkle_detection", False):
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
file_rows = conn.execute(
|
||||
"""
|
||||
SELECT f.name AS name, mh.sha256 AS sha256
|
||||
FROM files f
|
||||
LEFT JOIN merkle_hashes mh ON mh.file_id = f.id
|
||||
ORDER BY f.name
|
||||
"""
|
||||
).fetchall()
|
||||
|
||||
subdir_rows = conn.execute(
|
||||
"SELECT name, index_path FROM subdirs ORDER BY name"
|
||||
).fetchall()
|
||||
except sqlite3.Error as exc:
|
||||
self.logger.debug("Failed to compute merkle root: %s", exc)
|
||||
return None
|
||||
|
||||
items: List[str] = []
|
||||
|
||||
for row in file_rows:
|
||||
name = row["name"]
|
||||
sha = (row["sha256"] or "").strip()
|
||||
items.append(f"f:{name}:{sha}")
|
||||
|
||||
def read_child_root(index_path: str) -> str:
|
||||
try:
|
||||
with sqlite3.connect(index_path) as child_conn:
|
||||
child_conn.row_factory = sqlite3.Row
|
||||
child_row = child_conn.execute(
|
||||
"SELECT root_hash FROM merkle_state WHERE id=1"
|
||||
).fetchone()
|
||||
return child_row["root_hash"] if child_row and child_row["root_hash"] else ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
for row in subdir_rows:
|
||||
name = row["name"]
|
||||
index_path = row["index_path"]
|
||||
child_hash = read_child_root(index_path) if index_path else ""
|
||||
items.append(f"d:{name}:{child_hash}")
|
||||
|
||||
root_hash = hashlib.sha256("\n".join(items).encode("utf-8", errors="ignore")).hexdigest()
|
||||
now = time.time()
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO merkle_state(id, root_hash, updated_at)
|
||||
VALUES(1, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
root_hash=excluded.root_hash,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
(root_hash, now),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error as exc:
|
||||
self.logger.debug("Failed to persist merkle root: %s", exc)
|
||||
return None
|
||||
|
||||
return root_hash
|
||||
|
||||
def add_file_incremental(
|
||||
self,
|
||||
@@ -433,6 +688,7 @@ class DirIndexStore:
        content: str,
        language: str,
        symbols: Optional[List[Symbol]] = None,
        relationships: Optional[List[CodeRelationship]] = None,
    ) -> Optional[int]:
        """Add or update a file only if it has changed (incremental indexing).
@@ -444,6 +700,7 @@ class DirIndexStore:
            content: File content for indexing
            language: Programming language identifier
            symbols: List of Symbol objects from the file
            relationships: Optional list of CodeRelationship edges from this file

        Returns:
            Database file_id if indexed, None if skipped (unchanged)
@@ -456,7 +713,7 @@ class DirIndexStore:
            return None  # Skip unchanged file

        # File changed or new, perform full indexing
        return self.add_file(name, full_path, content, language, symbols, relationships)

    def cleanup_deleted_files(self, source_dir: Path) -> int:
        """Remove indexed files that no longer exist in the source directory.
@@ -1767,6 +2024,39 @@ class DirIndexStore:
|
||||
"""
|
||||
)
|
||||
|
||||
# Precomputed graph neighbors cache for search expansion (v7)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS graph_neighbors (
|
||||
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
|
||||
neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
|
||||
relationship_depth INTEGER NOT NULL,
|
||||
PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Merkle hashes for incremental change detection (v8)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS merkle_hashes (
|
||||
file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE,
|
||||
sha256 TEXT NOT NULL,
|
||||
updated_at REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS merkle_state (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
root_hash TEXT,
|
||||
updated_at REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Indexes (v5: removed idx_symbols_type)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_files_name ON files(name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_files_path ON files(full_path)")
|
||||
@@ -1780,6 +2070,14 @@ class DirIndexStore:
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth "
|
||||
"ON graph_neighbors(source_symbol_id, relationship_depth)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor "
|
||||
"ON graph_neighbors(neighbor_symbol_id)"
|
||||
)
|
||||
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise StorageError(f"Failed to create schema: {exc}") from exc
|
||||
|
||||
@@ -8,11 +8,13 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.parsers.factory import ParserFactory
|
||||
@@ -247,6 +249,9 @@ class IndexTreeBuilder:
|
||||
try:
|
||||
with DirIndexStore(result.index_path, config=self.config, global_index=global_index) as store:
|
||||
deleted_count = store.cleanup_deleted_files(result.source_path)
|
||||
if deleted_count > 0:
|
||||
_compute_graph_neighbors(store, logger=self.logger)
|
||||
store.update_merkle_root()
|
||||
total_deleted += deleted_count
|
||||
if deleted_count > 0:
|
||||
self.logger.debug("Removed %d deleted files from %s", deleted_count, result.source_path)
|
||||
@@ -575,6 +580,7 @@ class IndexTreeBuilder:
|
||||
content=text,
|
||||
language=language_id,
|
||||
symbols=indexed_file.symbols,
|
||||
relationships=indexed_file.relationships,
|
||||
)
|
||||
|
||||
files_count += 1
|
||||
@@ -584,6 +590,9 @@ class IndexTreeBuilder:
|
||||
self.logger.debug("Failed to index %s: %s", file_path, exc)
|
||||
continue
|
||||
|
||||
if files_count > 0:
|
||||
_compute_graph_neighbors(store, logger=self.logger)
|
||||
|
||||
# Get list of subdirectories
|
||||
subdirs = [
|
||||
d.name
|
||||
@@ -593,6 +602,7 @@ class IndexTreeBuilder:
|
||||
and not d.name.startswith(".")
|
||||
]
|
||||
|
||||
store.update_merkle_root()
|
||||
store.close()
|
||||
if global_index is not None:
|
||||
global_index.close()
|
||||
@@ -654,31 +664,29 @@ class IndexTreeBuilder:
        parent_index_db = self.mapper.source_to_index_db(parent_path)

        try:
            with DirIndexStore(parent_index_db, config=self.config) as store:
                for result in all_results:
                    # Only register direct children (parent is one level up)
                    if result.source_path.parent != parent_path:
                        continue

                    if result.error:
                        continue

                    # Register subdirectory link
                    store.register_subdir(
                        name=result.source_path.name,
                        index_path=result.index_path,
                        files_count=result.files_count,
                        direct_files=result.files_count,
                    )
                    self.logger.debug(
                        "Linked %s to parent %s",
                        result.source_path.name,
                        parent_path,
                    )

                store.update_merkle_root()

        except Exception as exc:
            self.logger.error(
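Calling update_merkle_root() on the parent after linking means the parent's root hash is derived from its own file hashes plus each child's stored root, so a content change anywhere in the tree eventually changes every ancestor root. A small worked sketch of that composition, mirroring the `f:name:sha` / `d:name:hash` item format used by update_merkle_root:

import hashlib

def combine(file_hashes: dict[str, str], child_roots: dict[str, str]) -> str:
    """Directory root hash over sorted file and child-directory entries."""
    items = [f"f:{name}:{sha}" for name, sha in sorted(file_hashes.items())]
    items += [f"d:{name}:{root}" for name, root in sorted(child_roots.items())]
    return hashlib.sha256("\n".join(items).encode("utf-8")).hexdigest()

root_before = combine({"a.py": "111"}, {"child": "aaa"})
root_after = combine({"a.py": "111"}, {"child": "bbb"})  # child subtree changed
assert root_before != root_after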
@@ -726,6 +734,164 @@ class IndexTreeBuilder:
|
||||
return files
|
||||
|
||||
|
||||
def _normalize_relationship_target(target: str) -> str:
|
||||
"""Best-effort normalization of a relationship target into a local symbol name."""
|
||||
target = (target or "").strip()
|
||||
if not target:
|
||||
return ""
|
||||
|
||||
# Drop trailing call parentheses when present (e.g., "foo()" -> "foo").
|
||||
if target.endswith("()"):
|
||||
target = target[:-2]
|
||||
|
||||
# Keep the leaf identifier for common qualified formats.
|
||||
for sep in ("::", ".", "#"):
|
||||
if sep in target:
|
||||
target = target.split(sep)[-1]
|
||||
|
||||
# Strip non-identifier suffix/prefix noise.
|
||||
target = re.sub(r"^[^A-Za-z0-9_]+", "", target)
|
||||
target = re.sub(r"[^A-Za-z0-9_]+$", "", target)
|
||||
return target
|
||||
|
||||
|
||||
def _compute_graph_neighbors(
|
||||
store: DirIndexStore,
|
||||
*,
|
||||
max_depth: int = 2,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
"""Compute and persist N-hop neighbors for all symbols in a directory index."""
|
||||
if max_depth <= 0:
|
||||
return
|
||||
|
||||
log = logger or logging.getLogger(__name__)
|
||||
|
||||
with store._lock:
|
||||
conn = store._get_connection()
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Ensure schema exists even for older databases pinned to the same user_version.
|
||||
try:
|
||||
from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade
|
||||
|
||||
upgrade(conn)
|
||||
except Exception as exc:
|
||||
log.debug("Graph neighbor schema ensure failed: %s", exc)
|
||||
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("DELETE FROM graph_neighbors")
|
||||
except sqlite3.Error:
|
||||
# Table missing or schema mismatch; skip gracefully.
|
||||
return
|
||||
|
||||
try:
|
||||
symbol_rows = cursor.execute(
|
||||
"SELECT id, file_id, name FROM symbols"
|
||||
).fetchall()
|
||||
rel_rows = cursor.execute(
|
||||
"SELECT source_symbol_id, target_qualified_name FROM code_relationships"
|
||||
).fetchall()
|
||||
except sqlite3.Error:
|
||||
return
|
||||
|
||||
if not symbol_rows or not rel_rows:
|
||||
try:
|
||||
conn.commit()
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
return
|
||||
|
||||
symbol_file_by_id: Dict[int, int] = {}
|
||||
symbols_by_file_and_name: Dict[Tuple[int, str], List[int]] = {}
|
||||
symbols_by_name: Dict[str, List[int]] = {}
|
||||
|
||||
for row in symbol_rows:
|
||||
symbol_id = int(row["id"])
|
||||
file_id = int(row["file_id"])
|
||||
name = str(row["name"])
|
||||
symbol_file_by_id[symbol_id] = file_id
|
||||
symbols_by_file_and_name.setdefault((file_id, name), []).append(symbol_id)
|
||||
symbols_by_name.setdefault(name, []).append(symbol_id)
|
||||
|
||||
adjacency: Dict[int, Set[int]] = {}
|
||||
|
||||
for row in rel_rows:
|
||||
source_id = int(row["source_symbol_id"])
|
||||
target_raw = str(row["target_qualified_name"] or "")
|
||||
target_name = _normalize_relationship_target(target_raw)
|
||||
if not target_name:
|
||||
continue
|
||||
|
||||
source_file_id = symbol_file_by_id.get(source_id)
|
||||
if source_file_id is None:
|
||||
continue
|
||||
|
||||
candidate_ids = symbols_by_file_and_name.get((source_file_id, target_name))
|
||||
if not candidate_ids:
|
||||
global_candidates = symbols_by_name.get(target_name, [])
|
||||
# Only resolve cross-file by name when unambiguous.
|
||||
candidate_ids = global_candidates if len(global_candidates) == 1 else []
|
||||
|
||||
for target_id in candidate_ids:
|
||||
if target_id == source_id:
|
||||
continue
|
||||
adjacency.setdefault(source_id, set()).add(target_id)
|
||||
adjacency.setdefault(target_id, set()).add(source_id)
|
||||
|
||||
if not adjacency:
|
||||
try:
|
||||
conn.commit()
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
return
|
||||
|
||||
insert_rows: List[Tuple[int, int, int]] = []
|
||||
max_depth = min(int(max_depth), 2)
|
||||
|
||||
for source_id, first_hop in adjacency.items():
|
||||
if not first_hop:
|
||||
continue
|
||||
for neighbor_id in first_hop:
|
||||
insert_rows.append((source_id, neighbor_id, 1))
|
||||
|
||||
if max_depth < 2:
|
||||
continue
|
||||
|
||||
second_hop: Set[int] = set()
|
||||
for neighbor_id in first_hop:
|
||||
second_hop.update(adjacency.get(neighbor_id, set()))
|
||||
|
||||
second_hop.discard(source_id)
|
||||
second_hop.difference_update(first_hop)
|
||||
|
||||
for neighbor_id in second_hop:
|
||||
insert_rows.append((source_id, neighbor_id, 2))
|
||||
|
||||
if not insert_rows:
|
||||
try:
|
||||
conn.commit()
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
return
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO graph_neighbors(
|
||||
source_symbol_id, neighbor_symbol_id, relationship_depth
|
||||
)
|
||||
VALUES(?, ?, ?)
|
||||
""",
|
||||
insert_rows,
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.Error:
|
||||
return
|
||||
|
||||
|
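At search time the precomputed cache keeps expansion cheap: one indexed lookup per seed symbol, no graph traversal. A sketch of the kind of query a GraphExpander-style consumer might run against the `graph_neighbors` table (the expander's actual implementation is not part of this excerpt, so the helper below is illustrative):

import sqlite3

def related_symbols(conn: sqlite3.Connection, symbol_id: int, max_depth: int = 2):
    """Fetch precomputed neighbors for one symbol, nearest first."""
    return conn.execute(
        """
        SELECT s.name, s.file_id, gn.relationship_depth
        FROM graph_neighbors gn
        JOIN symbols s ON s.id = gn.neighbor_symbol_id
        WHERE gn.source_symbol_id = ?
          AND gn.relationship_depth <= ?
        ORDER BY gn.relationship_depth
        """,
        (symbol_id, max_depth),
    ).fetchall()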
||||
# === Worker Function for ProcessPoolExecutor ===
|
||||
|
||||
|
||||
@@ -795,6 +961,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
|
||||
content=text,
|
||||
language=language_id,
|
||||
symbols=indexed_file.symbols,
|
||||
relationships=indexed_file.relationships,
|
||||
)
|
||||
|
||||
files_count += 1
|
||||
@@ -803,6 +970,9 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if files_count > 0:
|
||||
_compute_graph_neighbors(store)
|
||||
|
||||
# Get subdirectories
|
||||
ignore_dirs = {
|
||||
".git",
|
||||
@@ -821,6 +991,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
|
||||
if d.is_dir() and d.name not in ignore_dirs and not d.name.startswith(".")
|
||||
]
|
||||
|
||||
store.update_merkle_root()
|
||||
store.close()
|
||||
if global_index is not None:
|
||||
global_index.close()
|
||||
|
||||
codex-lens/src/codexlens/storage/merkle_tree.py (new file, 136 lines)
@@ -0,0 +1,136 @@
|
||||
"""Merkle tree utilities for change detection.
|
||||
|
||||
This module provides a generic, file-system based Merkle tree implementation
|
||||
that can be used to efficiently diff directory states.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
|
||||
def sha256_bytes(data: bytes) -> str:
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def sha256_text(text: str) -> str:
|
||||
return sha256_bytes(text.encode("utf-8", errors="ignore"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MerkleNode:
|
||||
"""A Merkle node representing either a file (leaf) or directory (internal)."""
|
||||
|
||||
name: str
|
||||
rel_path: str
|
||||
hash: str
|
||||
is_dir: bool
|
||||
children: Dict[str, "MerkleNode"] = field(default_factory=dict)
|
||||
|
||||
def iter_files(self) -> Iterable["MerkleNode"]:
|
||||
if not self.is_dir:
|
||||
yield self
|
||||
return
|
||||
for child in self.children.values():
|
||||
yield from child.iter_files()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MerkleTree:
|
||||
"""Merkle tree for a directory snapshot."""
|
||||
|
||||
root: MerkleNode
|
||||
|
||||
@classmethod
|
||||
def build_from_directory(cls, root_dir: Path) -> "MerkleTree":
|
||||
root_dir = Path(root_dir).resolve()
|
||||
node = cls._build_node(root_dir, base=root_dir)
|
||||
return cls(root=node)
|
||||
|
||||
@classmethod
|
||||
def _build_node(cls, path: Path, *, base: Path) -> MerkleNode:
|
||||
if path.is_file():
|
||||
rel = str(path.relative_to(base)).replace("\\", "/")
|
||||
return MerkleNode(
|
||||
name=path.name,
|
||||
rel_path=rel,
|
||||
hash=sha256_bytes(path.read_bytes()),
|
||||
is_dir=False,
|
||||
)
|
||||
|
||||
if not path.is_dir():
|
||||
rel = str(path.relative_to(base)).replace("\\", "/")
|
||||
return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False)
|
||||
|
||||
children: Dict[str, MerkleNode] = {}
|
||||
for child in sorted(path.iterdir(), key=lambda p: p.name):
|
||||
child_node = cls._build_node(child, base=base)
|
||||
children[child_node.name] = child_node
|
||||
|
||||
items = [
|
||||
f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}"
|
||||
for name, n in sorted(children.items(), key=lambda kv: kv[0])
|
||||
]
|
||||
dir_hash = sha256_text("\n".join(items))
|
||||
|
||||
rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/")
|
||||
return MerkleNode(
|
||||
name="." if path == base else path.name,
|
||||
rel_path=rel_path,
|
||||
hash=dir_hash,
|
||||
is_dir=True,
|
||||
children=children,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]:
|
||||
"""Find changed/added/removed files between two trees.
|
||||
|
||||
Returns:
|
||||
List of relative file paths (POSIX-style separators).
|
||||
"""
|
||||
if old is None and new is None:
|
||||
return []
|
||||
if old is None:
|
||||
return sorted({n.rel_path for n in new.root.iter_files()}) # type: ignore[union-attr]
|
||||
if new is None:
|
||||
return sorted({n.rel_path for n in old.root.iter_files()})
|
||||
|
||||
changed: set[str] = set()
|
||||
|
||||
def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None:
|
||||
if old_node is None and new_node is None:
|
||||
return
|
||||
|
||||
if old_node is None and new_node is not None:
|
||||
changed.update(n.rel_path for n in new_node.iter_files())
|
||||
return
|
||||
|
||||
if new_node is None and old_node is not None:
|
||||
changed.update(n.rel_path for n in old_node.iter_files())
|
||||
return
|
||||
|
||||
assert old_node is not None and new_node is not None
|
||||
|
||||
if old_node.hash == new_node.hash:
|
||||
return
|
||||
|
||||
if not old_node.is_dir and not new_node.is_dir:
|
||||
changed.add(new_node.rel_path)
|
||||
return
|
||||
|
||||
if old_node.is_dir != new_node.is_dir:
|
||||
changed.update(n.rel_path for n in old_node.iter_files())
|
||||
changed.update(n.rel_path for n in new_node.iter_files())
|
||||
return
|
||||
|
||||
names = set(old_node.children.keys()) | set(new_node.children.keys())
|
||||
for name in names:
|
||||
walk(old_node.children.get(name), new_node.children.get(name))
|
||||
|
||||
walk(old.root, new.root)
|
||||
return sorted(changed)
|
||||
|
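A quick usage sketch of the helper above: snapshot a directory twice and diff the two trees to get the relative paths that need re-indexing (the two-snapshot workflow here is illustrative; only the MerkleTree API shown above is assumed):

from pathlib import Path

from codexlens.storage.merkle_tree import MerkleTree

old_tree = MerkleTree.build_from_directory(Path("project"))
# ... files are edited, added, or removed here ...
new_tree = MerkleTree.build_from_directory(Path("project"))

for rel_path in MerkleTree.find_changed_files(old_tree, new_tree):
    print("changed:", rel_path)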
||||
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Migration 006: Ensure relationship tables and indexes exist.
|
||||
|
||||
This migration is intentionally idempotent. It creates the `code_relationships`
|
||||
table (used for graph visualization) and its indexes if missing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection) -> None:
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Ensuring code_relationships table exists...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS code_relationships (
|
||||
id INTEGER PRIMARY KEY,
|
||||
source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE,
|
||||
target_qualified_name TEXT NOT NULL,
|
||||
relationship_type TEXT NOT NULL,
|
||||
source_line INTEGER NOT NULL,
|
||||
target_file TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Ensuring relationship indexes exist...")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Migration 007: Add precomputed graph neighbor table for search expansion.
|
||||
|
||||
Adds:
|
||||
- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids)
|
||||
|
||||
This table is derived data (a cache) and is safe to rebuild at any time.
|
||||
The migration is intentionally idempotent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection) -> None:
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Creating graph_neighbors table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS graph_neighbors (
|
||||
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
|
||||
neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
|
||||
relationship_depth INTEGER NOT NULL,
|
||||
PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating indexes for graph_neighbors...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth
|
||||
ON graph_neighbors(source_symbol_id, relationship_depth)
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor
|
||||
ON graph_neighbors(neighbor_symbol_id)
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Migration 008: Add Merkle hash tables for content-based incremental indexing.
|
||||
|
||||
Adds:
|
||||
- merkle_hashes: per-file SHA-256 hashes (keyed by file_id)
|
||||
- merkle_state: directory-level root hash (single row, id=1)
|
||||
|
||||
Backfills merkle_hashes using the existing `files.content` column when available.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection) -> None:
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Creating merkle_hashes table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS merkle_hashes (
|
||||
file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE,
|
||||
sha256 TEXT NOT NULL,
|
||||
updated_at REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating merkle_state table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS merkle_state (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
root_hash TEXT,
|
||||
updated_at REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Backfill file hashes from stored content (best-effort).
|
||||
try:
|
||||
rows = cursor.execute("SELECT id, content FROM files").fetchall()
|
||||
except Exception as exc:
|
||||
log.warning("Unable to backfill merkle hashes (files table missing?): %s", exc)
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
inserts: list[tuple[int, str, float]] = []
|
||||
|
||||
for row in rows:
|
||||
file_id = int(row[0])
|
||||
content = row[1]
|
||||
if content is None:
|
||||
continue
|
||||
try:
|
||||
digest = hashlib.sha256(str(content).encode("utf-8", errors="ignore")).hexdigest()
|
||||
inserts.append((file_id, digest, now))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not inserts:
|
||||
return
|
||||
|
||||
log.info("Backfilling %d file hashes...", len(inserts))
|
||||
cursor.executemany(
|
||||
"""
|
||||
INSERT INTO merkle_hashes(file_id, sha256, updated_at)
|
||||
VALUES(?, ?, ?)
|
||||
ON CONFLICT(file_id) DO UPDATE SET
|
||||
sha256=excluded.sha256,
|
||||
updated_at=excluded.updated_at
|
||||
""",
|
||||
inserts,
|
||||
)
|
||||
|
||||
codex-lens/tests/test_graph_expansion.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import CodeRelationship, RelationshipType, SearchResult, Symbol
|
||||
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
|
||||
from codexlens.search.graph_expander import GraphExpander
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
from codexlens.storage.index_tree import _compute_graph_neighbors
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def temp_paths() -> Path:
|
||||
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
|
||||
root = Path(tmpdir.name)
|
||||
yield root
|
||||
try:
|
||||
tmpdir.cleanup()
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def _create_index_with_neighbors(root: Path) -> tuple[PathMapper, Path, Path]:
|
||||
project_root = root / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = root / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
content = "\n".join(
|
||||
[
|
||||
"def a():",
|
||||
" b()",
|
||||
"",
|
||||
"def b():",
|
||||
" c()",
|
||||
"",
|
||||
"def c():",
|
||||
" return 1",
|
||||
"",
|
||||
]
|
||||
)
|
||||
file_path = project_root / "module.py"
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
|
||||
symbols = [
|
||||
Symbol(name="a", kind="function", range=(1, 2), file=str(file_path)),
|
||||
Symbol(name="b", kind="function", range=(4, 5), file=str(file_path)),
|
||||
Symbol(name="c", kind="function", range=(7, 8), file=str(file_path)),
|
||||
]
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="a",
|
||||
target_symbol="b",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=2,
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="b",
|
||||
target_symbol="c",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=5,
|
||||
),
|
||||
]
|
||||
|
||||
config = Config(data_dir=root / "data")
|
||||
store = DirIndexStore(index_db_path, config=config)
|
||||
store.initialize()
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=content,
|
||||
language="python",
|
||||
symbols=symbols,
|
||||
relationships=relationships,
|
||||
)
|
||||
_compute_graph_neighbors(store)
|
||||
store.close()
|
||||
|
||||
return mapper, project_root, file_path
|
||||
|
||||
|
||||
def test_graph_neighbors_precomputed_two_hop(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
|
||||
conn = sqlite3.connect(str(index_db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT s1.name AS source_name, s2.name AS neighbor_name, gn.relationship_depth
|
||||
FROM graph_neighbors gn
|
||||
JOIN symbols s1 ON s1.id = gn.source_symbol_id
|
||||
JOIN symbols s2 ON s2.id = gn.neighbor_symbol_id
|
||||
ORDER BY source_name, neighbor_name, relationship_depth
|
||||
"""
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
triples = {(r["source_name"], r["neighbor_name"], int(r["relationship_depth"])) for r in rows}
|
||||
assert ("a", "b", 1) in triples
|
||||
assert ("a", "c", 2) in triples
|
||||
assert ("b", "c", 1) in triples
|
||||
assert ("c", "b", 1) in triples
|
||||
assert file_path.exists()
|
||||
|
||||
|
||||
def test_graph_expander_returns_related_results_with_depth_metadata(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = project_root
|
||||
|
||||
expander = GraphExpander(mapper, config=Config(data_dir=temp_paths / "data", graph_expansion_depth=2))
|
||||
base = SearchResult(
|
||||
path=str(file_path.resolve()),
|
||||
score=1.0,
|
||||
excerpt="",
|
||||
content=None,
|
||||
start_line=1,
|
||||
end_line=2,
|
||||
symbol_name="a",
|
||||
symbol_kind="function",
|
||||
)
|
||||
related = expander.expand([base], depth=2, max_expand=1, max_related=10)
|
||||
|
||||
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in related}
|
||||
assert depth_by_symbol.get("b") == 1
|
||||
assert depth_by_symbol.get("c") == 2
|
||||
|
||||
|
||||
def test_chain_search_populates_related_results_when_enabled(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = file_path
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
|
||||
config = Config(
|
||||
data_dir=temp_paths / "data",
|
||||
enable_graph_expansion=True,
|
||||
graph_expansion_depth=2,
|
||||
)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
try:
|
||||
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
|
||||
result = engine.search("b", project_root, options)
|
||||
|
||||
assert result.results
|
||||
assert result.results[0].symbol_name == "a"
|
||||
|
||||
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in result.related_results}
|
||||
assert depth_by_symbol.get("b") == 1
|
||||
assert depth_by_symbol.get("c") == 2
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
|
||||
def test_chain_search_related_results_empty_when_disabled(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = file_path
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
|
||||
config = Config(
|
||||
data_dir=temp_paths / "data",
|
||||
enable_graph_expansion=False,
|
||||
)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
try:
|
||||
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
|
||||
result = engine.search("b", project_root, options)
|
||||
assert result.related_results == []
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
@@ -869,3 +869,47 @@ class TestHybridSearchAdaptiveWeights:
|
||||
) as rerank_mock:
|
||||
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
|
||||
def test_cross_encoder_reranking_enabled(self, tmp_path):
|
||||
"""Cross-encoder stage runs only when explicitly enabled via config."""
|
||||
from unittest.mock import patch
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
return [[1.0, 0.0] for _ in texts]
|
||||
|
||||
class DummyReranker:
|
||||
def score_pairs(self, pairs, batch_size=32):
|
||||
return [0.0 for _ in pairs]
|
||||
|
||||
config = Config(
|
||||
data_dir=tmp_path / "ce",
|
||||
enable_reranking=True,
|
||||
enable_cross_encoder_rerank=True,
|
||||
reranker_top_k=10,
|
||||
)
|
||||
engine = HybridSearchEngine(config=config, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results",
|
||||
side_effect=lambda q, r, e, top_k=50: r,
|
||||
) as rerank_mock, patch.object(
|
||||
HybridSearchEngine,
|
||||
"_get_cross_encoder_reranker",
|
||||
return_value=DummyReranker(),
|
||||
) as get_ce_mock, patch(
|
||||
"codexlens.search.hybrid_search.cross_encoder_rerank",
|
||||
side_effect=lambda q, r, ce, top_k=50: r,
|
||||
) as ce_mock:
|
||||
engine.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
assert get_ce_mock.call_count == 1
|
||||
assert ce_mock.call_count == 1
|
||||
|
||||
codex-lens/tests/test_merkle_detection.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
|
||||
|
||||
def _make_merkle_config(tmp_path: Path) -> Config:
|
||||
data_dir = tmp_path / "data"
|
||||
return Config(
|
||||
data_dir=data_dir,
|
||||
venv_path=data_dir / "venv",
|
||||
enable_merkle_detection=True,
|
||||
)
|
||||
|
||||
|
||||
class TestMerkleDetection:
|
||||
def test_needs_reindex_touch_updates_mtime(self, tmp_path: Path) -> None:
|
||||
config = _make_merkle_config(tmp_path)
|
||||
source_dir = tmp_path / "src"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = source_dir / "a.py"
|
||||
file_path.write_text("print('hi')\n", encoding="utf-8")
|
||||
original_content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
index_db = tmp_path / "_index.db"
|
||||
with DirIndexStore(index_db, config=config) as store:
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=original_content,
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
|
||||
stored_mtime_before = store.get_file_mtime(file_path)
|
||||
assert stored_mtime_before is not None
|
||||
|
||||
# Touch file without changing content
|
||||
time.sleep(0.02)
|
||||
file_path.write_text(original_content, encoding="utf-8")
|
||||
|
||||
assert store.needs_reindex(file_path) is False
|
||||
|
||||
stored_mtime_after = store.get_file_mtime(file_path)
|
||||
assert stored_mtime_after is not None
|
||||
assert stored_mtime_after != stored_mtime_before
|
||||
|
||||
current_mtime = file_path.stat().st_mtime
|
||||
assert abs(stored_mtime_after - current_mtime) <= 0.001
|
||||
|
||||
def test_parent_root_changes_when_child_changes(self, tmp_path: Path) -> None:
|
||||
config = _make_merkle_config(tmp_path)
|
||||
|
||||
source_root = tmp_path / "project"
|
||||
child_dir = source_root / "child"
|
||||
child_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
child_file = child_dir / "child.py"
|
||||
child_file.write_text("x = 1\n", encoding="utf-8")
|
||||
|
||||
child_db = tmp_path / "child_index.db"
|
||||
parent_db = tmp_path / "parent_index.db"
|
||||
|
||||
with DirIndexStore(child_db, config=config) as child_store:
|
||||
child_store.add_file(
|
||||
name=child_file.name,
|
||||
full_path=child_file,
|
||||
content=child_file.read_text(encoding="utf-8"),
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
child_root_1 = child_store.update_merkle_root()
|
||||
assert child_root_1
|
||||
|
||||
with DirIndexStore(parent_db, config=config) as parent_store:
|
||||
parent_store.register_subdir(name="child", index_path=child_db, files_count=1)
|
||||
parent_root_1 = parent_store.update_merkle_root()
|
||||
assert parent_root_1
|
||||
|
||||
time.sleep(0.02)
|
||||
child_file.write_text("x = 2\n", encoding="utf-8")
|
||||
|
||||
with DirIndexStore(child_db, config=config) as child_store:
|
||||
child_store.add_file(
|
||||
name=child_file.name,
|
||||
full_path=child_file,
|
||||
content=child_file.read_text(encoding="utf-8"),
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
child_root_2 = child_store.update_merkle_root()
|
||||
assert child_root_2
|
||||
assert child_root_2 != child_root_1
|
||||
|
||||
with DirIndexStore(parent_db, config=config) as parent_store:
|
||||
parent_root_2 = parent_store.update_merkle_root()
|
||||
assert parent_root_2
|
||||
assert parent_root_2 != parent_root_1
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for performance optimizations in CodexLens storage.
|
||||
"""Tests for performance optimizations in CodexLens.
|
||||
|
||||
This module tests the following optimizations:
|
||||
1. Normalized keywords search (migration_001)
|
||||
2. Optimized path lookup in registry
|
||||
3. Prefix-mode symbol search
|
||||
4. Graph expansion neighbor precompute overhead (<20%)
|
||||
5. Cross-encoder reranking latency (<200ms)
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -479,3 +481,113 @@ class TestPerformanceComparison:
|
||||
print(f" Substring: {substring_time*1000:.3f}ms ({len(substring_results)} results)")
|
||||
print(f" Ratio: {prefix_time/substring_time:.2f}x")
|
||||
print(f" Note: Performance benefits appear with 1000+ symbols")
|
||||
|
||||
|
||||
class TestPerformanceBenchmarks:
|
||||
"""Benchmark-style assertions for key performance requirements."""
|
||||
|
||||
def test_graph_expansion_indexing_overhead_under_20_percent(self, temp_index_db, tmp_path):
|
||||
"""Graph neighbor precompute adds <20% overhead versus indexing baseline."""
|
||||
from codexlens.entities import CodeRelationship, RelationshipType, Symbol
|
||||
from codexlens.storage.index_tree import _compute_graph_neighbors
|
||||
|
||||
store = temp_index_db
|
||||
|
||||
file_count = 60
|
||||
symbols_per_file = 8
|
||||
|
||||
start = time.perf_counter()
|
||||
for file_idx in range(file_count):
|
||||
file_path = tmp_path / f"graph_{file_idx}.py"
|
||||
lines = []
|
||||
for sym_idx in range(symbols_per_file):
|
||||
lines.append(f"def func_{file_idx}_{sym_idx}():")
|
||||
lines.append(f" return {sym_idx}")
|
||||
lines.append("")
|
||||
content = "\n".join(lines)
|
||||
|
||||
symbols = [
|
||||
Symbol(
|
||||
name=f"func_{file_idx}_{sym_idx}",
|
||||
kind="function",
|
||||
range=(sym_idx * 3 + 1, sym_idx * 3 + 2),
|
||||
file=str(file_path),
|
||||
)
|
||||
for sym_idx in range(symbols_per_file)
|
||||
]
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol=f"func_{file_idx}_{sym_idx}",
|
||||
target_symbol=f"func_{file_idx}_{sym_idx + 1}",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=sym_idx * 3 + 2,
|
||||
)
|
||||
for sym_idx in range(symbols_per_file - 1)
|
||||
]
|
||||
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=content,
|
||||
language="python",
|
||||
symbols=symbols,
|
||||
relationships=relationships,
|
||||
)
|
||||
baseline_time = time.perf_counter() - start
|
||||
|
||||
durations = []
|
||||
for _ in range(3):
|
||||
start = time.perf_counter()
|
||||
_compute_graph_neighbors(store)
|
||||
durations.append(time.perf_counter() - start)
|
||||
graph_time = min(durations)
|
||||
|
||||
# Sanity-check that the benchmark exercised graph neighbor generation.
|
||||
conn = store._get_connection()
|
||||
neighbor_count = conn.execute(
|
||||
"SELECT COUNT(*) as c FROM graph_neighbors"
|
||||
).fetchone()["c"]
|
||||
assert neighbor_count > 0
|
||||
|
||||
assert baseline_time > 0.0
|
||||
overhead_ratio = graph_time / baseline_time
|
||||
assert overhead_ratio < 0.2, (
|
||||
f"Graph neighbor precompute overhead too high: {overhead_ratio:.2%} "
|
||||
f"(baseline={baseline_time:.3f}s, graph={graph_time:.3f}s)"
|
||||
)
|
||||
|
||||
def test_cross_encoder_reranking_latency_under_200ms(self):
|
||||
"""Cross-encoder rerank step completes under 200ms (excluding model load)."""
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import cross_encoder_rerank
|
||||
|
||||
query = "find function"
|
||||
results = [
|
||||
SearchResult(
|
||||
path=f"file_{idx}.py",
|
||||
score=1.0 / (idx + 1),
|
||||
excerpt=f"def func_{idx}():\n return {idx}",
|
||||
symbol_name=f"func_{idx}",
|
||||
symbol_kind="function",
|
||||
)
|
||||
for idx in range(50)
|
||||
]
|
||||
|
||||
class DummyReranker:
|
||||
def score_pairs(self, pairs, batch_size=32):
|
||||
_ = batch_size
|
||||
# Return deterministic pseudo-logits to exercise sigmoid normalization.
|
||||
return [float(i) for i in range(len(pairs))]
|
||||
|
||||
reranker = DummyReranker()
|
||||
|
||||
start = time.perf_counter()
|
||||
reranked = cross_encoder_rerank(query, results, reranker, top_k=50, batch_size=32)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
assert len(reranked) == len(results)
|
||||
assert any(r.metadata.get("cross_encoder_reranked") for r in reranked[:50])
|
||||
assert elapsed_ms < 200.0, f"Cross-encoder rerank too slow: {elapsed_ms:.1f}ms"
|
||||
|
||||
@@ -19,7 +19,7 @@ from codexlens.entities import Symbol
|
||||
|
||||
|
||||
class TestSchemaCleanupMigration:
|
||||
"""Test schema cleanup migration (v4 -> v5)."""
|
||||
"""Test schema cleanup migration (v4 -> latest)."""
|
||||
|
||||
def test_migration_from_v4_to_v5(self):
|
||||
"""Test that migration successfully removes deprecated fields."""
|
||||
@@ -129,10 +129,12 @@ class TestSchemaCleanupMigration:
|
||||
# Now initialize store - this should trigger migration
|
||||
store.initialize()
|
||||
|
||||
# Verify schema version is now the latest
|
||||
conn = store._get_connection()
|
||||
version_row = conn.execute("PRAGMA user_version").fetchone()
|
||||
assert version_row[0] == DirIndexStore.SCHEMA_VERSION, (
|
||||
f"Expected schema version {DirIndexStore.SCHEMA_VERSION}, got {version_row[0]}"
|
||||
)
|
||||
|
||||
# Check that deprecated columns are removed
|
||||
# 1. Check semantic_metadata doesn't have keywords column
|
||||
@@ -166,7 +168,7 @@ class TestSchemaCleanupMigration:
|
||||
store.close()
|
||||
|
||||
def test_new_database_has_clean_schema(self):
|
||||
"""Test that new databases are created with clean schema (v5)."""
|
||||
"""Test that new databases are created with clean schema (latest)."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "_index.db"
|
||||
store = DirIndexStore(db_path)
|
||||
@@ -174,9 +176,9 @@ class TestSchemaCleanupMigration:
|
||||
|
||||
conn = store._get_connection()
|
||||
|
||||
# Verify schema version is the latest
|
||||
version_row = conn.execute("PRAGMA user_version").fetchone()
|
||||
assert version_row[0] == DirIndexStore.SCHEMA_VERSION
|
||||
|
||||
# Check that new schema doesn't have deprecated columns
|
||||
cursor = conn.execute("PRAGMA table_info(semantic_metadata)")
|
||||
|
||||
@@ -582,6 +582,7 @@ class TestChainSearchResult:
|
||||
)
|
||||
assert result.query == "test"
|
||||
assert result.results == []
|
||||
assert result.related_results == []
|
||||
assert result.symbols == []
|
||||
assert result.stats.dirs_searched == 0
|
||||
|
||||
|
||||
@@ -1173,6 +1173,7 @@ class TestChainSearchResultExtended:
|
||||
assert result.query == "test query"
|
||||
assert len(result.results) == 1
|
||||
assert len(result.symbols) == 1
|
||||
assert result.related_results == []
|
||||
assert result.stats.dirs_searched == 5
|
||||
|
||||
def test_result_with_empty_collections(self):
|
||||
@@ -1186,5 +1187,6 @@ class TestChainSearchResultExtended:
|
||||
|
||||
assert result.query == "no matches"
|
||||
assert result.results == []
|
||||
assert result.related_results == []
|
||||
assert result.symbols == []
|
||||
assert result.stats.dirs_searched == 0
|
||||
|
||||
@@ -110,6 +110,37 @@ class DataProcessor:
|
||||
assert result is not None
|
||||
assert len(result.symbols) == 0
|
||||
|
||||
def test_extracts_relationships_with_alias_resolution(self):
|
||||
parser = TreeSitterSymbolParser("python")
|
||||
code = """
|
||||
import os.path as osp
|
||||
from math import sqrt as sq
|
||||
|
||||
class Base:
|
||||
pass
|
||||
|
||||
class Child(Base):
|
||||
pass
|
||||
|
||||
def main():
|
||||
osp.join("a", "b")
|
||||
sq(4)
|
||||
"""
|
||||
result = parser.parse(code, Path("test.py"))
|
||||
|
||||
assert result is not None
|
||||
|
||||
rels = [r for r in result.relationships if r.source_symbol == "main"]
|
||||
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
|
||||
assert "os.path.join" in targets
|
||||
assert "math.sqrt" in targets
|
||||
|
||||
inherits = [
|
||||
r for r in result.relationships
|
||||
if r.source_symbol == "Child" and r.relationship_type.value == "inherits"
|
||||
]
|
||||
assert any(r.target_symbol == "Base" for r in inherits)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
|
||||
class TestTreeSitterJavaScriptParser:
|
||||
@@ -175,6 +206,22 @@ export const arrowFunc = () => {}
|
||||
assert "exported" in names
|
||||
assert "arrowFunc" in names
|
||||
|
||||
def test_extracts_relationships_with_import_alias(self):
|
||||
parser = TreeSitterSymbolParser("javascript")
|
||||
code = """
|
||||
import { readFile as rf } from "fs";
|
||||
|
||||
function main() {
|
||||
rf("a");
|
||||
}
|
||||
"""
|
||||
result = parser.parse(code, Path("test.js"))
|
||||
|
||||
assert result is not None
|
||||
rels = [r for r in result.relationships if r.source_symbol == "main"]
|
||||
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
|
||||
assert "fs.readFile" in targets
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
|
||||
class TestTreeSitterTypeScriptParser:
|
||||
|
||||