Add graph expansion and cross-encoder reranking features

- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors.
- Added CrossEncoderReranker for second-stage search ranking, allowing for improved result scoring.
- Created migrations to establish necessary database tables for relationships and graph neighbors.
- Developed tests for graph expansion functionality, ensuring related results are populated correctly.
- Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead.
- Updated schema cleanup tests to reflect changes in versioning and deprecated fields.
- Added new test cases for the Tree-sitter parser to validate relationship extraction with alias resolution.
This commit is contained in:
catlog22
2025-12-31 16:58:59 +08:00
parent 4bde13e83a
commit 31a45f1f30
27 changed files with 2566 additions and 97 deletions

View File

@@ -13,7 +13,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
@@ -30,36 +30,39 @@ class SimpleRegexParser:
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
indexed = ts_parser.parse(text, path)
if indexed is not None:
return indexed
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols_regex(text)
relationships = _parse_python_relationships_regex(text, path)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols_regex(text)
relationships = _parse_js_ts_relationships_regex(text, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
relationships = []
elif self.language_id == "go":
symbols = _parse_go_symbols(text)
relationships = []
elif self.language_id == "markdown":
symbols = _parse_markdown_symbols(text)
relationships = []
elif self.language_id == "text":
symbols = _parse_text_symbols(text)
relationships = []
else:
symbols = _parse_generic_symbols(text)
relationships = []
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
@@ -78,6 +81,9 @@ class ParserFactory:
# Regexes used by the regex-based Python fallback parser.

# Class header; optional leading indentation; captures the class name.
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
# Function header (optionally `async`); optional leading indentation;
# captures the function name.
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
# `import a, b` or `from pkg import a, b`. Group 1 is the `from` module (None
# for plain imports), group 2 the raw imported-name list. NOTE: the pattern is
# anchored with `^` and allows no leading whitespace, so it only matches text
# that starts at the keyword itself.
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
# A bare word followed by "(". The lookbehind rejects names preceded by "." or
# a word character, so attribute calls such as `obj.method(` are not captured.
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
@@ -127,12 +133,81 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
    """Extract import and call relationships from Python source using regexes.

    This is the line-oriented fallback used when tree-sitter is unavailable.
    The most recently seen ``class``/``def`` header is treated as the current
    scope; lines before the first header are ignored. Only bare-name calls are
    captured (``_PY_CALL_RE`` rejects ``obj.method(``).

    Args:
        text: Full source text of the file.
        path: Path of the file; resolved and stored as ``source_file``.

    Returns:
        Relationships in source order, each with ``target_file=None``.
    """
    import keyword  # function-scope import keeps this block self-contained

    # Builtins that are too ubiquitous to be useful as call targets.
    ignored_builtins = frozenset(
        {"print", "len", "str", "int", "float", "list", "dict", "set", "tuple"}
    )

    relationships: List[CodeRelationship] = []
    current_scope: str | None = None
    source_file = str(path.resolve())

    for line_num, line in enumerate(text.splitlines(), start=1):
        class_match = _PY_CLASS_RE.match(line)
        if class_match:
            current_scope = class_match.group(1)
            continue

        def_match = _PY_DEF_RE.match(line)
        if def_match:
            current_scope = def_match.group(1)
            continue

        if current_scope is None:
            continue

        # _PY_IMPORT_RE is anchored at the start of the string and allows no
        # leading whitespace, so strip indentation first; otherwise imports
        # inside a function or class body would never match.
        import_match = _PY_IMPORT_RE.search(line.lstrip())
        if import_match:
            import_target = import_match.group(1) or import_match.group(2)
            if import_target:
                relationships.append(
                    CodeRelationship(
                        source_symbol=current_scope,
                        target_symbol=import_target.strip(),
                        relationship_type=RelationshipType.IMPORTS,
                        source_file=source_file,
                        target_file=None,
                        source_line=line_num,
                    )
                )

        for call_match in _PY_CALL_RE.finditer(line):
            call_name = call_match.group(1)
            # Skip recursion into the current scope, noisy builtins, and every
            # Python keyword (`if (`, `elif (`, `not (`, `and (`, ...), none of
            # which are real call targets.
            if (
                call_name == current_scope
                or call_name in ignored_builtins
                or keyword.iskeyword(call_name)
            ):
                continue
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=call_name,
                    relationship_type=RelationshipType.CALL,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

    return relationships
# Regexes used by the regex-based JavaScript/TypeScript fallback parser.

# `function name(`, optionally prefixed by `export` and/or `async`; captures
# the function name.
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
# `class Name`, optionally prefixed by `export`; captures the class name.
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
# `const|let|var name = (...) =>` on a single line (optionally `export` and
# `async`); captures the variable name the arrow function is bound to.
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
# Indented `name(...) {`, optionally `async`; requires leading whitespace
# (`^\s+`) and captures the name.
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
# `import ... from "module"`; captures the quoted module specifier.
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
# A bare word followed by "(". The lookbehind rejects names preceded by "." or
# a word character, so member calls such as `obj.method(` are not captured.
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
@@ -174,6 +249,61 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
    """Extract import and call relationships from JS/TS source using regexes.

    This is the line-oriented fallback used when tree-sitter is unavailable.
    The most recently seen class, ``function`` declaration or single-line
    arrow-function binding is treated as the current scope; lines before the
    first such header are ignored. Only bare-name calls are captured
    (``_JS_CALL_RE`` rejects ``obj.method(``).

    Args:
        text: Full source text of the file.
        path: Path of the file; resolved and stored as ``source_file``.

    Returns:
        Relationships in source order, each with ``target_file=None``.
    """
    # Reserved words that are routinely followed by "(" but are not calls.
    # Without this filter every `if (...)` / `for (...)` line would yield a
    # bogus CALL edge. Contextual keywords that are also legal identifiers
    # (e.g. `of`, `async`) are deliberately left out.
    non_call_words = frozenset(
        {
            "if", "else", "for", "while", "do", "switch", "case", "catch",
            "with", "return", "throw", "typeof", "instanceof", "in", "void",
            "delete", "new", "function", "await", "yield", "super", "import",
        }
    )

    relationships: List[CodeRelationship] = []
    current_scope: str | None = None
    source_file = str(path.resolve())

    for line_num, line in enumerate(text.splitlines(), start=1):
        # A scope header only updates the scope; the rest of that line is not
        # scanned for imports or calls.
        scope_match = (
            _JS_CLASS_RE.match(line)
            or _JS_FUNC_RE.match(line)
            or _JS_ARROW_RE.match(line)
        )
        if scope_match:
            current_scope = scope_match.group(1)
            continue

        if current_scope is None:
            continue

        import_match = _JS_IMPORT_RE.search(line)
        if import_match:
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=import_match.group(1),
                    relationship_type=RelationshipType.IMPORTS,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

        for call_match in _JS_CALL_RE.finditer(line):
            call_name = call_match.group(1)
            if call_name == current_scope or call_name in non_call_words:
                continue
            relationships.append(
                CodeRelationship(
                    source_symbol=current_scope,
                    target_symbol=call_name,
                    relationship_type=RelationshipType.CALL,
                    source_file=source_file,
                    target_file=None,
                    source_line=line_num,
                )
            )

    return relationships
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
@@ -253,4 +383,3 @@ def _parse_text_symbols(text: str) -> List[Symbol]:
# Text files don't have structured symbols, return empty list
# The file content will still be indexed for FTS search
return []

View File

@@ -1,12 +1,17 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing with fallback to regex-based parsing.
Provides precise AST-level parsing via tree-sitter.
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
@@ -19,7 +24,7 @@ except ImportError:
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import IndexedFile, Symbol
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@@ -85,6 +90,16 @@ class TreeSitterSymbolParser:
"""
return self._parser is not None and self._language is not None
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
return source_bytes, tree.root_node
except Exception:
return None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
@@ -95,17 +110,15 @@ class TreeSitterSymbolParser:
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle parsing errors
# Gracefully handle extraction errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
@@ -118,19 +131,21 @@ class TreeSitterSymbolParser:
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
symbols = self.parse_symbols(text)
if symbols is None:
return None
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
except Exception:
# Gracefully handle parsing errors
@@ -153,6 +168,465 @@ class TreeSitterSymbolParser:
else:
return []
def _extract_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships(
    self,
    source_bytes: bytes,
    root: TreeSitterNode,
    path: Path,
) -> List[CodeRelationship]:
    """Walk a Python syntax tree and collect INHERITS/IMPORTS/CALL edges.

    Each class or function definition opens a scope. Relationships are
    attributed to the innermost open scope, and anything found while no scope
    is open is not recorded (module-level imports still register aliases).
    Every scope starts with a copy of its parent's alias table, so aliases
    introduced inside a scope do not leak out of it.

    Args:
        source_bytes: UTF-8 encoded source the tree was parsed from.
        root: Root node of the parsed tree.
        path: Path of the file; resolved and stored as ``source_file``.

    Returns:
        Relationships in depth-first traversal order, ``target_file=None``.
    """
    source_file = str(path.resolve())
    relationships: List[CodeRelationship] = []
    scope_stack: List[str] = []
    alias_stack: List[Dict[str, str]] = [{}]
    scope_node_types = {
        "class_definition",
        "function_definition",
        "async_function_definition",
    }

    def emit(kind: RelationshipType, target_symbol: str, source_line: int) -> None:
        # Blank targets and edges outside any scope are dropped.
        if not target_symbol.strip() or not scope_stack:
            return
        relationships.append(
            CodeRelationship(
                source_symbol=scope_stack[-1],
                target_symbol=target_symbol,
                relationship_type=kind,
                source_file=source_file,
                target_file=None,
                source_line=source_line,
            )
        )

    def visit(node: TreeSitterNode) -> None:
        node_type = node.type
        entered_scope = False

        if node_type in scope_node_types:
            name_node = node.child_by_field_name("name")
            scope_name = (
                self._node_text(source_bytes, name_node).strip()
                if name_node is not None
                else ""
            )
            if scope_name:
                scope_stack.append(scope_name)
                alias_stack.append(dict(alias_stack[-1]))
                entered_scope = True

        if entered_scope and node_type == "class_definition":
            superclasses = node.child_by_field_name("superclasses")
            bases = superclasses.children if superclasses is not None else ()
            for base in bases:
                dotted = self._python_expression_to_dotted(source_bytes, base)
                if dotted:
                    emit(
                        RelationshipType.INHERITS,
                        self._resolve_alias_dotted(dotted, alias_stack[-1]),
                        self._node_start_line(node),
                    )

        if node_type in {"import_statement", "import_from_statement"}:
            updates, imported_targets = self._python_import_aliases_and_targets(
                source_bytes, node
            )
            if updates:
                alias_stack[-1].update(updates)
            for imported in imported_targets:
                emit(RelationshipType.IMPORTS, imported, self._node_start_line(node))

        if node_type == "call":
            fn_node = node.child_by_field_name("function")
            dotted = (
                self._python_expression_to_dotted(source_bytes, fn_node)
                if fn_node is not None
                else ""
            )
            if dotted:
                resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
                # Calls through `self`/`cls` are not recorded.
                if resolved.split(".", 1)[0] not in {"self", "cls"}:
                    emit(RelationshipType.CALL, resolved, self._node_start_line(node))

        for child in node.children:
            visit(child)

        if entered_scope:
            alias_stack.pop()
            scope_stack.pop()

    visit(root)
    return relationships
def _extract_js_ts_relationships(
    self,
    source_bytes: bytes,
    root: TreeSitterNode,
    path: Path,
) -> List[CodeRelationship]:
    """Walk a JS/TS syntax tree and collect INHERITS/IMPORTS/CALL edges.

    Scopes are opened by function declarations, classes, arrow functions
    bound in a variable declarator, and class methods other than
    ``constructor``. Relationships are attributed to the innermost open
    scope, and anything found while no scope is open is not recorded
    (imports still register aliases). Every scope starts with a copy of its
    parent's alias table.

    Args:
        source_bytes: UTF-8 encoded source the tree was parsed from.
        root: Root node of the parsed tree.
        path: Path of the file; resolved and stored as ``source_file``.

    Returns:
        Relationships in depth-first traversal order, ``target_file=None``.
    """
    source_file = str(path.resolve())
    relationships: List[CodeRelationship] = []
    scope_stack: List[str] = []
    alias_stack: List[Dict[str, str]] = [{}]
    function_types = {"function_declaration", "generator_function_declaration"}
    class_types = {"class_declaration", "class"}

    def text_of(node: Optional[TreeSitterNode]) -> str:
        return self._node_text(source_bytes, node).strip() if node is not None else ""

    def emit(kind: RelationshipType, target_symbol: str, source_line: int) -> None:
        # Blank targets and edges outside any scope are dropped.
        if not target_symbol.strip() or not scope_stack:
            return
        relationships.append(
            CodeRelationship(
                source_symbol=scope_stack[-1],
                target_symbol=target_symbol,
                relationship_type=kind,
                source_file=source_file,
                target_file=None,
                source_line=source_line,
            )
        )

    def scope_name_for(node: TreeSitterNode) -> str:
        """Return the scope name ``node`` introduces, or "" if it opens none."""
        node_type = node.type
        if node_type in function_types or node_type in class_types:
            return text_of(node.child_by_field_name("name"))
        if node_type == "variable_declarator":
            name_node = node.child_by_field_name("name")
            value_node = node.child_by_field_name("value")
            if (
                name_node is not None
                and value_node is not None
                and name_node.type in {"identifier", "property_identifier"}
                and value_node.type == "arrow_function"
            ):
                return text_of(name_node)
            return ""
        if node_type == "method_definition" and self._has_class_ancestor(node):
            method_name = text_of(node.child_by_field_name("name"))
            # Constructor bodies stay attributed to the enclosing class.
            return "" if method_name == "constructor" else method_name
        return ""

    def superclass_dotted(class_node: TreeSitterNode) -> str:
        """Return the dotted name of the class's base, or "" if none is found."""
        direct = class_node.child_by_field_name("superclass")
        if direct is not None:
            return self._js_expression_to_dotted(source_bytes, direct)
        # NOTE(review): the tree-sitter JavaScript/TypeScript grammars appear
        # to expose the base class under a `class_heritage` child (wrapped in
        # an `extends_clause` with a `value` field for TypeScript) rather than
        # a `superclass` field - confirm against the bundled grammar version.
        # This fallback only runs when the field is absent.
        for heritage in class_node.children:
            if heritage.type != "class_heritage":
                continue
            for part in heritage.children:
                candidate = part
                if part.type == "extends_clause":
                    candidate = part.child_by_field_name("value")
                    if candidate is None:
                        continue
                dotted = self._js_expression_to_dotted(source_bytes, candidate)
                if dotted:
                    return dotted
        return ""

    def visit(node: TreeSitterNode) -> None:
        node_type = node.type

        scope_name = scope_name_for(node)
        entered_scope = bool(scope_name)
        if entered_scope:
            scope_stack.append(scope_name)
            alias_stack.append(dict(alias_stack[-1]))

        if entered_scope and node_type in class_types:
            dotted = superclass_dotted(node)
            if dotted:
                emit(
                    RelationshipType.INHERITS,
                    self._resolve_alias_dotted(dotted, alias_stack[-1]),
                    self._node_start_line(node),
                )

        if node_type in {"import_declaration", "import_statement"}:
            updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
            if updates:
                alias_stack[-1].update(updates)
            for imported in imported_targets:
                emit(RelationshipType.IMPORTS, imported, self._node_start_line(node))

        # Best-effort support for CommonJS require() imports:
        #   const fs = require("fs")
        if node_type == "variable_declarator":
            name_node = node.child_by_field_name("name")
            value_node = node.child_by_field_name("value")
            if (
                name_node is not None
                and value_node is not None
                and name_node.type == "identifier"
                and value_node.type == "call_expression"
            ):
                callee = value_node.child_by_field_name("function")
                args = value_node.child_by_field_name("arguments")
                if callee is not None and text_of(callee) == "require" and args is not None:
                    module_name = self._js_first_string_argument(source_bytes, args)
                    if module_name:
                        alias_stack[-1][text_of(name_node)] = module_name
                        emit(RelationshipType.IMPORTS, module_name, self._node_start_line(node))

        if node_type == "call_expression":
            fn_node = node.child_by_field_name("function")
            dotted = (
                self._js_expression_to_dotted(source_bytes, fn_node)
                if fn_node is not None
                else ""
            )
            if dotted:
                resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
                # Calls through `this`/`super` are not recorded.
                if resolved.split(".", 1)[0] not in {"this", "super"}:
                    emit(RelationshipType.CALL, resolved, self._node_start_line(node))

        for child in node.children:
            visit(child)

        if entered_scope:
            alias_stack.pop()
            scope_stack.pop()

    visit(root)
    return relationships
def _node_start_line(self, node: TreeSitterNode) -> int:
return node.start_point[0] + 1
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
dotted = (dotted or "").strip()
if not dotted:
return ""
base, sep, rest = dotted.partition(".")
resolved_base = aliases.get(base, base)
if not rest:
return resolved_base
if resolved_base and rest:
return f"{resolved_base}.{rest}"
return resolved_base
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"identifier", "dotted_name"}:
return self._node_text(source_bytes, node).strip()
if node.type == "attribute":
obj = node.child_by_field_name("object")
attr = node.child_by_field_name("attribute")
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
if obj_text and attr_text:
return f"{obj_text}.{attr_text}"
return obj_text or attr_text
return ""
def _python_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
if node.type == "import_statement":
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
module_name = self._node_text(source_bytes, name_node).strip()
if not module_name:
continue
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else module_name.split(".", 1)[0]
)
if bound_name:
aliases[bound_name] = module_name
targets.append(module_name)
elif child.type == "dotted_name":
module_name = self._node_text(source_bytes, child).strip()
if not module_name:
continue
bound_name = module_name.split(".", 1)[0]
if bound_name:
aliases[bound_name] = bound_name
targets.append(module_name)
if node.type == "import_from_statement":
module_name = ""
module_node = node.child_by_field_name("module_name")
if module_node is None:
for child in node.children:
if child.type == "dotted_name":
module_node = child
break
if module_node is not None:
module_name = self._node_text(source_bytes, module_node).strip()
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
imported_name = self._node_text(source_bytes, name_node).strip()
if not imported_name or imported_name == "*":
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported_name
)
if bound_name:
aliases[bound_name] = target
targets.append(target)
elif child.type == "identifier":
imported_name = self._node_text(source_bytes, child).strip()
if not imported_name or imported_name in {"from", "import", "*"}:
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
aliases[imported_name] = target
targets.append(target)
return aliases, targets
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"this", "super"}:
return node.type
if node.type in {"identifier", "property_identifier"}:
return self._node_text(source_bytes, node).strip()
if node.type == "member_expression":
obj = node.child_by_field_name("object")
prop = node.child_by_field_name("property")
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
if obj_text and prop_text:
return f"{obj_text}.{prop_text}"
return obj_text or prop_text
return ""
def _js_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
module_name = ""
source_node = node.child_by_field_name("source")
if source_node is not None:
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
if module_name:
targets.append(module_name)
for child in node.children:
if child.type == "import_clause":
for clause_child in child.children:
if clause_child.type == "identifier":
# Default import: import React from "react"
local = self._node_text(source_bytes, clause_child).strip()
if local and module_name:
aliases[local] = module_name
if clause_child.type == "namespace_import":
# Namespace import: import * as fs from "fs"
name_node = clause_child.child_by_field_name("name")
if name_node is not None and module_name:
local = self._node_text(source_bytes, name_node).strip()
if local:
aliases[local] = module_name
if clause_child.type == "named_imports":
for spec in clause_child.children:
if spec.type != "import_specifier":
continue
name_node = spec.child_by_field_name("name")
alias_node = spec.child_by_field_name("alias")
if name_node is None:
continue
imported = self._node_text(source_bytes, name_node).strip()
if not imported:
continue
local = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported
)
if local and module_name:
aliases[local] = f"{module_name}.{imported}"
targets.append(f"{module_name}.{imported}")
return aliases, targets
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
for child in args_node.children:
if child.type == "string":
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
return ""
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.