Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-09 02:24:11 +08:00
Add unit tests for LspGraphBuilder class
- Implement comprehensive unit tests for the LspGraphBuilder class to validate its functionality in building code association graphs.
- Tests cover various scenarios, including single-level graph expansion, max-nodes and max-depth boundaries, concurrent expansion limits, document symbol caching, error handling during node expansion, and edge cases such as empty seed lists and self-referencing nodes.
- Use pytest and asyncio for asynchronous testing and mocking of LspBridge methods.
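A minimal sketch of what such a test can look like, assuming pytest-asyncio is available and a mocked LspBridge via unittest.mock.AsyncMock (test names, paths, and values here are illustrative, not the actual test file in this commit):

import pytest
from unittest.mock import AsyncMock

from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp.lsp_graph_builder import LspGraphBuilder


@pytest.mark.asyncio
async def test_empty_seed_list_returns_empty_graph():
    # No seeds means nothing to expand and no LSP traffic at all.
    bridge = AsyncMock()
    graph = await LspGraphBuilder().build_from_seeds([], bridge)
    assert len(graph) == 0
    bridge.get_references.assert_not_called()


@pytest.mark.asyncio
async def test_max_nodes_boundary_stops_expansion():
    # With max_nodes=1 the single seed already fills the node budget.
    seed = CodeSymbolNode(
        id="/tmp/app.py:main:1", name="main", kind="function",
        file_path="/tmp/app.py",
        range=Range(start_line=1, start_character=0, end_line=5, end_character=0),
    )
    bridge = AsyncMock()
    bridge.get_references.return_value = []
    bridge.get_call_hierarchy.return_value = []
    graph = await LspGraphBuilder(max_nodes=1).build_from_seeds([seed], bridge)
    assert list(graph.nodes) == [seed.id]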
codex-lens/src/codexlens/hybrid_search/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""Hybrid Search data structures for CodexLens.

This module provides core data structures for hybrid search:
- CodeSymbolNode: Graph node representing a code symbol
- CodeAssociationGraph: Graph of code relationships
- SearchResultCluster: Clustered search results
- Range: Position range in source files
- CallHierarchyItem: LSP call hierarchy item

Note: The search engine is in codexlens.search.hybrid_search
LSP-based expansion is in codexlens.lsp module
"""

from codexlens.hybrid_search.data_structures import (
    CallHierarchyItem,
    CodeAssociationGraph,
    CodeSymbolNode,
    Range,
    SearchResultCluster,
)

__all__ = [
    "CallHierarchyItem",
    "CodeAssociationGraph",
    "CodeSymbolNode",
    "Range",
    "SearchResultCluster",
]
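As a rough usage sketch of these exports (assuming the definitions in data_structures.py below; the paths and line numbers are made up for illustration):

from codexlens.hybrid_search import CodeAssociationGraph, CodeSymbolNode, Range

caller = CodeSymbolNode(
    id="/src/app.py:handler:10", name="handler", kind="function",
    file_path="/src/app.py",
    range=Range(start_line=10, start_character=0, end_line=20, end_character=0),
)
callee = CodeSymbolNode(
    id="/src/db.py:query:3", name="query", kind="function",
    file_path="/src/db.py",
    range=Range(start_line=3, start_character=0, end_line=8, end_character=0),
)

graph = CodeAssociationGraph()
graph.add_node(caller)
graph.add_node(callee)
graph.add_edge(caller.id, callee.id, "calls")  # both endpoints must already be nodes

assert graph.get_neighbors(caller.id, rel_type="calls") == [callee]
assert graph.get_incoming(callee.id) == [caller]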
codex-lens/src/codexlens/hybrid_search/data_structures.py (new file, 602 lines)
@@ -0,0 +1,602 @@
|
||||
"""Core data structures for the hybrid search system.
|
||||
|
||||
This module defines the fundamental data structures used throughout the
|
||||
hybrid search pipeline, including code symbol representations, association
|
||||
graphs, and clustered search results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import networkx as nx
|
||||
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
"""Position range within a source file.
|
||||
|
||||
Attributes:
|
||||
start_line: Starting line number (0-based).
|
||||
start_character: Starting character offset within the line.
|
||||
end_line: Ending line number (0-based).
|
||||
end_character: Ending character offset within the line.
|
||||
"""
|
||||
|
||||
start_line: int
|
||||
start_character: int
|
||||
end_line: int
|
||||
end_character: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate range values."""
|
||||
if self.start_line < 0:
|
||||
raise ValueError("start_line must be >= 0")
|
||||
if self.start_character < 0:
|
||||
raise ValueError("start_character must be >= 0")
|
||||
if self.end_line < 0:
|
||||
raise ValueError("end_line must be >= 0")
|
||||
if self.end_character < 0:
|
||||
raise ValueError("end_character must be >= 0")
|
||||
if self.end_line < self.start_line:
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
if self.end_line == self.start_line and self.end_character < self.start_character:
|
||||
raise ValueError("end_character must be >= start_character on the same line")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"start": {"line": self.start_line, "character": self.start_character},
|
||||
"end": {"line": self.end_line, "character": self.end_character},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> Range:
|
||||
"""Create Range from dictionary representation."""
|
||||
return cls(
|
||||
start_line=data["start"]["line"],
|
||||
start_character=data["start"]["character"],
|
||||
end_line=data["end"]["line"],
|
||||
end_character=data["end"]["character"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
|
||||
"""Create Range from LSP Range object.
|
||||
|
||||
LSP Range format:
|
||||
{"start": {"line": int, "character": int},
|
||||
"end": {"line": int, "character": int}}
|
||||
"""
|
||||
return cls(
|
||||
start_line=lsp_range["start"]["line"],
|
||||
start_character=lsp_range["start"]["character"],
|
||||
end_line=lsp_range["end"]["line"],
|
||||
end_character=lsp_range["end"]["character"],
|
||||
)
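# Illustrative sketch, not part of the committed file: Range round-trips the
# LSP wire shape, assuming the Range class defined above.
_lsp_payload = {"start": {"line": 3, "character": 4},
                "end": {"line": 3, "character": 20}}
_rng = Range.from_lsp_range(_lsp_payload)
assert (_rng.start_line, _rng.end_character) == (3, 20)
assert _rng.to_dict() == _lsp_payload  # the same nested shape comes back out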
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallHierarchyItem:
|
||||
"""LSP CallHierarchyItem for representing callers/callees.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name (function, method, class name).
|
||||
kind: Symbol kind (function, method, class, etc.).
|
||||
file_path: Absolute file path where the symbol is defined.
|
||||
range: Position range in the source file.
|
||||
detail: Optional additional detail about the symbol.
|
||||
"""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
detail: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
}
|
||||
if self.detail:
|
||||
result["detail"] = self.detail
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from dictionary representation."""
|
||||
return cls(
|
||||
name=data["name"],
|
||||
kind=data["kind"],
|
||||
file_path=data["file_path"],
|
||||
range=Range.from_dict(data["range"]),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeSymbolNode:
|
||||
"""Graph node representing a code symbol.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier in format 'file_path:name:line'.
|
||||
name: Symbol name (function, class, variable name).
|
||||
kind: Symbol kind (function, class, method, variable, etc.).
|
||||
file_path: Absolute file path where symbol is defined.
|
||||
range: Start/end position in the source file.
|
||||
embedding: Optional vector embedding for semantic search.
|
||||
raw_code: Raw source code of the symbol.
|
||||
docstring: Documentation string (if available).
|
||||
score: Ranking score (used during reranking).
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
embedding: Optional[List[float]] = None
|
||||
raw_code: str = ""
|
||||
docstring: str = ""
|
||||
score: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate required fields."""
|
||||
if not self.id:
|
||||
raise ValueError("id cannot be empty")
|
||||
if not self.name:
|
||||
raise ValueError("name cannot be empty")
|
||||
if not self.kind:
|
||||
raise ValueError("kind cannot be empty")
|
||||
if not self.file_path:
|
||||
raise ValueError("file_path cannot be empty")
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on unique ID."""
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Equality based on unique ID."""
|
||||
if not isinstance(other, CodeSymbolNode):
|
||||
return False
|
||||
return self.id == other.id
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
"score": self.score,
|
||||
}
|
||||
if self.raw_code:
|
||||
result["raw_code"] = self.raw_code
|
||||
if self.docstring:
|
||||
result["docstring"] = self.docstring
|
||||
# Exclude embedding from serialization (too large for JSON responses)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
|
||||
"""Create CodeSymbolNode from dictionary representation."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
kind=data["kind"],
|
||||
file_path=data["file_path"],
|
||||
range=Range.from_dict(data["range"]),
|
||||
embedding=data.get("embedding"),
|
||||
raw_code=data.get("raw_code", ""),
|
||||
docstring=data.get("docstring", ""),
|
||||
score=data.get("score", 0.0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp_location(
|
||||
cls,
|
||||
uri: str,
|
||||
name: str,
|
||||
kind: str,
|
||||
lsp_range: Dict[str, Any],
|
||||
raw_code: str = "",
|
||||
docstring: str = "",
|
||||
) -> CodeSymbolNode:
|
||||
"""Create CodeSymbolNode from LSP location data.
|
||||
|
||||
Args:
|
||||
uri: File URI (file:// prefix will be stripped).
|
||||
name: Symbol name.
|
||||
kind: Symbol kind.
|
||||
lsp_range: LSP Range object.
|
||||
raw_code: Optional raw source code.
|
||||
docstring: Optional documentation string.
|
||||
|
||||
Returns:
|
||||
New CodeSymbolNode instance.
|
||||
"""
|
||||
# Strip file:// prefix if present
|
||||
file_path = uri
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
# Handle Windows paths (file:///C:/...)
|
||||
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||
file_path = file_path[1:]
|
||||
|
||||
range_obj = Range.from_lsp_range(lsp_range)
|
||||
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
|
||||
|
||||
return cls(
|
||||
id=symbol_id,
|
||||
name=name,
|
||||
kind=kind,
|
||||
file_path=file_path,
|
||||
range=range_obj,
|
||||
raw_code=raw_code,
|
||||
docstring=docstring,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_id(cls, file_path: str, name: str, line: int) -> str:
|
||||
"""Generate a unique symbol ID.
|
||||
|
||||
Args:
|
||||
file_path: Absolute file path.
|
||||
name: Symbol name.
|
||||
line: Start line number.
|
||||
|
||||
Returns:
|
||||
Unique ID string in format 'file_path:name:line'.
|
||||
"""
|
||||
return f"{file_path}:{name}:{line}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeAssociationGraph:
|
||||
"""Graph of code relationships between symbols.
|
||||
|
||||
This graph represents the association between code symbols discovered
|
||||
through LSP queries (references, call hierarchy, etc.).
|
||||
|
||||
Attributes:
|
||||
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
|
||||
edges: List of (from_id, to_id, relationship_type) tuples.
|
||||
relationship_type: 'calls', 'references', 'inherits', 'imports'.
|
||||
"""
|
||||
|
||||
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
|
||||
edges: List[Tuple[str, str, str]] = field(default_factory=list)
|
||||
|
||||
def add_node(self, node: CodeSymbolNode) -> None:
|
||||
"""Add a node to the graph.
|
||||
|
||||
Args:
|
||||
node: CodeSymbolNode to add. If a node with the same ID exists,
|
||||
it will be replaced.
|
||||
"""
|
||||
self.nodes[node.id] = node
|
||||
|
||||
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||
"""Add an edge to the graph.
|
||||
|
||||
Args:
|
||||
from_id: Source node ID.
|
||||
to_id: Target node ID.
|
||||
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
|
||||
|
||||
Raises:
|
||||
ValueError: If from_id or to_id not in graph nodes.
|
||||
"""
|
||||
if from_id not in self.nodes:
|
||||
raise ValueError(f"Source node '{from_id}' not found in graph")
|
||||
if to_id not in self.nodes:
|
||||
raise ValueError(f"Target node '{to_id}' not found in graph")
|
||||
|
||||
edge = (from_id, to_id, rel_type)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
|
||||
"""Add an edge without validating node existence.
|
||||
|
||||
Use this method during bulk graph construction where nodes may be
|
||||
added after edges, or when performance is critical.
|
||||
|
||||
Args:
|
||||
from_id: Source node ID.
|
||||
to_id: Target node ID.
|
||||
rel_type: Relationship type.
|
||||
"""
|
||||
edge = (from_id, to_id, rel_type)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
|
||||
"""Get a node by ID.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to look up.
|
||||
|
||||
Returns:
|
||||
CodeSymbolNode if found, None otherwise.
|
||||
"""
|
||||
return self.nodes.get(node_id)
|
||||
|
||||
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||
"""Get neighboring nodes connected by outgoing edges.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to find neighbors for.
|
||||
rel_type: Optional filter by relationship type.
|
||||
|
||||
Returns:
|
||||
List of neighboring CodeSymbolNode objects.
|
||||
"""
|
||||
neighbors = []
|
||||
for from_id, to_id, edge_rel in self.edges:
|
||||
if from_id == node_id:
|
||||
if rel_type is None or edge_rel == rel_type:
|
||||
node = self.nodes.get(to_id)
|
||||
if node:
|
||||
neighbors.append(node)
|
||||
return neighbors
|
||||
|
||||
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
|
||||
"""Get nodes connected by incoming edges.
|
||||
|
||||
Args:
|
||||
node_id: Node ID to find incoming connections for.
|
||||
rel_type: Optional filter by relationship type.
|
||||
|
||||
Returns:
|
||||
List of CodeSymbolNode objects with edges pointing to node_id.
|
||||
"""
|
||||
incoming = []
|
||||
for from_id, to_id, edge_rel in self.edges:
|
||||
if to_id == node_id:
|
||||
if rel_type is None or edge_rel == rel_type:
|
||||
node = self.nodes.get(from_id)
|
||||
if node:
|
||||
incoming.append(node)
|
||||
return incoming
|
||||
|
||||
def to_networkx(self) -> "nx.DiGraph":
|
||||
"""Convert to NetworkX DiGraph for graph algorithms.
|
||||
|
||||
Returns:
|
||||
NetworkX directed graph with nodes and edges.
|
||||
|
||||
Raises:
|
||||
ImportError: If networkx is not installed.
|
||||
"""
|
||||
try:
|
||||
import networkx as nx
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"networkx is required for graph algorithms. "
|
||||
"Install with: pip install networkx"
|
||||
)
|
||||
|
||||
graph = nx.DiGraph()
|
||||
|
||||
# Add nodes with attributes
|
||||
for node_id, node in self.nodes.items():
|
||||
graph.add_node(
|
||||
node_id,
|
||||
name=node.name,
|
||||
kind=node.kind,
|
||||
file_path=node.file_path,
|
||||
score=node.score,
|
||||
)
|
||||
|
||||
# Add edges with relationship type
|
||||
for from_id, to_id, rel_type in self.edges:
|
||||
graph.add_edge(from_id, to_id, relationship=rel_type)
|
||||
|
||||
return graph
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary with 'nodes' and 'edges' keys.
|
||||
"""
|
||||
return {
|
||||
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
|
||||
"edges": [
|
||||
{"from": from_id, "to": to_id, "relationship": rel_type}
|
||||
for from_id, to_id, rel_type in self.edges
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
|
||||
"""Create CodeAssociationGraph from dictionary representation.
|
||||
|
||||
Args:
|
||||
data: Dictionary with 'nodes' and 'edges' keys.
|
||||
|
||||
Returns:
|
||||
New CodeAssociationGraph instance.
|
||||
"""
|
||||
graph = cls()
|
||||
|
||||
# Load nodes
|
||||
for node_id, node_data in data.get("nodes", {}).items():
|
||||
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
|
||||
|
||||
# Load edges
|
||||
for edge_data in data.get("edges", []):
|
||||
graph.edges.append((
|
||||
edge_data["from"],
|
||||
edge_data["to"],
|
||||
edge_data["relationship"],
|
||||
))
|
||||
|
||||
return graph
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of nodes in the graph."""
|
||||
return len(self.nodes)
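# Illustrative sketch, not part of the committed file: the graph serialises to
# plain dicts and back, so it can cross a JSON boundary intact. Names below are
# made up for illustration.
_g = CodeAssociationGraph()
_g.add_node(CodeSymbolNode(
    id="/src/a.py:f:1", name="f", kind="function", file_path="/src/a.py",
    range=Range(start_line=1, start_character=0, end_line=2, end_character=0),
))
_restored = CodeAssociationGraph.from_dict(_g.to_dict())
assert list(_restored.nodes) == ["/src/a.py:f:1"] and _restored.edges == []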
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResultCluster:
|
||||
"""Clustered search result containing related code symbols.
|
||||
|
||||
Search results are grouped into clusters based on graph community
|
||||
detection or embedding similarity. Each cluster represents a
|
||||
conceptually related group of code symbols.
|
||||
|
||||
Attributes:
|
||||
cluster_id: Unique cluster identifier.
|
||||
score: Cluster relevance score (max of symbol scores).
|
||||
title: Human-readable cluster title/summary.
|
||||
symbols: List of CodeSymbolNode in this cluster.
|
||||
metadata: Additional cluster metadata.
|
||||
"""
|
||||
|
||||
cluster_id: str
|
||||
score: float
|
||||
title: str
|
||||
symbols: List[CodeSymbolNode] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate cluster fields."""
|
||||
if not self.cluster_id:
|
||||
raise ValueError("cluster_id cannot be empty")
|
||||
if self.score < 0:
|
||||
raise ValueError("score must be >= 0")
|
||||
|
||||
def add_symbol(self, symbol: CodeSymbolNode) -> None:
|
||||
"""Add a symbol to the cluster.
|
||||
|
||||
Args:
|
||||
symbol: CodeSymbolNode to add.
|
||||
"""
|
||||
self.symbols.append(symbol)
|
||||
|
||||
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
|
||||
"""Get top N symbols by score.
|
||||
|
||||
Args:
|
||||
n: Number of symbols to return.
|
||||
|
||||
Returns:
|
||||
List of top N CodeSymbolNode objects sorted by score descending.
|
||||
"""
|
||||
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
|
||||
return sorted_symbols[:n]
|
||||
|
||||
def update_score(self) -> None:
|
||||
"""Update cluster score to max of symbol scores."""
|
||||
if self.symbols:
|
||||
self.score = max(s.score for s in self.symbols)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the cluster.
|
||||
"""
|
||||
return {
|
||||
"cluster_id": self.cluster_id,
|
||||
"score": self.score,
|
||||
"title": self.title,
|
||||
"symbols": [s.to_dict() for s in self.symbols],
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
|
||||
"""Create SearchResultCluster from dictionary representation.
|
||||
|
||||
Args:
|
||||
data: Dictionary with cluster data.
|
||||
|
||||
Returns:
|
||||
New SearchResultCluster instance.
|
||||
"""
|
||||
return cls(
|
||||
cluster_id=data["cluster_id"],
|
||||
score=data["score"],
|
||||
title=data["title"],
|
||||
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of symbols in the cluster."""
|
||||
return len(self.symbols)
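# Illustrative sketch, not part of the committed file: a cluster's relevance
# follows its best-scoring member after update_score(). Values are made up.
_cluster = SearchResultCluster(cluster_id="c1", score=0.0, title="database access")
_cluster.add_symbol(CodeSymbolNode(
    id="/src/db.py:query:3", name="query", kind="function", file_path="/src/db.py",
    range=Range(start_line=3, start_character=0, end_line=8, end_character=0),
    score=0.72,
))
_cluster.update_score()
assert _cluster.score == 0.72 and len(_cluster) == 1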
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallHierarchyItem:
|
||||
"""LSP CallHierarchyItem for representing callers/callees.
|
||||
|
||||
Attributes:
|
||||
name: Symbol name (function, method, etc.).
|
||||
kind: Symbol kind (function, method, etc.).
|
||||
file_path: Absolute file path.
|
||||
range: Position range in the file.
|
||||
detail: Optional additional detail (e.g., signature).
|
||||
"""
|
||||
|
||||
name: str
|
||||
kind: str
|
||||
file_path: str
|
||||
range: Range
|
||||
detail: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"kind": self.kind,
|
||||
"file_path": self.file_path,
|
||||
"range": self.range.to_dict(),
|
||||
}
|
||||
if self.detail:
|
||||
result["detail"] = self.detail
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from dictionary representation."""
|
||||
return cls(
|
||||
name=data.get("name", "unknown"),
|
||||
kind=data.get("kind", "unknown"),
|
||||
file_path=data.get("file_path", data.get("uri", "")),
|
||||
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
|
||||
"""Create CallHierarchyItem from LSP response format.
|
||||
|
||||
LSP uses 0-based line numbers and a 'character' field for column offsets.
|
||||
"""
|
||||
uri = data.get("uri", data.get("file_path", ""))
|
||||
# Strip file:// prefix
|
||||
file_path = uri
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
|
||||
file_path = file_path[1:]
|
||||
|
||||
return cls(
|
||||
name=data.get("name", "unknown"),
|
||||
kind=str(data.get("kind", "unknown")),
|
||||
file_path=file_path,
|
||||
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
|
||||
detail=data.get("detail"),
|
||||
)
|
||||
codex-lens/src/codexlens/lsp/__init__.py (modified, 7 lines replaced by 34)
@@ -1,7 +1,34 @@
-"""codex-lens Language Server Protocol implementation."""
-
-from __future__ import annotations
-
-from codexlens.lsp.server import CodexLensLanguageServer, main
-
-__all__ = ["CodexLensLanguageServer", "main"]
+"""LSP module for real-time language server integration.
+
+This module provides:
+- LspBridge: HTTP bridge to VSCode language servers
+- LspGraphBuilder: Build code association graphs via LSP
+- Location: Position in a source file
+
+Example:
+    >>> from codexlens.lsp import LspBridge, LspGraphBuilder
+    >>>
+    >>> async with LspBridge() as bridge:
+    ...     refs = await bridge.get_references(symbol)
+    ...     graph = await LspGraphBuilder().build_from_seeds(seeds, bridge)
+"""
+
+from codexlens.lsp.lsp_bridge import (
+    CacheEntry,
+    Location,
+    LspBridge,
+)
+from codexlens.lsp.lsp_graph_builder import (
+    LspGraphBuilder,
+)
+
+# Alias for backward compatibility
+GraphBuilder = LspGraphBuilder
+
+__all__ = [
+    "CacheEntry",
+    "GraphBuilder",
+    "Location",
+    "LspBridge",
+    "LspGraphBuilder",
+]
codex-lens/src/codexlens/lsp/lsp_bridge.py (new file, 834 lines)
@@ -0,0 +1,834 @@
|
||||
"""LspBridge service for real-time LSP communication with caching.
|
||||
|
||||
This module provides a bridge to communicate with language servers either via:
|
||||
1. Standalone LSP Manager (direct subprocess communication - default)
|
||||
2. VSCode Bridge extension (HTTP-based, legacy mode)
|
||||
|
||||
Features:
|
||||
- Direct communication with language servers (no VSCode dependency)
|
||||
- Cache with TTL and file modification time invalidation
|
||||
- Graceful error handling with empty results on failure
|
||||
- Support for definition, references, hover, and call hierarchy
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||
|
||||
# Check for optional dependencies
|
||||
try:
|
||||
import aiohttp
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
from codexlens.hybrid_search.data_structures import (
|
||||
CallHierarchyItem,
|
||||
CodeSymbolNode,
|
||||
Range,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Location:
|
||||
"""A location in a source file (LSP response format)."""
|
||||
|
||||
file_path: str
|
||||
line: int
|
||||
character: int
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"line": self.line,
|
||||
"character": self.character,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_lsp_response(cls, data: Dict[str, Any]) -> "Location":
|
||||
"""Create Location from LSP response format.
|
||||
|
||||
Handles both direct format and VSCode URI format.
|
||||
"""
|
||||
# Handle VSCode URI format (file:///path/to/file)
|
||||
uri = data.get("uri", data.get("file_path", ""))
|
||||
if uri.startswith("file:///"):
|
||||
# Windows: file:///C:/path -> C:/path
|
||||
# Unix: file:///path -> /path
|
||||
file_path = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||
elif uri.startswith("file://"):
|
||||
file_path = uri[7:]
|
||||
else:
|
||||
file_path = uri
|
||||
|
||||
# Get position from range or direct fields
|
||||
if "range" in data:
|
||||
range_data = data["range"]
|
||||
start = range_data.get("start", {})
|
||||
line = start.get("line", 0) + 1 # LSP is 0-based, convert to 1-based
|
||||
character = start.get("character", 0) + 1
|
||||
else:
|
||||
line = data.get("line", 1)
|
||||
character = data.get("character", 1)
|
||||
|
||||
return cls(file_path=file_path, line=line, character=character)
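# Illustrative sketch, not part of the committed file: a typical LSP location
# payload collapses to a 1-based Location, assuming the class defined above.
_loc = Location.from_lsp_response({
    "uri": "file:///home/user/project/app.py",
    "range": {"start": {"line": 9, "character": 4},
              "end": {"line": 9, "character": 12}},
})
assert _loc.file_path == "/home/user/project/app.py"
assert (_loc.line, _loc.character) == (10, 5)  # converted from 0-based LSP offsets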
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached LSP response with expiration metadata.
|
||||
|
||||
Attributes:
|
||||
data: The cached response data
|
||||
file_mtime: File modification time when cached (for invalidation)
|
||||
cached_at: Unix timestamp when entry was cached
|
||||
"""
|
||||
|
||||
data: Any
|
||||
file_mtime: float
|
||||
cached_at: float
|
||||
|
||||
|
||||
class LspBridge:
|
||||
"""Bridge for real-time LSP communication with language servers.
|
||||
|
||||
By default, uses StandaloneLspManager to directly spawn and communicate
|
||||
with language servers via JSON-RPC over stdio. No VSCode dependency required.
|
||||
|
||||
For legacy mode, can use VSCode Bridge HTTP server (set use_vscode_bridge=True).
|
||||
|
||||
Features:
|
||||
- Direct language server communication (default)
|
||||
- Response caching with TTL and file modification invalidation
|
||||
- Timeout handling
|
||||
- Graceful error handling returning empty results
|
||||
|
||||
Example:
|
||||
# Default: standalone mode (no VSCode needed)
|
||||
async with LspBridge() as bridge:
|
||||
refs = await bridge.get_references(symbol)
|
||||
definition = await bridge.get_definition(symbol)
|
||||
|
||||
# Legacy: VSCode Bridge mode
|
||||
async with LspBridge(use_vscode_bridge=True) as bridge:
|
||||
refs = await bridge.get_references(symbol)
|
||||
"""
|
||||
|
||||
DEFAULT_BRIDGE_URL = "http://127.0.0.1:3457"
|
||||
DEFAULT_TIMEOUT = 30.0 # seconds (increased for standalone mode)
|
||||
DEFAULT_CACHE_TTL = 300 # 5 minutes
|
||||
DEFAULT_MAX_CACHE_SIZE = 1000 # Maximum cache entries
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bridge_url: str = DEFAULT_BRIDGE_URL,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
cache_ttl: int = DEFAULT_CACHE_TTL,
|
||||
max_cache_size: int = DEFAULT_MAX_CACHE_SIZE,
|
||||
use_vscode_bridge: bool = False,
|
||||
workspace_root: Optional[str] = None,
|
||||
config_file: Optional[str] = None,
|
||||
):
|
||||
"""Initialize LspBridge.
|
||||
|
||||
Args:
|
||||
bridge_url: URL of the VSCode Bridge HTTP server (legacy mode only)
|
||||
timeout: Request timeout in seconds
|
||||
cache_ttl: Cache time-to-live in seconds
|
||||
max_cache_size: Maximum number of cache entries (LRU eviction)
|
||||
use_vscode_bridge: If True, use VSCode Bridge HTTP mode (requires aiohttp)
|
||||
workspace_root: Root directory for standalone LSP manager
|
||||
config_file: Path to lsp-servers.json configuration file
|
||||
"""
|
||||
self.bridge_url = bridge_url
|
||||
self.timeout = timeout
|
||||
self.cache_ttl = cache_ttl
|
||||
self.max_cache_size = max_cache_size
|
||||
self.use_vscode_bridge = use_vscode_bridge
|
||||
self.workspace_root = workspace_root
|
||||
self.config_file = config_file
|
||||
|
||||
self.cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||
|
||||
# VSCode Bridge mode (legacy)
|
||||
self._session: Optional["aiohttp.ClientSession"] = None
|
||||
|
||||
# Standalone mode (default)
|
||||
self._manager: Optional["StandaloneLspManager"] = None
|
||||
self._manager_started = False
|
||||
|
||||
# Validate dependencies
|
||||
if use_vscode_bridge and not HAS_AIOHTTP:
|
||||
raise ImportError(
|
||||
"aiohttp is required for VSCode Bridge mode: pip install aiohttp"
|
||||
)
|
||||
|
||||
async def _ensure_manager(self) -> "StandaloneLspManager":
|
||||
"""Ensure standalone LSP manager is started."""
|
||||
if self._manager is None:
|
||||
from codexlens.lsp.standalone_manager import StandaloneLspManager
|
||||
self._manager = StandaloneLspManager(
|
||||
workspace_root=self.workspace_root,
|
||||
config_file=self.config_file,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if not self._manager_started:
|
||||
await self._manager.start()
|
||||
self._manager_started = True
|
||||
|
||||
return self._manager
|
||||
|
||||
async def _get_session(self) -> "aiohttp.ClientSession":
|
||||
"""Get or create the aiohttp session (VSCode Bridge mode only)."""
|
||||
if not HAS_AIOHTTP:
|
||||
raise ImportError("aiohttp required for VSCode Bridge mode")
|
||||
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close connections and cleanup resources."""
|
||||
# Close VSCode Bridge session
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
# Stop standalone manager
|
||||
if self._manager and self._manager_started:
|
||||
await self._manager.stop()
|
||||
self._manager_started = False
|
||||
|
||||
def _get_file_mtime(self, file_path: str) -> float:
|
||||
"""Get file modification time, or 0 if file doesn't exist."""
|
||||
try:
|
||||
return os.path.getmtime(file_path)
|
||||
except OSError:
|
||||
return 0.0
|
||||
|
||||
def _is_cached(self, cache_key: str, file_path: str) -> bool:
|
||||
"""Check if cache entry is valid.
|
||||
|
||||
Cache is invalid if:
|
||||
- Entry doesn't exist
|
||||
- TTL has expired
|
||||
- File has been modified since caching
|
||||
|
||||
Args:
|
||||
cache_key: The cache key to check
|
||||
file_path: Path to source file for mtime check
|
||||
|
||||
Returns:
|
||||
True if cache is valid and can be used
|
||||
"""
|
||||
if cache_key not in self.cache:
|
||||
return False
|
||||
|
||||
entry = self.cache[cache_key]
|
||||
now = time.time()
|
||||
|
||||
# Check TTL
|
||||
if now - entry.cached_at > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return False
|
||||
|
||||
# Check file modification time
|
||||
current_mtime = self._get_file_mtime(file_path)
|
||||
if current_mtime != entry.file_mtime:
|
||||
del self.cache[cache_key]
|
||||
return False
|
||||
|
||||
# Move to end on access (LRU behavior)
|
||||
self.cache.move_to_end(cache_key)
|
||||
return True
|
||||
|
||||
def _cache(self, key: str, file_path: str, data: Any) -> None:
|
||||
"""Store data in cache with LRU eviction.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
file_path: Path to source file (for mtime tracking)
|
||||
data: Data to cache
|
||||
"""
|
||||
# Remove oldest entries if at capacity
|
||||
while len(self.cache) >= self.max_cache_size:
|
||||
self.cache.popitem(last=False) # Remove oldest (FIFO order)
|
||||
|
||||
# Move to end if key exists (update access order)
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
|
||||
self.cache[key] = CacheEntry(
|
||||
data=data,
|
||||
file_mtime=self._get_file_mtime(file_path),
|
||||
cached_at=time.time(),
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
self.cache.clear()
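# Illustrative sketch, not part of the committed file: how a cached entry ages
# out (key and path are made up).
#
#     bridge._cache("refs:/src/app.py:handler:10", "/src/app.py", [])
#     bridge._is_cached("refs:/src/app.py:handler:10", "/src/app.py")   # True
#     # ...once cache_ttl seconds elapse, or /src/app.py is modified, the same
#     # check evicts the entry and returns False.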
|
||||
|
||||
async def _request_vscode_bridge(self, action: str, params: Dict[str, Any]) -> Any:
|
||||
"""Make HTTP request to VSCode Bridge (legacy mode).
|
||||
|
||||
Args:
|
||||
action: The endpoint/action name (e.g., "get_definition")
|
||||
params: Request parameters
|
||||
|
||||
Returns:
|
||||
Response data on success, None on failure
|
||||
"""
|
||||
url = f"{self.bridge_url}/{action}"
|
||||
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.post(url, json=params) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
if data.get("success") is False:
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_references(self, symbol: CodeSymbolNode) -> List[Location]:
|
||||
"""Get all references to a symbol via real-time LSP.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to find references for
|
||||
|
||||
Returns:
|
||||
List of Location objects where the symbol is referenced.
|
||||
Returns empty list on error or timeout.
|
||||
"""
|
||||
cache_key = f"refs:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
locations: List[Location] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_references", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
# Don't cache on connection error (result is None)
|
||||
if result is None:
|
||||
return locations
|
||||
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
try:
|
||||
locations.append(Location.from_lsp_response(item))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_references(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
for item in result:
|
||||
try:
|
||||
locations.append(Location.from_lsp_response(item))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
|
||||
self._cache(cache_key, symbol.file_path, locations)
|
||||
return locations
|
||||
|
||||
async def get_definition(self, symbol: CodeSymbolNode) -> Optional[Location]:
|
||||
"""Get symbol definition location.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to find definition for
|
||||
|
||||
Returns:
|
||||
Location of the definition, or None if not found
|
||||
"""
|
||||
cache_key = f"def:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
location: Optional[Location] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_definition", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
try:
|
||||
location = Location.from_lsp_response(result[0])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
elif isinstance(result, dict):
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_definition(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if result:
|
||||
try:
|
||||
location = Location.from_lsp_response(result)
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
self._cache(cache_key, symbol.file_path, location)
|
||||
return location
|
||||
|
||||
async def get_call_hierarchy(self, symbol: CodeSymbolNode) -> List[CallHierarchyItem]:
|
||||
"""Get incoming/outgoing calls for a symbol.
|
||||
|
||||
If call hierarchy is not supported by the language server,
|
||||
falls back to using references.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get call hierarchy for
|
||||
|
||||
Returns:
|
||||
List of CallHierarchyItem representing callers/callees.
|
||||
Returns empty list on error or if not supported.
|
||||
"""
|
||||
cache_key = f"calls:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
items: List[CallHierarchyItem] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_call_hierarchy", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result is None:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
elif isinstance(result, list):
|
||||
for item in result:
|
||||
try:
|
||||
range_data = item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=item.get("name", "unknown"),
|
||||
kind=item.get("kind", "unknown"),
|
||||
file_path=item.get("file_path", item.get("uri", "")),
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=item.get("detail"),
|
||||
))
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
|
||||
# Try to get call hierarchy items
|
||||
hierarchy_items = await manager.get_call_hierarchy_items(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
if hierarchy_items:
|
||||
# Get incoming calls for each item
|
||||
for h_item in hierarchy_items:
|
||||
incoming = await manager.get_incoming_calls(h_item)
|
||||
for call in incoming:
|
||||
from_item = call.get("from", {})
|
||||
range_data = from_item.get("range", {})
|
||||
start = range_data.get("start", {})
|
||||
end = range_data.get("end", {})
|
||||
|
||||
# Parse URI
|
||||
uri = from_item.get("uri", "")
|
||||
if uri.startswith("file:///"):
|
||||
fp = uri[8:] if uri[8:9].isalpha() and uri[9:10] == ":" else uri[7:]
|
||||
elif uri.startswith("file://"):
|
||||
fp = uri[7:]
|
||||
else:
|
||||
fp = uri
|
||||
|
||||
items.append(CallHierarchyItem(
|
||||
name=from_item.get("name", "unknown"),
|
||||
kind=str(from_item.get("kind", "unknown")),
|
||||
file_path=fp,
|
||||
range=Range(
|
||||
start_line=start.get("line", 0) + 1,
|
||||
start_character=start.get("character", 0) + 1,
|
||||
end_line=end.get("line", 0) + 1,
|
||||
end_character=end.get("character", 0) + 1,
|
||||
),
|
||||
detail=from_item.get("detail"),
|
||||
))
|
||||
else:
|
||||
# Fallback: use references
|
||||
refs = await self.get_references(symbol)
|
||||
for ref in refs:
|
||||
items.append(CallHierarchyItem(
|
||||
name=f"caller@{ref.line}",
|
||||
kind="reference",
|
||||
file_path=ref.file_path,
|
||||
range=Range(
|
||||
start_line=ref.line,
|
||||
start_character=ref.character,
|
||||
end_line=ref.line,
|
||||
end_character=ref.character,
|
||||
),
|
||||
detail="Inferred from reference",
|
||||
))
|
||||
|
||||
self._cache(cache_key, symbol.file_path, items)
|
||||
return items
|
||||
|
||||
async def get_document_symbols(self, file_path: str) -> List[Dict[str, Any]]:
|
||||
"""Get all symbols in a document (batch operation).
|
||||
|
||||
This is more efficient than individual hover queries when processing
|
||||
multiple locations in the same file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file
|
||||
|
||||
Returns:
|
||||
List of symbol dictionaries with name, kind, range, etc.
|
||||
Returns empty list on error or timeout.
|
||||
"""
|
||||
cache_key = f"symbols:{file_path}"
|
||||
|
||||
if self._is_cached(cache_key, file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
symbols: List[Dict[str, Any]] = []
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_document_symbols", {
|
||||
"file_path": file_path,
|
||||
})
|
||||
|
||||
if isinstance(result, list):
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
result = await manager.get_document_symbols(file_path)
|
||||
|
||||
if result:
|
||||
symbols = self._flatten_document_symbols(result)
|
||||
|
||||
self._cache(cache_key, file_path, symbols)
|
||||
return symbols
|
||||
|
||||
def _flatten_document_symbols(
|
||||
self, symbols: List[Dict[str, Any]], parent_name: str = ""
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Flatten nested document symbols into a flat list.
|
||||
|
||||
Document symbols can be nested (e.g., methods inside classes).
|
||||
This flattens them for easier lookup by line number.
|
||||
|
||||
Args:
|
||||
symbols: List of symbol dictionaries (may be nested)
|
||||
parent_name: Name of parent symbol for qualification
|
||||
|
||||
Returns:
|
||||
Flat list of all symbols with their ranges
|
||||
"""
|
||||
flat: List[Dict[str, Any]] = []
|
||||
|
||||
for sym in symbols:
|
||||
# Add the symbol itself
|
||||
symbol_entry = {
|
||||
"name": sym.get("name", "unknown"),
|
||||
"kind": self._symbol_kind_to_string(sym.get("kind", 0)),
|
||||
"range": sym.get("range", sym.get("location", {}).get("range", {})),
|
||||
"selection_range": sym.get("selectionRange", {}),
|
||||
"detail": sym.get("detail", ""),
|
||||
"parent": parent_name,
|
||||
}
|
||||
flat.append(symbol_entry)
|
||||
|
||||
# Recursively process children
|
||||
children = sym.get("children", [])
|
||||
if children:
|
||||
qualified_name = sym.get("name", "")
|
||||
if parent_name:
|
||||
qualified_name = f"{parent_name}.{qualified_name}"
|
||||
flat.extend(self._flatten_document_symbols(children, qualified_name))
|
||||
|
||||
return flat
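# Illustrative sketch, not part of the committed file: a nested documentSymbol
# response flattens into parent-qualified entries (range, selection_range, and
# detail fields omitted here for brevity).
#
#     nested = [{"name": "UserService", "kind": 5, "range": {...},
#                "children": [{"name": "save", "kind": 6, "range": {...}}]}]
#     bridge._flatten_document_symbols(nested)
#     # -> [{"name": "UserService", "kind": "class",  "parent": ""},
#     #     {"name": "save",        "kind": "method", "parent": "UserService"}]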
|
||||
|
||||
def _symbol_kind_to_string(self, kind: int) -> str:
|
||||
"""Convert LSP SymbolKind integer to string.
|
||||
|
||||
Args:
|
||||
kind: LSP SymbolKind enum value
|
||||
|
||||
Returns:
|
||||
Human-readable string representation
|
||||
"""
|
||||
# LSP SymbolKind enum (1-indexed)
|
||||
kinds = {
|
||||
1: "file",
|
||||
2: "module",
|
||||
3: "namespace",
|
||||
4: "package",
|
||||
5: "class",
|
||||
6: "method",
|
||||
7: "property",
|
||||
8: "field",
|
||||
9: "constructor",
|
||||
10: "enum",
|
||||
11: "interface",
|
||||
12: "function",
|
||||
13: "variable",
|
||||
14: "constant",
|
||||
15: "string",
|
||||
16: "number",
|
||||
17: "boolean",
|
||||
18: "array",
|
||||
19: "object",
|
||||
20: "key",
|
||||
21: "null",
|
||||
22: "enum_member",
|
||||
23: "struct",
|
||||
24: "event",
|
||||
25: "operator",
|
||||
26: "type_parameter",
|
||||
}
|
||||
return kinds.get(kind, "unknown")
|
||||
|
||||
async def get_hover(self, symbol: CodeSymbolNode) -> Optional[str]:
|
||||
"""Get hover documentation for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: The code symbol to get hover info for
|
||||
|
||||
Returns:
|
||||
Hover documentation as string, or None if not available
|
||||
"""
|
||||
cache_key = f"hover:{symbol.id}"
|
||||
|
||||
if self._is_cached(cache_key, symbol.file_path):
|
||||
return self.cache[cache_key].data
|
||||
|
||||
hover_text: Optional[str] = None
|
||||
|
||||
if self.use_vscode_bridge:
|
||||
# Legacy: VSCode Bridge HTTP mode
|
||||
result = await self._request_vscode_bridge("get_hover", {
|
||||
"file_path": symbol.file_path,
|
||||
"line": symbol.range.start_line,
|
||||
"character": symbol.range.start_character,
|
||||
})
|
||||
|
||||
if result:
|
||||
hover_text = self._parse_hover_result(result)
|
||||
else:
|
||||
# Default: Standalone mode
|
||||
manager = await self._ensure_manager()
|
||||
hover_text = await manager.get_hover(
|
||||
file_path=symbol.file_path,
|
||||
line=symbol.range.start_line,
|
||||
character=symbol.range.start_character,
|
||||
)
|
||||
|
||||
self._cache(cache_key, symbol.file_path, hover_text)
|
||||
return hover_text
|
||||
|
||||
def _parse_hover_result(self, result: Any) -> Optional[str]:
|
||||
"""Parse hover result into string."""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
elif isinstance(result, list):
|
||||
parts = []
|
||||
for item in result:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
value = item.get("value", item.get("contents", ""))
|
||||
if value:
|
||||
parts.append(str(value))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
elif isinstance(result, dict):
|
||||
contents = result.get("contents", result.get("value", ""))
|
||||
if isinstance(contents, str):
|
||||
return contents
|
||||
elif isinstance(contents, list):
|
||||
parts = []
|
||||
for c in contents:
|
||||
if isinstance(c, str):
|
||||
parts.append(c)
|
||||
elif isinstance(c, dict):
|
||||
parts.append(str(c.get("value", "")))
|
||||
return "\n\n".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
async def __aenter__(self) -> "LspBridge":
|
||||
"""Async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit - close connections."""
|
||||
await self.close()
|
||||
|
||||
|
||||
# Simple test
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
async def test_lsp_bridge():
|
||||
"""Simple test of LspBridge functionality."""
|
||||
print("Testing LspBridge (Standalone Mode)...")
|
||||
print(f"Timeout: {LspBridge.DEFAULT_TIMEOUT}s")
|
||||
print(f"Cache TTL: {LspBridge.DEFAULT_CACHE_TTL}s")
|
||||
print()
|
||||
|
||||
# Create a test symbol pointing to this file
|
||||
test_file = os.path.abspath(__file__)
|
||||
test_symbol = CodeSymbolNode(
|
||||
id=f"{test_file}:LspBridge:96",
|
||||
name="LspBridge",
|
||||
kind="class",
|
||||
file_path=test_file,
|
||||
range=Range(
|
||||
start_line=96,
|
||||
start_character=1,
|
||||
end_line=200,
|
||||
end_character=1,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Test symbol: {test_symbol.name} in {os.path.basename(test_symbol.file_path)}")
|
||||
print()
|
||||
|
||||
# Use standalone mode (default)
|
||||
async with LspBridge(
|
||||
workspace_root=str(Path(__file__).parent.parent.parent.parent),
|
||||
) as bridge:
|
||||
print("1. Testing get_document_symbols...")
|
||||
try:
|
||||
symbols = await bridge.get_document_symbols(test_file)
|
||||
print(f" Found {len(symbols)} symbols")
|
||||
for sym in symbols[:5]:
|
||||
print(f" - {sym.get('name')} ({sym.get('kind')})")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("2. Testing get_definition...")
|
||||
try:
|
||||
definition = await bridge.get_definition(test_symbol)
|
||||
if definition:
|
||||
print(f" Definition: {os.path.basename(definition.file_path)}:{definition.line}")
|
||||
else:
|
||||
print(" No definition found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("3. Testing get_references...")
|
||||
try:
|
||||
refs = await bridge.get_references(test_symbol)
|
||||
print(f" Found {len(refs)} references")
|
||||
for ref in refs[:3]:
|
||||
print(f" - {os.path.basename(ref.file_path)}:{ref.line}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("4. Testing get_hover...")
|
||||
try:
|
||||
hover = await bridge.get_hover(test_symbol)
|
||||
if hover:
|
||||
print(f" Hover: {hover[:100]}...")
|
||||
else:
|
||||
print(" No hover info found")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("5. Testing get_call_hierarchy...")
|
||||
try:
|
||||
calls = await bridge.get_call_hierarchy(test_symbol)
|
||||
print(f" Found {len(calls)} call hierarchy items")
|
||||
for call in calls[:3]:
|
||||
print(f" - {call.name} in {os.path.basename(call.file_path)}")
|
||||
except Exception as e:
|
||||
print(f" Error: {e}")
|
||||
|
||||
print()
|
||||
print("6. Testing cache...")
|
||||
print(f" Cache entries: {len(bridge.cache)}")
|
||||
for key in list(bridge.cache.keys())[:5]:
|
||||
print(f" - {key}")
|
||||
|
||||
print()
|
||||
print("Test complete!")
|
||||
|
||||
# Run the test
|
||||
# Note: On Windows, keep the default ProactorEventLoop (it supports subprocess creation)
|
||||
|
||||
asyncio.run(test_lsp_bridge())
|
||||
codex-lens/src/codexlens/lsp/lsp_graph_builder.py (new file, 375 lines)
@@ -0,0 +1,375 @@
|
||||
"""Graph builder for code association graphs via LSP."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from codexlens.hybrid_search.data_structures import (
|
||||
CallHierarchyItem,
|
||||
CodeAssociationGraph,
|
||||
CodeSymbolNode,
|
||||
Range,
|
||||
)
|
||||
from codexlens.lsp.lsp_bridge import (
|
||||
Location,
|
||||
LspBridge,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LspGraphBuilder:
|
||||
"""Builds code association graph by expanding from seed symbols using LSP."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_depth: int = 2,
|
||||
max_nodes: int = 100,
|
||||
max_concurrent: int = 10,
|
||||
):
|
||||
"""Initialize GraphBuilder.
|
||||
|
||||
Args:
|
||||
max_depth: Maximum depth for BFS expansion from seeds.
|
||||
max_nodes: Maximum number of nodes in the graph.
|
||||
max_concurrent: Maximum concurrent LSP requests.
|
||||
"""
|
||||
self.max_depth = max_depth
|
||||
self.max_nodes = max_nodes
|
||||
self.max_concurrent = max_concurrent
|
||||
# Cache for document symbols per file (avoids per-location hover queries)
|
||||
self._document_symbols_cache: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
async def build_from_seeds(
|
||||
self,
|
||||
seeds: List[CodeSymbolNode],
|
||||
lsp_bridge: LspBridge,
|
||||
) -> CodeAssociationGraph:
|
||||
"""Build association graph by BFS expansion from seeds.
|
||||
|
||||
For each seed:
|
||||
1. Get references via LSP
|
||||
2. Get call hierarchy via LSP
|
||||
3. Add nodes and edges to graph
|
||||
4. Continue expanding until max_depth or max_nodes reached
|
||||
|
||||
Args:
|
||||
seeds: Initial seed symbols to expand from.
|
||||
lsp_bridge: LSP bridge for querying language servers.
|
||||
|
||||
Returns:
|
||||
CodeAssociationGraph with expanded nodes and relationships.
|
||||
"""
|
||||
graph = CodeAssociationGraph()
|
||||
visited: Set[str] = set()
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
# Initialize queue with seeds at depth 0
|
||||
queue: List[Tuple[CodeSymbolNode, int]] = [(s, 0) for s in seeds]
|
||||
|
||||
# Add seed nodes to graph
|
||||
for seed in seeds:
|
||||
graph.add_node(seed)
|
||||
|
||||
# BFS expansion
|
||||
while queue and len(graph.nodes) < self.max_nodes:
|
||||
# Take a batch of nodes from queue
|
||||
batch_size = min(self.max_concurrent, len(queue))
|
||||
batch = queue[:batch_size]
|
||||
queue = queue[batch_size:]
|
||||
|
||||
# Expand nodes in parallel
|
||||
tasks = [
|
||||
self._expand_node(
|
||||
node, depth, graph, lsp_bridge, visited, semaphore
|
||||
)
|
||||
for node, depth in batch
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process results and add new nodes to queue
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning("Error expanding node: %s", result)
|
||||
continue
|
||||
if result:
|
||||
# Add new nodes to queue if not at max depth
|
||||
for new_node, new_depth in result:
|
||||
if (
|
||||
new_depth <= self.max_depth
|
||||
and len(graph.nodes) < self.max_nodes
|
||||
):
|
||||
queue.append((new_node, new_depth))
|
||||
|
||||
return graph
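# Illustrative sketch, not part of the committed file: expanding a single seed
# one hop outwards, capped at 25 nodes (the path and limits are made up).
#
#     async with LspBridge(workspace_root="/path/to/repo") as bridge:
#         builder = LspGraphBuilder(max_depth=1, max_nodes=25)
#         graph = await builder.build_from_seeds([seed], bridge)
#     # graph now holds the seed plus any referencing or calling symbols found
#     # within one hop of it.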
|
||||
|
||||
async def _expand_node(
|
||||
self,
        node: CodeSymbolNode,
        depth: int,
        graph: CodeAssociationGraph,
        lsp_bridge: LspBridge,
        visited: Set[str],
        semaphore: asyncio.Semaphore,
    ) -> List[Tuple[CodeSymbolNode, int]]:
        """Expand a single node, return new nodes to process.

        Args:
            node: Node to expand.
            depth: Current depth in BFS.
            graph: Graph to add nodes and edges to.
            lsp_bridge: LSP bridge for queries.
            visited: Set of visited node IDs.
            semaphore: Semaphore for concurrency control.

        Returns:
            List of (new_node, new_depth) tuples to add to queue.
        """
        # Skip if already visited or at max depth
        if node.id in visited:
            return []
        if depth > self.max_depth:
            return []
        if len(graph.nodes) >= self.max_nodes:
            return []

        visited.add(node.id)
        new_nodes: List[Tuple[CodeSymbolNode, int]] = []

        async with semaphore:
            # Get relationships in parallel
            try:
                refs_task = lsp_bridge.get_references(node)
                calls_task = lsp_bridge.get_call_hierarchy(node)

                refs, calls = await asyncio.gather(
                    refs_task, calls_task, return_exceptions=True
                )

                # Handle reference results
                if isinstance(refs, Exception):
                    logger.debug(
                        "Failed to get references for %s: %s", node.id, refs
                    )
                    refs = []

                # Handle call hierarchy results
                if isinstance(calls, Exception):
                    logger.debug(
                        "Failed to get call hierarchy for %s: %s",
                        node.id,
                        calls,
                    )
                    calls = []

                # Process references
                for ref in refs:
                    if len(graph.nodes) >= self.max_nodes:
                        break

                    ref_node = await self._location_to_node(ref, lsp_bridge)
                    if ref_node and ref_node.id != node.id:
                        if ref_node.id not in graph.nodes:
                            graph.add_node(ref_node)
                            new_nodes.append((ref_node, depth + 1))
                        # Use add_edge since both nodes should exist now
                        graph.add_edge(node.id, ref_node.id, "references")

                # Process call hierarchy (incoming calls)
                for call in calls:
                    if len(graph.nodes) >= self.max_nodes:
                        break

                    call_node = await self._call_hierarchy_to_node(
                        call, lsp_bridge
                    )
                    if call_node and call_node.id != node.id:
                        if call_node.id not in graph.nodes:
                            graph.add_node(call_node)
                            new_nodes.append((call_node, depth + 1))
                        # Incoming call: call_node calls node
                        graph.add_edge(call_node.id, node.id, "calls")

            except Exception as e:
                logger.warning(
                    "Error during node expansion for %s: %s", node.id, e
                )

        return new_nodes

    def clear_cache(self) -> None:
        """Clear the document symbols cache.

        Call this between searches to free memory and ensure fresh data.
        """
        self._document_symbols_cache.clear()

    async def _get_symbol_at_location(
        self,
        file_path: str,
        line: int,
        lsp_bridge: LspBridge,
    ) -> Optional[Dict[str, Any]]:
        """Find symbol at location using cached document symbols.

        This is much more efficient than individual hover queries because
        document symbols are fetched once per file and cached.

        Args:
            file_path: Path to the source file.
            line: Line number (1-based).
            lsp_bridge: LSP bridge for fetching document symbols.

        Returns:
            Symbol dictionary with name, kind, range, etc., or None if not found.
        """
        # Get or fetch document symbols for this file
        if file_path not in self._document_symbols_cache:
            symbols = await lsp_bridge.get_document_symbols(file_path)
            self._document_symbols_cache[file_path] = symbols

        symbols = self._document_symbols_cache[file_path]

        # Find symbol containing this line (best match = smallest range)
        best_match: Optional[Dict[str, Any]] = None
        best_range_size = float("inf")

        for symbol in symbols:
            sym_range = symbol.get("range", {})
            start = sym_range.get("start", {})
            end = sym_range.get("end", {})

            # LSP ranges are 0-based, our line is 1-based
            start_line = start.get("line", 0) + 1
            end_line = end.get("line", 0) + 1

            if start_line <= line <= end_line:
                range_size = end_line - start_line
                if range_size < best_range_size:
                    best_match = symbol
                    best_range_size = range_size

        return best_match

    async def _location_to_node(
        self,
        location: Location,
        lsp_bridge: LspBridge,
    ) -> Optional[CodeSymbolNode]:
        """Convert LSP location to CodeSymbolNode.

        Uses cached document symbols instead of individual hover queries
        for better performance.

        Args:
            location: LSP location to convert.
            lsp_bridge: LSP bridge for additional queries.

        Returns:
            CodeSymbolNode or None if conversion fails.
        """
        try:
            file_path = location.file_path
            start_line = location.line

            # Try to find symbol info from cached document symbols (fast)
            symbol_info = await self._get_symbol_at_location(
                file_path, start_line, lsp_bridge
            )

            if symbol_info:
                name = symbol_info.get("name", f"symbol_L{start_line}")
                kind = symbol_info.get("kind", "unknown")

                # Extract range from symbol if available
                sym_range = symbol_info.get("range", {})
                start = sym_range.get("start", {})
                end = sym_range.get("end", {})

                location_range = Range(
                    start_line=start.get("line", start_line - 1) + 1,
                    start_character=start.get("character", location.character - 1) + 1,
                    end_line=end.get("line", start_line - 1) + 1,
                    end_character=end.get("character", location.character - 1) + 1,
                )
            else:
                # Fallback to basic node without symbol info
                name = f"symbol_L{start_line}"
                kind = "unknown"
                location_range = Range(
                    start_line=location.line,
                    start_character=location.character,
                    end_line=location.line,
                    end_character=location.character,
                )

            node_id = self._create_node_id(file_path, name, start_line)

            return CodeSymbolNode(
                id=node_id,
                name=name,
                kind=kind,
                file_path=file_path,
                range=location_range,
                docstring="",  # Skip hover for performance
            )

        except Exception as e:
            logger.debug("Failed to convert location to node: %s", e)
            return None

    async def _call_hierarchy_to_node(
        self,
        call_item: CallHierarchyItem,
        lsp_bridge: LspBridge,
    ) -> Optional[CodeSymbolNode]:
        """Convert CallHierarchyItem to CodeSymbolNode.

        Args:
            call_item: Call hierarchy item to convert.
            lsp_bridge: LSP bridge (unused, kept for API consistency).

        Returns:
            CodeSymbolNode or None if conversion fails.
        """
        try:
            file_path = call_item.file_path
            name = call_item.name
            start_line = call_item.range.start_line
            # CallHierarchyItem.kind is already a string
            kind = call_item.kind

            node_id = self._create_node_id(file_path, name, start_line)

            return CodeSymbolNode(
                id=node_id,
                name=name,
                kind=kind,
                file_path=file_path,
                range=call_item.range,
                docstring=call_item.detail or "",
            )

        except Exception as e:
            logger.debug(
                "Failed to convert call hierarchy item to node: %s", e
            )
            return None

    def _create_node_id(
        self, file_path: str, name: str, line: int
    ) -> str:
        """Create unique node ID.

        Args:
            file_path: Path to the file.
            name: Symbol name.
            line: Line number (0-based).

        Returns:
            Unique node ID string.
        """
        return f"{file_path}:{name}:{line}"
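Driving the builder takes only a few lines through LspBridge; this mirrors the calls HybridSearchEngine._search_lsp_graph makes further down in this commit. The following is a minimal sketch, not part of the diff: the seed symbol's values are invented, and it assumes CodeSymbolNode accepts exactly the fields used by the constructors above.

import asyncio

from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range
from codexlens.lsp import LspBridge, LspGraphBuilder


async def demo() -> None:
    # Hypothetical seed symbol; in the engine these come from vector/SPLADE/FTS hits.
    seed = CodeSymbolNode(
        id="src/auth.py:login:42",
        name="login",
        kind="function",
        file_path="src/auth.py",
        range=Range(start_line=42, start_character=0, end_line=60, end_character=0),
        docstring="",
    )

    async with LspBridge() as bridge:
        builder = LspGraphBuilder(max_depth=1, max_nodes=20)
        graph = await builder.build_from_seeds([seed], bridge)
        builder.clear_cache()  # drop cached document symbols between searches

    for node_id, node in graph.nodes.items():
        print(node_id, node.kind, node.file_path, node.range.start_line)


asyncio.run(demo())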
1049
codex-lens/src/codexlens/lsp/standalone_manager.py
Normal file
File diff suppressed because it is too large
@@ -49,6 +49,13 @@ from codexlens.search.ranking import (
)
from codexlens.storage.dir_index import DirIndexStore

# Optional LSP imports (for real-time graph expansion)
try:
    from codexlens.lsp import LspBridge, LspGraphBuilder
    HAS_LSP = True
except ImportError:
    HAS_LSP = False


# Three-way fusion weights (FTS + Vector + SPLADE)
THREE_WAY_WEIGHTS = {
@@ -113,6 +120,9 @@ class HybridSearchEngine:
        enable_vector: bool = False,
        pure_vector: bool = False,
        enable_splade: bool = False,
        enable_lsp_graph: bool = False,
        lsp_max_depth: int = 1,
        lsp_max_nodes: int = 20,
    ) -> List[SearchResult]:
        """Execute hybrid search with parallel retrieval and RRF fusion.

@@ -124,6 +134,9 @@ class HybridSearchEngine:
            enable_vector: Enable vector search (default False)
            pure_vector: If True, only use vector search without FTS fallback (default False)
            enable_splade: If True, force SPLADE sparse neural search (default False)
            enable_lsp_graph: If True, enable real-time LSP graph expansion (default False)
            lsp_max_depth: Maximum depth for LSP graph BFS expansion (default 1)
            lsp_max_nodes: Maximum nodes to collect in LSP graph (default 20)

        Returns:
            List of SearchResult objects sorted by fusion score
@@ -140,6 +153,9 @@ class HybridSearchEngine:
            >>> # SPLADE sparse neural search
            >>> results = engine.search(Path("project/_index.db"), "auth flow",
            ...     enable_splade=True, enable_vector=True)
            >>> # With LSP graph expansion (real-time)
            >>> results = engine.search(Path("project/_index.db"), "auth flow",
            ...     enable_vector=True, enable_lsp_graph=True)
            >>> for r in results[:5]:
            ...     print(f"{r.path}: {r.score:.3f}")
        """
@@ -228,9 +244,21 @@ class HybridSearchEngine:
        if enable_vector:
            backends["vector"] = True

        # Add LSP graph expansion if requested and available
        if enable_lsp_graph and HAS_LSP:
            backends["lsp_graph"] = True
        elif enable_lsp_graph and not HAS_LSP:
            self.logger.warning(
                "LSP graph search requested but dependencies not available. "
                "Install: pip install aiohttp"
            )

        # Execute parallel searches
        with timer("parallel_search_total", self.logger):
            results_map = self._search_parallel(index_path, query, backends, limit, vector_category)
            results_map = self._search_parallel(
                index_path, query, backends, limit, vector_category,
                lsp_max_depth, lsp_max_nodes
            )

        # Provide helpful message if pure-vector mode returns no results
        if pure_vector and enable_vector and len(results_map.get("vector", [])) == 0:
@@ -427,6 +455,8 @@ class HybridSearchEngine:
        backends: Dict[str, bool],
        limit: int,
        category: Optional[str] = None,
        lsp_max_depth: int = 1,
        lsp_max_nodes: int = 20,
    ) -> Dict[str, List[SearchResult]]:
        """Execute parallel searches across enabled backends.

@@ -436,6 +466,8 @@ class HybridSearchEngine:
            backends: Dictionary of backend name to enabled flag
            limit: Results limit per backend
            category: Optional category filter for vector search ('code' or 'doc')
            lsp_max_depth: Maximum depth for LSP graph BFS expansion (default 1)
            lsp_max_nodes: Maximum nodes to collect in LSP graph (default 20)

        Returns:
            Dictionary mapping source name to results list
@@ -477,6 +509,14 @@ class HybridSearchEngine:
            )
            future_to_source[future] = "splade"

            if backends.get("lsp_graph"):
                submit_times["lsp_graph"] = time.perf_counter()
                future = executor.submit(
                    self._search_lsp_graph, index_path, query, limit,
                    lsp_max_depth, lsp_max_nodes
                )
                future_to_source[future] = "lsp_graph"

            # Collect results as they complete with timeout protection
            try:
                for future in as_completed(future_to_source, timeout=30.0):
@@ -1211,7 +1251,159 @@ class HybridSearchEngine:
                ))

            return results

        except Exception as exc:
            self.logger.debug("SPLADE search error: %s", exc)
            return []

    def _search_lsp_graph(
        self,
        index_path: Path,
        query: str,
        limit: int,
        max_depth: int = 1,
        max_nodes: int = 20,
    ) -> List[SearchResult]:
        """Execute LSP-based graph expansion search.

        Uses real-time LSP to expand from seed results and find related code.
        This provides accurate, up-to-date code relationships.

        Args:
            index_path: Path to _index.db file
            query: Natural language query string
            limit: Maximum results
            max_depth: Maximum depth for LSP graph BFS expansion (default 1)
            max_nodes: Maximum nodes to collect in LSP graph (default 20)

        Returns:
            List of SearchResult from graph expansion
        """
        import asyncio

        if not HAS_LSP:
            self.logger.debug("LSP dependencies not available")
            return []

        try:
            # Try multiple seed sources in priority order
            seeds = []
            seed_source = "none"

            # 1. Try vector search first (best semantic match)
            seeds = self._search_vector(index_path, query, limit=3, category="code")
            if seeds:
                seed_source = "vector"

            # 2. Fallback to SPLADE if vector returns nothing
            if not seeds:
                self.logger.debug("Vector search returned no seeds, trying SPLADE")
                seeds = self._search_splade(index_path, query, limit=3)
                if seeds:
                    seed_source = "splade"

            # 3. Fallback to exact FTS if SPLADE also fails
            if not seeds:
                self.logger.debug("SPLADE returned no seeds, trying exact FTS")
                seeds = self._search_exact(index_path, query, limit=3)
                if seeds:
                    seed_source = "exact_fts"

            # 4. No seeds available from any source
            if not seeds:
                self.logger.debug("No seed results available for LSP graph expansion")
                return []

            self.logger.debug(
                "LSP graph expansion using %d seeds from %s",
                len(seeds),
                seed_source,
            )

            # Convert SearchResult to CodeSymbolNode for LSP processing
            from codexlens.hybrid_search.data_structures import CodeSymbolNode, Range

            seed_nodes = []
            for seed in seeds:
                try:
                    node = CodeSymbolNode(
                        id=f"{seed.path}:{seed.symbol_name or 'unknown'}:{seed.start_line or 0}",
                        name=seed.symbol_name or "unknown",
                        kind=seed.symbol_kind or "unknown",
                        file_path=seed.path,
                        range=Range(
                            start_line=seed.start_line or 1,
                            start_character=0,
                            end_line=seed.end_line or seed.start_line or 1,
                            end_character=0,
                        ),
                        raw_code=seed.content or "",
                        docstring=seed.excerpt or "",
                    )
                    seed_nodes.append(node)
                except Exception as e:
                    self.logger.debug("Failed to create seed node: %s", e)
                    continue

            if not seed_nodes:
                return []

            # Run async LSP expansion in sync context
            async def expand_graph():
                async with LspBridge() as bridge:
                    builder = LspGraphBuilder(max_depth=max_depth, max_nodes=max_nodes)
                    graph = await builder.build_from_seeds(seed_nodes, bridge)
                    return graph

            # Run the async code
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    # Already in async context - use run_coroutine_threadsafe
                    import concurrent.futures
                    future = asyncio.run_coroutine_threadsafe(expand_graph(), loop)
                    graph = future.result(timeout=5.0)
                else:
                    graph = loop.run_until_complete(expand_graph())
            except RuntimeError:
                # No event loop - create new one
                graph = asyncio.run(expand_graph())

            # Convert graph nodes to SearchResult
            # Create set of seed identifiers for fast lookup
            seed_ids = set()
            for seed in seeds:
                seed_id = f"{seed.path}:{seed.symbol_name or 'unknown'}:{seed.start_line or 0}"
                seed_ids.add(seed_id)

            results = []
            for node_id, node in graph.nodes.items():
                # Skip seed nodes using ID comparison (already in other results)
                if node_id in seed_ids or node.id in seed_ids:
                    continue

                # Calculate score based on graph position
                # Nodes closer to seeds get higher scores
                depth = 1  # Simple heuristic, could be improved
                score = 0.8 / (1 + depth)  # Score decreases with depth

                results.append(SearchResult(
                    path=node.file_path,
                    score=score,
                    excerpt=node.docstring[:200] if node.docstring else node.raw_code[:200] if node.raw_code else "",
                    content=node.raw_code,
                    symbol=None,
                    metadata={"lsp_node_id": node_id, "lsp_kind": node.kind},
                    start_line=node.range.start_line,
                    end_line=node.range.end_line,
                    symbol_name=node.name,
                    symbol_kind=node.kind,
                ))

            # Sort by score
            results.sort(key=lambda r: r.score, reverse=True)
            return results[:limit]

        except Exception as exc:
            self.logger.debug("LSP graph search error: %s", exc)
            return []
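Because _search_lsp_graph runs inside a ThreadPoolExecutor worker, it has to bridge from synchronous code into asyncio explicitly. The pattern it uses is distilled into the standalone sketch below; the expand coroutine is a stand-in for the real LSP expansion, and on recent Python versions a plain worker thread normally hits the RuntimeError branch and gets a fresh loop via asyncio.run.

import asyncio


async def expand() -> str:
    # Stand-in for the real LSP graph expansion coroutine.
    await asyncio.sleep(0)
    return "graph"


def run_expansion_sync(timeout: float = 5.0) -> str:
    """Run a coroutine from synchronous code, mirroring _search_lsp_graph."""
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # A loop is already running in this context: schedule onto it
            # and block this thread until the result (or timeout) arrives.
            future = asyncio.run_coroutine_threadsafe(expand(), loop)
            return future.result(timeout=timeout)
        return loop.run_until_complete(expand())
    except RuntimeError:
        # No usable event loop in this thread: create one just for this call.
        return asyncio.run(expand())


print(run_expansion_sync())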
@@ -17,15 +17,17 @@ from codexlens.entities import SearchResult, AdditionalLocation

# Default RRF weights for SPLADE-based hybrid search
DEFAULT_WEIGHTS = {
    "splade": 0.4,  # Replaces exact(0.3) + fuzzy(0.1)
    "vector": 0.6,
    "splade": 0.35,  # Replaces exact(0.3) + fuzzy(0.1)
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}

# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
    "exact": 0.3,
    "exact": 0.25,
    "fuzzy": 0.1,
    "vector": 0.6,
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}
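These per-backend weights feed the RRF fusion step referenced in search()'s docstring. Below is a sketch of weighted Reciprocal Rank Fusion in its usual form; the k = 60 constant and the exact scoring function are assumptions for illustration, not something this diff confirms about codexlens.

from typing import Dict, List, Tuple

DEFAULT_WEIGHTS = {"splade": 0.35, "vector": 0.5, "lsp_graph": 0.15}


def weighted_rrf(
    ranked: Dict[str, List[str]],
    weights: Dict[str, float],
    k: int = 60,
) -> List[Tuple[str, float]]:
    """Fuse per-backend rankings: score(doc) = sum_b weight_b / (k + rank_b(doc))."""
    scores: Dict[str, float] = {}
    for backend, docs in ranked.items():
        w = weights.get(backend, 0.0)
        for rank, doc in enumerate(docs, start=1):
            scores[doc] = scores.get(doc, 0.0) + w / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Example: each backend returns a small ranked list of file paths.
fused = weighted_rrf(
    {
        "splade": ["auth.py", "login.py"],
        "vector": ["login.py", "session.py"],
        "lsp_graph": ["session.py"],
    },
    DEFAULT_WEIGHTS,
)
print(fused)  # login.py ranks first: it appears high in two heavily weighted lists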