From 31a45f1f30c609f5d012c9470e80e749587d959c Mon Sep 17 00:00:00 2001 From: catlog22 Date: Wed, 31 Dec 2025 16:58:59 +0800 Subject: [PATCH] Add graph expansion and cross-encoder reranking features - Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors. - Added CrossEncoderReranker for second-stage search ranking, allowing for improved result scoring. - Created migrations to establish necessary database tables for relationships and graph neighbors. - Developed tests for graph expansion functionality, ensuring related results are populated correctly. - Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead. - Updated schema cleanup tests to reflect changes in versioning and deprecated fields. - Added new test cases for Treesitter parser to validate relationship extraction with alias resolution. --- codex-lens/pyproject.toml | 6 + codex-lens/src/codexlens/cli/commands.py | 4 +- codex-lens/src/codexlens/config.py | 10 + codex-lens/src/codexlens/entities.py | 3 +- .../codexlens/indexing/symbol_extractor.py | 88 ++- codex-lens/src/codexlens/parsers/factory.py | 149 +++++- .../codexlens/parsers/treesitter_parser.py | 500 +++++++++++++++++- .../src/codexlens/search/chain_search.py | 17 +- codex-lens/src/codexlens/search/enrichment.py | 21 + .../src/codexlens/search/graph_expander.py | 264 +++++++++ .../src/codexlens/search/hybrid_search.py | 54 +- codex-lens/src/codexlens/search/ranking.py | 111 ++++ codex-lens/src/codexlens/semantic/reranker.py | 86 +++ codex-lens/src/codexlens/storage/dir_index.py | 322 ++++++++++- .../src/codexlens/storage/index_tree.py | 217 +++++++- .../src/codexlens/storage/merkle_tree.py | 136 +++++ .../migration_006_enhance_relationships.py | 37 ++ .../migration_007_add_graph_neighbors.py | 47 ++ .../migration_008_add_merkle_hashes.py | 81 +++ codex-lens/tests/test_graph_expansion.py | 188 +++++++ codex-lens/tests/test_hybrid_search_e2e.py | 44 ++ codex-lens/tests/test_merkle_detection.py | 100 ++++ .../tests/test_performance_optimizations.py | 114 +++- .../tests/test_schema_cleanup_migration.py | 14 +- codex-lens/tests/test_search_comprehensive.py | 1 + codex-lens/tests/test_search_full_coverage.py | 2 + codex-lens/tests/test_treesitter_parser.py | 47 ++ 27 files changed, 2566 insertions(+), 97 deletions(-) create mode 100644 codex-lens/src/codexlens/search/graph_expander.py create mode 100644 codex-lens/src/codexlens/semantic/reranker.py create mode 100644 codex-lens/src/codexlens/storage/merkle_tree.py create mode 100644 codex-lens/src/codexlens/storage/migrations/migration_006_enhance_relationships.py create mode 100644 codex-lens/src/codexlens/storage/migrations/migration_007_add_graph_neighbors.py create mode 100644 codex-lens/src/codexlens/storage/migrations/migration_008_add_merkle_hashes.py create mode 100644 codex-lens/tests/test_graph_expansion.py create mode 100644 codex-lens/tests/test_merkle_detection.py diff --git a/codex-lens/pyproject.toml b/codex-lens/pyproject.toml index 3198bea6..fbb61294 100644 --- a/codex-lens/pyproject.toml +++ b/codex-lens/pyproject.toml @@ -49,6 +49,12 @@ semantic-directml = [ "onnxruntime-directml>=1.15.0", # DirectML support ] +# Cross-encoder reranking (second-stage, optional) +# Install with: pip install codexlens[reranker] +reranker = [ + "sentence-transformers>=2.2", +] + # Encoding detection for non-UTF8 files encoding = [ "chardet>=5.0", diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py 
index a42983ff..10f66e9b 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -407,7 +407,7 @@ def search( registry.initialize() mapper = PathMapper() - engine = ChainSearchEngine(registry, mapper) + engine = ChainSearchEngine(registry, mapper, config=config) # Auto-detect mode if set to "auto" actual_mode = mode @@ -550,7 +550,7 @@ def symbol( registry.initialize() mapper = PathMapper() - engine = ChainSearchEngine(registry, mapper) + engine = ChainSearchEngine(registry, mapper, config=config) options = SearchOptions(depth=depth, total_limit=limit) syms = engine.search_symbols(name, search_path, kind=kind, options=options) diff --git a/codex-lens/src/codexlens/config.py b/codex-lens/src/codexlens/config.py index 7436b18d..11c550bf 100644 --- a/codex-lens/src/codexlens/config.py +++ b/codex-lens/src/codexlens/config.py @@ -105,12 +105,22 @@ class Config: # Indexing/search optimizations global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path + enable_merkle_detection: bool = True # Enable content-hash based incremental indexing + + # Graph expansion (search-time, uses precomputed neighbors) + enable_graph_expansion: bool = False + graph_expansion_depth: int = 2 # Optional search reranking (disabled by default) enable_reranking: bool = False reranking_top_k: int = 50 symbol_boost_factor: float = 1.5 + # Optional cross-encoder reranking (second stage, requires codexlens[reranker]) + enable_cross_encoder_rerank: bool = False + reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" + reranker_top_k: int = 50 + # Multi-endpoint configuration for litellm backend embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list) # List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}] diff --git a/codex-lens/src/codexlens/entities.py b/codex-lens/src/codexlens/entities.py index 2e1477d0..9de58c07 100644 --- a/codex-lens/src/codexlens/entities.py +++ b/codex-lens/src/codexlens/entities.py @@ -58,6 +58,7 @@ class IndexedFile(BaseModel): language: str = Field(..., min_length=1) symbols: List[Symbol] = Field(default_factory=list) chunks: List[SemanticChunk] = Field(default_factory=list) + relationships: List["CodeRelationship"] = Field(default_factory=list) @field_validator("path", "language") @classmethod @@ -70,7 +71,7 @@ class IndexedFile(BaseModel): class RelationshipType(str, Enum): """Types of code relationships.""" - CALL = "call" + CALL = "calls" INHERITS = "inherits" IMPORTS = "imports" diff --git a/codex-lens/src/codexlens/indexing/symbol_extractor.py b/codex-lens/src/codexlens/indexing/symbol_extractor.py index 092541a2..45439e7b 100644 --- a/codex-lens/src/codexlens/indexing/symbol_extractor.py +++ b/codex-lens/src/codexlens/indexing/symbol_extractor.py @@ -4,6 +4,11 @@ import sqlite3 from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +try: + from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser +except Exception: # pragma: no cover - optional dependency / platform variance + TreeSitterSymbolParser = None # type: ignore[assignment] + class SymbolExtractor: """Extract symbols and relationships from source code using regex patterns.""" @@ -118,7 +123,7 @@ class SymbolExtractor: patterns = self.PATTERNS[lang] symbols = [] - relationships = [] + relationships: List[Dict] = [] lines = content.split('\n') current_scope = None @@ -141,33 +146,62 @@ class SymbolExtractor: }) current_scope = name - # Extract imports - if 
'import' in patterns: - match = re.search(patterns['import'], line) - if match: - import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1) - if import_target and current_scope: - relationships.append({ - 'source_scope': current_scope, - 'target': import_target.strip(), - 'type': 'imports', - 'file_path': str(file_path), - 'line': line_num, - }) + if TreeSitterSymbolParser is not None: + try: + ts_parser = TreeSitterSymbolParser(lang, file_path) + if ts_parser.is_available(): + indexed = ts_parser.parse(content, file_path) + if indexed is not None and indexed.relationships: + relationships = [ + { + "source_scope": r.source_symbol, + "target": r.target_symbol, + "type": r.relationship_type.value, + "file_path": str(file_path), + "line": r.source_line, + } + for r in indexed.relationships + ] + except Exception: + relationships = [] - # Extract function calls (simplified) - if 'call' in patterns and current_scope: - for match in re.finditer(patterns['call'], line): - call_name = match.group(1) - # Skip common keywords and the current function - if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]: - relationships.append({ - 'source_scope': current_scope, - 'target': call_name, - 'type': 'calls', - 'file_path': str(file_path), - 'line': line_num, - }) + # Regex fallback for relationships (when tree-sitter is unavailable) + if not relationships: + current_scope = None + for line_num, line in enumerate(lines, 1): + for kind in ['function', 'class']: + if kind in patterns: + match = re.search(patterns[kind], line) + if match: + current_scope = match.group(1) + + # Extract imports + if 'import' in patterns: + match = re.search(patterns['import'], line) + if match: + import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1) + if import_target and current_scope: + relationships.append({ + 'source_scope': current_scope, + 'target': import_target.strip(), + 'type': 'imports', + 'file_path': str(file_path), + 'line': line_num, + }) + + # Extract function calls (simplified) + if 'call' in patterns and current_scope: + for match in re.finditer(patterns['call'], line): + call_name = match.group(1) + # Skip common keywords and the current function + if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]: + relationships.append({ + 'source_scope': current_scope, + 'target': call_name, + 'type': 'calls', + 'file_path': str(file_path), + 'line': line_num, + }) return symbols, relationships diff --git a/codex-lens/src/codexlens/parsers/factory.py b/codex-lens/src/codexlens/parsers/factory.py index cd868ba6..0f8f4f14 100644 --- a/codex-lens/src/codexlens/parsers/factory.py +++ b/codex-lens/src/codexlens/parsers/factory.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Dict, List, Optional, Protocol from codexlens.config import Config -from codexlens.entities import IndexedFile, Symbol +from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser @@ -30,36 +30,39 @@ class SimpleRegexParser: if self.language_id in {"python", "javascript", "typescript"}: ts_parser = TreeSitterSymbolParser(self.language_id, path) if ts_parser.is_available(): - symbols = ts_parser.parse_symbols(text) - if symbols is not None: - return IndexedFile( - path=str(path.resolve()), - 
language=self.language_id, - symbols=symbols, - chunks=[], - ) + indexed = ts_parser.parse(text, path) + if indexed is not None: + return indexed # Fallback to regex parsing if self.language_id == "python": symbols = _parse_python_symbols_regex(text) + relationships = _parse_python_relationships_regex(text, path) elif self.language_id in {"javascript", "typescript"}: symbols = _parse_js_ts_symbols_regex(text) + relationships = _parse_js_ts_relationships_regex(text, path) elif self.language_id == "java": symbols = _parse_java_symbols(text) + relationships = [] elif self.language_id == "go": symbols = _parse_go_symbols(text) + relationships = [] elif self.language_id == "markdown": symbols = _parse_markdown_symbols(text) + relationships = [] elif self.language_id == "text": symbols = _parse_text_symbols(text) + relationships = [] else: symbols = _parse_generic_symbols(text) + relationships = [] return IndexedFile( path=str(path.resolve()), language=self.language_id, symbols=symbols, chunks=[], + relationships=relationships, ) @@ -78,6 +81,9 @@ class ParserFactory: _PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b") _PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(") +_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)") +_PY_CALL_RE = re.compile(r"(? List[Symbol]: return symbols +def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]: + relationships: List[CodeRelationship] = [] + current_scope: str | None = None + source_file = str(path.resolve()) + + for line_num, line in enumerate(text.splitlines(), start=1): + class_match = _PY_CLASS_RE.match(line) + if class_match: + current_scope = class_match.group(1) + continue + + def_match = _PY_DEF_RE.match(line) + if def_match: + current_scope = def_match.group(1) + continue + + if current_scope is None: + continue + + import_match = _PY_IMPORT_RE.search(line) + if import_match: + import_target = import_match.group(1) or import_match.group(2) + if import_target: + relationships.append( + CodeRelationship( + source_symbol=current_scope, + target_symbol=import_target.strip(), + relationship_type=RelationshipType.IMPORTS, + source_file=source_file, + target_file=None, + source_line=line_num, + ) + ) + + for call_match in _PY_CALL_RE.finditer(line): + call_name = call_match.group(1) + if call_name in { + "if", + "for", + "while", + "return", + "print", + "len", + "str", + "int", + "float", + "list", + "dict", + "set", + "tuple", + current_scope, + }: + continue + relationships.append( + CodeRelationship( + source_symbol=current_scope, + target_symbol=call_name, + relationship_type=RelationshipType.CALL, + source_file=source_file, + target_file=None, + source_line=line_num, + ) + ) + + return relationships + + _JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(") _JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b") _JS_ARROW_RE = re.compile( r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>" ) _JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{") +_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]") +_JS_CALL_RE = re.compile(r"(? 
List[Symbol]: @@ -174,6 +249,61 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]: return symbols +def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]: + relationships: List[CodeRelationship] = [] + current_scope: str | None = None + source_file = str(path.resolve()) + + for line_num, line in enumerate(text.splitlines(), start=1): + class_match = _JS_CLASS_RE.match(line) + if class_match: + current_scope = class_match.group(1) + continue + + func_match = _JS_FUNC_RE.match(line) + if func_match: + current_scope = func_match.group(1) + continue + + arrow_match = _JS_ARROW_RE.match(line) + if arrow_match: + current_scope = arrow_match.group(1) + continue + + if current_scope is None: + continue + + import_match = _JS_IMPORT_RE.search(line) + if import_match: + relationships.append( + CodeRelationship( + source_symbol=current_scope, + target_symbol=import_match.group(1), + relationship_type=RelationshipType.IMPORTS, + source_file=source_file, + target_file=None, + source_line=line_num, + ) + ) + + for call_match in _JS_CALL_RE.finditer(line): + call_name = call_match.group(1) + if call_name in {current_scope}: + continue + relationships.append( + CodeRelationship( + source_symbol=current_scope, + target_symbol=call_name, + relationship_type=RelationshipType.CALL, + source_file=source_file, + target_file=None, + source_line=line_num, + ) + ) + + return relationships + + _JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b") _JAVA_METHOD_RE = re.compile( r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\(" @@ -253,4 +383,3 @@ def _parse_text_symbols(text: str) -> List[Symbol]: # Text files don't have structured symbols, return empty list # The file content will still be indexed for FTS search return [] - diff --git a/codex-lens/src/codexlens/parsers/treesitter_parser.py b/codex-lens/src/codexlens/parsers/treesitter_parser.py index b104a30a..4ae44cae 100644 --- a/codex-lens/src/codexlens/parsers/treesitter_parser.py +++ b/codex-lens/src/codexlens/parsers/treesitter_parser.py @@ -1,12 +1,17 @@ """Tree-sitter based parser for CodexLens. -Provides precise AST-level parsing with fallback to regex-based parsing. +Provides precise AST-level parsing via tree-sitter. + +Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`. +If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()` +return `None`; callers should use a regex-based fallback such as +`codexlens.parsers.factory.SimpleRegexParser`. 
""" from __future__ import annotations from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional try: from tree_sitter import Language as TreeSitterLanguage @@ -19,7 +24,7 @@ except ImportError: TreeSitterParser = None # type: ignore[assignment] TREE_SITTER_AVAILABLE = False -from codexlens.entities import IndexedFile, Symbol +from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol from codexlens.parsers.tokenizer import get_default_tokenizer @@ -85,6 +90,16 @@ class TreeSitterSymbolParser: """ return self._parser is not None and self._language is not None + def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]: + if not self.is_available() or self._parser is None: + return None + + try: + source_bytes = text.encode("utf8") + tree = self._parser.parse(source_bytes) # type: ignore[attr-defined] + return source_bytes, tree.root_node + except Exception: + return None def parse_symbols(self, text: str) -> Optional[List[Symbol]]: """Parse source code and extract symbols without creating IndexedFile. @@ -95,17 +110,15 @@ class TreeSitterSymbolParser: Returns: List of symbols if parsing succeeds, None if tree-sitter unavailable """ - if not self.is_available() or self._parser is None: + parsed = self._parse_tree(text) + if parsed is None: return None + source_bytes, root = parsed try: - source_bytes = text.encode("utf8") - tree = self._parser.parse(source_bytes) # type: ignore[attr-defined] - root = tree.root_node - return self._extract_symbols(source_bytes, root) except Exception: - # Gracefully handle parsing errors + # Gracefully handle extraction errors return None def parse(self, text: str, path: Path) -> Optional[IndexedFile]: @@ -118,19 +131,21 @@ class TreeSitterSymbolParser: Returns: IndexedFile if parsing succeeds, None if tree-sitter unavailable """ - if not self.is_available() or self._parser is None: + parsed = self._parse_tree(text) + if parsed is None: return None + source_bytes, root = parsed try: - symbols = self.parse_symbols(text) - if symbols is None: - return None + symbols = self._extract_symbols(source_bytes, root) + relationships = self._extract_relationships(source_bytes, root, path) return IndexedFile( path=str(path.resolve()), language=self.language_id, symbols=symbols, chunks=[], + relationships=relationships, ) except Exception: # Gracefully handle parsing errors @@ -153,6 +168,465 @@ class TreeSitterSymbolParser: else: return [] + def _extract_relationships( + self, + source_bytes: bytes, + root: TreeSitterNode, + path: Path, + ) -> List[CodeRelationship]: + if self.language_id == "python": + return self._extract_python_relationships(source_bytes, root, path) + if self.language_id in {"javascript", "typescript"}: + return self._extract_js_ts_relationships(source_bytes, root, path) + return [] + + def _extract_python_relationships( + self, + source_bytes: bytes, + root: TreeSitterNode, + path: Path, + ) -> List[CodeRelationship]: + source_file = str(path.resolve()) + relationships: List[CodeRelationship] = [] + + scope_stack: List[str] = [] + alias_stack: List[Dict[str, str]] = [{}] + + def record_import(target_symbol: str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.IMPORTS, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def record_call(target_symbol: 
str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + base = target_symbol.split(".", 1)[0] + if base in {"self", "cls"}: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.CALL, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def record_inherits(target_symbol: str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.INHERITS, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def visit(node: TreeSitterNode) -> None: + pushed_scope = False + pushed_aliases = False + + if node.type in {"class_definition", "function_definition", "async_function_definition"}: + name_node = node.child_by_field_name("name") + if name_node is not None: + scope_name = self._node_text(source_bytes, name_node).strip() + if scope_name: + scope_stack.append(scope_name) + pushed_scope = True + alias_stack.append(dict(alias_stack[-1])) + pushed_aliases = True + + if node.type == "class_definition" and pushed_scope: + superclasses = node.child_by_field_name("superclasses") + if superclasses is not None: + for child in superclasses.children: + dotted = self._python_expression_to_dotted(source_bytes, child) + if not dotted: + continue + resolved = self._resolve_alias_dotted(dotted, alias_stack[-1]) + record_inherits(resolved, self._node_start_line(node)) + + if node.type in {"import_statement", "import_from_statement"}: + updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node) + if updates: + alias_stack[-1].update(updates) + for target_symbol in imported_targets: + record_import(target_symbol, self._node_start_line(node)) + + if node.type == "call": + fn_node = node.child_by_field_name("function") + if fn_node is not None: + dotted = self._python_expression_to_dotted(source_bytes, fn_node) + if dotted: + resolved = self._resolve_alias_dotted(dotted, alias_stack[-1]) + record_call(resolved, self._node_start_line(node)) + + for child in node.children: + visit(child) + + if pushed_aliases: + alias_stack.pop() + if pushed_scope: + scope_stack.pop() + + visit(root) + return relationships + + def _extract_js_ts_relationships( + self, + source_bytes: bytes, + root: TreeSitterNode, + path: Path, + ) -> List[CodeRelationship]: + source_file = str(path.resolve()) + relationships: List[CodeRelationship] = [] + + scope_stack: List[str] = [] + alias_stack: List[Dict[str, str]] = [{}] + + def record_import(target_symbol: str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.IMPORTS, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def record_call(target_symbol: str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + base = target_symbol.split(".", 1)[0] + if base in {"this", "super"}: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.CALL, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def record_inherits(target_symbol: 
str, source_line: int) -> None: + if not target_symbol.strip() or not scope_stack: + return + relationships.append( + CodeRelationship( + source_symbol=scope_stack[-1], + target_symbol=target_symbol, + relationship_type=RelationshipType.INHERITS, + source_file=source_file, + target_file=None, + source_line=source_line, + ) + ) + + def visit(node: TreeSitterNode) -> None: + pushed_scope = False + pushed_aliases = False + + if node.type in {"function_declaration", "generator_function_declaration"}: + name_node = node.child_by_field_name("name") + if name_node is not None: + scope_name = self._node_text(source_bytes, name_node).strip() + if scope_name: + scope_stack.append(scope_name) + pushed_scope = True + alias_stack.append(dict(alias_stack[-1])) + pushed_aliases = True + + if node.type in {"class_declaration", "class"}: + name_node = node.child_by_field_name("name") + if name_node is not None: + scope_name = self._node_text(source_bytes, name_node).strip() + if scope_name: + scope_stack.append(scope_name) + pushed_scope = True + alias_stack.append(dict(alias_stack[-1])) + pushed_aliases = True + + if pushed_scope: + superclass = node.child_by_field_name("superclass") + if superclass is not None: + dotted = self._js_expression_to_dotted(source_bytes, superclass) + if dotted: + resolved = self._resolve_alias_dotted(dotted, alias_stack[-1]) + record_inherits(resolved, self._node_start_line(node)) + + if node.type == "variable_declarator": + name_node = node.child_by_field_name("name") + value_node = node.child_by_field_name("value") + if ( + name_node is not None + and value_node is not None + and name_node.type in {"identifier", "property_identifier"} + and value_node.type == "arrow_function" + ): + scope_name = self._node_text(source_bytes, name_node).strip() + if scope_name: + scope_stack.append(scope_name) + pushed_scope = True + alias_stack.append(dict(alias_stack[-1])) + pushed_aliases = True + + if node.type == "method_definition" and self._has_class_ancestor(node): + name_node = node.child_by_field_name("name") + if name_node is not None: + scope_name = self._node_text(source_bytes, name_node).strip() + if scope_name and scope_name != "constructor": + scope_stack.append(scope_name) + pushed_scope = True + alias_stack.append(dict(alias_stack[-1])) + pushed_aliases = True + + if node.type in {"import_declaration", "import_statement"}: + updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node) + if updates: + alias_stack[-1].update(updates) + for target_symbol in imported_targets: + record_import(target_symbol, self._node_start_line(node)) + + # Best-effort support for CommonJS require() imports: + # const fs = require("fs") + if node.type == "variable_declarator": + name_node = node.child_by_field_name("name") + value_node = node.child_by_field_name("value") + if ( + name_node is not None + and value_node is not None + and name_node.type == "identifier" + and value_node.type == "call_expression" + ): + callee = value_node.child_by_field_name("function") + args = value_node.child_by_field_name("arguments") + if ( + callee is not None + and self._node_text(source_bytes, callee).strip() == "require" + and args is not None + ): + module_name = self._js_first_string_argument(source_bytes, args) + if module_name: + alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name + record_import(module_name, self._node_start_line(node)) + + if node.type == "call_expression": + fn_node = node.child_by_field_name("function") + if fn_node is not None: + 
dotted = self._js_expression_to_dotted(source_bytes, fn_node) + if dotted: + resolved = self._resolve_alias_dotted(dotted, alias_stack[-1]) + record_call(resolved, self._node_start_line(node)) + + for child in node.children: + visit(child) + + if pushed_aliases: + alias_stack.pop() + if pushed_scope: + scope_stack.pop() + + visit(root) + return relationships + + def _node_start_line(self, node: TreeSitterNode) -> int: + return node.start_point[0] + 1 + + def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str: + dotted = (dotted or "").strip() + if not dotted: + return "" + + base, sep, rest = dotted.partition(".") + resolved_base = aliases.get(base, base) + if not rest: + return resolved_base + if resolved_base and rest: + return f"{resolved_base}.{rest}" + return resolved_base + + def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str: + if node.type in {"identifier", "dotted_name"}: + return self._node_text(source_bytes, node).strip() + if node.type == "attribute": + obj = node.child_by_field_name("object") + attr = node.child_by_field_name("attribute") + obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else "" + attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else "" + if obj_text and attr_text: + return f"{obj_text}.{attr_text}" + return obj_text or attr_text + return "" + + def _python_import_aliases_and_targets( + self, + source_bytes: bytes, + node: TreeSitterNode, + ) -> tuple[Dict[str, str], List[str]]: + aliases: Dict[str, str] = {} + targets: List[str] = [] + + if node.type == "import_statement": + for child in node.children: + if child.type == "aliased_import": + name_node = child.child_by_field_name("name") + alias_node = child.child_by_field_name("alias") + if name_node is None: + continue + module_name = self._node_text(source_bytes, name_node).strip() + if not module_name: + continue + bound_name = ( + self._node_text(source_bytes, alias_node).strip() + if alias_node is not None + else module_name.split(".", 1)[0] + ) + if bound_name: + aliases[bound_name] = module_name + targets.append(module_name) + elif child.type == "dotted_name": + module_name = self._node_text(source_bytes, child).strip() + if not module_name: + continue + bound_name = module_name.split(".", 1)[0] + if bound_name: + aliases[bound_name] = bound_name + targets.append(module_name) + + if node.type == "import_from_statement": + module_name = "" + module_node = node.child_by_field_name("module_name") + if module_node is None: + for child in node.children: + if child.type == "dotted_name": + module_node = child + break + if module_node is not None: + module_name = self._node_text(source_bytes, module_node).strip() + + for child in node.children: + if child.type == "aliased_import": + name_node = child.child_by_field_name("name") + alias_node = child.child_by_field_name("alias") + if name_node is None: + continue + imported_name = self._node_text(source_bytes, name_node).strip() + if not imported_name or imported_name == "*": + continue + target = f"{module_name}.{imported_name}" if module_name else imported_name + bound_name = ( + self._node_text(source_bytes, alias_node).strip() + if alias_node is not None + else imported_name + ) + if bound_name: + aliases[bound_name] = target + targets.append(target) + elif child.type == "identifier": + imported_name = self._node_text(source_bytes, child).strip() + if not imported_name or imported_name in {"from", "import", "*"}: + continue + target = 
f"{module_name}.{imported_name}" if module_name else imported_name + aliases[imported_name] = target + targets.append(target) + + return aliases, targets + + def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str: + if node.type in {"this", "super"}: + return node.type + if node.type in {"identifier", "property_identifier"}: + return self._node_text(source_bytes, node).strip() + if node.type == "member_expression": + obj = node.child_by_field_name("object") + prop = node.child_by_field_name("property") + obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else "" + prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else "" + if obj_text and prop_text: + return f"{obj_text}.{prop_text}" + return obj_text or prop_text + return "" + + def _js_import_aliases_and_targets( + self, + source_bytes: bytes, + node: TreeSitterNode, + ) -> tuple[Dict[str, str], List[str]]: + aliases: Dict[str, str] = {} + targets: List[str] = [] + + module_name = "" + source_node = node.child_by_field_name("source") + if source_node is not None: + module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip() + if module_name: + targets.append(module_name) + + for child in node.children: + if child.type == "import_clause": + for clause_child in child.children: + if clause_child.type == "identifier": + # Default import: import React from "react" + local = self._node_text(source_bytes, clause_child).strip() + if local and module_name: + aliases[local] = module_name + if clause_child.type == "namespace_import": + # Namespace import: import * as fs from "fs" + name_node = clause_child.child_by_field_name("name") + if name_node is not None and module_name: + local = self._node_text(source_bytes, name_node).strip() + if local: + aliases[local] = module_name + if clause_child.type == "named_imports": + for spec in clause_child.children: + if spec.type != "import_specifier": + continue + name_node = spec.child_by_field_name("name") + alias_node = spec.child_by_field_name("alias") + if name_node is None: + continue + imported = self._node_text(source_bytes, name_node).strip() + if not imported: + continue + local = ( + self._node_text(source_bytes, alias_node).strip() + if alias_node is not None + else imported + ) + if local and module_name: + aliases[local] = f"{module_name}.{imported}" + targets.append(f"{module_name}.{imported}") + + return aliases, targets + + def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str: + for child in args_node.children: + if child.type == "string": + return self._node_text(source_bytes, child).strip().strip("\"'").strip() + return "" + def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]: """Extract Python symbols from AST. 
diff --git a/codex-lens/src/codexlens/search/chain_search.py b/codex-lens/src/codexlens/search/chain_search.py index 23a20f19..958c0285 100644 --- a/codex-lens/src/codexlens/search/chain_search.py +++ b/codex-lens/src/codexlens/search/chain_search.py @@ -83,6 +83,7 @@ class ChainSearchResult: Attributes: query: Original search query results: List of SearchResult objects + related_results: Expanded results from graph neighbors (optional) symbols: List of Symbol objects (if include_symbols=True) stats: SearchStats with execution metrics """ @@ -90,6 +91,7 @@ class ChainSearchResult: results: List[SearchResult] symbols: List[Symbol] stats: SearchStats + related_results: List[SearchResult] = field(default_factory=list) class ChainSearchEngine: @@ -236,13 +238,26 @@ class ChainSearchEngine: index_paths, query, None, options.total_limit ) + # Optional: graph expansion using precomputed neighbors + related_results: List[SearchResult] = [] + if self._config is not None and getattr(self._config, "enable_graph_expansion", False): + try: + from codexlens.search.enrichment import SearchEnrichmentPipeline + + pipeline = SearchEnrichmentPipeline(self.mapper, config=self._config) + related_results = pipeline.expand_related_results(final_results) + except Exception as exc: + self.logger.debug("Graph expansion failed: %s", exc) + related_results = [] + stats.time_ms = (time.time() - start_time) * 1000 return ChainSearchResult( query=query, results=final_results, symbols=symbols, - stats=stats + stats=stats, + related_results=related_results, ) def search_files_only(self, query: str, diff --git a/codex-lens/src/codexlens/search/enrichment.py b/codex-lens/src/codexlens/search/enrichment.py index c231551c..110f56b7 100644 --- a/codex-lens/src/codexlens/search/enrichment.py +++ b/codex-lens/src/codexlens/search/enrichment.py @@ -4,6 +4,11 @@ import sqlite3 from pathlib import Path from typing import List, Dict, Any, Optional +from codexlens.config import Config +from codexlens.entities import SearchResult +from codexlens.search.graph_expander import GraphExpander +from codexlens.storage.path_mapper import PathMapper + class RelationshipEnricher: """Enriches search results with code graph relationships.""" @@ -148,3 +153,19 @@ class RelationshipEnricher: def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + + +class SearchEnrichmentPipeline: + """Search post-processing pipeline (optional enrichments).""" + + def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None: + self._config = config + self._graph_expander = GraphExpander(mapper, config=config) + + def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]: + """Expand base results with related symbols when enabled in config.""" + if self._config is None or not getattr(self._config, "enable_graph_expansion", False): + return [] + + depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2) + return self._graph_expander.expand(results, depth=depth) diff --git a/codex-lens/src/codexlens/search/graph_expander.py b/codex-lens/src/codexlens/search/graph_expander.py new file mode 100644 index 00000000..73261d53 --- /dev/null +++ b/codex-lens/src/codexlens/search/graph_expander.py @@ -0,0 +1,264 @@ +"""Graph expansion for search results using precomputed neighbors. + +Expands top search results with related symbol definitions by traversing +precomputed N-hop neighbors stored in the per-directory index databases. 
+""" + +from __future__ import annotations + +import logging +import sqlite3 +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +from codexlens.config import Config +from codexlens.entities import SearchResult +from codexlens.storage.path_mapper import PathMapper + +logger = logging.getLogger(__name__) + + +def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]: + return (result.path, result.symbol_name, result.start_line, result.end_line) + + +def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]: + if content is None: + return None + if start_line is None or end_line is None: + return None + if start_line < 1 or end_line < start_line: + return None + + lines = content.splitlines() + start_idx = max(0, start_line - 1) + end_idx = min(len(lines), end_line) + if start_idx >= len(lines): + return None + return "\n".join(lines[start_idx:end_idx]) + + +class GraphExpander: + """Expands SearchResult lists with related symbols from the code graph.""" + + def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None: + self._mapper = mapper + self._config = config + self._logger = logging.getLogger(__name__) + + def expand( + self, + results: Sequence[SearchResult], + *, + depth: Optional[int] = None, + max_expand: int = 10, + max_related: int = 50, + ) -> List[SearchResult]: + """Expand top results with related symbols. + + Args: + results: Base ranked results. + depth: Maximum relationship depth to include (defaults to Config or 2). + max_expand: Only expand the top-N base results to bound cost. + max_related: Maximum related results to return. + + Returns: + A list of related SearchResult objects with relationship_depth metadata. 
+ """ + if not results: + return [] + + configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2 + max_depth = int(depth if depth is not None else configured_depth) + if max_depth <= 0: + return [] + max_depth = min(max_depth, 2) + + expand_count = max(0, int(max_expand)) + related_limit = max(0, int(max_related)) + if expand_count == 0 or related_limit == 0: + return [] + + seen = {_result_key(r) for r in results} + related_results: List[SearchResult] = [] + conn_cache: Dict[Path, sqlite3.Connection] = {} + + try: + for base in list(results)[:expand_count]: + if len(related_results) >= related_limit: + break + + if not base.symbol_name or not base.path: + continue + + index_path = self._mapper.source_to_index_db(Path(base.path).parent) + conn = conn_cache.get(index_path) + if conn is None: + conn = self._connect_readonly(index_path) + if conn is None: + continue + conn_cache[index_path] = conn + + source_ids = self._resolve_source_symbol_ids( + conn, + file_path=base.path, + symbol_name=base.symbol_name, + symbol_kind=base.symbol_kind, + ) + if not source_ids: + continue + + for source_id in source_ids: + neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit) + for neighbor_id, rel_depth in neighbors: + if len(related_results) >= related_limit: + break + row = self._get_symbol_details(conn, neighbor_id) + if row is None: + continue + + path = str(row["full_path"]) + symbol_name = str(row["name"]) + symbol_kind = str(row["kind"]) + start_line = int(row["start_line"]) if row["start_line"] is not None else None + end_line = int(row["end_line"]) if row["end_line"] is not None else None + content_block = _slice_content_block( + str(row["content"]) if row["content"] is not None else "", + start_line, + end_line, + ) + + score = float(base.score) * (0.5 ** int(rel_depth)) + candidate = SearchResult( + path=path, + score=max(0.0, score), + excerpt=None, + content=content_block, + start_line=start_line, + end_line=end_line, + symbol_name=symbol_name, + symbol_kind=symbol_kind, + metadata={"relationship_depth": int(rel_depth)}, + ) + + key = _result_key(candidate) + if key in seen: + continue + seen.add(key) + related_results.append(candidate) + + finally: + for conn in conn_cache.values(): + try: + conn.close() + except Exception: + pass + + return related_results + + def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]: + try: + if not index_path.exists() or index_path.stat().st_size == 0: + return None + except OSError: + return None + + try: + conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + except Exception as exc: + self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc) + return None + + def _resolve_source_symbol_ids( + self, + conn: sqlite3.Connection, + *, + file_path: str, + symbol_name: str, + symbol_kind: Optional[str], + ) -> List[int]: + try: + if symbol_kind: + rows = conn.execute( + """ + SELECT s.id + FROM symbols s + JOIN files f ON f.id = s.file_id + WHERE f.full_path = ? AND s.name = ? AND s.kind = ? + """, + (file_path, symbol_name, symbol_kind), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT s.id + FROM symbols s + JOIN files f ON f.id = s.file_id + WHERE f.full_path = ? AND s.name = ? 
+ """, + (file_path, symbol_name), + ).fetchall() + except sqlite3.Error: + return [] + + ids: List[int] = [] + for row in rows: + try: + ids.append(int(row["id"])) + except Exception: + continue + return ids + + def _get_neighbors( + self, + conn: sqlite3.Connection, + source_symbol_id: int, + *, + max_depth: int, + limit: int, + ) -> List[Tuple[int, int]]: + try: + rows = conn.execute( + """ + SELECT neighbor_symbol_id, relationship_depth + FROM graph_neighbors + WHERE source_symbol_id = ? AND relationship_depth <= ? + ORDER BY relationship_depth ASC, neighbor_symbol_id ASC + LIMIT ? + """, + (int(source_symbol_id), int(max_depth), int(limit)), + ).fetchall() + except sqlite3.Error: + return [] + + neighbors: List[Tuple[int, int]] = [] + for row in rows: + try: + neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"]))) + except Exception: + continue + return neighbors + + def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]: + try: + return conn.execute( + """ + SELECT + s.id, + s.name, + s.kind, + s.start_line, + s.end_line, + f.full_path, + f.content + FROM symbols s + JOIN files f ON f.id = s.file_id + WHERE s.id = ? + """, + (int(symbol_id),), + ).fetchone() + except sqlite3.Error: + return None + diff --git a/codex-lens/src/codexlens/search/hybrid_search.py b/codex-lens/src/codexlens/search/hybrid_search.py index 0e5ccbe2..b461e02a 100644 --- a/codex-lens/src/codexlens/search/hybrid_search.py +++ b/codex-lens/src/codexlens/search/hybrid_search.py @@ -34,6 +34,7 @@ from codexlens.config import Config from codexlens.entities import SearchResult from codexlens.search.ranking import ( apply_symbol_boost, + cross_encoder_rerank, get_rrf_weights, reciprocal_rank_fusion, rerank_results, @@ -77,6 +78,7 @@ class HybridSearchEngine: self.weights = weights or self.DEFAULT_WEIGHTS.copy() self._config = config self.embedder = embedder + self.reranker: Any = None def search( self, @@ -112,6 +114,14 @@ class HybridSearchEngine: >>> for r in results[:5]: ... print(f"{r.path}: {r.score:.3f}") """ + # Defensive: avoid creating/locking an index database when callers pass + # an empty placeholder file (common in tests and misconfigured callers). 
+ try: + if index_path.exists() and index_path.stat().st_size == 0: + return [] + except OSError: + return [] + # Determine which backends to use backends = {} @@ -180,9 +190,30 @@ class HybridSearchEngine: query, fused_results[:100], self.embedder, - top_k=self._config.reranking_top_k, + top_k=( + 100 + if self._config.enable_cross_encoder_rerank + else self._config.reranking_top_k + ), ) + # Optional: cross-encoder reranking as a second stage + if ( + self._config is not None + and self._config.enable_reranking + and self._config.enable_cross_encoder_rerank + ): + with timer("cross_encoder_rerank", self.logger): + if self.reranker is None: + self.reranker = self._get_cross_encoder_reranker() + if self.reranker is not None: + fused_results = cross_encoder_rerank( + query, + fused_results, + self.reranker, + top_k=self._config.reranker_top_k, + ) + # Apply final limit return fused_results[:limit] @@ -222,6 +253,27 @@ class HybridSearchEngine: ) return None + def _get_cross_encoder_reranker(self) -> Any: + if self._config is None: + return None + + try: + from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available + except Exception as exc: + self.logger.debug("Cross-encoder reranker unavailable: %s", exc) + return None + + ok, err = check_cross_encoder_available() + if not ok: + self.logger.debug("Cross-encoder reranker unavailable: %s", err) + return None + + try: + return CrossEncoderReranker(model_name=self._config.reranker_model) + except Exception as exc: + self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc) + return None + def _search_parallel( self, index_path: Path, diff --git a/codex-lens/src/codexlens/search/ranking.py b/codex-lens/src/codexlens/search/ranking.py index 6f59eaf2..34bd7719 100644 --- a/codex-lens/src/codexlens/search/ranking.py +++ b/codex-lens/src/codexlens/search/ranking.py @@ -379,6 +379,117 @@ def rerank_results( return reranked_results +def cross_encoder_rerank( + query: str, + results: List[SearchResult], + reranker: Any, + top_k: int = 50, + batch_size: int = 32, +) -> List[SearchResult]: + """Second-stage reranking using a cross-encoder model. + + This function is dependency-agnostic: callers can pass any object that exposes + a compatible `score_pairs(pairs, batch_size=...)` method. + """ + if not results: + return [] + + if reranker is None or top_k <= 0: + return results + + rerank_count = min(int(top_k), len(results)) + + def text_for_pair(r: SearchResult) -> str: + if r.excerpt and r.excerpt.strip(): + return r.excerpt + if r.content and r.content.strip(): + return r.content + if r.chunk and r.chunk.content and r.chunk.content.strip(): + return r.chunk.content + return r.symbol_name or r.path + + pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]] + + try: + if hasattr(reranker, "score_pairs"): + raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size)) + elif hasattr(reranker, "predict"): + raw_scores = reranker.predict(pairs, batch_size=int(batch_size)) + else: + return results + except Exception: + return results + + if not raw_scores or len(raw_scores) != rerank_count: + return results + + scores = [float(s) for s in raw_scores] + min_s = min(scores) + max_s = max(scores) + + def sigmoid(x: float) -> float: + # Clamp to keep exp() stable. 
+ x = max(-50.0, min(50.0, x)) + return 1.0 / (1.0 + math.exp(-x)) + + if 0.0 <= min_s and max_s <= 1.0: + probs = scores + else: + probs = [sigmoid(s) for s in scores] + + reranked_results: List[SearchResult] = [] + + for idx, result in enumerate(results): + if idx < rerank_count: + prev_score = float(result.score) + ce_score = scores[idx] + ce_prob = probs[idx] + combined_score = 0.5 * prev_score + 0.5 * ce_prob + + reranked_results.append( + SearchResult( + path=result.path, + score=combined_score, + excerpt=result.excerpt, + content=result.content, + symbol=result.symbol, + chunk=result.chunk, + metadata={ + **result.metadata, + "pre_cross_encoder_score": prev_score, + "cross_encoder_score": ce_score, + "cross_encoder_prob": ce_prob, + "cross_encoder_reranked": True, + }, + start_line=result.start_line, + end_line=result.end_line, + symbol_name=result.symbol_name, + symbol_kind=result.symbol_kind, + additional_locations=list(result.additional_locations), + ) + ) + else: + reranked_results.append( + SearchResult( + path=result.path, + score=result.score, + excerpt=result.excerpt, + content=result.content, + symbol=result.symbol, + chunk=result.chunk, + metadata={**result.metadata}, + start_line=result.start_line, + end_line=result.end_line, + symbol_name=result.symbol_name, + symbol_kind=result.symbol_kind, + additional_locations=list(result.additional_locations), + ) + ) + + reranked_results.sort(key=lambda r: r.score, reverse=True) + return reranked_results + + def normalize_bm25_score(score: float) -> float: """Normalize BM25 scores from SQLite FTS5 to 0-1 range. diff --git a/codex-lens/src/codexlens/semantic/reranker.py b/codex-lens/src/codexlens/semantic/reranker.py new file mode 100644 index 00000000..99a720fe --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker.py @@ -0,0 +1,86 @@ +"""Optional cross-encoder reranker for second-stage search ranking. + +Install with: pip install codexlens[reranker] +""" + +from __future__ import annotations + +import logging +import threading +from typing import List, Sequence, Tuple + +logger = logging.getLogger(__name__) + +try: + from sentence_transformers import CrossEncoder as _CrossEncoder + + CROSS_ENCODER_AVAILABLE = True + _import_error: str | None = None +except ImportError as exc: # pragma: no cover - optional dependency + _CrossEncoder = None # type: ignore[assignment] + CROSS_ENCODER_AVAILABLE = False + _import_error = str(exc) + + +def check_cross_encoder_available() -> tuple[bool, str | None]: + if CROSS_ENCODER_AVAILABLE: + return True, None + return False, _import_error or "sentence-transformers not available. 
Install with: pip install codexlens[reranker]" + + +class CrossEncoderReranker: + """Cross-encoder reranker with lazy model loading.""" + + def __init__(self, model_name: str, *, device: str | None = None) -> None: + self.model_name = (model_name or "").strip() + if not self.model_name: + raise ValueError("model_name cannot be blank") + + self.device = (device or "").strip() or None + self._model = None + self._lock = threading.RLock() + + def _load_model(self) -> None: + if self._model is not None: + return + + ok, err = check_cross_encoder_available() + if not ok: + raise ImportError(err) + + with self._lock: + if self._model is not None: + return + + try: + if self.device: + self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc] + else: + self._model = _CrossEncoder(self.model_name) # type: ignore[misc] + except Exception as exc: + logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc) + raise + + def score_pairs( + self, + pairs: Sequence[Tuple[str, str]], + *, + batch_size: int = 32, + ) -> List[float]: + """Score (query, doc) pairs using the cross-encoder. + + Returns: + List of scores (one per pair) in the model's native scale (usually logits). + """ + if not pairs: + return [] + + self._load_model() + + if self._model is None: # pragma: no cover - defensive + return [] + + bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32 + scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr] + return [float(s) for s in scores] + diff --git a/codex-lens/src/codexlens/storage/dir_index.py b/codex-lens/src/codexlens/storage/dir_index.py index e9647ec1..9a3a91a9 100644 --- a/codex-lens/src/codexlens/storage/dir_index.py +++ b/codex-lens/src/codexlens/storage/dir_index.py @@ -10,15 +10,17 @@ Each directory maintains its own _index.db with: from __future__ import annotations import logging +import hashlib import re import sqlite3 import threading +import time from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from codexlens.config import Config -from codexlens.entities import SearchResult, Symbol +from codexlens.entities import CodeRelationship, SearchResult, Symbol from codexlens.errors import StorageError from codexlens.storage.global_index import GlobalSymbolIndex @@ -60,7 +62,7 @@ class DirIndexStore: # Schema version for migration tracking # Increment this when schema changes require migration - SCHEMA_VERSION = 5 + SCHEMA_VERSION = 8 def __init__( self, @@ -150,6 +152,21 @@ class DirIndexStore: from codexlens.storage.migrations.migration_005_cleanup_unused_fields import upgrade upgrade(conn) + # Migration v5 -> v6: Ensure relationship tables/indexes exist + if from_version < 6: + from codexlens.storage.migrations.migration_006_enhance_relationships import upgrade + upgrade(conn) + + # Migration v6 -> v7: Add graph neighbor cache for search expansion + if from_version < 7: + from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade + upgrade(conn) + + # Migration v7 -> v8: Add Merkle hashes for incremental change detection + if from_version < 8: + from codexlens.storage.migrations.migration_008_add_merkle_hashes import upgrade + upgrade(conn) + def close(self) -> None: """Close database connection.""" with self._lock: @@ -179,6 +196,7 @@ class DirIndexStore: content: str, language: str, symbols: Optional[List[Symbol]] = None, + relationships: Optional[List[CodeRelationship]] = None, ) -> int: """Add or update a 
file in the current directory index. @@ -188,6 +206,7 @@ class DirIndexStore: content: File content for indexing language: Programming language identifier symbols: List of Symbol objects from the file + relationships: Optional list of CodeRelationship edges from this file Returns: Database file_id @@ -240,6 +259,8 @@ class DirIndexStore: symbol_rows, ) + self._save_merkle_hash(conn, file_id=file_id, content=content) + self._save_relationships(conn, file_id=file_id, relationships=relationships) conn.commit() self._maybe_update_global_symbols(full_path_str, symbols or []) return file_id @@ -248,6 +269,96 @@ class DirIndexStore: conn.rollback() raise StorageError(f"Failed to add file {name}: {exc}") from exc + def save_relationships(self, file_id: int, relationships: List[CodeRelationship]) -> None: + """Save relationships for an already-indexed file. + + Args: + file_id: Database file id + relationships: Relationship edges to persist + """ + if not relationships: + return + with self._lock: + conn = self._get_connection() + self._save_relationships(conn, file_id=file_id, relationships=relationships) + conn.commit() + + def _save_relationships( + self, + conn: sqlite3.Connection, + file_id: int, + relationships: Optional[List[CodeRelationship]], + ) -> None: + if not relationships: + return + + rows = conn.execute( + "SELECT id, name FROM symbols WHERE file_id=? ORDER BY start_line, id", + (file_id,), + ).fetchall() + + name_to_id: Dict[str, int] = {} + for row in rows: + name = row["name"] + if name not in name_to_id: + name_to_id[name] = int(row["id"]) + + if not name_to_id: + return + + rel_rows: List[Tuple[int, str, str, int, Optional[str]]] = [] + seen: set[tuple[int, str, str, int, Optional[str]]] = set() + + for rel in relationships: + source_id = name_to_id.get(rel.source_symbol) + if source_id is None: + continue + + target = (rel.target_symbol or "").strip() + if not target: + continue + + rel_type = rel.relationship_type.value + source_line = int(rel.source_line) + key = (source_id, target, rel_type, source_line, rel.target_file) + if key in seen: + continue + seen.add(key) + + rel_rows.append((source_id, target, rel_type, source_line, rel.target_file)) + + if not rel_rows: + return + + conn.executemany( + """ + INSERT INTO code_relationships( + source_symbol_id, target_qualified_name, + relationship_type, source_line, target_file + ) + VALUES(?, ?, ?, ?, ?) + """, + rel_rows, + ) + + def _save_merkle_hash(self, conn: sqlite3.Connection, file_id: int, content: str) -> None: + """Upsert a SHA-256 content hash for the given file_id (best-effort).""" + try: + digest = hashlib.sha256(content.encode("utf-8", errors="ignore")).hexdigest() + now = time.time() + conn.execute( + """ + INSERT INTO merkle_hashes(file_id, sha256, updated_at) + VALUES(?, ?, ?) + ON CONFLICT(file_id) DO UPDATE SET + sha256=excluded.sha256, + updated_at=excluded.updated_at + """, + (file_id, digest, now), + ) + except sqlite3.Error: + return + def add_files_batch( self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]] ) -> int: @@ -312,6 +423,8 @@ class DirIndexStore: symbol_rows, ) + self._save_merkle_hash(conn, file_id=file_id, content=content) + conn.commit() return count @@ -395,9 +508,13 @@ class DirIndexStore: return float(row["mtime"]) if row and row["mtime"] else None def needs_reindex(self, full_path: str | Path) -> bool: - """Check if a file needs reindexing based on mtime comparison. + """Check if a file needs reindexing. 
- Uses 1ms tolerance to handle filesystem timestamp precision variations. + Default behavior uses mtime comparison (with 1ms tolerance). + + When `Config.enable_merkle_detection` is enabled and Merkle metadata is + available, uses SHA-256 content hash comparison (with mtime as a fast + path to avoid hashing unchanged files). Args: full_path: Complete source file path @@ -415,16 +532,154 @@ class DirIndexStore: except OSError: return False # Can't read file stats, skip - # Get stored mtime from database - stored_mtime = self.get_file_mtime(full_path_obj) + MTIME_TOLERANCE = 0.001 - # File not in index, needs indexing - if stored_mtime is None: + # Fast path: mtime-only mode (default / backward-compatible) + if self._config is None or not getattr(self._config, "enable_merkle_detection", False): + stored_mtime = self.get_file_mtime(full_path_obj) + if stored_mtime is None: + return True + return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE + + full_path_str = str(full_path_obj) + + # Hash-based change detection (best-effort, falls back to mtime when metadata missing) + with self._lock: + conn = self._get_connection() + try: + row = conn.execute( + """ + SELECT f.id AS file_id, f.mtime AS mtime, mh.sha256 AS sha256 + FROM files f + LEFT JOIN merkle_hashes mh ON mh.file_id = f.id + WHERE f.full_path=? + """, + (full_path_str,), + ).fetchone() + except sqlite3.Error: + row = None + + if row is None: return True - # Compare with 1ms tolerance for floating point precision - MTIME_TOLERANCE = 0.001 - return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE + stored_mtime = float(row["mtime"]) if row["mtime"] else None + stored_hash = row["sha256"] if row["sha256"] else None + file_id = int(row["file_id"]) + + # Missing Merkle data: fall back to mtime + if stored_hash is None: + if stored_mtime is None: + return True + return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE + + # If mtime is unchanged within tolerance, assume unchanged without hashing. + if stored_mtime is not None and abs(current_mtime - stored_mtime) <= MTIME_TOLERANCE: + return False + + try: + current_text = full_path_obj.read_text(encoding="utf-8", errors="ignore") + current_hash = hashlib.sha256(current_text.encode("utf-8", errors="ignore")).hexdigest() + except OSError: + return False + + if current_hash == stored_hash: + # Content unchanged, but mtime drifted: update stored mtime to avoid repeated hashing. + with self._lock: + conn = self._get_connection() + conn.execute("UPDATE files SET mtime=? WHERE id=?", (current_mtime, file_id)) + conn.commit() + return False + + return True + + def get_merkle_root_hash(self) -> Optional[str]: + """Return the stored Merkle root hash for this directory index (if present).""" + with self._lock: + conn = self._get_connection() + try: + row = conn.execute( + "SELECT root_hash FROM merkle_state WHERE id=1" + ).fetchone() + except sqlite3.Error: + return None + + return row["root_hash"] if row and row["root_hash"] else None + + def update_merkle_root(self) -> Optional[str]: + """Compute and persist the Merkle root hash for this directory index. 
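Condensed, the hash-aware reindex decision above has three tiers: fall back to mtime when no hash is stored, trust an unchanged mtime without hashing, and otherwise compare SHA-256 digests. A sketch with illustrative helper callables (not the patch's signatures):

def needs_reindex_sketch(current_mtime, stored_mtime, stored_hash, hash_file, refresh_mtime):
    tolerance = 0.001
    if stored_hash is None:
        # No Merkle data yet: behave like the mtime-only path.
        return stored_mtime is None or abs(current_mtime - stored_mtime) > tolerance
    if stored_mtime is not None and abs(current_mtime - stored_mtime) <= tolerance:
        return False  # mtime unchanged: skip hashing entirely
    if hash_file() == stored_hash:
        refresh_mtime(current_mtime)  # content identical, only mtime drifted
        return False
    return True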
+ + The root hash includes: + - Direct file hashes from `merkle_hashes` + - Direct subdirectory root hashes (read from child `_index.db` files) + """ + if self._config is None or not getattr(self._config, "enable_merkle_detection", False): + return None + + with self._lock: + conn = self._get_connection() + try: + file_rows = conn.execute( + """ + SELECT f.name AS name, mh.sha256 AS sha256 + FROM files f + LEFT JOIN merkle_hashes mh ON mh.file_id = f.id + ORDER BY f.name + """ + ).fetchall() + + subdir_rows = conn.execute( + "SELECT name, index_path FROM subdirs ORDER BY name" + ).fetchall() + except sqlite3.Error as exc: + self.logger.debug("Failed to compute merkle root: %s", exc) + return None + + items: List[str] = [] + + for row in file_rows: + name = row["name"] + sha = (row["sha256"] or "").strip() + items.append(f"f:{name}:{sha}") + + def read_child_root(index_path: str) -> str: + try: + with sqlite3.connect(index_path) as child_conn: + child_conn.row_factory = sqlite3.Row + child_row = child_conn.execute( + "SELECT root_hash FROM merkle_state WHERE id=1" + ).fetchone() + return child_row["root_hash"] if child_row and child_row["root_hash"] else "" + except Exception: + return "" + + for row in subdir_rows: + name = row["name"] + index_path = row["index_path"] + child_hash = read_child_root(index_path) if index_path else "" + items.append(f"d:{name}:{child_hash}") + + root_hash = hashlib.sha256("\n".join(items).encode("utf-8", errors="ignore")).hexdigest() + now = time.time() + + with self._lock: + conn = self._get_connection() + try: + conn.execute( + """ + INSERT INTO merkle_state(id, root_hash, updated_at) + VALUES(1, ?, ?) + ON CONFLICT(id) DO UPDATE SET + root_hash=excluded.root_hash, + updated_at=excluded.updated_at + """, + (root_hash, now), + ) + conn.commit() + except sqlite3.Error as exc: + self.logger.debug("Failed to persist merkle root: %s", exc) + return None + + return root_hash def add_file_incremental( self, @@ -433,6 +688,7 @@ class DirIndexStore: content: str, language: str, symbols: Optional[List[Symbol]] = None, + relationships: Optional[List[CodeRelationship]] = None, ) -> Optional[int]: """Add or update a file only if it has changed (incremental indexing). @@ -444,6 +700,7 @@ class DirIndexStore: content: File content for indexing language: Programming language identifier symbols: List of Symbol objects from the file + relationships: Optional list of CodeRelationship edges from this file Returns: Database file_id if indexed, None if skipped (unchanged) @@ -456,7 +713,7 @@ class DirIndexStore: return None # Skip unchanged file # File changed or new, perform full indexing - return self.add_file(name, full_path, content, language, symbols) + return self.add_file(name, full_path, content, language, symbols, relationships) def cleanup_deleted_files(self, source_dir: Path) -> int: """Remove indexed files that no longer exist in the source directory. 
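The directory root hash above is just a SHA-256 over one line per entry, files first and subdirectories second, each group sorted by name, so any child change propagates upward. A self-contained sketch of that composition (inputs are hypothetical dicts, not the patch's query rows):

import hashlib


def directory_root_hash(file_hashes, child_roots):
    # file_hashes: {file_name: sha256}, child_roots: {subdir_name: child_root_hash}
    items = [f"f:{name}:{sha}" for name, sha in sorted(file_hashes.items())]
    items += [f"d:{name}:{root}" for name, root in sorted(child_roots.items())]
    blob = "\n".join(items).encode("utf-8", errors="ignore")
    return hashlib.sha256(blob).hexdigest()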
@@ -1767,6 +2024,39 @@ class DirIndexStore: """ ) + # Precomputed graph neighbors cache for search expansion (v7) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS graph_neighbors ( + source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE, + neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE, + relationship_depth INTEGER NOT NULL, + PRIMARY KEY (source_symbol_id, neighbor_symbol_id) + ) + """ + ) + + # Merkle hashes for incremental change detection (v8) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS merkle_hashes ( + file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE, + sha256 TEXT NOT NULL, + updated_at REAL + ) + """ + ) + + conn.execute( + """ + CREATE TABLE IF NOT EXISTS merkle_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + root_hash TEXT, + updated_at REAL + ) + """ + ) + # Indexes (v5: removed idx_symbols_type) conn.execute("CREATE INDEX IF NOT EXISTS idx_files_name ON files(name)") conn.execute("CREATE INDEX IF NOT EXISTS idx_files_path ON files(full_path)") @@ -1780,6 +2070,14 @@ class DirIndexStore: conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)") conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)") conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth " + "ON graph_neighbors(source_symbol_id, relationship_depth)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor " + "ON graph_neighbors(neighbor_symbol_id)" + ) except sqlite3.DatabaseError as exc: raise StorageError(f"Failed to create schema: {exc}") from exc diff --git a/codex-lens/src/codexlens/storage/index_tree.py b/codex-lens/src/codexlens/storage/index_tree.py index fe91330f..3cbe601e 100644 --- a/codex-lens/src/codexlens/storage/index_tree.py +++ b/codex-lens/src/codexlens/storage/index_tree.py @@ -8,11 +8,13 @@ from __future__ import annotations import logging import os +import re +import sqlite3 import time from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from codexlens.config import Config from codexlens.parsers.factory import ParserFactory @@ -247,6 +249,9 @@ class IndexTreeBuilder: try: with DirIndexStore(result.index_path, config=self.config, global_index=global_index) as store: deleted_count = store.cleanup_deleted_files(result.source_path) + if deleted_count > 0: + _compute_graph_neighbors(store, logger=self.logger) + store.update_merkle_root() total_deleted += deleted_count if deleted_count > 0: self.logger.debug("Removed %d deleted files from %s", deleted_count, result.source_path) @@ -575,6 +580,7 @@ class IndexTreeBuilder: content=text, language=language_id, symbols=indexed_file.symbols, + relationships=indexed_file.relationships, ) files_count += 1 @@ -584,6 +590,9 @@ class IndexTreeBuilder: self.logger.debug("Failed to index %s: %s", file_path, exc) continue + if files_count > 0: + _compute_graph_neighbors(store, logger=self.logger) + # Get list of subdirectories subdirs = [ d.name @@ -593,6 +602,7 @@ class IndexTreeBuilder: and not d.name.startswith(".") ] + store.update_merkle_root() store.close() if global_index is not None: global_index.close() @@ -654,31 +664,29 @@ class IndexTreeBuilder: parent_index_db = 
self.mapper.source_to_index_db(parent_path) try: - store = DirIndexStore(parent_index_db) - store.initialize() + with DirIndexStore(parent_index_db, config=self.config) as store: + for result in all_results: + # Only register direct children (parent is one level up) + if result.source_path.parent != parent_path: + continue - for result in all_results: - # Only register direct children (parent is one level up) - if result.source_path.parent != parent_path: - continue + if result.error: + continue - if result.error: - continue + # Register subdirectory link + store.register_subdir( + name=result.source_path.name, + index_path=result.index_path, + files_count=result.files_count, + direct_files=result.files_count, + ) + self.logger.debug( + "Linked %s to parent %s", + result.source_path.name, + parent_path, + ) - # Register subdirectory link - store.register_subdir( - name=result.source_path.name, - index_path=result.index_path, - files_count=result.files_count, - direct_files=result.files_count, - ) - self.logger.debug( - "Linked %s to parent %s", - result.source_path.name, - parent_path, - ) - - store.close() + store.update_merkle_root() except Exception as exc: self.logger.error( @@ -726,6 +734,164 @@ class IndexTreeBuilder: return files +def _normalize_relationship_target(target: str) -> str: + """Best-effort normalization of a relationship target into a local symbol name.""" + target = (target or "").strip() + if not target: + return "" + + # Drop trailing call parentheses when present (e.g., "foo()" -> "foo"). + if target.endswith("()"): + target = target[:-2] + + # Keep the leaf identifier for common qualified formats. + for sep in ("::", ".", "#"): + if sep in target: + target = target.split(sep)[-1] + + # Strip non-identifier suffix/prefix noise. + target = re.sub(r"^[^A-Za-z0-9_]+", "", target) + target = re.sub(r"[^A-Za-z0-9_]+$", "", target) + return target + + +def _compute_graph_neighbors( + store: DirIndexStore, + *, + max_depth: int = 2, + logger: Optional[logging.Logger] = None, +) -> None: + """Compute and persist N-hop neighbors for all symbols in a directory index.""" + if max_depth <= 0: + return + + log = logger or logging.getLogger(__name__) + + with store._lock: + conn = store._get_connection() + conn.row_factory = sqlite3.Row + + # Ensure schema exists even for older databases pinned to the same user_version. + try: + from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade + + upgrade(conn) + except Exception as exc: + log.debug("Graph neighbor schema ensure failed: %s", exc) + + cursor = conn.cursor() + + try: + cursor.execute("DELETE FROM graph_neighbors") + except sqlite3.Error: + # Table missing or schema mismatch; skip gracefully. 
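Given the normalization rules in _normalize_relationship_target above (strip call parentheses, keep the leaf identifier of qualified names, trim non-identifier noise), a few illustrative expectations, assuming the helper behaves exactly as shown:

from codexlens.storage.index_tree import _normalize_relationship_target

assert _normalize_relationship_target("os.path.join()") == "join"
assert _normalize_relationship_target("Base::method") == "method"
assert _normalize_relationship_target("pkg#helper") == "helper"
assert _normalize_relationship_target("") == ""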
+ return + + try: + symbol_rows = cursor.execute( + "SELECT id, file_id, name FROM symbols" + ).fetchall() + rel_rows = cursor.execute( + "SELECT source_symbol_id, target_qualified_name FROM code_relationships" + ).fetchall() + except sqlite3.Error: + return + + if not symbol_rows or not rel_rows: + try: + conn.commit() + except sqlite3.Error: + pass + return + + symbol_file_by_id: Dict[int, int] = {} + symbols_by_file_and_name: Dict[Tuple[int, str], List[int]] = {} + symbols_by_name: Dict[str, List[int]] = {} + + for row in symbol_rows: + symbol_id = int(row["id"]) + file_id = int(row["file_id"]) + name = str(row["name"]) + symbol_file_by_id[symbol_id] = file_id + symbols_by_file_and_name.setdefault((file_id, name), []).append(symbol_id) + symbols_by_name.setdefault(name, []).append(symbol_id) + + adjacency: Dict[int, Set[int]] = {} + + for row in rel_rows: + source_id = int(row["source_symbol_id"]) + target_raw = str(row["target_qualified_name"] or "") + target_name = _normalize_relationship_target(target_raw) + if not target_name: + continue + + source_file_id = symbol_file_by_id.get(source_id) + if source_file_id is None: + continue + + candidate_ids = symbols_by_file_and_name.get((source_file_id, target_name)) + if not candidate_ids: + global_candidates = symbols_by_name.get(target_name, []) + # Only resolve cross-file by name when unambiguous. + candidate_ids = global_candidates if len(global_candidates) == 1 else [] + + for target_id in candidate_ids: + if target_id == source_id: + continue + adjacency.setdefault(source_id, set()).add(target_id) + adjacency.setdefault(target_id, set()).add(source_id) + + if not adjacency: + try: + conn.commit() + except sqlite3.Error: + pass + return + + insert_rows: List[Tuple[int, int, int]] = [] + max_depth = min(int(max_depth), 2) + + for source_id, first_hop in adjacency.items(): + if not first_hop: + continue + for neighbor_id in first_hop: + insert_rows.append((source_id, neighbor_id, 1)) + + if max_depth < 2: + continue + + second_hop: Set[int] = set() + for neighbor_id in first_hop: + second_hop.update(adjacency.get(neighbor_id, set())) + + second_hop.discard(source_id) + second_hop.difference_update(first_hop) + + for neighbor_id in second_hop: + insert_rows.append((source_id, neighbor_id, 2)) + + if not insert_rows: + try: + conn.commit() + except sqlite3.Error: + pass + return + + try: + cursor.executemany( + """ + INSERT INTO graph_neighbors( + source_symbol_id, neighbor_symbol_id, relationship_depth + ) + VALUES(?, ?, ?) 
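In-memory, the neighbor rows inserted here amount to a capped two-hop expansion over an undirected adjacency map. A condensed restatement (illustrative function mirroring the logic above):

def expand_two_hop(adjacency):
    # adjacency: {symbol_id: set(neighbor_ids)}, undirected.
    rows = []
    for source, first_hop in adjacency.items():
        for neighbor in first_hop:
            rows.append((source, neighbor, 1))
        second_hop = set()
        for neighbor in first_hop:
            second_hop |= adjacency.get(neighbor, set())
        second_hop.discard(source)
        second_hop -= first_hop
        for neighbor in second_hop:
            rows.append((source, neighbor, 2))
    return rows

# For a simple call chain a -> b -> c this yields depth-1 pairs (a,b), (b,a),
# (b,c), (c,b) and depth-2 pairs (a,c), (c,a), matching the expectations in
# tests/test_graph_expansion.py below.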
+ """, + insert_rows, + ) + conn.commit() + except sqlite3.Error: + return + + # === Worker Function for ProcessPoolExecutor === @@ -795,6 +961,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult: content=text, language=language_id, symbols=indexed_file.symbols, + relationships=indexed_file.relationships, ) files_count += 1 @@ -803,6 +970,9 @@ def _build_dir_worker(args: tuple) -> DirBuildResult: except Exception: continue + if files_count > 0: + _compute_graph_neighbors(store) + # Get subdirectories ignore_dirs = { ".git", @@ -821,6 +991,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult: if d.is_dir() and d.name not in ignore_dirs and not d.name.startswith(".") ] + store.update_merkle_root() store.close() if global_index is not None: global_index.close() diff --git a/codex-lens/src/codexlens/storage/merkle_tree.py b/codex-lens/src/codexlens/storage/merkle_tree.py new file mode 100644 index 00000000..c8c76988 --- /dev/null +++ b/codex-lens/src/codexlens/storage/merkle_tree.py @@ -0,0 +1,136 @@ +"""Merkle tree utilities for change detection. + +This module provides a generic, file-system based Merkle tree implementation +that can be used to efficiently diff directory states. +""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, Iterable, List, Optional + + +def sha256_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def sha256_text(text: str) -> str: + return sha256_bytes(text.encode("utf-8", errors="ignore")) + + +@dataclass +class MerkleNode: + """A Merkle node representing either a file (leaf) or directory (internal).""" + + name: str + rel_path: str + hash: str + is_dir: bool + children: Dict[str, "MerkleNode"] = field(default_factory=dict) + + def iter_files(self) -> Iterable["MerkleNode"]: + if not self.is_dir: + yield self + return + for child in self.children.values(): + yield from child.iter_files() + + +@dataclass +class MerkleTree: + """Merkle tree for a directory snapshot.""" + + root: MerkleNode + + @classmethod + def build_from_directory(cls, root_dir: Path) -> "MerkleTree": + root_dir = Path(root_dir).resolve() + node = cls._build_node(root_dir, base=root_dir) + return cls(root=node) + + @classmethod + def _build_node(cls, path: Path, *, base: Path) -> MerkleNode: + if path.is_file(): + rel = str(path.relative_to(base)).replace("\\", "/") + return MerkleNode( + name=path.name, + rel_path=rel, + hash=sha256_bytes(path.read_bytes()), + is_dir=False, + ) + + if not path.is_dir(): + rel = str(path.relative_to(base)).replace("\\", "/") + return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False) + + children: Dict[str, MerkleNode] = {} + for child in sorted(path.iterdir(), key=lambda p: p.name): + child_node = cls._build_node(child, base=base) + children[child_node.name] = child_node + + items = [ + f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}" + for name, n in sorted(children.items(), key=lambda kv: kv[0]) + ] + dir_hash = sha256_text("\n".join(items)) + + rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/") + return MerkleNode( + name="." if path == base else path.name, + rel_path=rel_path, + hash=dir_hash, + is_dir=True, + children=children, + ) + + @staticmethod + def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]: + """Find changed/added/removed files between two trees. + + Returns: + List of relative file paths (POSIX-style separators). 
+ """ + if old is None and new is None: + return [] + if old is None: + return sorted({n.rel_path for n in new.root.iter_files()}) # type: ignore[union-attr] + if new is None: + return sorted({n.rel_path for n in old.root.iter_files()}) + + changed: set[str] = set() + + def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None: + if old_node is None and new_node is None: + return + + if old_node is None and new_node is not None: + changed.update(n.rel_path for n in new_node.iter_files()) + return + + if new_node is None and old_node is not None: + changed.update(n.rel_path for n in old_node.iter_files()) + return + + assert old_node is not None and new_node is not None + + if old_node.hash == new_node.hash: + return + + if not old_node.is_dir and not new_node.is_dir: + changed.add(new_node.rel_path) + return + + if old_node.is_dir != new_node.is_dir: + changed.update(n.rel_path for n in old_node.iter_files()) + changed.update(n.rel_path for n in new_node.iter_files()) + return + + names = set(old_node.children.keys()) | set(new_node.children.keys()) + for name in names: + walk(old_node.children.get(name), new_node.children.get(name)) + + walk(old.root, new.root) + return sorted(changed) + diff --git a/codex-lens/src/codexlens/storage/migrations/migration_006_enhance_relationships.py b/codex-lens/src/codexlens/storage/migrations/migration_006_enhance_relationships.py new file mode 100644 index 00000000..2c7c6cd8 --- /dev/null +++ b/codex-lens/src/codexlens/storage/migrations/migration_006_enhance_relationships.py @@ -0,0 +1,37 @@ +""" +Migration 006: Ensure relationship tables and indexes exist. + +This migration is intentionally idempotent. It creates the `code_relationships` +table (used for graph visualization) and its indexes if missing. +""" + +from __future__ import annotations + +import logging +from sqlite3 import Connection + +log = logging.getLogger(__name__) + + +def upgrade(db_conn: Connection) -> None: + cursor = db_conn.cursor() + + log.info("Ensuring code_relationships table exists...") + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS code_relationships ( + id INTEGER PRIMARY KEY, + source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE, + target_qualified_name TEXT NOT NULL, + relationship_type TEXT NOT NULL, + source_line INTEGER NOT NULL, + target_file TEXT + ) + """ + ) + + log.info("Ensuring relationship indexes exist...") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)") + diff --git a/codex-lens/src/codexlens/storage/migrations/migration_007_add_graph_neighbors.py b/codex-lens/src/codexlens/storage/migrations/migration_007_add_graph_neighbors.py new file mode 100644 index 00000000..83306886 --- /dev/null +++ b/codex-lens/src/codexlens/storage/migrations/migration_007_add_graph_neighbors.py @@ -0,0 +1,47 @@ +""" +Migration 007: Add precomputed graph neighbor table for search expansion. + +Adds: +- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids) + +This table is derived data (a cache) and is safe to rebuild at any time. +The migration is intentionally idempotent. 
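A small usage sketch for the MerkleTree utilities above (the directory name is made up for illustration):

from pathlib import Path

from codexlens.storage.merkle_tree import MerkleTree

before = MerkleTree.build_from_directory(Path("project"))
# ... edit, add, or delete files under project/ ...
after = MerkleTree.build_from_directory(Path("project"))

# Relative POSIX-style paths of anything changed, added, or removed.
changed = MerkleTree.find_changed_files(before, after)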
+""" + +from __future__ import annotations + +import logging +from sqlite3 import Connection + +log = logging.getLogger(__name__) + + +def upgrade(db_conn: Connection) -> None: + cursor = db_conn.cursor() + + log.info("Creating graph_neighbors table...") + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS graph_neighbors ( + source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE, + neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE, + relationship_depth INTEGER NOT NULL, + PRIMARY KEY (source_symbol_id, neighbor_symbol_id) + ) + """ + ) + + log.info("Creating indexes for graph_neighbors...") + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth + ON graph_neighbors(source_symbol_id, relationship_depth) + """ + ) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor + ON graph_neighbors(neighbor_symbol_id) + """ + ) + diff --git a/codex-lens/src/codexlens/storage/migrations/migration_008_add_merkle_hashes.py b/codex-lens/src/codexlens/storage/migrations/migration_008_add_merkle_hashes.py new file mode 100644 index 00000000..092fc20a --- /dev/null +++ b/codex-lens/src/codexlens/storage/migrations/migration_008_add_merkle_hashes.py @@ -0,0 +1,81 @@ +""" +Migration 008: Add Merkle hash tables for content-based incremental indexing. + +Adds: +- merkle_hashes: per-file SHA-256 hashes (keyed by file_id) +- merkle_state: directory-level root hash (single row, id=1) + +Backfills merkle_hashes using the existing `files.content` column when available. +""" + +from __future__ import annotations + +import hashlib +import logging +import time +from sqlite3 import Connection + +log = logging.getLogger(__name__) + + +def upgrade(db_conn: Connection) -> None: + cursor = db_conn.cursor() + + log.info("Creating merkle_hashes table...") + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS merkle_hashes ( + file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE, + sha256 TEXT NOT NULL, + updated_at REAL + ) + """ + ) + + log.info("Creating merkle_state table...") + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS merkle_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + root_hash TEXT, + updated_at REAL + ) + """ + ) + + # Backfill file hashes from stored content (best-effort). + try: + rows = cursor.execute("SELECT id, content FROM files").fetchall() + except Exception as exc: + log.warning("Unable to backfill merkle hashes (files table missing?): %s", exc) + return + + now = time.time() + inserts: list[tuple[int, str, float]] = [] + + for row in rows: + file_id = int(row[0]) + content = row[1] + if content is None: + continue + try: + digest = hashlib.sha256(str(content).encode("utf-8", errors="ignore")).hexdigest() + inserts.append((file_id, digest, now)) + except Exception: + continue + + if not inserts: + return + + log.info("Backfilling %d file hashes...", len(inserts)) + cursor.executemany( + """ + INSERT INTO merkle_hashes(file_id, sha256, updated_at) + VALUES(?, ?, ?) 
+ ON CONFLICT(file_id) DO UPDATE SET + sha256=excluded.sha256, + updated_at=excluded.updated_at + """, + inserts, + ) + diff --git a/codex-lens/tests/test_graph_expansion.py b/codex-lens/tests/test_graph_expansion.py new file mode 100644 index 00000000..6588a5e4 --- /dev/null +++ b/codex-lens/tests/test_graph_expansion.py @@ -0,0 +1,188 @@ +import sqlite3 +import tempfile +from pathlib import Path + +import pytest + +from codexlens.config import Config +from codexlens.entities import CodeRelationship, RelationshipType, SearchResult, Symbol +from codexlens.search.chain_search import ChainSearchEngine, SearchOptions +from codexlens.search.graph_expander import GraphExpander +from codexlens.storage.dir_index import DirIndexStore +from codexlens.storage.index_tree import _compute_graph_neighbors +from codexlens.storage.path_mapper import PathMapper +from codexlens.storage.registry import RegistryStore + + +@pytest.fixture() +def temp_paths() -> Path: + tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + root = Path(tmpdir.name) + yield root + try: + tmpdir.cleanup() + except (PermissionError, OSError): + pass + + +def _create_index_with_neighbors(root: Path) -> tuple[PathMapper, Path, Path]: + project_root = root / "project" + project_root.mkdir(parents=True, exist_ok=True) + + index_root = root / "indexes" + mapper = PathMapper(index_root=index_root) + index_db_path = mapper.source_to_index_db(project_root) + index_db_path.parent.mkdir(parents=True, exist_ok=True) + + content = "\n".join( + [ + "def a():", + " b()", + "", + "def b():", + " c()", + "", + "def c():", + " return 1", + "", + ] + ) + file_path = project_root / "module.py" + file_path.write_text(content, encoding="utf-8") + + symbols = [ + Symbol(name="a", kind="function", range=(1, 2), file=str(file_path)), + Symbol(name="b", kind="function", range=(4, 5), file=str(file_path)), + Symbol(name="c", kind="function", range=(7, 8), file=str(file_path)), + ] + relationships = [ + CodeRelationship( + source_symbol="a", + target_symbol="b", + relationship_type=RelationshipType.CALL, + source_file=str(file_path), + target_file=None, + source_line=2, + ), + CodeRelationship( + source_symbol="b", + target_symbol="c", + relationship_type=RelationshipType.CALL, + source_file=str(file_path), + target_file=None, + source_line=5, + ), + ] + + config = Config(data_dir=root / "data") + store = DirIndexStore(index_db_path, config=config) + store.initialize() + store.add_file( + name=file_path.name, + full_path=file_path, + content=content, + language="python", + symbols=symbols, + relationships=relationships, + ) + _compute_graph_neighbors(store) + store.close() + + return mapper, project_root, file_path + + +def test_graph_neighbors_precomputed_two_hop(temp_paths: Path) -> None: + mapper, project_root, file_path = _create_index_with_neighbors(temp_paths) + index_db_path = mapper.source_to_index_db(project_root) + + conn = sqlite3.connect(str(index_db_path)) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + """ + SELECT s1.name AS source_name, s2.name AS neighbor_name, gn.relationship_depth + FROM graph_neighbors gn + JOIN symbols s1 ON s1.id = gn.source_symbol_id + JOIN symbols s2 ON s2.id = gn.neighbor_symbol_id + ORDER BY source_name, neighbor_name, relationship_depth + """ + ).fetchall() + finally: + conn.close() + + triples = {(r["source_name"], r["neighbor_name"], int(r["relationship_depth"])) for r in rows} + assert ("a", "b", 1) in triples + assert ("a", "c", 2) in triples + assert ("b", "c", 1) in triples + 
assert ("c", "b", 1) in triples + assert file_path.exists() + + +def test_graph_expander_returns_related_results_with_depth_metadata(temp_paths: Path) -> None: + mapper, project_root, file_path = _create_index_with_neighbors(temp_paths) + _ = project_root + + expander = GraphExpander(mapper, config=Config(data_dir=temp_paths / "data", graph_expansion_depth=2)) + base = SearchResult( + path=str(file_path.resolve()), + score=1.0, + excerpt="", + content=None, + start_line=1, + end_line=2, + symbol_name="a", + symbol_kind="function", + ) + related = expander.expand([base], depth=2, max_expand=1, max_related=10) + + depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in related} + assert depth_by_symbol.get("b") == 1 + assert depth_by_symbol.get("c") == 2 + + +def test_chain_search_populates_related_results_when_enabled(temp_paths: Path) -> None: + mapper, project_root, file_path = _create_index_with_neighbors(temp_paths) + _ = file_path + + registry = RegistryStore(db_path=temp_paths / "registry.db") + registry.initialize() + + config = Config( + data_dir=temp_paths / "data", + enable_graph_expansion=True, + graph_expansion_depth=2, + ) + engine = ChainSearchEngine(registry, mapper, config=config) + try: + options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False) + result = engine.search("b", project_root, options) + + assert result.results + assert result.results[0].symbol_name == "a" + + depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in result.related_results} + assert depth_by_symbol.get("b") == 1 + assert depth_by_symbol.get("c") == 2 + finally: + engine.close() + + +def test_chain_search_related_results_empty_when_disabled(temp_paths: Path) -> None: + mapper, project_root, file_path = _create_index_with_neighbors(temp_paths) + _ = file_path + + registry = RegistryStore(db_path=temp_paths / "registry.db") + registry.initialize() + + config = Config( + data_dir=temp_paths / "data", + enable_graph_expansion=False, + ) + engine = ChainSearchEngine(registry, mapper, config=config) + try: + options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False) + result = engine.search("b", project_root, options) + assert result.related_results == [] + finally: + engine.close() + diff --git a/codex-lens/tests/test_hybrid_search_e2e.py b/codex-lens/tests/test_hybrid_search_e2e.py index 3b952d58..66f513ea 100644 --- a/codex-lens/tests/test_hybrid_search_e2e.py +++ b/codex-lens/tests/test_hybrid_search_e2e.py @@ -869,3 +869,47 @@ class TestHybridSearchAdaptiveWeights: ) as rerank_mock: engine_on.search(Path("dummy.db"), "query", enable_vector=True) assert rerank_mock.call_count == 1 + + def test_cross_encoder_reranking_enabled(self, tmp_path): + """Cross-encoder stage runs only when explicitly enabled via config.""" + from unittest.mock import patch + + results_map = { + "exact": [SearchResult(path="a.py", score=10.0, excerpt="a")], + "fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")], + "vector": [SearchResult(path="c.py", score=0.9, excerpt="c")], + } + + class DummyEmbedder: + def embed(self, texts): + if isinstance(texts, str): + texts = [texts] + return [[1.0, 0.0] for _ in texts] + + class DummyReranker: + def score_pairs(self, pairs, batch_size=32): + return [0.0 for _ in pairs] + + config = Config( + data_dir=tmp_path / "ce", + enable_reranking=True, + enable_cross_encoder_rerank=True, + reranker_top_k=10, + ) + engine = HybridSearchEngine(config=config, embedder=DummyEmbedder()) + + with 
patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch( + "codexlens.search.hybrid_search.rerank_results", + side_effect=lambda q, r, e, top_k=50: r, + ) as rerank_mock, patch.object( + HybridSearchEngine, + "_get_cross_encoder_reranker", + return_value=DummyReranker(), + ) as get_ce_mock, patch( + "codexlens.search.hybrid_search.cross_encoder_rerank", + side_effect=lambda q, r, ce, top_k=50: r, + ) as ce_mock: + engine.search(Path("dummy.db"), "query", enable_vector=True) + assert rerank_mock.call_count == 1 + assert get_ce_mock.call_count == 1 + assert ce_mock.call_count == 1 diff --git a/codex-lens/tests/test_merkle_detection.py b/codex-lens/tests/test_merkle_detection.py new file mode 100644 index 00000000..e4afdccd --- /dev/null +++ b/codex-lens/tests/test_merkle_detection.py @@ -0,0 +1,100 @@ +import time +from pathlib import Path + +from codexlens.config import Config +from codexlens.storage.dir_index import DirIndexStore + + +def _make_merkle_config(tmp_path: Path) -> Config: + data_dir = tmp_path / "data" + return Config( + data_dir=data_dir, + venv_path=data_dir / "venv", + enable_merkle_detection=True, + ) + + +class TestMerkleDetection: + def test_needs_reindex_touch_updates_mtime(self, tmp_path: Path) -> None: + config = _make_merkle_config(tmp_path) + source_dir = tmp_path / "src" + source_dir.mkdir(parents=True, exist_ok=True) + + file_path = source_dir / "a.py" + file_path.write_text("print('hi')\n", encoding="utf-8") + original_content = file_path.read_text(encoding="utf-8") + + index_db = tmp_path / "_index.db" + with DirIndexStore(index_db, config=config) as store: + store.add_file( + name=file_path.name, + full_path=file_path, + content=original_content, + language="python", + symbols=[], + ) + + stored_mtime_before = store.get_file_mtime(file_path) + assert stored_mtime_before is not None + + # Touch file without changing content + time.sleep(0.02) + file_path.write_text(original_content, encoding="utf-8") + + assert store.needs_reindex(file_path) is False + + stored_mtime_after = store.get_file_mtime(file_path) + assert stored_mtime_after is not None + assert stored_mtime_after != stored_mtime_before + + current_mtime = file_path.stat().st_mtime + assert abs(stored_mtime_after - current_mtime) <= 0.001 + + def test_parent_root_changes_when_child_changes(self, tmp_path: Path) -> None: + config = _make_merkle_config(tmp_path) + + source_root = tmp_path / "project" + child_dir = source_root / "child" + child_dir.mkdir(parents=True, exist_ok=True) + + child_file = child_dir / "child.py" + child_file.write_text("x = 1\n", encoding="utf-8") + + child_db = tmp_path / "child_index.db" + parent_db = tmp_path / "parent_index.db" + + with DirIndexStore(child_db, config=config) as child_store: + child_store.add_file( + name=child_file.name, + full_path=child_file, + content=child_file.read_text(encoding="utf-8"), + language="python", + symbols=[], + ) + child_root_1 = child_store.update_merkle_root() + assert child_root_1 + + with DirIndexStore(parent_db, config=config) as parent_store: + parent_store.register_subdir(name="child", index_path=child_db, files_count=1) + parent_root_1 = parent_store.update_merkle_root() + assert parent_root_1 + + time.sleep(0.02) + child_file.write_text("x = 2\n", encoding="utf-8") + + with DirIndexStore(child_db, config=config) as child_store: + child_store.add_file( + name=child_file.name, + full_path=child_file, + content=child_file.read_text(encoding="utf-8"), + language="python", + symbols=[], + ) + child_root_2 = 
child_store.update_merkle_root() + assert child_root_2 + assert child_root_2 != child_root_1 + + with DirIndexStore(parent_db, config=config) as parent_store: + parent_root_2 = parent_store.update_merkle_root() + assert parent_root_2 + assert parent_root_2 != parent_root_1 diff --git a/codex-lens/tests/test_performance_optimizations.py b/codex-lens/tests/test_performance_optimizations.py index 4f44876a..3bbe334e 100644 --- a/codex-lens/tests/test_performance_optimizations.py +++ b/codex-lens/tests/test_performance_optimizations.py @@ -1,9 +1,11 @@ -"""Tests for performance optimizations in CodexLens storage. +"""Tests for performance optimizations in CodexLens. This module tests the following optimizations: 1. Normalized keywords search (migration_001) 2. Optimized path lookup in registry 3. Prefix-mode symbol search +4. Graph expansion neighbor precompute overhead (<20%) +5. Cross-encoder reranking latency (<200ms) """ import json @@ -479,3 +481,113 @@ class TestPerformanceComparison: print(f" Substring: {substring_time*1000:.3f}ms ({len(substring_results)} results)") print(f" Ratio: {prefix_time/substring_time:.2f}x") print(f" Note: Performance benefits appear with 1000+ symbols") + + +class TestPerformanceBenchmarks: + """Benchmark-style assertions for key performance requirements.""" + + def test_graph_expansion_indexing_overhead_under_20_percent(self, temp_index_db, tmp_path): + """Graph neighbor precompute adds <20% overhead versus indexing baseline.""" + from codexlens.entities import CodeRelationship, RelationshipType, Symbol + from codexlens.storage.index_tree import _compute_graph_neighbors + + store = temp_index_db + + file_count = 60 + symbols_per_file = 8 + + start = time.perf_counter() + for file_idx in range(file_count): + file_path = tmp_path / f"graph_{file_idx}.py" + lines = [] + for sym_idx in range(symbols_per_file): + lines.append(f"def func_{file_idx}_{sym_idx}():") + lines.append(f" return {sym_idx}") + lines.append("") + content = "\n".join(lines) + + symbols = [ + Symbol( + name=f"func_{file_idx}_{sym_idx}", + kind="function", + range=(sym_idx * 3 + 1, sym_idx * 3 + 2), + file=str(file_path), + ) + for sym_idx in range(symbols_per_file) + ] + + relationships = [ + CodeRelationship( + source_symbol=f"func_{file_idx}_{sym_idx}", + target_symbol=f"func_{file_idx}_{sym_idx + 1}", + relationship_type=RelationshipType.CALL, + source_file=str(file_path), + target_file=None, + source_line=sym_idx * 3 + 2, + ) + for sym_idx in range(symbols_per_file - 1) + ] + + store.add_file( + name=file_path.name, + full_path=file_path, + content=content, + language="python", + symbols=symbols, + relationships=relationships, + ) + baseline_time = time.perf_counter() - start + + durations = [] + for _ in range(3): + start = time.perf_counter() + _compute_graph_neighbors(store) + durations.append(time.perf_counter() - start) + graph_time = min(durations) + + # Sanity-check that the benchmark exercised graph neighbor generation. 
+ conn = store._get_connection() + neighbor_count = conn.execute( + "SELECT COUNT(*) as c FROM graph_neighbors" + ).fetchone()["c"] + assert neighbor_count > 0 + + assert baseline_time > 0.0 + overhead_ratio = graph_time / baseline_time + assert overhead_ratio < 0.2, ( + f"Graph neighbor precompute overhead too high: {overhead_ratio:.2%} " + f"(baseline={baseline_time:.3f}s, graph={graph_time:.3f}s)" + ) + + def test_cross_encoder_reranking_latency_under_200ms(self): + """Cross-encoder rerank step completes under 200ms (excluding model load).""" + from codexlens.entities import SearchResult + from codexlens.search.ranking import cross_encoder_rerank + + query = "find function" + results = [ + SearchResult( + path=f"file_{idx}.py", + score=1.0 / (idx + 1), + excerpt=f"def func_{idx}():\n return {idx}", + symbol_name=f"func_{idx}", + symbol_kind="function", + ) + for idx in range(50) + ] + + class DummyReranker: + def score_pairs(self, pairs, batch_size=32): + _ = batch_size + # Return deterministic pseudo-logits to exercise sigmoid normalization. + return [float(i) for i in range(len(pairs))] + + reranker = DummyReranker() + + start = time.perf_counter() + reranked = cross_encoder_rerank(query, results, reranker, top_k=50, batch_size=32) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + + assert len(reranked) == len(results) + assert any(r.metadata.get("cross_encoder_reranked") for r in reranked[:50]) + assert elapsed_ms < 200.0, f"Cross-encoder rerank too slow: {elapsed_ms:.1f}ms" diff --git a/codex-lens/tests/test_schema_cleanup_migration.py b/codex-lens/tests/test_schema_cleanup_migration.py index 0eaef27c..e7848f33 100644 --- a/codex-lens/tests/test_schema_cleanup_migration.py +++ b/codex-lens/tests/test_schema_cleanup_migration.py @@ -19,7 +19,7 @@ from codexlens.entities import Symbol class TestSchemaCleanupMigration: - """Test schema cleanup migration (v4 -> v5).""" + """Test schema cleanup migration (v4 -> latest).""" def test_migration_from_v4_to_v5(self): """Test that migration successfully removes deprecated fields.""" @@ -129,10 +129,12 @@ class TestSchemaCleanupMigration: # Now initialize store - this should trigger migration store.initialize() - # Verify schema version is now 5 + # Verify schema version is now the latest conn = store._get_connection() version_row = conn.execute("PRAGMA user_version").fetchone() - assert version_row[0] == 5, f"Expected schema version 5, got {version_row[0]}" + assert version_row[0] == DirIndexStore.SCHEMA_VERSION, ( + f"Expected schema version {DirIndexStore.SCHEMA_VERSION}, got {version_row[0]}" + ) # Check that deprecated columns are removed # 1. 
Check semantic_metadata doesn't have keywords column @@ -166,7 +168,7 @@ class TestSchemaCleanupMigration: store.close() def test_new_database_has_clean_schema(self): - """Test that new databases are created with clean schema (v5).""" + """Test that new databases are created with clean schema (latest).""" with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "_index.db" store = DirIndexStore(db_path) @@ -174,9 +176,9 @@ class TestSchemaCleanupMigration: conn = store._get_connection() - # Verify schema version is 5 + # Verify schema version is the latest version_row = conn.execute("PRAGMA user_version").fetchone() - assert version_row[0] == 5 + assert version_row[0] == DirIndexStore.SCHEMA_VERSION # Check that new schema doesn't have deprecated columns cursor = conn.execute("PRAGMA table_info(semantic_metadata)") diff --git a/codex-lens/tests/test_search_comprehensive.py b/codex-lens/tests/test_search_comprehensive.py index f26f256c..dcde8e9a 100644 --- a/codex-lens/tests/test_search_comprehensive.py +++ b/codex-lens/tests/test_search_comprehensive.py @@ -582,6 +582,7 @@ class TestChainSearchResult: ) assert result.query == "test" assert result.results == [] + assert result.related_results == [] assert result.symbols == [] assert result.stats.dirs_searched == 0 diff --git a/codex-lens/tests/test_search_full_coverage.py b/codex-lens/tests/test_search_full_coverage.py index 6e3e8143..1de3c350 100644 --- a/codex-lens/tests/test_search_full_coverage.py +++ b/codex-lens/tests/test_search_full_coverage.py @@ -1173,6 +1173,7 @@ class TestChainSearchResultExtended: assert result.query == "test query" assert len(result.results) == 1 assert len(result.symbols) == 1 + assert result.related_results == [] assert result.stats.dirs_searched == 5 def test_result_with_empty_collections(self): @@ -1186,5 +1187,6 @@ class TestChainSearchResultExtended: assert result.query == "no matches" assert result.results == [] + assert result.related_results == [] assert result.symbols == [] assert result.stats.dirs_searched == 0 diff --git a/codex-lens/tests/test_treesitter_parser.py b/codex-lens/tests/test_treesitter_parser.py index c631040f..62303fc5 100644 --- a/codex-lens/tests/test_treesitter_parser.py +++ b/codex-lens/tests/test_treesitter_parser.py @@ -110,6 +110,37 @@ class DataProcessor: assert result is not None assert len(result.symbols) == 0 + def test_extracts_relationships_with_alias_resolution(self): + parser = TreeSitterSymbolParser("python") + code = """ +import os.path as osp +from math import sqrt as sq + +class Base: + pass + +class Child(Base): + pass + +def main(): + osp.join("a", "b") + sq(4) +""" + result = parser.parse(code, Path("test.py")) + + assert result is not None + + rels = [r for r in result.relationships if r.source_symbol == "main"] + targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"} + assert "os.path.join" in targets + assert "math.sqrt" in targets + + inherits = [ + r for r in result.relationships + if r.source_symbol == "Child" and r.relationship_type.value == "inherits" + ] + assert any(r.target_symbol == "Base" for r in inherits) + @pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed") class TestTreeSitterJavaScriptParser: @@ -175,6 +206,22 @@ export const arrowFunc = () => {} assert "exported" in names assert "arrowFunc" in names + def test_extracts_relationships_with_import_alias(self): + parser = TreeSitterSymbolParser("javascript") + code = """ +import { readFile as rf } from "fs"; + +function main() 
{
+    rf("a");
+}
+"""
+        result = parser.parse(code, Path("test.js"))
+
+        assert result is not None
+        rels = [r for r in result.relationships if r.source_symbol == "main"]
+        targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
+        assert "fs.readFile" in targets
+
 
 @pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
 class TestTreeSitterTypeScriptParser: