Add graph expansion and cross-encoder reranking features

- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors (usage sketch below).
- Added CrossEncoderReranker for optional second-stage search ranking to improve result scoring.
- Created migrations to establish necessary database tables for relationships and graph neighbors.
- Developed tests for graph expansion functionality, ensuring related results are populated correctly.
- Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead.
- Updated schema cleanup tests to reflect changes in versioning and deprecated fields.
- Added test cases for the tree-sitter parser to validate relationship extraction with alias resolution.
catlog22
2025-12-31 16:58:59 +08:00
parent 4bde13e83a
commit 31a45f1f30
27 changed files with 2566 additions and 97 deletions
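
For orientation, a rough usage sketch of the new knobs, based on the Config fields and ChainSearchResult attributes introduced in the diffs below; registry construction and the exact engine.search signature are not shown in this commit, so they appear only as comments.

from codexlens.config import Config

# New optional feature flags added by this commit (all default to off):
config = Config(
    enable_graph_expansion=True,        # expand top hits with precomputed graph neighbors
    graph_expansion_depth=2,            # GraphExpander caps traversal at 2 hops
    enable_reranking=True,              # first-stage rerank (embedder-based)
    enable_cross_encoder_rerank=True,   # second stage; needs `pip install codexlens[reranker]`
    reranker_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)

# ChainSearchEngine now accepts the config and exposes graph-expanded hits:
#   engine = ChainSearchEngine(registry, mapper, config=config)
#   result = engine.search(...)               # ChainSearchResult
#   for r in result.related_results:          # neighbors of the top-ranked results
#       print(r.path, r.symbol_name, r.metadata.get("relationship_depth"))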

View File

@@ -49,6 +49,12 @@ semantic-directml = [
"onnxruntime-directml>=1.15.0", # DirectML support
]
# Cross-encoder reranking (second-stage, optional)
# Install with: pip install codexlens[reranker]
reranker = [
"sentence-transformers>=2.2",
]
# Encoding detection for non-UTF8 files
encoding = [
"chardet>=5.0",

View File

@@ -407,7 +407,7 @@ def search(
registry.initialize()
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
engine = ChainSearchEngine(registry, mapper, config=config)
# Auto-detect mode if set to "auto"
actual_mode = mode
@@ -550,7 +550,7 @@ def symbol(
registry.initialize()
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
engine = ChainSearchEngine(registry, mapper, config=config)
options = SearchOptions(depth=depth, total_limit=limit)
syms = engine.search_symbols(name, search_path, kind=kind, options=options)

View File

@@ -105,12 +105,22 @@ class Config:
# Indexing/search optimizations
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
# Graph expansion (search-time, uses precomputed neighbors)
enable_graph_expansion: bool = False
graph_expansion_depth: int = 2
# Optional search reranking (disabled by default)
enable_reranking: bool = False
reranking_top_k: int = 50
symbol_boost_factor: float = 1.5
# Optional cross-encoder reranking (second stage, requires codexlens[reranker])
enable_cross_encoder_rerank: bool = False
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker_top_k: int = 50
# Multi-endpoint configuration for litellm backend
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]

View File

@@ -58,6 +58,7 @@ class IndexedFile(BaseModel):
language: str = Field(..., min_length=1)
symbols: List[Symbol] = Field(default_factory=list)
chunks: List[SemanticChunk] = Field(default_factory=list)
relationships: List["CodeRelationship"] = Field(default_factory=list)
@field_validator("path", "language")
@classmethod
@@ -70,7 +71,7 @@ class IndexedFile(BaseModel):
class RelationshipType(str, Enum):
"""Types of code relationships."""
CALL = "call"
CALL = "calls"
INHERITS = "inherits"
IMPORTS = "imports"

View File

@@ -4,6 +4,11 @@ import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
try:
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
except Exception: # pragma: no cover - optional dependency / platform variance
TreeSitterSymbolParser = None # type: ignore[assignment]
class SymbolExtractor:
"""Extract symbols and relationships from source code using regex patterns."""
@@ -118,7 +123,7 @@ class SymbolExtractor:
patterns = self.PATTERNS[lang]
symbols = []
relationships = []
relationships: List[Dict] = []
lines = content.split('\n')
current_scope = None
@@ -141,33 +146,62 @@ class SymbolExtractor:
})
current_scope = name
# Extract imports
if 'import' in patterns:
match = re.search(patterns['import'], line)
if match:
import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
if import_target and current_scope:
relationships.append({
'source_scope': current_scope,
'target': import_target.strip(),
'type': 'imports',
'file_path': str(file_path),
'line': line_num,
})
if TreeSitterSymbolParser is not None:
try:
ts_parser = TreeSitterSymbolParser(lang, file_path)
if ts_parser.is_available():
indexed = ts_parser.parse(content, file_path)
if indexed is not None and indexed.relationships:
relationships = [
{
"source_scope": r.source_symbol,
"target": r.target_symbol,
"type": r.relationship_type.value,
"file_path": str(file_path),
"line": r.source_line,
}
for r in indexed.relationships
]
except Exception:
relationships = []
# Extract function calls (simplified)
if 'call' in patterns and current_scope:
for match in re.finditer(patterns['call'], line):
call_name = match.group(1)
# Skip common keywords and the current function
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
relationships.append({
'source_scope': current_scope,
'target': call_name,
'type': 'calls',
'file_path': str(file_path),
'line': line_num,
})
# Regex fallback for relationships (when tree-sitter is unavailable)
if not relationships:
current_scope = None
for line_num, line in enumerate(lines, 1):
for kind in ['function', 'class']:
if kind in patterns:
match = re.search(patterns[kind], line)
if match:
current_scope = match.group(1)
# Extract imports
if 'import' in patterns:
match = re.search(patterns['import'], line)
if match:
import_target = match.group(1) or match.group(2) if match.lastindex >= 2 else match.group(1)
if import_target and current_scope:
relationships.append({
'source_scope': current_scope,
'target': import_target.strip(),
'type': 'imports',
'file_path': str(file_path),
'line': line_num,
})
# Extract function calls (simplified)
if 'call' in patterns and current_scope:
for match in re.finditer(patterns['call'], line):
call_name = match.group(1)
# Skip common keywords and the current function
if call_name not in ['if', 'for', 'while', 'return', 'print', 'len', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', current_scope]:
relationships.append({
'source_scope': current_scope,
'target': call_name,
'type': 'calls',
'file_path': str(file_path),
'line': line_num,
})
return symbols, relationships

View File

@@ -13,7 +13,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
@@ -30,36 +30,39 @@ class SimpleRegexParser:
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
indexed = ts_parser.parse(text, path)
if indexed is not None:
return indexed
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols_regex(text)
relationships = _parse_python_relationships_regex(text, path)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols_regex(text)
relationships = _parse_js_ts_relationships_regex(text, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
relationships = []
elif self.language_id == "go":
symbols = _parse_go_symbols(text)
relationships = []
elif self.language_id == "markdown":
symbols = _parse_markdown_symbols(text)
relationships = []
elif self.language_id == "text":
symbols = _parse_text_symbols(text)
relationships = []
else:
symbols = _parse_generic_symbols(text)
relationships = []
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
@@ -78,6 +81,9 @@ class ParserFactory:
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
@@ -127,12 +133,81 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
current_scope = def_match.group(1)
continue
if current_scope is None:
continue
import_match = _PY_IMPORT_RE.search(line)
if import_match:
import_target = import_match.group(1) or import_match.group(2)
if import_target:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_target.strip(),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _PY_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {
"if",
"for",
"while",
"return",
"print",
"len",
"str",
"int",
"float",
"list",
"dict",
"set",
"tuple",
current_scope,
}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
@@ -174,6 +249,61 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _JS_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
func_match = _JS_FUNC_RE.match(line)
if func_match:
current_scope = func_match.group(1)
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
current_scope = arrow_match.group(1)
continue
if current_scope is None:
continue
import_match = _JS_IMPORT_RE.search(line)
if import_match:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_match.group(1),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _JS_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {current_scope}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
@@ -253,4 +383,3 @@ def _parse_text_symbols(text: str) -> List[Symbol]:
# Text files don't have structured symbols, return empty list
# The file content will still be indexed for FTS search
return []

View File

@@ -1,12 +1,17 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing with fallback to regex-based parsing.
Provides precise AST-level parsing via tree-sitter.
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
@@ -19,7 +24,7 @@ except ImportError:
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import IndexedFile, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@@ -85,6 +90,16 @@ class TreeSitterSymbolParser:
"""
return self._parser is not None and self._language is not None
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
return source_bytes, tree.root_node
except Exception:
return None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
@@ -95,17 +110,15 @@ class TreeSitterSymbolParser:
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle parsing errors
# Gracefully handle extraction errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
@@ -118,19 +131,21 @@ class TreeSitterSymbolParser:
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
symbols = self.parse_symbols(text)
if symbols is None:
return None
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
except Exception:
# Gracefully handle parsing errors
@@ -153,6 +168,465 @@ class TreeSitterSymbolParser:
else:
return []
def _extract_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"self", "cls"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "class_definition" and pushed_scope:
superclasses = node.child_by_field_name("superclasses")
if superclasses is not None:
for child in superclasses.children:
dotted = self._python_expression_to_dotted(source_bytes, child)
if not dotted:
continue
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type in {"import_statement", "import_from_statement"}:
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
if node.type == "call":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
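
To make the alias handling concrete, a small sketch of what this Python extractor is expected to report for a toy snippet, assuming tree-sitter and the Python grammar are installed (otherwise parse() returns None):

from pathlib import Path
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser

snippet = (
    "import numpy as np\n"
    "from os import path as p\n"
    "\n"
    "class Loader:\n"
    "    def load(self):\n"
    "        np.array([1, 2])\n"
    "        p.join('a', 'b')\n"
)
parser = TreeSitterSymbolParser("python", Path("demo.py"))
indexed = parser.parse(snippet, Path("demo.py"))
if indexed is not None:
    for rel in indexed.relationships:
        print(rel.source_symbol, rel.relationship_type.value, rel.target_symbol)
# Roughly expected: "load calls numpy.array" and "load calls os.path.join";
# the module-level imports yield no edges because there is no enclosing scope.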
def _extract_js_ts_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"this", "super"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if pushed_scope:
superclass = node.child_by_field_name("superclass")
if superclass is not None:
dotted = self._js_expression_to_dotted(source_bytes, superclass)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type in {"identifier", "property_identifier"}
and value_node.type == "arrow_function"
):
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name and scope_name != "constructor":
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"import_declaration", "import_statement"}:
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
# Best-effort support for CommonJS require() imports:
# const fs = require("fs")
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type == "identifier"
and value_node.type == "call_expression"
):
callee = value_node.child_by_field_name("function")
args = value_node.child_by_field_name("arguments")
if (
callee is not None
and self._node_text(source_bytes, callee).strip() == "require"
and args is not None
):
module_name = self._js_first_string_argument(source_bytes, args)
if module_name:
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
record_import(module_name, self._node_start_line(node))
if node.type == "call_expression":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _node_start_line(self, node: TreeSitterNode) -> int:
return node.start_point[0] + 1
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
dotted = (dotted or "").strip()
if not dotted:
return ""
base, sep, rest = dotted.partition(".")
resolved_base = aliases.get(base, base)
if not rest:
return resolved_base
if resolved_base and rest:
return f"{resolved_base}.{rest}"
return resolved_base
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"identifier", "dotted_name"}:
return self._node_text(source_bytes, node).strip()
if node.type == "attribute":
obj = node.child_by_field_name("object")
attr = node.child_by_field_name("attribute")
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
if obj_text and attr_text:
return f"{obj_text}.{attr_text}"
return obj_text or attr_text
return ""
def _python_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
if node.type == "import_statement":
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
module_name = self._node_text(source_bytes, name_node).strip()
if not module_name:
continue
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else module_name.split(".", 1)[0]
)
if bound_name:
aliases[bound_name] = module_name
targets.append(module_name)
elif child.type == "dotted_name":
module_name = self._node_text(source_bytes, child).strip()
if not module_name:
continue
bound_name = module_name.split(".", 1)[0]
if bound_name:
aliases[bound_name] = bound_name
targets.append(module_name)
if node.type == "import_from_statement":
module_name = ""
module_node = node.child_by_field_name("module_name")
if module_node is None:
for child in node.children:
if child.type == "dotted_name":
module_node = child
break
if module_node is not None:
module_name = self._node_text(source_bytes, module_node).strip()
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
imported_name = self._node_text(source_bytes, name_node).strip()
if not imported_name or imported_name == "*":
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported_name
)
if bound_name:
aliases[bound_name] = target
targets.append(target)
elif child.type == "identifier":
imported_name = self._node_text(source_bytes, child).strip()
if not imported_name or imported_name in {"from", "import", "*"}:
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
aliases[imported_name] = target
targets.append(target)
return aliases, targets
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"this", "super"}:
return node.type
if node.type in {"identifier", "property_identifier"}:
return self._node_text(source_bytes, node).strip()
if node.type == "member_expression":
obj = node.child_by_field_name("object")
prop = node.child_by_field_name("property")
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
if obj_text and prop_text:
return f"{obj_text}.{prop_text}"
return obj_text or prop_text
return ""
def _js_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
module_name = ""
source_node = node.child_by_field_name("source")
if source_node is not None:
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
if module_name:
targets.append(module_name)
for child in node.children:
if child.type == "import_clause":
for clause_child in child.children:
if clause_child.type == "identifier":
# Default import: import React from "react"
local = self._node_text(source_bytes, clause_child).strip()
if local and module_name:
aliases[local] = module_name
if clause_child.type == "namespace_import":
# Namespace import: import * as fs from "fs"
name_node = clause_child.child_by_field_name("name")
if name_node is not None and module_name:
local = self._node_text(source_bytes, name_node).strip()
if local:
aliases[local] = module_name
if clause_child.type == "named_imports":
for spec in clause_child.children:
if spec.type != "import_specifier":
continue
name_node = spec.child_by_field_name("name")
alias_node = spec.child_by_field_name("alias")
if name_node is None:
continue
imported = self._node_text(source_bytes, name_node).strip()
if not imported:
continue
local = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported
)
if local and module_name:
aliases[local] = f"{module_name}.{imported}"
targets.append(f"{module_name}.{imported}")
return aliases, targets
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
for child in args_node.children:
if child.type == "string":
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
return ""
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.

View File

@@ -83,6 +83,7 @@ class ChainSearchResult:
Attributes:
query: Original search query
results: List of SearchResult objects
related_results: Expanded results from graph neighbors (optional)
symbols: List of Symbol objects (if include_symbols=True)
stats: SearchStats with execution metrics
"""
@@ -90,6 +91,7 @@ class ChainSearchResult:
results: List[SearchResult]
symbols: List[Symbol]
stats: SearchStats
related_results: List[SearchResult] = field(default_factory=list)
class ChainSearchEngine:
@@ -236,13 +238,26 @@ class ChainSearchEngine:
index_paths, query, None, options.total_limit
)
# Optional: graph expansion using precomputed neighbors
related_results: List[SearchResult] = []
if self._config is not None and getattr(self._config, "enable_graph_expansion", False):
try:
from codexlens.search.enrichment import SearchEnrichmentPipeline
pipeline = SearchEnrichmentPipeline(self.mapper, config=self._config)
related_results = pipeline.expand_related_results(final_results)
except Exception as exc:
self.logger.debug("Graph expansion failed: %s", exc)
related_results = []
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=final_results,
symbols=symbols,
stats=stats
stats=stats,
related_results=related_results,
)
def search_files_only(self, query: str,

View File

@@ -4,6 +4,11 @@ import sqlite3
from pathlib import Path
from typing import List, Dict, Any, Optional
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper
class RelationshipEnricher:
"""Enriches search results with code graph relationships."""
@@ -148,3 +153,19 @@ class RelationshipEnricher:
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
class SearchEnrichmentPipeline:
"""Search post-processing pipeline (optional enrichments)."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._config = config
self._graph_expander = GraphExpander(mapper, config=config)
def expand_related_results(self, results: List[SearchResult]) -> List[SearchResult]:
"""Expand base results with related symbols when enabled in config."""
if self._config is None or not getattr(self._config, "enable_graph_expansion", False):
return []
depth = int(getattr(self._config, "graph_expansion_depth", 2) or 2)
return self._graph_expander.expand(results, depth=depth)

View File

@@ -0,0 +1,264 @@
"""Graph expansion for search results using precomputed neighbors.
Expands top search results with related symbol definitions by traversing
precomputed N-hop neighbors stored in the per-directory index databases.
"""
from __future__ import annotations
import logging
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.storage.path_mapper import PathMapper
logger = logging.getLogger(__name__)
def _result_key(result: SearchResult) -> Tuple[str, Optional[str], Optional[int], Optional[int]]:
return (result.path, result.symbol_name, result.start_line, result.end_line)
def _slice_content_block(content: str, start_line: Optional[int], end_line: Optional[int]) -> Optional[str]:
if content is None:
return None
if start_line is None or end_line is None:
return None
if start_line < 1 or end_line < start_line:
return None
lines = content.splitlines()
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
if start_idx >= len(lines):
return None
return "\n".join(lines[start_idx:end_idx])
class GraphExpander:
"""Expands SearchResult lists with related symbols from the code graph."""
def __init__(self, mapper: PathMapper, *, config: Optional[Config] = None) -> None:
self._mapper = mapper
self._config = config
self._logger = logging.getLogger(__name__)
def expand(
self,
results: Sequence[SearchResult],
*,
depth: Optional[int] = None,
max_expand: int = 10,
max_related: int = 50,
) -> List[SearchResult]:
"""Expand top results with related symbols.
Args:
results: Base ranked results.
depth: Maximum relationship depth to include (defaults to Config.graph_expansion_depth, else 2).
max_expand: Only expand the top-N base results to bound cost.
max_related: Maximum related results to return.
Returns:
A list of related SearchResult objects with relationship_depth metadata.
"""
if not results:
return []
configured_depth = getattr(self._config, "graph_expansion_depth", 2) if self._config else 2
max_depth = int(depth if depth is not None else configured_depth)
if max_depth <= 0:
return []
max_depth = min(max_depth, 2)
expand_count = max(0, int(max_expand))
related_limit = max(0, int(max_related))
if expand_count == 0 or related_limit == 0:
return []
seen = {_result_key(r) for r in results}
related_results: List[SearchResult] = []
conn_cache: Dict[Path, sqlite3.Connection] = {}
try:
for base in list(results)[:expand_count]:
if len(related_results) >= related_limit:
break
if not base.symbol_name or not base.path:
continue
index_path = self._mapper.source_to_index_db(Path(base.path).parent)
conn = conn_cache.get(index_path)
if conn is None:
conn = self._connect_readonly(index_path)
if conn is None:
continue
conn_cache[index_path] = conn
source_ids = self._resolve_source_symbol_ids(
conn,
file_path=base.path,
symbol_name=base.symbol_name,
symbol_kind=base.symbol_kind,
)
if not source_ids:
continue
for source_id in source_ids:
neighbors = self._get_neighbors(conn, source_id, max_depth=max_depth, limit=related_limit)
for neighbor_id, rel_depth in neighbors:
if len(related_results) >= related_limit:
break
row = self._get_symbol_details(conn, neighbor_id)
if row is None:
continue
path = str(row["full_path"])
symbol_name = str(row["name"])
symbol_kind = str(row["kind"])
start_line = int(row["start_line"]) if row["start_line"] is not None else None
end_line = int(row["end_line"]) if row["end_line"] is not None else None
content_block = _slice_content_block(
str(row["content"]) if row["content"] is not None else "",
start_line,
end_line,
)
score = float(base.score) * (0.5 ** int(rel_depth))
candidate = SearchResult(
path=path,
score=max(0.0, score),
excerpt=None,
content=content_block,
start_line=start_line,
end_line=end_line,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
metadata={"relationship_depth": int(rel_depth)},
)
key = _result_key(candidate)
if key in seen:
continue
seen.add(key)
related_results.append(candidate)
finally:
for conn in conn_cache.values():
try:
conn.close()
except Exception:
pass
return related_results
def _connect_readonly(self, index_path: Path) -> Optional[sqlite3.Connection]:
try:
if not index_path.exists() or index_path.stat().st_size == 0:
return None
except OSError:
return None
try:
conn = sqlite3.connect(f"file:{index_path}?mode=ro", uri=True, check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
except Exception as exc:
self._logger.debug("GraphExpander failed to open %s: %s", index_path, exc)
return None
def _resolve_source_symbol_ids(
self,
conn: sqlite3.Connection,
*,
file_path: str,
symbol_name: str,
symbol_kind: Optional[str],
) -> List[int]:
try:
if symbol_kind:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ? AND s.kind = ?
""",
(file_path, symbol_name, symbol_kind),
).fetchall()
else:
rows = conn.execute(
"""
SELECT s.id
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE f.full_path = ? AND s.name = ?
""",
(file_path, symbol_name),
).fetchall()
except sqlite3.Error:
return []
ids: List[int] = []
for row in rows:
try:
ids.append(int(row["id"]))
except Exception:
continue
return ids
def _get_neighbors(
self,
conn: sqlite3.Connection,
source_symbol_id: int,
*,
max_depth: int,
limit: int,
) -> List[Tuple[int, int]]:
try:
rows = conn.execute(
"""
SELECT neighbor_symbol_id, relationship_depth
FROM graph_neighbors
WHERE source_symbol_id = ? AND relationship_depth <= ?
ORDER BY relationship_depth ASC, neighbor_symbol_id ASC
LIMIT ?
""",
(int(source_symbol_id), int(max_depth), int(limit)),
).fetchall()
except sqlite3.Error:
return []
neighbors: List[Tuple[int, int]] = []
for row in rows:
try:
neighbors.append((int(row["neighbor_symbol_id"]), int(row["relationship_depth"])))
except Exception:
continue
return neighbors
def _get_symbol_details(self, conn: sqlite3.Connection, symbol_id: int) -> Optional[sqlite3.Row]:
try:
return conn.execute(
"""
SELECT
s.id,
s.name,
s.kind,
s.start_line,
s.end_line,
f.full_path,
f.content
FROM symbols s
JOIN files f ON f.id = s.file_id
WHERE s.id = ?
""",
(int(symbol_id),),
).fetchone()
except sqlite3.Error:
return None
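
A rough sketch of driving GraphExpander directly, assuming an index already exists for the file's directory and that the base hit carries a symbol name; in normal use ChainSearchEngine invokes this through SearchEnrichmentPipeline when enable_graph_expansion is set.

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.path_mapper import PathMapper

base_hit = SearchResult(
    path="/repo/app/config.py",          # hypothetical indexed file
    score=0.92,
    excerpt=None,
    content=None,
    start_line=10,
    end_line=24,
    symbol_name="parse_config",
    symbol_kind="function",
    metadata={},
)
expander = GraphExpander(PathMapper(), config=Config(enable_graph_expansion=True))
for related in expander.expand([base_hit], depth=2, max_expand=10, max_related=50):
    print(related.symbol_name, related.metadata["relationship_depth"], related.score)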

View File

@@ -34,6 +34,7 @@ from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.ranking import (
apply_symbol_boost,
cross_encoder_rerank,
get_rrf_weights,
reciprocal_rank_fusion,
rerank_results,
@@ -77,6 +78,7 @@ class HybridSearchEngine:
self.weights = weights or self.DEFAULT_WEIGHTS.copy()
self._config = config
self.embedder = embedder
self.reranker: Any = None
def search(
self,
@@ -112,6 +114,14 @@ class HybridSearchEngine:
>>> for r in results[:5]:
... print(f"{r.path}: {r.score:.3f}")
"""
# Defensive: avoid creating/locking an index database when callers pass
# an empty placeholder file (common in tests and misconfigured callers).
try:
if index_path.exists() and index_path.stat().st_size == 0:
return []
except OSError:
return []
# Determine which backends to use
backends = {}
@@ -180,9 +190,30 @@ class HybridSearchEngine:
query,
fused_results[:100],
self.embedder,
top_k=self._config.reranking_top_k,
top_k=(
100
if self._config.enable_cross_encoder_rerank
else self._config.reranking_top_k
),
)
# Optional: cross-encoder reranking as a second stage
if (
self._config is not None
and self._config.enable_reranking
and self._config.enable_cross_encoder_rerank
):
with timer("cross_encoder_rerank", self.logger):
if self.reranker is None:
self.reranker = self._get_cross_encoder_reranker()
if self.reranker is not None:
fused_results = cross_encoder_rerank(
query,
fused_results,
self.reranker,
top_k=self._config.reranker_top_k,
)
# Apply final limit
return fused_results[:limit]
@@ -222,6 +253,27 @@ class HybridSearchEngine:
)
return None
def _get_cross_encoder_reranker(self) -> Any:
if self._config is None:
return None
try:
from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available
except Exception as exc:
self.logger.debug("Cross-encoder reranker unavailable: %s", exc)
return None
ok, err = check_cross_encoder_available()
if not ok:
self.logger.debug("Cross-encoder reranker unavailable: %s", err)
return None
try:
return CrossEncoderReranker(model_name=self._config.reranker_model)
except Exception as exc:
self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc)
return None
def _search_parallel(
self,
index_path: Path,

View File

@@ -379,6 +379,117 @@ def rerank_results(
return reranked_results
def cross_encoder_rerank(
query: str,
results: List[SearchResult],
reranker: Any,
top_k: int = 50,
batch_size: int = 32,
) -> List[SearchResult]:
"""Second-stage reranking using a cross-encoder model.
This function is dependency-agnostic: callers can pass any object that exposes
a compatible `score_pairs(pairs, batch_size=...)` method.
"""
if not results:
return []
if reranker is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def text_for_pair(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
return r.symbol_name or r.path
pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]
try:
if hasattr(reranker, "score_pairs"):
raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
elif hasattr(reranker, "predict"):
raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
else:
return results
except Exception:
return results
if not raw_scores or len(raw_scores) != rerank_count:
return results
scores = [float(s) for s in raw_scores]
min_s = min(scores)
max_s = max(scores)
def sigmoid(x: float) -> float:
# Clamp to keep exp() stable.
x = max(-50.0, min(50.0, x))
return 1.0 / (1.0 + math.exp(-x))
if 0.0 <= min_s and max_s <= 1.0:
probs = scores
else:
probs = [sigmoid(s) for s in scores]
reranked_results: List[SearchResult] = []
for idx, result in enumerate(results):
if idx < rerank_count:
prev_score = float(result.score)
ce_score = scores[idx]
ce_prob = probs[idx]
combined_score = 0.5 * prev_score + 0.5 * ce_prob
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"pre_cross_encoder_score": prev_score,
"cross_encoder_score": ce_score,
"cross_encoder_prob": ce_prob,
"cross_encoder_reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
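
Because the function only needs a score_pairs (or predict) method, a stub scorer is enough to exercise it; the stub and results below are purely illustrative.

from codexlens.entities import SearchResult
from codexlens.search.ranking import cross_encoder_rerank

class StubReranker:
    """Toy scorer: positive logit when the result text mentions the query."""
    def score_pairs(self, pairs, batch_size=32):
        return [2.0 if q.lower() in text.lower() else -2.0 for q, text in pairs]

hits = [
    SearchResult(path="a.py", score=0.40, excerpt="def load_config(path): ...",
                 content=None, start_line=1, end_line=3,
                 symbol_name="load_config", symbol_kind="function", metadata={}),
    SearchResult(path="b.py", score=0.60, excerpt="def draw_widget(): ...",
                 content=None, start_line=1, end_line=3,
                 symbol_name="draw_widget", symbol_kind="function", metadata={}),
]
reranked = cross_encoder_rerank("load_config", hits, StubReranker(), top_k=2)
for r in reranked:
    print(r.path, round(r.score, 3), r.metadata.get("cross_encoder_score"))
# The lower-ranked but matching hit (a.py) moves ahead after the combined score.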
def normalize_bm25_score(score: float) -> float:
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.

View File

@@ -0,0 +1,86 @@
"""Optional cross-encoder reranker for second-stage search ranking.
Install with: pip install codexlens[reranker]
"""
from __future__ import annotations
import logging
import threading
from typing import List, Sequence, Tuple
logger = logging.getLogger(__name__)
try:
from sentence_transformers import CrossEncoder as _CrossEncoder
CROSS_ENCODER_AVAILABLE = True
_import_error: str | None = None
except ImportError as exc: # pragma: no cover - optional dependency
_CrossEncoder = None # type: ignore[assignment]
CROSS_ENCODER_AVAILABLE = False
_import_error = str(exc)
def check_cross_encoder_available() -> tuple[bool, str | None]:
if CROSS_ENCODER_AVAILABLE:
return True, None
return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
class CrossEncoderReranker:
"""Cross-encoder reranker with lazy model loading."""
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = (model_name or "").strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.device = (device or "").strip() or None
self._model = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None:
return
ok, err = check_cross_encoder_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None:
return
try:
if self.device:
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
else:
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
except Exception as exc:
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
raise
def score_pairs(
self,
pairs: Sequence[Tuple[str, str]],
*,
batch_size: int = 32,
) -> List[float]:
"""Score (query, doc) pairs using the cross-encoder.
Returns:
List of scores (one per pair) in the model's native scale (usually logits).
"""
if not pairs:
return []
self._load_model()
if self._model is None: # pragma: no cover - defensive
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
return [float(s) for s in scores]
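
A guarded usage sketch for the optional dependency: check availability first, then score a couple of (query, text) pairs; the model downloads lazily on the first call, and the pairs here are just examples.

from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available

ok, err = check_cross_encoder_available()
if not ok:
    print(f"cross-encoder reranking disabled: {err}")
else:
    reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = reranker.score_pairs(
        [
            ("how is config loaded", "def load_config(path): return json.loads(path.read_text())"),
            ("how is config loaded", "class Button(Widget): ..."),
        ],
        batch_size=16,
    )
    print(scores)  # raw model scores, typically logits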

View File

@@ -10,15 +10,17 @@ Each directory maintains its own _index.db with:
from __future__ import annotations
import logging
import hashlib
import re
import sqlite3
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from codexlens.config import Config
from codexlens.entities import SearchResult, Symbol
from codexlens.entities import CodeRelationship, SearchResult, Symbol
from codexlens.errors import StorageError
from codexlens.storage.global_index import GlobalSymbolIndex
@@ -60,7 +62,7 @@ class DirIndexStore:
# Schema version for migration tracking
# Increment this when schema changes require migration
SCHEMA_VERSION = 5
SCHEMA_VERSION = 8
def __init__(
self,
@@ -150,6 +152,21 @@ class DirIndexStore:
from codexlens.storage.migrations.migration_005_cleanup_unused_fields import upgrade
upgrade(conn)
# Migration v5 -> v6: Ensure relationship tables/indexes exist
if from_version < 6:
from codexlens.storage.migrations.migration_006_enhance_relationships import upgrade
upgrade(conn)
# Migration v6 -> v7: Add graph neighbor cache for search expansion
if from_version < 7:
from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade
upgrade(conn)
# Migration v7 -> v8: Add Merkle hashes for incremental change detection
if from_version < 8:
from codexlens.storage.migrations.migration_008_add_merkle_hashes import upgrade
upgrade(conn)
def close(self) -> None:
"""Close database connection."""
with self._lock:
@@ -179,6 +196,7 @@ class DirIndexStore:
content: str,
language: str,
symbols: Optional[List[Symbol]] = None,
relationships: Optional[List[CodeRelationship]] = None,
) -> int:
"""Add or update a file in the current directory index.
@@ -188,6 +206,7 @@ class DirIndexStore:
content: File content for indexing
language: Programming language identifier
symbols: List of Symbol objects from the file
relationships: Optional list of CodeRelationship edges from this file
Returns:
Database file_id
@@ -240,6 +259,8 @@ class DirIndexStore:
symbol_rows,
)
self._save_merkle_hash(conn, file_id=file_id, content=content)
self._save_relationships(conn, file_id=file_id, relationships=relationships)
conn.commit()
self._maybe_update_global_symbols(full_path_str, symbols or [])
return file_id
@@ -248,6 +269,96 @@ class DirIndexStore:
conn.rollback()
raise StorageError(f"Failed to add file {name}: {exc}") from exc
def save_relationships(self, file_id: int, relationships: List[CodeRelationship]) -> None:
"""Save relationships for an already-indexed file.
Args:
file_id: Database file id
relationships: Relationship edges to persist
"""
if not relationships:
return
with self._lock:
conn = self._get_connection()
self._save_relationships(conn, file_id=file_id, relationships=relationships)
conn.commit()
def _save_relationships(
self,
conn: sqlite3.Connection,
file_id: int,
relationships: Optional[List[CodeRelationship]],
) -> None:
if not relationships:
return
rows = conn.execute(
"SELECT id, name FROM symbols WHERE file_id=? ORDER BY start_line, id",
(file_id,),
).fetchall()
name_to_id: Dict[str, int] = {}
for row in rows:
name = row["name"]
if name not in name_to_id:
name_to_id[name] = int(row["id"])
if not name_to_id:
return
rel_rows: List[Tuple[int, str, str, int, Optional[str]]] = []
seen: set[tuple[int, str, str, int, Optional[str]]] = set()
for rel in relationships:
source_id = name_to_id.get(rel.source_symbol)
if source_id is None:
continue
target = (rel.target_symbol or "").strip()
if not target:
continue
rel_type = rel.relationship_type.value
source_line = int(rel.source_line)
key = (source_id, target, rel_type, source_line, rel.target_file)
if key in seen:
continue
seen.add(key)
rel_rows.append((source_id, target, rel_type, source_line, rel.target_file))
if not rel_rows:
return
conn.executemany(
"""
INSERT INTO code_relationships(
source_symbol_id, target_qualified_name,
relationship_type, source_line, target_file
)
VALUES(?, ?, ?, ?, ?)
""",
rel_rows,
)
def _save_merkle_hash(self, conn: sqlite3.Connection, file_id: int, content: str) -> None:
"""Upsert a SHA-256 content hash for the given file_id (best-effort)."""
try:
digest = hashlib.sha256(content.encode("utf-8", errors="ignore")).hexdigest()
now = time.time()
conn.execute(
"""
INSERT INTO merkle_hashes(file_id, sha256, updated_at)
VALUES(?, ?, ?)
ON CONFLICT(file_id) DO UPDATE SET
sha256=excluded.sha256,
updated_at=excluded.updated_at
""",
(file_id, digest, now),
)
except sqlite3.Error:
return
def add_files_batch(
self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
) -> int:
@@ -312,6 +423,8 @@ class DirIndexStore:
symbol_rows,
)
self._save_merkle_hash(conn, file_id=file_id, content=content)
conn.commit()
return count
@@ -395,9 +508,13 @@ class DirIndexStore:
return float(row["mtime"]) if row and row["mtime"] else None
def needs_reindex(self, full_path: str | Path) -> bool:
"""Check if a file needs reindexing based on mtime comparison.
"""Check if a file needs reindexing.
Uses 1ms tolerance to handle filesystem timestamp precision variations.
Default behavior uses mtime comparison (with 1ms tolerance).
When `Config.enable_merkle_detection` is enabled and Merkle metadata is
available, uses SHA-256 content hash comparison (with mtime as a fast
path to avoid hashing unchanged files).
Args:
full_path: Complete source file path
@@ -415,16 +532,154 @@ class DirIndexStore:
except OSError:
return False # Can't read file stats, skip
# Get stored mtime from database
stored_mtime = self.get_file_mtime(full_path_obj)
MTIME_TOLERANCE = 0.001
# File not in index, needs indexing
if stored_mtime is None:
# Fast path: mtime-only mode (default / backward-compatible)
if self._config is None or not getattr(self._config, "enable_merkle_detection", False):
stored_mtime = self.get_file_mtime(full_path_obj)
if stored_mtime is None:
return True
return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE
full_path_str = str(full_path_obj)
# Hash-based change detection (best-effort, falls back to mtime when metadata missing)
with self._lock:
conn = self._get_connection()
try:
row = conn.execute(
"""
SELECT f.id AS file_id, f.mtime AS mtime, mh.sha256 AS sha256
FROM files f
LEFT JOIN merkle_hashes mh ON mh.file_id = f.id
WHERE f.full_path=?
""",
(full_path_str,),
).fetchone()
except sqlite3.Error:
row = None
if row is None:
return True
# Compare with 1ms tolerance for floating point precision
MTIME_TOLERANCE = 0.001
return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE
stored_mtime = float(row["mtime"]) if row["mtime"] else None
stored_hash = row["sha256"] if row["sha256"] else None
file_id = int(row["file_id"])
# Missing Merkle data: fall back to mtime
if stored_hash is None:
if stored_mtime is None:
return True
return abs(current_mtime - stored_mtime) > MTIME_TOLERANCE
# If mtime is unchanged within tolerance, assume unchanged without hashing.
if stored_mtime is not None and abs(current_mtime - stored_mtime) <= MTIME_TOLERANCE:
return False
try:
current_text = full_path_obj.read_text(encoding="utf-8", errors="ignore")
current_hash = hashlib.sha256(current_text.encode("utf-8", errors="ignore")).hexdigest()
except OSError:
return False
if current_hash == stored_hash:
# Content unchanged, but mtime drifted: update stored mtime to avoid repeated hashing.
with self._lock:
conn = self._get_connection()
conn.execute("UPDATE files SET mtime=? WHERE id=?", (current_mtime, file_id))
conn.commit()
return False
return True
def get_merkle_root_hash(self) -> Optional[str]:
"""Return the stored Merkle root hash for this directory index (if present)."""
with self._lock:
conn = self._get_connection()
try:
row = conn.execute(
"SELECT root_hash FROM merkle_state WHERE id=1"
).fetchone()
except sqlite3.Error:
return None
return row["root_hash"] if row and row["root_hash"] else None
def update_merkle_root(self) -> Optional[str]:
"""Compute and persist the Merkle root hash for this directory index.
The root hash includes:
- Direct file hashes from `merkle_hashes`
- Direct subdirectory root hashes (read from child `_index.db` files)
"""
if self._config is None or not getattr(self._config, "enable_merkle_detection", False):
return None
with self._lock:
conn = self._get_connection()
try:
file_rows = conn.execute(
"""
SELECT f.name AS name, mh.sha256 AS sha256
FROM files f
LEFT JOIN merkle_hashes mh ON mh.file_id = f.id
ORDER BY f.name
"""
).fetchall()
subdir_rows = conn.execute(
"SELECT name, index_path FROM subdirs ORDER BY name"
).fetchall()
except sqlite3.Error as exc:
self.logger.debug("Failed to compute merkle root: %s", exc)
return None
items: List[str] = []
for row in file_rows:
name = row["name"]
sha = (row["sha256"] or "").strip()
items.append(f"f:{name}:{sha}")
def read_child_root(index_path: str) -> str:
try:
with sqlite3.connect(index_path) as child_conn:
child_conn.row_factory = sqlite3.Row
child_row = child_conn.execute(
"SELECT root_hash FROM merkle_state WHERE id=1"
).fetchone()
return child_row["root_hash"] if child_row and child_row["root_hash"] else ""
except Exception:
return ""
for row in subdir_rows:
name = row["name"]
index_path = row["index_path"]
child_hash = read_child_root(index_path) if index_path else ""
items.append(f"d:{name}:{child_hash}")
root_hash = hashlib.sha256("\n".join(items).encode("utf-8", errors="ignore")).hexdigest()
now = time.time()
with self._lock:
conn = self._get_connection()
try:
conn.execute(
"""
INSERT INTO merkle_state(id, root_hash, updated_at)
VALUES(1, ?, ?)
ON CONFLICT(id) DO UPDATE SET
root_hash=excluded.root_hash,
updated_at=excluded.updated_at
""",
(root_hash, now),
)
conn.commit()
except sqlite3.Error as exc:
self.logger.debug("Failed to persist merkle root: %s", exc)
return None
return root_hash
def add_file_incremental(
self,
@@ -433,6 +688,7 @@ class DirIndexStore:
content: str,
language: str,
symbols: Optional[List[Symbol]] = None,
relationships: Optional[List[CodeRelationship]] = None,
) -> Optional[int]:
"""Add or update a file only if it has changed (incremental indexing).
@@ -444,6 +700,7 @@ class DirIndexStore:
content: File content for indexing
language: Programming language identifier
symbols: List of Symbol objects from the file
relationships: Optional list of CodeRelationship edges from this file
Returns:
Database file_id if indexed, None if skipped (unchanged)
@@ -456,7 +713,7 @@ class DirIndexStore:
return None # Skip unchanged file
# File changed or new, perform full indexing
return self.add_file(name, full_path, content, language, symbols)
return self.add_file(name, full_path, content, language, symbols, relationships)
def cleanup_deleted_files(self, source_dir: Path) -> int:
"""Remove indexed files that no longer exist in the source directory.
@@ -1767,6 +2024,39 @@ class DirIndexStore:
"""
)
# Precomputed graph neighbors cache for search expansion (v7)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS graph_neighbors (
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
relationship_depth INTEGER NOT NULL,
PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
)
"""
)
# Merkle hashes for incremental change detection (v8)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS merkle_hashes (
file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE,
sha256 TEXT NOT NULL,
updated_at REAL
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS merkle_state (
id INTEGER PRIMARY KEY CHECK (id = 1),
root_hash TEXT,
updated_at REAL
)
"""
)
# Indexes (v5: removed idx_symbols_type)
conn.execute("CREATE INDEX IF NOT EXISTS idx_files_name ON files(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_files_path ON files(full_path)")
@@ -1780,6 +2070,14 @@ class DirIndexStore:
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth "
"ON graph_neighbors(source_symbol_id, relationship_depth)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor "
"ON graph_neighbors(neighbor_symbol_id)"
)
except sqlite3.DatabaseError as exc:
raise StorageError(f"Failed to create schema: {exc}") from exc

View File

@@ -8,11 +8,13 @@ from __future__ import annotations
import logging
import os
import re
import sqlite3
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple
from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
@@ -247,6 +249,9 @@ class IndexTreeBuilder:
try:
with DirIndexStore(result.index_path, config=self.config, global_index=global_index) as store:
deleted_count = store.cleanup_deleted_files(result.source_path)
if deleted_count > 0:
_compute_graph_neighbors(store, logger=self.logger)
store.update_merkle_root()
total_deleted += deleted_count
if deleted_count > 0:
self.logger.debug("Removed %d deleted files from %s", deleted_count, result.source_path)
@@ -575,6 +580,7 @@ class IndexTreeBuilder:
content=text,
language=language_id,
symbols=indexed_file.symbols,
relationships=indexed_file.relationships,
)
files_count += 1
@@ -584,6 +590,9 @@ class IndexTreeBuilder:
self.logger.debug("Failed to index %s: %s", file_path, exc)
continue
if files_count > 0:
_compute_graph_neighbors(store, logger=self.logger)
# Get list of subdirectories
subdirs = [
d.name
@@ -593,6 +602,7 @@ class IndexTreeBuilder:
and not d.name.startswith(".")
]
store.update_merkle_root()
store.close()
if global_index is not None:
global_index.close()
@@ -654,31 +664,29 @@ class IndexTreeBuilder:
parent_index_db = self.mapper.source_to_index_db(parent_path)
try:
store = DirIndexStore(parent_index_db)
store.initialize()
with DirIndexStore(parent_index_db, config=self.config) as store:
for result in all_results:
# Only register direct children (parent is one level up)
if result.source_path.parent != parent_path:
continue
for result in all_results:
# Only register direct children (parent is one level up)
if result.source_path.parent != parent_path:
continue
if result.error:
continue
if result.error:
continue
# Register subdirectory link
store.register_subdir(
name=result.source_path.name,
index_path=result.index_path,
files_count=result.files_count,
direct_files=result.files_count,
)
self.logger.debug(
"Linked %s to parent %s",
result.source_path.name,
parent_path,
)
# Register subdirectory link
store.register_subdir(
name=result.source_path.name,
index_path=result.index_path,
files_count=result.files_count,
direct_files=result.files_count,
)
self.logger.debug(
"Linked %s to parent %s",
result.source_path.name,
parent_path,
)
store.close()
store.update_merkle_root()
except Exception as exc:
self.logger.error(
@@ -726,6 +734,164 @@ class IndexTreeBuilder:
return files
def _normalize_relationship_target(target: str) -> str:
"""Best-effort normalization of a relationship target into a local symbol name."""
target = (target or "").strip()
if not target:
return ""
# Drop trailing call parentheses when present (e.g., "foo()" -> "foo").
if target.endswith("()"):
target = target[:-2]
# Keep the leaf identifier for common qualified formats.
for sep in ("::", ".", "#"):
if sep in target:
target = target.split(sep)[-1]
# Strip non-identifier suffix/prefix noise.
target = re.sub(r"^[^A-Za-z0-9_]+", "", target)
target = re.sub(r"[^A-Za-z0-9_]+$", "", target)
return target
def _compute_graph_neighbors(
store: DirIndexStore,
*,
max_depth: int = 2,
logger: Optional[logging.Logger] = None,
) -> None:
"""Compute and persist N-hop neighbors for all symbols in a directory index."""
if max_depth <= 0:
return
log = logger or logging.getLogger(__name__)
with store._lock:
conn = store._get_connection()
conn.row_factory = sqlite3.Row
# Ensure schema exists even for older databases pinned to the same user_version.
try:
from codexlens.storage.migrations.migration_007_add_graph_neighbors import upgrade
upgrade(conn)
except Exception as exc:
log.debug("Graph neighbor schema ensure failed: %s", exc)
cursor = conn.cursor()
try:
cursor.execute("DELETE FROM graph_neighbors")
except sqlite3.Error:
# Table missing or schema mismatch; skip gracefully.
return
try:
symbol_rows = cursor.execute(
"SELECT id, file_id, name FROM symbols"
).fetchall()
rel_rows = cursor.execute(
"SELECT source_symbol_id, target_qualified_name FROM code_relationships"
).fetchall()
except sqlite3.Error:
return
if not symbol_rows or not rel_rows:
try:
conn.commit()
except sqlite3.Error:
pass
return
symbol_file_by_id: Dict[int, int] = {}
symbols_by_file_and_name: Dict[Tuple[int, str], List[int]] = {}
symbols_by_name: Dict[str, List[int]] = {}
for row in symbol_rows:
symbol_id = int(row["id"])
file_id = int(row["file_id"])
name = str(row["name"])
symbol_file_by_id[symbol_id] = file_id
symbols_by_file_and_name.setdefault((file_id, name), []).append(symbol_id)
symbols_by_name.setdefault(name, []).append(symbol_id)
adjacency: Dict[int, Set[int]] = {}
for row in rel_rows:
source_id = int(row["source_symbol_id"])
target_raw = str(row["target_qualified_name"] or "")
target_name = _normalize_relationship_target(target_raw)
if not target_name:
continue
source_file_id = symbol_file_by_id.get(source_id)
if source_file_id is None:
continue
candidate_ids = symbols_by_file_and_name.get((source_file_id, target_name))
if not candidate_ids:
global_candidates = symbols_by_name.get(target_name, [])
# Only resolve cross-file by name when unambiguous.
candidate_ids = global_candidates if len(global_candidates) == 1 else []
for target_id in candidate_ids:
if target_id == source_id:
continue
adjacency.setdefault(source_id, set()).add(target_id)
adjacency.setdefault(target_id, set()).add(source_id)
if not adjacency:
try:
conn.commit()
except sqlite3.Error:
pass
return
insert_rows: List[Tuple[int, int, int]] = []
max_depth = min(int(max_depth), 2)
for source_id, first_hop in adjacency.items():
if not first_hop:
continue
for neighbor_id in first_hop:
insert_rows.append((source_id, neighbor_id, 1))
if max_depth < 2:
continue
second_hop: Set[int] = set()
for neighbor_id in first_hop:
second_hop.update(adjacency.get(neighbor_id, set()))
second_hop.discard(source_id)
second_hop.difference_update(first_hop)
for neighbor_id in second_hop:
insert_rows.append((source_id, neighbor_id, 2))
if not insert_rows:
try:
conn.commit()
except sqlite3.Error:
pass
return
try:
cursor.executemany(
"""
INSERT INTO graph_neighbors(
source_symbol_id, neighbor_symbol_id, relationship_depth
)
VALUES(?, ?, ?)
""",
insert_rows,
)
conn.commit()
except sqlite3.Error:
return
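# Worked example (illustration only): for call relationships a -> b and b -> c in
# one file, the undirected adjacency is {a: {b}, b: {a, c}, c: {b}} and the rows
# persisted into graph_neighbors are
#   (a, b, 1), (b, a, 1), (b, c, 1), (c, b, 1)   depth-1 neighbors
#   (a, c, 2), (c, a, 2)                         depth-2 neighbors
# which is what the graph expansion tests assert for the a()/b()/c() fixture.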
# === Worker Function for ProcessPoolExecutor ===
@@ -795,6 +961,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
content=text,
language=language_id,
symbols=indexed_file.symbols,
relationships=indexed_file.relationships,
)
files_count += 1
@@ -803,6 +970,9 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
except Exception:
continue
if files_count > 0:
_compute_graph_neighbors(store)
# Get subdirectories
ignore_dirs = {
".git",
@@ -821,6 +991,7 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
if d.is_dir() and d.name not in ignore_dirs and not d.name.startswith(".")
]
store.update_merkle_root()
store.close()
if global_index is not None:
global_index.close()

View File

@@ -0,0 +1,136 @@
"""Merkle tree utilities for change detection.
This module provides a generic, filesystem-based Merkle tree implementation
that can be used to efficiently diff two directory snapshots.
"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional
def sha256_bytes(data: bytes) -> str:
return hashlib.sha256(data).hexdigest()
def sha256_text(text: str) -> str:
return sha256_bytes(text.encode("utf-8", errors="ignore"))
@dataclass
class MerkleNode:
"""A Merkle node representing either a file (leaf) or directory (internal)."""
name: str
rel_path: str
hash: str
is_dir: bool
children: Dict[str, "MerkleNode"] = field(default_factory=dict)
def iter_files(self) -> Iterable["MerkleNode"]:
if not self.is_dir:
yield self
return
for child in self.children.values():
yield from child.iter_files()
@dataclass
class MerkleTree:
"""Merkle tree for a directory snapshot."""
root: MerkleNode
@classmethod
def build_from_directory(cls, root_dir: Path) -> "MerkleTree":
root_dir = Path(root_dir).resolve()
node = cls._build_node(root_dir, base=root_dir)
return cls(root=node)
@classmethod
def _build_node(cls, path: Path, *, base: Path) -> MerkleNode:
if path.is_file():
rel = str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(
name=path.name,
rel_path=rel,
hash=sha256_bytes(path.read_bytes()),
is_dir=False,
)
if not path.is_dir():
rel = str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(name=path.name, rel_path=rel, hash="", is_dir=False)
children: Dict[str, MerkleNode] = {}
for child in sorted(path.iterdir(), key=lambda p: p.name):
child_node = cls._build_node(child, base=base)
children[child_node.name] = child_node
items = [
f"{'d' if n.is_dir else 'f'}:{name}:{n.hash}"
for name, n in sorted(children.items(), key=lambda kv: kv[0])
]
dir_hash = sha256_text("\n".join(items))
rel_path = "." if path == base else str(path.relative_to(base)).replace("\\", "/")
return MerkleNode(
name="." if path == base else path.name,
rel_path=rel_path,
hash=dir_hash,
is_dir=True,
children=children,
)
@staticmethod
def find_changed_files(old: Optional["MerkleTree"], new: Optional["MerkleTree"]) -> List[str]:
"""Find changed/added/removed files between two trees.
Returns:
List of relative file paths (POSIX-style separators).
"""
if old is None and new is None:
return []
if old is None:
return sorted({n.rel_path for n in new.root.iter_files()}) # type: ignore[union-attr]
if new is None:
return sorted({n.rel_path for n in old.root.iter_files()})
changed: set[str] = set()
def walk(old_node: Optional[MerkleNode], new_node: Optional[MerkleNode]) -> None:
if old_node is None and new_node is None:
return
if old_node is None and new_node is not None:
changed.update(n.rel_path for n in new_node.iter_files())
return
if new_node is None and old_node is not None:
changed.update(n.rel_path for n in old_node.iter_files())
return
assert old_node is not None and new_node is not None
if old_node.hash == new_node.hash:
return
if not old_node.is_dir and not new_node.is_dir:
changed.add(new_node.rel_path)
return
if old_node.is_dir != new_node.is_dir:
changed.update(n.rel_path for n in old_node.iter_files())
changed.update(n.rel_path for n in new_node.iter_files())
return
names = set(old_node.children.keys()) | set(new_node.children.keys())
for name in names:
walk(old_node.children.get(name), new_node.children.get(name))
walk(old.root, new.root)
return sorted(changed)
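# Minimal usage sketch (illustration only): snapshot a directory before and after
# edits, then diff the two trees.
#
#     before = MerkleTree.build_from_directory(Path("project"))
#     # ... edit, add, or delete files under project/ ...
#     after = MerkleTree.build_from_directory(Path("project"))
#     changed = MerkleTree.find_changed_files(before, after)
#     # `changed` is a sorted list of POSIX-style relative paths whose content
#     # hash differs, or that exist in only one of the two snapshots.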

View File

@@ -0,0 +1,37 @@
"""
Migration 006: Ensure relationship tables and indexes exist.
This migration is intentionally idempotent. It creates the `code_relationships`
table (used for graph visualization) and its indexes if missing.
"""
from __future__ import annotations
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
cursor = db_conn.cursor()
log.info("Ensuring code_relationships table exists...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL REFERENCES symbols (id) ON DELETE CASCADE,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
)
"""
)
log.info("Ensuring relationship indexes exist...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_rel_type ON code_relationships(relationship_type)")

View File

@@ -0,0 +1,47 @@
"""
Migration 007: Add precomputed graph neighbor table for search expansion.
Adds:
- graph_neighbors: cached N-hop neighbors between symbols (keyed by symbol ids)
This table is derived data (a cache) and is safe to rebuild at any time.
The migration is intentionally idempotent.
"""
from __future__ import annotations
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
cursor = db_conn.cursor()
log.info("Creating graph_neighbors table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS graph_neighbors (
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
neighbor_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
relationship_depth INTEGER NOT NULL,
PRIMARY KEY (source_symbol_id, neighbor_symbol_id)
)
"""
)
log.info("Creating indexes for graph_neighbors...")
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_source_depth
ON graph_neighbors(source_symbol_id, relationship_depth)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_graph_neighbors_neighbor
ON graph_neighbors(neighbor_symbol_id)
"""
)
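# Example search-time lookup (illustration only): fetch the neighbors of a symbol
# up to a requested depth, direct relations first.
#
#     SELECT neighbor_symbol_id, relationship_depth
#     FROM graph_neighbors
#     WHERE source_symbol_id = ? AND relationship_depth <= ?
#     ORDER BY relationship_depth
#
# The (source_symbol_id, relationship_depth) index created above keeps this read cheap.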

View File

@@ -0,0 +1,81 @@
"""
Migration 008: Add Merkle hash tables for content-based incremental indexing.
Adds:
- merkle_hashes: per-file SHA-256 hashes (keyed by file_id)
- merkle_state: directory-level root hash (single row, id=1)
Backfills merkle_hashes using the existing `files.content` column when available.
"""
from __future__ import annotations
import hashlib
import logging
import time
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
cursor = db_conn.cursor()
log.info("Creating merkle_hashes table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS merkle_hashes (
file_id INTEGER PRIMARY KEY REFERENCES files(id) ON DELETE CASCADE,
sha256 TEXT NOT NULL,
updated_at REAL
)
"""
)
log.info("Creating merkle_state table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS merkle_state (
id INTEGER PRIMARY KEY CHECK (id = 1),
root_hash TEXT,
updated_at REAL
)
"""
)
# Backfill file hashes from stored content (best-effort).
try:
rows = cursor.execute("SELECT id, content FROM files").fetchall()
except Exception as exc:
log.warning("Unable to backfill merkle hashes (files table missing?): %s", exc)
return
now = time.time()
inserts: list[tuple[int, str, float]] = []
for row in rows:
file_id = int(row[0])
content = row[1]
if content is None:
continue
try:
digest = hashlib.sha256(str(content).encode("utf-8", errors="ignore")).hexdigest()
inserts.append((file_id, digest, now))
except Exception:
continue
if not inserts:
return
log.info("Backfilling %d file hashes...", len(inserts))
cursor.executemany(
"""
INSERT INTO merkle_hashes(file_id, sha256, updated_at)
VALUES(?, ?, ?)
ON CONFLICT(file_id) DO UPDATE SET
sha256=excluded.sha256,
updated_at=excluded.updated_at
""",
inserts,
)
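# Example change check (illustration only): a file needs reindexing when the
# SHA-256 of its current content no longer matches the stored hash.
#
#     digest = hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest()
#     row = conn.execute(
#         "SELECT sha256 FROM merkle_hashes WHERE file_id = ?", (file_id,)
#     ).fetchone()
#     needs_reindex = row is None or row[0] != digest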

View File

@@ -0,0 +1,188 @@
import sqlite3
import tempfile
from pathlib import Path
import pytest
from codexlens.config import Config
from codexlens.entities import CodeRelationship, RelationshipType, SearchResult, Symbol
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.search.graph_expander import GraphExpander
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.index_tree import _compute_graph_neighbors
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
@pytest.fixture()
def temp_paths() -> Path:
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
root = Path(tmpdir.name)
yield root
try:
tmpdir.cleanup()
except (PermissionError, OSError):
pass
def _create_index_with_neighbors(root: Path) -> tuple[PathMapper, Path, Path]:
project_root = root / "project"
project_root.mkdir(parents=True, exist_ok=True)
index_root = root / "indexes"
mapper = PathMapper(index_root=index_root)
index_db_path = mapper.source_to_index_db(project_root)
index_db_path.parent.mkdir(parents=True, exist_ok=True)
content = "\n".join(
[
"def a():",
" b()",
"",
"def b():",
" c()",
"",
"def c():",
" return 1",
"",
]
)
file_path = project_root / "module.py"
file_path.write_text(content, encoding="utf-8")
symbols = [
Symbol(name="a", kind="function", range=(1, 2), file=str(file_path)),
Symbol(name="b", kind="function", range=(4, 5), file=str(file_path)),
Symbol(name="c", kind="function", range=(7, 8), file=str(file_path)),
]
relationships = [
CodeRelationship(
source_symbol="a",
target_symbol="b",
relationship_type=RelationshipType.CALL,
source_file=str(file_path),
target_file=None,
source_line=2,
),
CodeRelationship(
source_symbol="b",
target_symbol="c",
relationship_type=RelationshipType.CALL,
source_file=str(file_path),
target_file=None,
source_line=5,
),
]
config = Config(data_dir=root / "data")
store = DirIndexStore(index_db_path, config=config)
store.initialize()
store.add_file(
name=file_path.name,
full_path=file_path,
content=content,
language="python",
symbols=symbols,
relationships=relationships,
)
_compute_graph_neighbors(store)
store.close()
return mapper, project_root, file_path
def test_graph_neighbors_precomputed_two_hop(temp_paths: Path) -> None:
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
index_db_path = mapper.source_to_index_db(project_root)
conn = sqlite3.connect(str(index_db_path))
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(
"""
SELECT s1.name AS source_name, s2.name AS neighbor_name, gn.relationship_depth
FROM graph_neighbors gn
JOIN symbols s1 ON s1.id = gn.source_symbol_id
JOIN symbols s2 ON s2.id = gn.neighbor_symbol_id
ORDER BY source_name, neighbor_name, relationship_depth
"""
).fetchall()
finally:
conn.close()
triples = {(r["source_name"], r["neighbor_name"], int(r["relationship_depth"])) for r in rows}
assert ("a", "b", 1) in triples
assert ("a", "c", 2) in triples
assert ("b", "c", 1) in triples
assert ("c", "b", 1) in triples
assert file_path.exists()
def test_graph_expander_returns_related_results_with_depth_metadata(temp_paths: Path) -> None:
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
_ = project_root
expander = GraphExpander(mapper, config=Config(data_dir=temp_paths / "data", graph_expansion_depth=2))
base = SearchResult(
path=str(file_path.resolve()),
score=1.0,
excerpt="",
content=None,
start_line=1,
end_line=2,
symbol_name="a",
symbol_kind="function",
)
related = expander.expand([base], depth=2, max_expand=1, max_related=10)
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in related}
assert depth_by_symbol.get("b") == 1
assert depth_by_symbol.get("c") == 2
def test_chain_search_populates_related_results_when_enabled(temp_paths: Path) -> None:
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
_ = file_path
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
config = Config(
data_dir=temp_paths / "data",
enable_graph_expansion=True,
graph_expansion_depth=2,
)
engine = ChainSearchEngine(registry, mapper, config=config)
try:
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
result = engine.search("b", project_root, options)
assert result.results
assert result.results[0].symbol_name == "a"
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in result.related_results}
assert depth_by_symbol.get("b") == 1
assert depth_by_symbol.get("c") == 2
finally:
engine.close()
def test_chain_search_related_results_empty_when_disabled(temp_paths: Path) -> None:
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
_ = file_path
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
config = Config(
data_dir=temp_paths / "data",
enable_graph_expansion=False,
)
engine = ChainSearchEngine(registry, mapper, config=config)
try:
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
result = engine.search("b", project_root, options)
assert result.related_results == []
finally:
engine.close()

View File

@@ -869,3 +869,47 @@ class TestHybridSearchAdaptiveWeights:
) as rerank_mock:
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
assert rerank_mock.call_count == 1
def test_cross_encoder_reranking_enabled(self, tmp_path):
"""Cross-encoder stage runs only when explicitly enabled via config."""
from unittest.mock import patch
results_map = {
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
}
class DummyEmbedder:
def embed(self, texts):
if isinstance(texts, str):
texts = [texts]
return [[1.0, 0.0] for _ in texts]
class DummyReranker:
def score_pairs(self, pairs, batch_size=32):
return [0.0 for _ in pairs]
config = Config(
data_dir=tmp_path / "ce",
enable_reranking=True,
enable_cross_encoder_rerank=True,
reranker_top_k=10,
)
engine = HybridSearchEngine(config=config, embedder=DummyEmbedder())
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
"codexlens.search.hybrid_search.rerank_results",
side_effect=lambda q, r, e, top_k=50: r,
) as rerank_mock, patch.object(
HybridSearchEngine,
"_get_cross_encoder_reranker",
return_value=DummyReranker(),
) as get_ce_mock, patch(
"codexlens.search.hybrid_search.cross_encoder_rerank",
side_effect=lambda q, r, ce, top_k=50: r,
) as ce_mock:
engine.search(Path("dummy.db"), "query", enable_vector=True)
assert rerank_mock.call_count == 1
assert get_ce_mock.call_count == 1
assert ce_mock.call_count == 1

View File

@@ -0,0 +1,100 @@
import time
from pathlib import Path
from codexlens.config import Config
from codexlens.storage.dir_index import DirIndexStore
def _make_merkle_config(tmp_path: Path) -> Config:
data_dir = tmp_path / "data"
return Config(
data_dir=data_dir,
venv_path=data_dir / "venv",
enable_merkle_detection=True,
)
class TestMerkleDetection:
def test_needs_reindex_touch_updates_mtime(self, tmp_path: Path) -> None:
config = _make_merkle_config(tmp_path)
source_dir = tmp_path / "src"
source_dir.mkdir(parents=True, exist_ok=True)
file_path = source_dir / "a.py"
file_path.write_text("print('hi')\n", encoding="utf-8")
original_content = file_path.read_text(encoding="utf-8")
index_db = tmp_path / "_index.db"
with DirIndexStore(index_db, config=config) as store:
store.add_file(
name=file_path.name,
full_path=file_path,
content=original_content,
language="python",
symbols=[],
)
stored_mtime_before = store.get_file_mtime(file_path)
assert stored_mtime_before is not None
# Touch file without changing content
time.sleep(0.02)
file_path.write_text(original_content, encoding="utf-8")
assert store.needs_reindex(file_path) is False
stored_mtime_after = store.get_file_mtime(file_path)
assert stored_mtime_after is not None
assert stored_mtime_after != stored_mtime_before
current_mtime = file_path.stat().st_mtime
assert abs(stored_mtime_after - current_mtime) <= 0.001
def test_parent_root_changes_when_child_changes(self, tmp_path: Path) -> None:
config = _make_merkle_config(tmp_path)
source_root = tmp_path / "project"
child_dir = source_root / "child"
child_dir.mkdir(parents=True, exist_ok=True)
child_file = child_dir / "child.py"
child_file.write_text("x = 1\n", encoding="utf-8")
child_db = tmp_path / "child_index.db"
parent_db = tmp_path / "parent_index.db"
with DirIndexStore(child_db, config=config) as child_store:
child_store.add_file(
name=child_file.name,
full_path=child_file,
content=child_file.read_text(encoding="utf-8"),
language="python",
symbols=[],
)
child_root_1 = child_store.update_merkle_root()
assert child_root_1
with DirIndexStore(parent_db, config=config) as parent_store:
parent_store.register_subdir(name="child", index_path=child_db, files_count=1)
parent_root_1 = parent_store.update_merkle_root()
assert parent_root_1
time.sleep(0.02)
child_file.write_text("x = 2\n", encoding="utf-8")
with DirIndexStore(child_db, config=config) as child_store:
child_store.add_file(
name=child_file.name,
full_path=child_file,
content=child_file.read_text(encoding="utf-8"),
language="python",
symbols=[],
)
child_root_2 = child_store.update_merkle_root()
assert child_root_2
assert child_root_2 != child_root_1
with DirIndexStore(parent_db, config=config) as parent_store:
parent_root_2 = parent_store.update_merkle_root()
assert parent_root_2
assert parent_root_2 != parent_root_1

View File

@@ -1,9 +1,11 @@
"""Tests for performance optimizations in CodexLens storage.
"""Tests for performance optimizations in CodexLens.
This module tests the following optimizations:
1. Normalized keywords search (migration_001)
2. Optimized path lookup in registry
3. Prefix-mode symbol search
4. Graph expansion neighbor precompute overhead (<20%)
5. Cross-encoder reranking latency (<200ms)
"""
import json
@@ -479,3 +481,113 @@ class TestPerformanceComparison:
print(f" Substring: {substring_time*1000:.3f}ms ({len(substring_results)} results)")
print(f" Ratio: {prefix_time/substring_time:.2f}x")
print(f" Note: Performance benefits appear with 1000+ symbols")
class TestPerformanceBenchmarks:
"""Benchmark-style assertions for key performance requirements."""
def test_graph_expansion_indexing_overhead_under_20_percent(self, temp_index_db, tmp_path):
"""Graph neighbor precompute adds <20% overhead versus indexing baseline."""
from codexlens.entities import CodeRelationship, RelationshipType, Symbol
from codexlens.storage.index_tree import _compute_graph_neighbors
store = temp_index_db
file_count = 60
symbols_per_file = 8
start = time.perf_counter()
for file_idx in range(file_count):
file_path = tmp_path / f"graph_{file_idx}.py"
lines = []
for sym_idx in range(symbols_per_file):
lines.append(f"def func_{file_idx}_{sym_idx}():")
lines.append(f" return {sym_idx}")
lines.append("")
content = "\n".join(lines)
symbols = [
Symbol(
name=f"func_{file_idx}_{sym_idx}",
kind="function",
range=(sym_idx * 3 + 1, sym_idx * 3 + 2),
file=str(file_path),
)
for sym_idx in range(symbols_per_file)
]
relationships = [
CodeRelationship(
source_symbol=f"func_{file_idx}_{sym_idx}",
target_symbol=f"func_{file_idx}_{sym_idx + 1}",
relationship_type=RelationshipType.CALL,
source_file=str(file_path),
target_file=None,
source_line=sym_idx * 3 + 2,
)
for sym_idx in range(symbols_per_file - 1)
]
store.add_file(
name=file_path.name,
full_path=file_path,
content=content,
language="python",
symbols=symbols,
relationships=relationships,
)
baseline_time = time.perf_counter() - start
durations = []
for _ in range(3):
start = time.perf_counter()
_compute_graph_neighbors(store)
durations.append(time.perf_counter() - start)
graph_time = min(durations)
# Sanity-check that the benchmark exercised graph neighbor generation.
conn = store._get_connection()
neighbor_count = conn.execute(
"SELECT COUNT(*) as c FROM graph_neighbors"
).fetchone()["c"]
assert neighbor_count > 0
assert baseline_time > 0.0
overhead_ratio = graph_time / baseline_time
assert overhead_ratio < 0.2, (
f"Graph neighbor precompute overhead too high: {overhead_ratio:.2%} "
f"(baseline={baseline_time:.3f}s, graph={graph_time:.3f}s)"
)
def test_cross_encoder_reranking_latency_under_200ms(self):
"""Cross-encoder rerank step completes under 200ms (excluding model load)."""
from codexlens.entities import SearchResult
from codexlens.search.ranking import cross_encoder_rerank
query = "find function"
results = [
SearchResult(
path=f"file_{idx}.py",
score=1.0 / (idx + 1),
excerpt=f"def func_{idx}():\n return {idx}",
symbol_name=f"func_{idx}",
symbol_kind="function",
)
for idx in range(50)
]
class DummyReranker:
def score_pairs(self, pairs, batch_size=32):
_ = batch_size
# Return deterministic pseudo-logits to exercise sigmoid normalization.
return [float(i) for i in range(len(pairs))]
reranker = DummyReranker()
start = time.perf_counter()
reranked = cross_encoder_rerank(query, results, reranker, top_k=50, batch_size=32)
elapsed_ms = (time.perf_counter() - start) * 1000.0
assert len(reranked) == len(results)
assert any(r.metadata.get("cross_encoder_reranked") for r in reranked[:50])
assert elapsed_ms < 200.0, f"Cross-encoder rerank too slow: {elapsed_ms:.1f}ms"
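        # Shape of the step exercised here (illustration only, not the exact
        # implementation): each (query, excerpt) pair is scored by the cross-encoder
        # via score_pairs, raw logits are squashed with a sigmoid, and reranked
        # results carry the "cross_encoder_reranked" metadata flag checked above.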

View File

@@ -19,7 +19,7 @@ from codexlens.entities import Symbol
class TestSchemaCleanupMigration:
"""Test schema cleanup migration (v4 -> v5)."""
"""Test schema cleanup migration (v4 -> latest)."""
def test_migration_from_v4_to_v5(self):
"""Test that migration successfully removes deprecated fields."""
@@ -129,10 +129,12 @@ class TestSchemaCleanupMigration:
# Now initialize store - this should trigger migration
store.initialize()
# Verify schema version is now 5
# Verify schema version is now the latest
conn = store._get_connection()
version_row = conn.execute("PRAGMA user_version").fetchone()
assert version_row[0] == 5, f"Expected schema version 5, got {version_row[0]}"
assert version_row[0] == DirIndexStore.SCHEMA_VERSION, (
f"Expected schema version {DirIndexStore.SCHEMA_VERSION}, got {version_row[0]}"
)
# Check that deprecated columns are removed
# 1. Check semantic_metadata doesn't have keywords column
@@ -166,7 +168,7 @@ class TestSchemaCleanupMigration:
store.close()
def test_new_database_has_clean_schema(self):
"""Test that new databases are created with clean schema (v5)."""
"""Test that new databases are created with clean schema (latest)."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
store = DirIndexStore(db_path)
@@ -174,9 +176,9 @@ class TestSchemaCleanupMigration:
conn = store._get_connection()
# Verify schema version is 5
# Verify schema version is the latest
version_row = conn.execute("PRAGMA user_version").fetchone()
assert version_row[0] == 5
assert version_row[0] == DirIndexStore.SCHEMA_VERSION
# Check that new schema doesn't have deprecated columns
cursor = conn.execute("PRAGMA table_info(semantic_metadata)")

View File

@@ -582,6 +582,7 @@ class TestChainSearchResult:
)
assert result.query == "test"
assert result.results == []
assert result.related_results == []
assert result.symbols == []
assert result.stats.dirs_searched == 0

View File

@@ -1173,6 +1173,7 @@ class TestChainSearchResultExtended:
assert result.query == "test query"
assert len(result.results) == 1
assert len(result.symbols) == 1
assert result.related_results == []
assert result.stats.dirs_searched == 5
def test_result_with_empty_collections(self):
@@ -1186,5 +1187,6 @@ class TestChainSearchResultExtended:
assert result.query == "no matches"
assert result.results == []
assert result.related_results == []
assert result.symbols == []
assert result.stats.dirs_searched == 0

View File

@@ -110,6 +110,37 @@ class DataProcessor:
assert result is not None
assert len(result.symbols) == 0
def test_extracts_relationships_with_alias_resolution(self):
parser = TreeSitterSymbolParser("python")
code = """
import os.path as osp
from math import sqrt as sq
class Base:
pass
class Child(Base):
pass
def main():
osp.join("a", "b")
sq(4)
"""
result = parser.parse(code, Path("test.py"))
assert result is not None
rels = [r for r in result.relationships if r.source_symbol == "main"]
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
assert "os.path.join" in targets
assert "math.sqrt" in targets
inherits = [
r for r in result.relationships
if r.source_symbol == "Child" and r.relationship_type.value == "inherits"
]
assert any(r.target_symbol == "Base" for r in inherits)
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
class TestTreeSitterJavaScriptParser:
@@ -175,6 +206,22 @@ export const arrowFunc = () => {}
assert "exported" in names
assert "arrowFunc" in names
def test_extracts_relationships_with_import_alias(self):
parser = TreeSitterSymbolParser("javascript")
code = """
import { readFile as rf } from "fs";
function main() {
rf("a");
}
"""
result = parser.parse(code, Path("test.js"))
assert result is not None
rels = [r for r in result.relationships if r.source_symbol == "main"]
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
assert "fs.readFile" in targets
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
class TestTreeSitterTypeScriptParser: