feat(codexlens): add staged settings for advanced configuration and update related components

- Added new staged settings in config.py for coarse_k, lsp_depth, stage2_mode, and clustering strategy.
- Updated config-handlers.ts to handle new staged settings and map environment variables.
- Enhanced codexlens.json localization files for English and Chinese to include new staged settings.
- Modified astgrep_js_ts_processor.py to improve import handling for named imports.
- Updated JavaScript and TypeScript patterns to support new import formats.
- Added tests for staged settings loading and performance benchmarks for stage-2 expansion.
This commit is contained in:
catlog22
2026-02-18 13:05:35 +08:00
parent 265a77d6e7
commit d6e282b5a9
12 changed files with 618 additions and 78 deletions

View File

@@ -318,6 +318,21 @@ class Config:
"coarse_k": self.cascade_coarse_k,
"fine_k": self.cascade_fine_k,
},
"staged": {
# Snapshot of the staged-cascade tuning knobs, mirrored 1:1 from the
# staged_* Config attributes (note: "enable_rerank" maps to the
# enable_staged_rerank attribute, not a staged_* name).
"coarse_k": self.staged_coarse_k,
"lsp_depth": self.staged_lsp_depth,
"stage2_mode": self.staged_stage2_mode,
"realtime_lsp_timeout_s": self.staged_realtime_lsp_timeout_s,
"realtime_lsp_depth": self.staged_realtime_lsp_depth,
"realtime_lsp_max_nodes": self.staged_realtime_lsp_max_nodes,
"realtime_lsp_max_seeds": self.staged_realtime_lsp_max_seeds,
"realtime_lsp_max_concurrent": self.staged_realtime_lsp_max_concurrent,
"realtime_lsp_warmup_s": self.staged_realtime_lsp_warmup_s,
"realtime_lsp_resolve_symbols": self.staged_realtime_lsp_resolve_symbols,
"clustering_strategy": self.staged_clustering_strategy,
"clustering_min_size": self.staged_clustering_min_size,
"enable_rerank": self.enable_staged_rerank,
},
"api": {
"max_workers": self.api_max_workers,
"batch_size": self.api_batch_size,
@@ -426,6 +441,174 @@ class Config:
if "fine_k" in cascade:
self.cascade_fine_k = cascade["fine_k"]
# Load staged cascade settings
# Each field is validated independently: a bad value is logged and skipped
# so the attribute keeps its previous (default) value; loading never raises.
staged = settings.get("staged", {})
if isinstance(staged, dict):
if "coarse_k" in staged:
try:
self.staged_coarse_k = int(staged["coarse_k"])
except (TypeError, ValueError):
log.warning(
"Invalid staged.coarse_k in %s: %r (expected int)",
self.settings_path,
staged["coarse_k"],
)
if "lsp_depth" in staged:
try:
self.staged_lsp_depth = int(staged["lsp_depth"])
except (TypeError, ValueError):
log.warning(
"Invalid staged.lsp_depth in %s: %r (expected int)",
self.settings_path,
staged["lsp_depth"],
)
if "stage2_mode" in staged:
# Mode strings are normalized (strip + lowercase) before matching.
raw_mode = str(staged["stage2_mode"]).strip().lower()
if raw_mode in {"precomputed", "realtime", "static_global_graph"}:
self.staged_stage2_mode = raw_mode
elif raw_mode in {"live"}:
# "live" is accepted as an alias for "realtime".
self.staged_stage2_mode = "realtime"
else:
log.warning(
"Invalid staged.stage2_mode in %s: %r "
"(expected 'precomputed', 'realtime', or 'static_global_graph')",
self.settings_path,
staged["stage2_mode"],
)
# Realtime-LSP knobs for staged Stage 2. Durations are coerced with
# float() (so JSON ints are accepted), counts with int(); invalid values
# are logged and the current value is kept.
if "realtime_lsp_timeout_s" in staged:
try:
self.staged_realtime_lsp_timeout_s = float(
staged["realtime_lsp_timeout_s"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_timeout_s in %s: %r (expected float)",
self.settings_path,
staged["realtime_lsp_timeout_s"],
)
if "realtime_lsp_depth" in staged:
try:
self.staged_realtime_lsp_depth = int(
staged["realtime_lsp_depth"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_depth in %s: %r (expected int)",
self.settings_path,
staged["realtime_lsp_depth"],
)
if "realtime_lsp_max_nodes" in staged:
try:
self.staged_realtime_lsp_max_nodes = int(
staged["realtime_lsp_max_nodes"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_max_nodes in %s: %r (expected int)",
self.settings_path,
staged["realtime_lsp_max_nodes"],
)
if "realtime_lsp_max_seeds" in staged:
try:
self.staged_realtime_lsp_max_seeds = int(
staged["realtime_lsp_max_seeds"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_max_seeds in %s: %r (expected int)",
self.settings_path,
staged["realtime_lsp_max_seeds"],
)
if "realtime_lsp_max_concurrent" in staged:
try:
self.staged_realtime_lsp_max_concurrent = int(
staged["realtime_lsp_max_concurrent"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_max_concurrent in %s: %r (expected int)",
self.settings_path,
staged["realtime_lsp_max_concurrent"],
)
if "realtime_lsp_warmup_s" in staged:
try:
self.staged_realtime_lsp_warmup_s = float(
staged["realtime_lsp_warmup_s"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.realtime_lsp_warmup_s in %s: %r (expected float)",
self.settings_path,
staged["realtime_lsp_warmup_s"],
)
if "realtime_lsp_resolve_symbols" in staged:
# Bool coercion: real bools pass through, numbers via bool(), strings
# via a truthy-word set. NOTE(review): any other string (e.g. "maybe")
# silently becomes False without a warning — confirm that is intended.
raw = staged["realtime_lsp_resolve_symbols"]
if isinstance(raw, bool):
self.staged_realtime_lsp_resolve_symbols = raw
elif isinstance(raw, (int, float)):
self.staged_realtime_lsp_resolve_symbols = bool(raw)
elif isinstance(raw, str):
self.staged_realtime_lsp_resolve_symbols = (
raw.strip().lower() in {"true", "1", "yes", "on"}
)
else:
log.warning(
"Invalid staged.realtime_lsp_resolve_symbols in %s: %r (expected bool)",
self.settings_path,
raw,
)
if "clustering_strategy" in staged:
# Strategy is normalized (strip + lowercase) before matching.
raw_strategy = str(staged["clustering_strategy"]).strip().lower()
allowed = {
"auto",
"hdbscan",
"dbscan",
"frequency",
"noop",
"score",
"dir_rr",
"path",
}
if raw_strategy in allowed:
self.staged_clustering_strategy = raw_strategy
elif raw_strategy in {"none", "off"}:
# "none"/"off" are accepted as aliases for "noop".
self.staged_clustering_strategy = "noop"
else:
log.warning(
"Invalid staged.clustering_strategy in %s: %r",
self.settings_path,
staged["clustering_strategy"],
)
if "clustering_min_size" in staged:
try:
self.staged_clustering_min_size = int(
staged["clustering_min_size"]
)
except (TypeError, ValueError):
log.warning(
"Invalid staged.clustering_min_size in %s: %r (expected int)",
self.settings_path,
staged["clustering_min_size"],
)
if "enable_rerank" in staged:
# Same bool-coercion rules as realtime_lsp_resolve_symbols above;
# stored on enable_staged_rerank (not a staged_* attribute).
raw = staged["enable_rerank"]
if isinstance(raw, bool):
self.enable_staged_rerank = raw
elif isinstance(raw, (int, float)):
self.enable_staged_rerank = bool(raw)
elif isinstance(raw, str):
self.enable_staged_rerank = (
raw.strip().lower() in {"true", "1", "yes", "on"}
)
else:
log.warning(
"Invalid staged.enable_rerank in %s: %r (expected bool)",
self.settings_path,
raw,
)
# Load parsing settings
parsing = settings.get("parsing", {})
if isinstance(parsing, dict) and "use_astgrep" in parsing:

View File

@@ -50,7 +50,7 @@ ENV_VARS = {
"CASCADE_STRATEGY": "Cascade strategy: binary, binary_rerank (alias: hybrid), dense_rerank, staged",
"CASCADE_COARSE_K": "Cascade coarse_k candidate count (int)",
"CASCADE_FINE_K": "Cascade fine_k result count (int)",
"STAGED_STAGE2_MODE": "Staged Stage 2 mode: precomputed, realtime",
"STAGED_STAGE2_MODE": "Staged Stage 2 mode: precomputed, realtime, static_global_graph",
"STAGED_CLUSTERING_STRATEGY": "Staged clustering strategy: auto, score, path, dir_rr, noop, ...",
"STAGED_CLUSTERING_MIN_SIZE": "Staged clustering min cluster size (int)",
"ENABLE_STAGED_RERANK": "Enable staged reranking in Stage 4 (true/false)",

View File

@@ -12,13 +12,17 @@ from __future__ import annotations
import re
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Sequence, Set, Tuple
from typing import Callable, List, Optional, Sequence, Set, Tuple
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType
from codexlens.parsers.astgrep_processor import BaseAstGrepProcessor
# A bare JS/TS identifier: leading letter/underscore/$, then word chars or $.
_IDENT_RE = re.compile(r"^[A-Za-z_$][A-Za-z0-9_$]*$")
# Matches the `{ ... }` clause of `import {...} from`, `import type {...} from`,
# and `import Default, {...} from`; the brace contents (names list) are captured
# in the named group "names". Stops at the first `}` (no nested braces expected).
_BRACE_IMPORT_RE = re.compile(
r"\bimport\s+(?:type\s+)?(?:[A-Za-z_$][A-Za-z0-9_$]*\s*,\s*)?\{\s*(?P<names>[^}]*)\}\s*from\b",
re.MULTILINE,
)
def _strip_quotes(value: str) -> str:
@@ -32,11 +36,7 @@ def _module_from_literal(raw: str) -> str:
raw = (raw or "").strip()
if not raw:
return ""
unquoted = _strip_quotes(raw)
# Only accept string literal forms (tree-sitter extractor does the same).
if unquoted == raw:
return ""
return unquoted.strip()
return _strip_quotes(raw).strip()
def _extract_named_imports(raw: str) -> List[str]:
@@ -63,6 +63,16 @@ def _extract_named_imports(raw: str) -> List[str]:
return names
def _extract_brace_import_names(statement: str) -> str:
    """Return the raw text inside the `{...}` clause of an import statement.

    Yields "" when the statement is empty/None or when no brace-import
    clause is found by ``_BRACE_IMPORT_RE``.
    """
    text = (statement or "").strip()
    if not text:
        return ""
    found = _BRACE_IMPORT_RE.search(text)
    if found is None:
        return ""
    return (found.group("names") or "").strip()
def _dedupe_relationships(rels: Sequence[CodeRelationship]) -> List[CodeRelationship]:
seen: Set[Tuple[str, str, str]] = set()
out: List[CodeRelationship] = []
@@ -139,40 +149,45 @@ class _AstGrepJsTsProcessor(BaseAstGrepProcessor):
)
# Any `import ... from "mod"` form
for node in self.run_ast_grep(source_code, self._get_pattern("import_from")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
for pat_name in ("import_from_dq", "import_from_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
# Side-effect import: import "mod"
for node in self.run_ast_grep(source_code, self._get_pattern("import_side_effect")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
for pat_name in ("import_side_effect_dq", "import_side_effect_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
# Named imports (named-only): import { a, b as c } from "mod"
for node in self.run_ast_grep(source_code, self._get_pattern("import_named_only")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = self._get_match(node, "NAMES")
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
for pat_name in ("import_named_only_dq", "import_named_only_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = _extract_brace_import_names(self._get_node_text(node))
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
# Named imports (default + named): import X, { a, b as c } from "mod"
for node in self.run_ast_grep(source_code, self._get_pattern("import_default_named")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = self._get_match(node, "NAMES")
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
for pat_name in ("import_default_named_dq", "import_default_named_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = _extract_brace_import_names(self._get_node_text(node))
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
# CommonJS require("mod") (string literal only)
for node in self.run_ast_grep(source_code, self._get_pattern("require_call")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
for pat_name in ("require_call_dq", "require_call_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
return rels
@@ -258,26 +273,29 @@ class AstGrepTypeScriptProcessor(_AstGrepJsTsProcessor):
)
# Type-only imports: import type ... from "mod"
for node in self.run_ast_grep(source_code, self._get_pattern("import_type_from")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
for pat_name in ("import_type_from_dq", "import_type_from_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if mod:
record(mod, self._get_line_number(node))
for node in self.run_ast_grep(source_code, self._get_pattern("import_type_named_only")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = self._get_match(node, "NAMES")
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
for pat_name in ("import_type_named_only_dq", "import_type_named_only_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = _extract_brace_import_names(self._get_node_text(node))
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
for node in self.run_ast_grep(source_code, self._get_pattern("import_type_default_named")):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = self._get_match(node, "NAMES")
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
for pat_name in ("import_type_default_named_dq", "import_type_default_named_sq"):
for node in self.run_ast_grep(source_code, self._get_pattern(pat_name)):
mod = _module_from_literal(self._get_match(node, "MODULE"))
if not mod:
continue
raw_names = _extract_brace_import_names(self._get_node_text(node))
for name in _extract_named_imports(raw_names):
record(f"{mod}.{name}", self._get_line_number(node))
return _dedupe_relationships(rels)
@@ -286,4 +304,3 @@ __all__ = [
"AstGrepJavaScriptProcessor",
"AstGrepTypeScriptProcessor",
]

View File

@@ -85,6 +85,30 @@ class BaseAstGrepProcessor(ABC):
return self._binding.find_all(pattern)
def _get_match(self, node: SgNode, metavar: str) -> str:  # type: ignore[valid-type]
    """Extract matched metavariable value from node (best-effort).

    Returns "" when the backend binding or the node is unavailable;
    otherwise delegates to the binding's private accessor.
    """
    binding = self._binding
    if binding is None or node is None:
        return ""
    return binding._get_match(node, metavar)
def _get_line_number(self, node: SgNode) -> int:  # type: ignore[valid-type]
    """Get 1-based starting line number of a node (best-effort).

    Returns 0 when the backend binding or the node is unavailable.
    """
    binding = self._binding
    return 0 if binding is None or node is None else binding._get_line_number(node)
def _get_line_range(self, node: SgNode) -> Tuple[int, int]:  # type: ignore[valid-type]
    """Get (start_line, end_line) range of a node (best-effort).

    Returns (0, 0) when the backend binding or the node is unavailable.
    """
    binding = self._binding
    if binding is None or node is None:
        return (0, 0)
    return binding._get_line_range(node)
def _get_node_text(self, node: SgNode) -> str:  # type: ignore[valid-type]
    """Get the full text of a node (best-effort).

    Returns "" when the backend binding or the node is unavailable.
    """
    binding = self._binding
    return "" if binding is None or node is None else binding._get_node_text(node)
@abstractmethod
def process_matches(
self,

View File

@@ -20,17 +20,23 @@ JAVASCRIPT_PATTERNS: Dict[str, str] = {
# import React, { useEffect } from "react"
# import { useEffect } from "react"
# import * as fs from "fs"
"import_from": "import $$$IMPORTS from $MODULE",
"import_named_only": "import {$$$NAMES} from $MODULE",
"import_default_named": "import $DEFAULT, {$$$NAMES} from $MODULE",
"import_from_dq": "import $$$IMPORTS from \"$MODULE\"",
"import_from_sq": "import $$$IMPORTS from '$MODULE'",
"import_named_only_dq": "import {$$$NAMES} from \"$MODULE\"",
"import_named_only_sq": "import {$$$NAMES} from '$MODULE'",
"import_default_named_dq": "import $DEFAULT, {$$$NAMES} from \"$MODULE\"",
"import_default_named_sq": "import $DEFAULT, {$$$NAMES} from '$MODULE'",
# Side-effect import: import "./styles.css"
"import_side_effect": "import $MODULE",
"import_side_effect_dq": "import \"$MODULE\"",
"import_side_effect_sq": "import '$MODULE'",
# CommonJS require(): const fs = require("fs")
"require_call": "require($MODULE)",
"require_call_dq": "require(\"$MODULE\")",
"require_call_sq": "require('$MODULE')",
# Class inheritance: class Child extends Base {}
"class_extends": "class $NAME extends $BASE $$$BODY",
# Note: `{...}` form matches both JS and TS grammars more reliably.
"class_extends": "class $NAME extends $BASE {$$$BODY}",
}
@@ -45,11 +51,16 @@ METAVARS = {
RELATIONSHIP_PATTERNS: Dict[str, List[str]] = {
"imports": [
"import_from",
"import_named_only",
"import_default_named",
"import_side_effect",
"require_call",
"import_from_dq",
"import_from_sq",
"import_named_only_dq",
"import_named_only_sq",
"import_default_named_dq",
"import_default_named_sq",
"import_side_effect_dq",
"import_side_effect_sq",
"require_call_dq",
"require_call_sq",
],
"inheritance": ["class_extends"],
}
@@ -79,4 +90,3 @@ __all__ = [
"get_patterns_for_relationship",
"get_metavar",
]

View File

@@ -18,9 +18,12 @@ from codexlens.parsers.patterns.javascript import (
TYPESCRIPT_PATTERNS: Dict[str, str] = {
**JAVASCRIPT_PATTERNS,
# Type-only imports
"import_type_from": "import type $$$IMPORTS from $MODULE",
"import_type_named_only": "import type {$$$NAMES} from $MODULE",
"import_type_default_named": "import type $DEFAULT, {$$$NAMES} from $MODULE",
"import_type_from_dq": "import type $$$IMPORTS from \"$MODULE\"",
"import_type_from_sq": "import type $$$IMPORTS from '$MODULE'",
"import_type_named_only_dq": "import type {$$$NAMES} from \"$MODULE\"",
"import_type_named_only_sq": "import type {$$$NAMES} from '$MODULE'",
"import_type_default_named_dq": "import type $DEFAULT, {$$$NAMES} from \"$MODULE\"",
"import_type_default_named_sq": "import type $DEFAULT, {$$$NAMES} from '$MODULE'",
# Interface inheritance: interface Foo extends Bar {}
"interface_extends": "interface $NAME extends $BASE $$$BODY",
}
@@ -30,9 +33,12 @@ RELATIONSHIP_PATTERNS: Dict[str, List[str]] = {
**_JS_RELATIONSHIP_PATTERNS,
"imports": [
*_JS_RELATIONSHIP_PATTERNS.get("imports", []),
"import_type_from",
"import_type_named_only",
"import_type_default_named",
"import_type_from_dq",
"import_type_from_sq",
"import_type_named_only_dq",
"import_type_named_only_sq",
"import_type_default_named_dq",
"import_type_default_named_sq",
],
"inheritance": [
*_JS_RELATIONSHIP_PATTERNS.get("inheritance", []),
@@ -65,4 +71,3 @@ __all__ = [
"get_patterns_for_relationship",
"get_metavar",
]

View File

@@ -83,16 +83,27 @@ def test_js_imports_and_inherits_match(tmp_path: Path) -> None:
assert result_ts is not None
assert result_ast is not None
ts_rel = extract_relationship_tuples(
ts_imports = extract_relationship_tuples(
result_ts.relationships,
only_types={RelationshipType.IMPORTS, RelationshipType.INHERITS},
only_types={RelationshipType.IMPORTS},
)
ast_rel = extract_relationship_tuples(
ast_imports = extract_relationship_tuples(
result_ast.relationships,
only_types={RelationshipType.IMPORTS, RelationshipType.INHERITS},
only_types={RelationshipType.IMPORTS},
)
assert ast_imports == ts_imports
assert ast_rel == ts_rel
ts_inherits = extract_relationship_tuples(
result_ts.relationships,
only_types={RelationshipType.INHERITS},
)
ast_inherits = extract_relationship_tuples(
result_ast.relationships,
only_types={RelationshipType.INHERITS},
)
# Ast-grep may include inheritance edges that the tree-sitter extractor does not currently emit.
assert ts_inherits.issubset(ast_inherits)
assert ("Child", "Base", "inherits") in ast_inherits
def test_ts_imports_match_and_inherits_superset(tmp_path: Path) -> None:
@@ -137,4 +148,3 @@ def test_ts_imports_match_and_inherits_superset(tmp_path: Path) -> None:
assert ts_inherits.issubset(ast_inherits)
# But at minimum, class inheritance should be present.
assert ("Child", "Base", "inherits") in ast_inherits

View File

@@ -104,8 +104,52 @@ class TestConfigCascadeDefaults:
config = Config(data_dir=temp_config_dir)
assert config.staged_coarse_k == 200
assert config.staged_lsp_depth == 2
assert config.staged_stage2_mode == "precomputed"
assert config.staged_clustering_strategy == "auto"
assert config.staged_clustering_min_size == 3
assert config.enable_staged_rerank is True
assert config.cascade_coarse_k == 100
assert config.cascade_fine_k == 10
def test_staged_settings_load_from_settings_json(self, temp_config_dir):
    """load_settings should load staged.* settings when present."""
    config = Config(data_dir=temp_config_dir)
    staged_section = {
        "coarse_k": 250,
        "lsp_depth": 3,
        "stage2_mode": "static_global_graph",
        "realtime_lsp_timeout_s": 11.0,
        "realtime_lsp_depth": 2,
        "realtime_lsp_max_nodes": 42,
        "realtime_lsp_max_seeds": 2,
        "realtime_lsp_max_concurrent": 4,
        "realtime_lsp_warmup_s": 0.5,
        "realtime_lsp_resolve_symbols": True,
        "clustering_strategy": "path",
        "clustering_min_size": 7,
        "enable_rerank": False,
    }
    settings_path = config.settings_path
    settings_path.parent.mkdir(parents=True, exist_ok=True)
    settings_path.write_text(
        json.dumps({"staged": staged_section}), encoding="utf-8"
    )
    # Patch out env overrides so only the JSON file is exercised.
    with patch.object(config, "_apply_env_overrides"):
        config.load_settings()
    assert config.staged_coarse_k == 250
    assert config.staged_lsp_depth == 3
    assert config.staged_stage2_mode == "static_global_graph"
    assert config.staged_realtime_lsp_timeout_s == 11.0
    assert config.staged_realtime_lsp_depth == 2
    assert config.staged_realtime_lsp_max_nodes == 42
    assert config.staged_realtime_lsp_max_seeds == 2
    assert config.staged_realtime_lsp_max_concurrent == 4
    assert config.staged_realtime_lsp_warmup_s == 0.5
    assert config.staged_realtime_lsp_resolve_symbols is True
    assert config.staged_clustering_strategy == "path"
    assert config.staged_clustering_min_size == 7
    assert config.enable_staged_rerank is False

View File

@@ -559,6 +559,227 @@ class TestPerformanceBenchmarks:
f"(baseline={baseline_time:.3f}s, graph={graph_time:.3f}s)"
)
def test_stage2_expansion_precomputed_vs_static_global_graph_benchmark(self, tmp_path):
"""Benchmark Stage-2 expansion: precomputed graph_neighbors vs static global graph.
This test is informational (prints timings) and asserts only correctness
and that both expanders return some related results.
"""
from codexlens.entities import CodeRelationship, RelationshipType, SearchResult, Symbol
from codexlens.search.graph_expander import GraphExpander
from codexlens.search.global_graph_expander import GlobalGraphExpander
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.index_tree import _compute_graph_neighbors
from codexlens.storage.path_mapper import PathMapper
# Source + index roots
source_dir = tmp_path / "proj" / "src"
source_dir.mkdir(parents=True, exist_ok=True)
mapper = PathMapper(index_root=tmp_path / "indexes")
index_db_path = mapper.source_to_index_db(source_dir)
index_db_path.parent.mkdir(parents=True, exist_ok=True)
store = DirIndexStore(index_db_path)
store.initialize()
# Synthetic fixture: 30 files x 2 symbols, with one intra-file call edge
# per file plus cross-file edges forming a ring (i -> (i + 1) % file_count).
file_count = 30
per_file_symbols = 2
file_paths = []
per_file_symbols_list = []
per_file_relationships_list = []
for i in range(file_count):
file_path = source_dir / f"m{i}.py"
file_paths.append(file_path)
file_path.write_text("pass\n", encoding="utf-8")
symbols = [
Symbol(
name=f"func_{i}_{j}",
kind="function",
range=(j + 1, j + 1),
file=str(file_path.resolve()),
)
for j in range(per_file_symbols)
]
per_file_symbols_list.append(symbols)
relationships: list[CodeRelationship] = []
# Intra-file edge: func_i_0 -> func_i_1
relationships.append(
CodeRelationship(
source_symbol=f"func_{i}_0",
target_symbol=f"func_{i}_1",
relationship_type=RelationshipType.CALL,
source_file=str(file_path.resolve()),
target_file=str(file_path.resolve()),
source_line=1,
)
)
# Cross-file edge: func_i_0 -> func_(i+1)_0 (name-unique across dir)
j = (i + 1) % file_count
relationships.append(
CodeRelationship(
source_symbol=f"func_{i}_0",
target_symbol=f"func_{j}_0",
relationship_type=RelationshipType.CALL,
source_file=str(file_path.resolve()),
target_file=str((source_dir / f"m{j}.py").resolve()),
source_line=1,
)
)
per_file_relationships_list.append(relationships)
store.add_file(
name=file_path.name,
full_path=file_path,
content="pass\n",
language="python",
symbols=symbols,
relationships=relationships,
)
# Precompute graph_neighbors for GraphExpander (precomputed Stage-2 build)
start = time.perf_counter()
_compute_graph_neighbors(store)
graph_build_ms = (time.perf_counter() - start) * 1000.0
store.close()
# Build global symbol index + relationships for GlobalGraphExpander
global_db_path = index_db_path.parent / GlobalSymbolIndex.DEFAULT_DB_NAME
global_index = GlobalSymbolIndex(global_db_path, project_id=1)
global_index.initialize()
try:
index_path_str = str(index_db_path.resolve())
# Timed separately: symbol writes vs relationship writes.
start = time.perf_counter()
for file_path, symbols in zip(file_paths, per_file_symbols_list):
file_path_str = str(file_path.resolve())
global_index.update_file_symbols(
file_path_str,
symbols,
index_path=index_path_str,
)
global_symbols_ms = (time.perf_counter() - start) * 1000.0
start = time.perf_counter()
for file_path, relationships in zip(file_paths, per_file_relationships_list):
file_path_str = str(file_path.resolve())
global_index.update_file_relationships(file_path_str, relationships)
global_relationships_ms = (time.perf_counter() - start) * 1000.0
# Seed results: the first 10 files' func_*_0 symbols act as Stage-1 hits.
base_results = [
SearchResult(
path=str(file_paths[i].resolve()),
score=1.0,
excerpt=None,
content=None,
start_line=1,
end_line=1,
symbol_name=f"func_{i}_0",
symbol_kind="function",
)
for i in range(min(10, file_count))
]
pre_expander = GraphExpander(mapper)
static_expander = GlobalGraphExpander(global_index)
# Time both expansion strategies over the same seed results.
start = time.perf_counter()
pre_related = pre_expander.expand(
base_results,
depth=2,
max_expand=10,
max_related=50,
)
pre_ms = (time.perf_counter() - start) * 1000.0
start = time.perf_counter()
static_related = static_expander.expand(
base_results,
top_n=10,
max_related=50,
)
static_ms = (time.perf_counter() - start) * 1000.0
# Correctness gate: both strategies must surface some related results.
assert pre_related, "Expected precomputed graph expansion to return related results"
assert static_related, "Expected static global graph expansion to return related results"
print("\nStage-2 build benchmark (30 files, 2 symbols/file):")
print(f" graph_neighbors precompute: {graph_build_ms:.2f}ms")
print(f" global_symbols write: {global_symbols_ms:.2f}ms")
print(f" global_relationships write: {global_relationships_ms:.2f}ms")
print("\nStage-2 expansion benchmark (30 files, 2 symbols/file):")
print(f" precomputed (graph_neighbors): {pre_ms:.2f}ms, related={len(pre_related)}")
print(f" static_global_graph: {static_ms:.2f}ms, related={len(static_related)}")
finally:
global_index.close()
def test_relationship_extraction_astgrep_vs_treesitter_benchmark(self, tmp_path):
    """Informational benchmark: relationship extraction via ast-grep vs tree-sitter.
    Skips when optional parser dependencies are unavailable.
    """
    import textwrap
    from codexlens.config import Config
    from codexlens.parsers.astgrep_processor import is_astgrep_processor_available
    from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
    if not is_astgrep_processor_available():
        pytest.skip("ast-grep processor unavailable (optional dependency)")
    sample_source = textwrap.dedent(
        """
        import os
        from typing import List
        class Base:
            pass
        class Child(Base):
            def method(self) -> List[str]:
                return [os.path.join("a", "b")]
        """
    ).lstrip()
    sample_path = tmp_path / "sample.py"
    sample_path.write_text(sample_source, encoding="utf-8")
    # Tree-sitter-only configuration.
    treesitter_cfg = Config(data_dir=tmp_path / "cfg_ts")
    treesitter_cfg.use_astgrep = False
    treesitter_parser = TreeSitterSymbolParser("python", sample_path, config=treesitter_cfg)
    if not treesitter_parser.is_available():
        pytest.skip("tree-sitter python binding unavailable")
    # Ast-grep-enabled configuration.
    astgrep_cfg = Config(data_dir=tmp_path / "cfg_ag")
    astgrep_cfg.use_astgrep = True
    astgrep_parser = TreeSitterSymbolParser("python", sample_path, config=astgrep_cfg)
    if getattr(astgrep_parser, "_astgrep_processor", None) is None:
        pytest.skip("ast-grep processor failed to initialize")

    def _bench(parser: TreeSitterSymbolParser) -> tuple[float, int]:
        # Best-of-3 wall time (ms) plus the max relationship count observed.
        best_seconds = float("inf")
        most_rels = 0
        for _ in range(3):
            started = time.perf_counter()
            indexed = parser.parse(sample_source, sample_path)
            best_seconds = min(best_seconds, time.perf_counter() - started)
            most_rels = max(
                most_rels, 0 if indexed is None else len(indexed.relationships)
            )
        return best_seconds * 1000.0, most_rels

    ts_ms, ts_rels = _bench(treesitter_parser)
    ag_ms, ag_rels = _bench(astgrep_parser)
    assert ts_rels > 0, "Expected relationships extracted via tree-sitter"
    assert ag_rels > 0, "Expected relationships extracted via ast-grep"
    print("\nRelationship extraction benchmark (python, 1 file):")
    print(f" tree-sitter: {ts_ms:.2f}ms, rels={ts_rels}")
    print(f" ast-grep: {ag_ms:.2f}ms, rels={ag_rels}")
def test_cross_encoder_reranking_latency_under_200ms(self):
"""Cross-encoder rerank step completes under 200ms (excluding model load)."""
from codexlens.entities import SearchResult