Add tests and implement functionality for staged cascade search and LSP expansion

- Introduced a new JSON file for verbose output of the Codex Lens search results.
- Added unit tests for binary search functionality in `test_stage1_binary_search_uses_chunk_lines.py`.
- Implemented regression tests for staged cascade Stage 2 expansion depth in `test_staged_cascade_lsp_depth.py`.
- Created unit tests for staged cascade Stage 2 realtime LSP graph expansion in `test_staged_cascade_realtime_lsp.py`.
- Enhanced the ChainSearchEngine to respect configuration settings for staged LSP depth and to improve search accuracy.
This commit is contained in:
catlog22
2026-02-08 21:54:42 +08:00
parent 166211dcd4
commit b9b2932f50
20 changed files with 1882 additions and 283 deletions

View File

@@ -455,6 +455,12 @@ def search(
hidden=True,
help="[Advanced] Cascade strategy for --method cascade."
),
staged_stage2_mode: Optional[str] = typer.Option(
None,
"--staged-stage2-mode",
hidden=True,
help="[Advanced] Stage 2 expansion mode for cascade strategy 'staged': precomputed | realtime.",
),
# Hidden deprecated parameter for backward compatibility
mode: Optional[str] = typer.Option(None, "--mode", hidden=True, help="[DEPRECATED] Use --method instead."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
@@ -545,7 +551,7 @@ def search(
# Validate cascade_strategy if provided (for advanced users)
if internal_cascade_strategy is not None:
valid_strategies = ["binary", "hybrid", "binary_rerank", "dense_rerank"]
valid_strategies = ["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]
if internal_cascade_strategy not in valid_strategies:
if json_mode:
print_json(success=False, error=f"Invalid cascade strategy: {internal_cascade_strategy}. Must be one of: {', '.join(valid_strategies)}")
@@ -606,6 +612,18 @@ def search(
engine = ChainSearchEngine(registry, mapper, config=config)
# Optional staged cascade overrides (only meaningful for cascade strategy 'staged')
if staged_stage2_mode is not None:
stage2 = staged_stage2_mode.strip().lower()
if stage2 not in {"precomputed", "realtime"}:
msg = "Invalid --staged-stage2-mode. Must be: precomputed | realtime."
if json_mode:
print_json(success=False, error=msg)
else:
console.print(f"[red]{msg}[/red]")
raise typer.Exit(code=1)
config.staged_stage2_mode = stage2
# Map method to SearchOptions flags
# fts: FTS-only search (optionally with fuzzy)
# vector: Pure vector semantic search
@@ -986,6 +1004,103 @@ def status(
registry.close()
@app.command(name="lsp-status")
def lsp_status(
    path: Path = typer.Option(Path("."), "--path", "-p", help="Workspace root for LSP probing."),
    probe_file: Optional[Path] = typer.Option(
        None,
        "--probe-file",
        help="Optional file path to probe (starts the matching language server and prints capabilities).",
    ),
    json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
    """Show standalone LSP configuration and optionally probe a language server.

    This exercises the existing LSP server selection/startup path in StandaloneLspManager.

    Args:
        path: Workspace root used to initialize the LSP manager.
        probe_file: Optional file; when given, the matching language server is
            started and its capabilities are reported.
        json_mode: Emit a machine-readable JSON payload instead of rich text.
        verbose: Enable debug logging.
    """
    _configure_logging(verbose, json_mode)

    # Local imports keep CLI startup fast when this command is not used.
    import asyncio
    import shutil

    from codexlens.lsp.standalone_manager import StandaloneLspManager

    workspace_root = path.expanduser().resolve()
    probe_path = probe_file.expanduser().resolve() if probe_file is not None else None

    async def _run():
        """Collect configured servers and (optionally) probe one file."""
        manager = StandaloneLspManager(workspace_root=str(workspace_root))
        await manager.start()
        servers = []
        # NOTE(review): reads private manager state; confirm no public accessor exists.
        for language_id, cfg in sorted(manager._configs.items()):  # type: ignore[attr-defined]
            cmd0 = cfg.command[0] if cfg.command else None
            servers.append(
                {
                    "language_id": language_id,
                    "display_name": cfg.display_name,
                    "extensions": list(cfg.extensions),
                    "command": list(cfg.command),
                    # Whether the server executable is actually on PATH.
                    "command_available": bool(shutil.which(cmd0)) if cmd0 else False,
                }
            )
        probe = None
        if probe_path is not None:
            state = await manager._get_server(str(probe_path))
            if state is None:
                probe = {
                    "file": str(probe_path),
                    "ok": False,
                    "error": "No language server configured/available for this file.",
                }
            else:
                probe = {
                    "file": str(probe_path),
                    "ok": True,
                    "language_id": state.config.language_id,
                    "display_name": state.config.display_name,
                    "initialized": bool(state.initialized),
                    "capabilities": state.capabilities,
                }
        await manager.stop()
        return {"workspace_root": str(workspace_root), "servers": servers, "probe": probe}

    try:
        payload = asyncio.run(_run())
    except Exception as exc:
        if json_mode:
            print_json(success=False, error=f"LSP status failed: {exc}")
        else:
            console.print(f"[red]LSP status failed:[/red] {exc}")
        raise typer.Exit(code=1)

    if json_mode:
        print_json(success=True, result=payload)
        return

    console.print("[bold]CodexLens LSP Status[/bold]")
    console.print(f" Workspace: {payload['workspace_root']}")
    console.print("\n[bold]Configured Servers:[/bold]")
    for s in payload["servers"]:
        # Fix: both branches previously produced an empty string (the check/cross
        # marks were lost), making availability invisible in the listing.
        ok = "✓" if s["command_available"] else "✗"
        console.print(f" {ok} {s['display_name']} ({s['language_id']}) -> {s['command'][0] if s['command'] else ''}")
        console.print(f" Extensions: {', '.join(s['extensions'])}")
    if payload["probe"] is not None:
        probe = payload["probe"]
        console.print("\n[bold]Probe:[/bold]")
        if not probe.get("ok"):
            console.print(f"{probe.get('file')}")
            console.print(f" {probe.get('error')}")
        else:
            console.print(f"{probe.get('file')}")
            console.print(f" Server: {probe.get('display_name')} ({probe.get('language_id')})")
            console.print(f" Initialized: {probe.get('initialized')}")
@app.command()
def projects(
action: str = typer.Argument("list", help="Action: list, show, remove"),
@@ -3962,4 +4077,3 @@ def index_migrate_deprecated(
json_mode=json_mode,
verbose=verbose,
)

View File

@@ -145,6 +145,11 @@ class Config:
# Staged cascade search configuration (4-stage pipeline)
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
staged_lsp_depth: int = 2 # LSP relationship expansion depth in Stage 2
staged_stage2_mode: str = "precomputed" # "precomputed" (graph_neighbors) | "realtime" (LSP)
staged_realtime_lsp_timeout_s: float = 10.0 # Max time budget for realtime LSP expansion
staged_realtime_lsp_max_nodes: int = 100 # Node cap for realtime graph expansion
staged_realtime_lsp_warmup_s: float = 2.0 # Wait for server analysis after opening seed docs
staged_realtime_lsp_resolve_symbols: bool = False # If True, resolves symbol names via documentSymbol (slower)
staged_clustering_strategy: str = "auto" # "auto", "hdbscan", "dbscan", "frequency", "noop"
staged_clustering_min_size: int = 3 # Minimum cluster size for Stage 3 grouping
enable_staged_rerank: bool = True # Enable optional cross-encoder reranking in Stage 4

View File

@@ -20,6 +20,7 @@ from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from urllib.parse import unquote
if TYPE_CHECKING:
from codexlens.lsp.standalone_manager import StandaloneLspManager
@@ -62,12 +63,14 @@ class Location:
"""
# Handle VSCode URI format (file:///path/to/file)
uri = data.get("uri", data.get("file_path", ""))
if uri.startswith("file:///"):
# Windows: file:///C:/path -> C:/path
# Unix: file:///path -> /path
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
elif uri.startswith("file://"):
file_path = uri[7:]
if uri.startswith("file://"):
# Strip scheme and decode percent-encoding (e.g. file:///d%3A/...).
# Keep behavior compatible with both Windows and Unix paths.
raw = unquote(uri[7:]) # keep leading slash for Unix paths
# Windows: file:///C:/... or file:///c%3A/... -> C:/...
if raw.startswith("/") and len(raw) > 2 and raw[2] == ":":
raw = raw[1:]
file_path = raw
else:
file_path = uri

View File

@@ -28,6 +28,7 @@ class LspGraphBuilder:
max_depth: int = 2,
max_nodes: int = 100,
max_concurrent: int = 10,
resolve_symbols: bool = True,
):
"""Initialize GraphBuilder.
@@ -35,10 +36,12 @@ class LspGraphBuilder:
max_depth: Maximum depth for BFS expansion from seeds.
max_nodes: Maximum number of nodes in the graph.
max_concurrent: Maximum concurrent LSP requests.
resolve_symbols: If False, skip documentSymbol lookups and create lightweight nodes.
"""
self.max_depth = max_depth
self.max_nodes = max_nodes
self.max_concurrent = max_concurrent
self.resolve_symbols = resolve_symbols
# Cache for document symbols per file (avoids per-location hover queries)
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
@@ -276,9 +279,11 @@ class LspGraphBuilder:
start_line = location.line
# Try to find symbol info from cached document symbols (fast)
symbol_info = await self._get_symbol_at_location(
file_path, start_line, lsp_bridge
)
symbol_info = None
if self.resolve_symbols:
symbol_info = await self._get_symbol_at_location(
file_path, start_line, lsp_bridge
)
if symbol_info:
name = symbol_info.get("name", f"symbol_L{start_line}")

View File

@@ -1094,15 +1094,15 @@ class ChainSearchEngine:
metadata = chunk.get("metadata")
symbol_name = None
symbol_kind = None
start_line = None
end_line = None
start_line = chunk.get("start_line")
end_line = chunk.get("end_line")
if metadata:
try:
meta_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
symbol_name = meta_dict.get("symbol_name")
symbol_kind = meta_dict.get("symbol_kind")
start_line = meta_dict.get("start_line")
end_line = meta_dict.get("end_line")
start_line = meta_dict.get("start_line", start_line)
end_line = meta_dict.get("end_line", end_line)
except Exception:
pass
@@ -1130,10 +1130,11 @@ class ChainSearchEngine:
coarse_results: List[SearchResult],
index_root: Optional[Path],
) -> List[SearchResult]:
"""Stage 2: LSP-based graph expansion using GraphExpander.
"""Stage 2: LSP/graph expansion for staged cascade.
Expands coarse results with related symbols (definitions, references,
callers, callees) using precomputed graph neighbors.
Supports two modes via Config.staged_stage2_mode:
- "precomputed" (default): GraphExpander over per-dir `graph_neighbors` table
- "realtime": on-demand graph expansion via live LSP servers (LspBridge + LspGraphBuilder)
Args:
coarse_results: Results from Stage 1 binary search
@@ -1146,44 +1147,14 @@ class ChainSearchEngine:
return coarse_results
try:
from codexlens.search.graph_expander import GraphExpander
# Get expansion depth from config
depth = 2
mode = "precomputed"
if self._config is not None:
depth = getattr(self._config, "graph_expansion_depth", 2)
mode = (getattr(self._config, "staged_stage2_mode", "precomputed") or "precomputed").strip().lower()
expander = GraphExpander(self.mapper, config=self._config)
if mode in {"realtime", "live"}:
return self._stage2_realtime_lsp_expand(coarse_results, index_root=index_root)
# Expand top results (limit expansion to avoid explosion)
max_expand = min(10, len(coarse_results))
max_related = 50
related_results = expander.expand(
coarse_results,
depth=depth,
max_expand=max_expand,
max_related=max_related,
)
if related_results:
self.logger.debug(
"Stage 2 expanded %d base results to %d related symbols",
len(coarse_results), len(related_results)
)
# Combine: original results + related results
# Keep original results first (higher relevance)
combined = list(coarse_results)
seen_keys = {(r.path, r.symbol_name, r.start_line) for r in coarse_results}
for related in related_results:
key = (related.path, related.symbol_name, related.start_line)
if key not in seen_keys:
seen_keys.add(key)
combined.append(related)
return combined
return self._stage2_precomputed_graph_expand(coarse_results, index_root=index_root)
except ImportError as exc:
self.logger.debug("GraphExpander not available: %s", exc)
@@ -1192,6 +1163,238 @@ class ChainSearchEngine:
self.logger.debug("Stage 2 LSP expansion failed: %s", exc)
return coarse_results
def _stage2_precomputed_graph_expand(
    self,
    coarse_results: List[SearchResult],
    *,
    index_root: Path,
) -> List[SearchResult]:
    """Stage 2 (precomputed): expand using GraphExpander over `graph_neighbors`."""
    from codexlens.search.graph_expander import GraphExpander

    # Depth preference order: staged_lsp_depth -> graph_expansion_depth -> 2.
    raw_depth = 2
    if self._config is not None:
        raw_depth = getattr(
            self._config,
            "staged_lsp_depth",
            getattr(self._config, "graph_expansion_depth", 2),
        )
    try:
        expansion_depth = int(raw_depth)
    except Exception:
        # Non-numeric config value: fall back to the default depth.
        expansion_depth = 2

    graph_expander = GraphExpander(self.mapper, config=self._config)
    related = graph_expander.expand(
        coarse_results,
        depth=expansion_depth,
        max_expand=min(10, len(coarse_results)),  # cap seed count to avoid explosion
        max_related=50,
    )
    if related:
        self.logger.debug(
            "Stage 2 (precomputed) expanded %d base results to %d related symbols",
            len(coarse_results), len(related)
        )
    return self._combine_stage2_results(coarse_results, related)
def _stage2_realtime_lsp_expand(
    self,
    coarse_results: List[SearchResult],
    *,
    index_root: Path,
) -> List[SearchResult]:
    """Stage 2 (realtime): compute expansion graph via live LSP servers.

    Builds seed nodes from the top coarse results, expands them with
    LspGraphBuilder over a live LspBridge, and converts the non-seed graph
    nodes into SearchResult entries merged after the originals.

    Args:
        coarse_results: Results from Stage 1 binary search.
        index_root: Per-directory index root; mapped back to the source tree.

    Returns:
        coarse_results plus de-duplicated related symbols, or coarse_results
        unchanged on any failure (best-effort expansion).
    """
    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
    from codexlens.lsp import LspBridge, LspGraphBuilder

    # Defaults mirror Config's staged_realtime_* fields; overridden below.
    max_depth = 2
    timeout_s = 10.0
    max_nodes = 100
    warmup_s = 2.0
    resolve_symbols = False
    if self._config is not None:
        # `or` maps falsy config values (None/0) back to the defaults.
        max_depth = int(getattr(self._config, "staged_lsp_depth", 2) or 2)
        timeout_s = float(getattr(self._config, "staged_realtime_lsp_timeout_s", 10.0) or 10.0)
        max_nodes = int(getattr(self._config, "staged_realtime_lsp_max_nodes", 100) or 100)
        warmup_s = float(getattr(self._config, "staged_realtime_lsp_warmup_s", 2.0) or 0.0)
        resolve_symbols = bool(getattr(self._config, "staged_realtime_lsp_resolve_symbols", False))

    # Map the index root back to the source tree; fall back to the first
    # result's parent directory if the mapper cannot resolve it.
    try:
        source_root = self.mapper.index_to_source(index_root)
    except Exception:
        source_root = Path(coarse_results[0].path).resolve().parent
    workspace_root = self._find_lsp_workspace_root(source_root)

    # Build seed nodes from the top coarse results (capped to limit fan-out).
    max_expand = min(10, len(coarse_results))
    seed_nodes: List[CodeSymbolNode] = []
    seed_ids: set[str] = set()
    for seed in list(coarse_results)[:max_expand]:
        if not seed.path:
            continue
        name = seed.symbol_name or Path(seed.path).stem
        kind = seed.symbol_kind or "unknown"
        start_line = int(seed.start_line or 1)
        end_line = int(seed.end_line or start_line)
        start_character = 1
        try:
            # Locate the symbol's column on its line so LSP position-based
            # queries land on the identifier rather than column 1.
            if seed.symbol_name and start_line >= 1:
                line_text = Path(seed.path).read_text(encoding="utf-8", errors="ignore").splitlines()[start_line - 1]
                idx = line_text.find(seed.symbol_name)
                if idx >= 0:
                    start_character = idx + 1  # 1-based for StandaloneLspManager
        except Exception:
            start_character = 1
        node_id = f"{seed.path}:{name}:{start_line}"
        seed_ids.add(node_id)
        seed_nodes.append(
            CodeSymbolNode(
                id=node_id,
                name=name,
                kind=kind,
                file_path=seed.path,
                range=Range(
                    start_line=start_line,
                    start_character=start_character,
                    end_line=end_line,
                    end_character=1,
                ),
                raw_code=seed.content or "",
                docstring=seed.excerpt or "",
            )
        )
    if not seed_nodes:
        return coarse_results

    async def expand_graph():
        async with LspBridge(workspace_root=str(workspace_root), timeout=timeout_s) as bridge:
            # Warm up analysis: open seed docs and wait a bit so references/call hierarchy are populated.
            if warmup_s > 0:
                for seed in seed_nodes[:3]:
                    try:
                        await bridge.get_document_symbols(seed.file_path)
                    except Exception:
                        continue
                try:
                    # Never sleep past the overall timeout budget.
                    await asyncio.sleep(min(warmup_s, max(0.0, timeout_s - 0.5)))
                except Exception:
                    pass
            builder = LspGraphBuilder(
                max_depth=max_depth,
                max_nodes=max_nodes,
                resolve_symbols=resolve_symbols,
            )
            return await builder.build_from_seeds(seed_nodes, bridge)

    def run_coro_blocking():
        # Runs the expansion on a fresh event loop, bounded by timeout_s.
        return asyncio.run(asyncio.wait_for(expand_graph(), timeout=timeout_s))

    try:
        # asyncio.run() cannot nest inside an already-running loop, so detect
        # one and, if present, execute on a dedicated worker thread instead.
        try:
            asyncio.get_running_loop()
            has_running_loop = True
        except RuntimeError:
            has_running_loop = False
        if has_running_loop:
            with ThreadPoolExecutor(max_workers=1) as executor:
                graph = executor.submit(run_coro_blocking).result(timeout=timeout_s + 1.0)
        else:
            graph = run_coro_blocking()
    except Exception as exc:
        self.logger.debug("Stage 2 (realtime) expansion failed: %s", exc)
        return coarse_results

    # Convert non-seed graph nodes into SearchResult entries.
    related_results: List[SearchResult] = []
    for node_id, node in getattr(graph, "nodes", {}).items():
        if node_id in seed_ids or getattr(node, "id", "") in seed_ids:
            continue
        try:
            start_line = int(getattr(node.range, "start_line", 1) or 1)
            end_line = int(getattr(node.range, "end_line", start_line) or start_line)
        except Exception:
            start_line, end_line = 1, 1
        related_results.append(
            SearchResult(
                path=node.file_path,
                score=0.5,  # fixed mid score: related symbols rank below direct hits
                excerpt=None,
                content=getattr(node, "raw_code", "") or None,
                symbol_name=node.name,
                symbol_kind=node.kind,
                start_line=start_line,
                end_line=end_line,
                metadata={"stage2_mode": "realtime", "lsp_node_id": node_id},
            )
        )
    if related_results:
        self.logger.debug(
            "Stage 2 (realtime) expanded %d base results to %d related symbols",
            len(coarse_results), len(related_results)
        )
    return self._combine_stage2_results(coarse_results, related_results)
def _combine_stage2_results(
self,
coarse_results: List[SearchResult],
related_results: List[SearchResult],
) -> List[SearchResult]:
combined = list(coarse_results)
seen_keys = {(r.path, r.symbol_name, r.start_line) for r in coarse_results}
for related in related_results:
key = (related.path, related.symbol_name, related.start_line)
if key not in seen_keys:
seen_keys.add(key)
combined.append(related)
return combined
def _find_lsp_workspace_root(self, start_path: Path) -> Path:
"""Best-effort workspace root selection for LSP initialization.
Many language servers (e.g. Pyright) use workspace-relative include/exclude
patterns, so using a deep subdir (like "src") as root can break reference
and call-hierarchy queries.
"""
start = Path(start_path).resolve()
if start.is_file():
start = start.parent
# Prefer an explicit LSP config file in the workspace.
for current in [start, *list(start.parents)]:
try:
if (current / "lsp-servers.json").is_file():
return current
except OSError:
continue
# Fallback heuristics for project root markers.
for current in [start, *list(start.parents)]:
try:
if (current / ".git").exists() or (current / "pyproject.toml").is_file():
return current
except OSError:
continue
return start
def _stage3_cluster_prune(
self,
expanded_results: List[SearchResult],