Files

603 lines
19 KiB
Python

"""Core data structures for the hybrid search system.
This module defines the fundamental data structures used throughout the
hybrid search pipeline, including code symbol representations, association
graphs, and clustered search results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
if TYPE_CHECKING:
import networkx as nx
@dataclass
class Range:
"""Position range within a source file.
Attributes:
start_line: Starting line number (0-based).
start_character: Starting character offset within the line.
end_line: Ending line number (0-based).
end_character: Ending character offset within the line.
"""
start_line: int
start_character: int
end_line: int
end_character: int
def __post_init__(self) -> None:
"""Validate range values."""
if self.start_line < 0:
raise ValueError("start_line must be >= 0")
if self.start_character < 0:
raise ValueError("start_character must be >= 0")
if self.end_line < 0:
raise ValueError("end_line must be >= 0")
if self.end_character < 0:
raise ValueError("end_character must be >= 0")
if self.end_line < self.start_line:
raise ValueError("end_line must be >= start_line")
if self.end_line == self.start_line and self.end_character < self.start_character:
raise ValueError("end_character must be >= start_character on the same line")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return {
"start": {"line": self.start_line, "character": self.start_character},
"end": {"line": self.end_line, "character": self.end_character},
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Range:
"""Create Range from dictionary representation."""
return cls(
start_line=data["start"]["line"],
start_character=data["start"]["character"],
end_line=data["end"]["line"],
end_character=data["end"]["character"],
)
@classmethod
def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
"""Create Range from LSP Range object.
LSP Range format:
{"start": {"line": int, "character": int},
"end": {"line": int, "character": int}}
"""
return cls(
start_line=lsp_range["start"]["line"],
start_character=lsp_range["start"]["character"],
end_line=lsp_range["end"]["line"],
end_character=lsp_range["end"]["character"],
)
@dataclass
class CallHierarchyItem:
"""LSP CallHierarchyItem for representing callers/callees.
Attributes:
name: Symbol name (function, method, class name).
kind: Symbol kind (function, method, class, etc.).
file_path: Absolute file path where the symbol is defined.
range: Position range in the source file.
detail: Optional additional detail about the symbol.
"""
name: str
kind: str
file_path: str
range: Range
detail: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
}
if self.detail:
result["detail"] = self.detail
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from dictionary representation."""
return cls(
name=data["name"],
kind=data["kind"],
file_path=data["file_path"],
range=Range.from_dict(data["range"]),
detail=data.get("detail"),
)
@dataclass
class CodeSymbolNode:
"""Graph node representing a code symbol.
Attributes:
id: Unique identifier in format 'file_path:name:line'.
name: Symbol name (function, class, variable name).
kind: Symbol kind (function, class, method, variable, etc.).
file_path: Absolute file path where symbol is defined.
range: Start/end position in the source file.
embedding: Optional vector embedding for semantic search.
raw_code: Raw source code of the symbol.
docstring: Documentation string (if available).
score: Ranking score (used during reranking).
"""
id: str
name: str
kind: str
file_path: str
range: Range
embedding: Optional[List[float]] = None
raw_code: str = ""
docstring: str = ""
score: float = 0.0
def __post_init__(self) -> None:
"""Validate required fields."""
if not self.id:
raise ValueError("id cannot be empty")
if not self.name:
raise ValueError("name cannot be empty")
if not self.kind:
raise ValueError("kind cannot be empty")
if not self.file_path:
raise ValueError("file_path cannot be empty")
def __hash__(self) -> int:
"""Hash based on unique ID."""
return hash(self.id)
def __eq__(self, other: object) -> bool:
"""Equality based on unique ID."""
if not isinstance(other, CodeSymbolNode):
return False
return self.id == other.id
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"id": self.id,
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
"score": self.score,
}
if self.raw_code:
result["raw_code"] = self.raw_code
if self.docstring:
result["docstring"] = self.docstring
# Exclude embedding from serialization (too large for JSON responses)
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
"""Create CodeSymbolNode from dictionary representation."""
return cls(
id=data["id"],
name=data["name"],
kind=data["kind"],
file_path=data["file_path"],
range=Range.from_dict(data["range"]),
embedding=data.get("embedding"),
raw_code=data.get("raw_code", ""),
docstring=data.get("docstring", ""),
score=data.get("score", 0.0),
)
@classmethod
def from_lsp_location(
cls,
uri: str,
name: str,
kind: str,
lsp_range: Dict[str, Any],
raw_code: str = "",
docstring: str = "",
) -> CodeSymbolNode:
"""Create CodeSymbolNode from LSP location data.
Args:
uri: File URI (file:// prefix will be stripped).
name: Symbol name.
kind: Symbol kind.
lsp_range: LSP Range object.
raw_code: Optional raw source code.
docstring: Optional documentation string.
Returns:
New CodeSymbolNode instance.
"""
# Strip file:// prefix if present
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
# Handle Windows paths (file:///C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
range_obj = Range.from_lsp_range(lsp_range)
symbol_id = f"{file_path}:{name}:{range_obj.start_line}"
return cls(
id=symbol_id,
name=name,
kind=kind,
file_path=file_path,
range=range_obj,
raw_code=raw_code,
docstring=docstring,
)
@classmethod
def create_id(cls, file_path: str, name: str, line: int) -> str:
"""Generate a unique symbol ID.
Args:
file_path: Absolute file path.
name: Symbol name.
line: Start line number.
Returns:
Unique ID string in format 'file_path:name:line'.
"""
return f"{file_path}:{name}:{line}"
@dataclass
class CodeAssociationGraph:
"""Graph of code relationships between symbols.
This graph represents the association between code symbols discovered
through LSP queries (references, call hierarchy, etc.).
Attributes:
nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
edges: List of (from_id, to_id, relationship_type) tuples.
relationship_type: 'calls', 'references', 'inherits', 'imports'.
"""
nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
edges: List[Tuple[str, str, str]] = field(default_factory=list)
def add_node(self, node: CodeSymbolNode) -> None:
"""Add a node to the graph.
Args:
node: CodeSymbolNode to add. If a node with the same ID exists,
it will be replaced.
"""
self.nodes[node.id] = node
def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge to the graph.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').
Raises:
ValueError: If from_id or to_id not in graph nodes.
"""
if from_id not in self.nodes:
raise ValueError(f"Source node '{from_id}' not found in graph")
if to_id not in self.nodes:
raise ValueError(f"Target node '{to_id}' not found in graph")
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
"""Add an edge without validating node existence.
Use this method during bulk graph construction where nodes may be
added after edges, or when performance is critical.
Args:
from_id: Source node ID.
to_id: Target node ID.
rel_type: Relationship type.
"""
edge = (from_id, to_id, rel_type)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
"""Get a node by ID.
Args:
node_id: Node ID to look up.
Returns:
CodeSymbolNode if found, None otherwise.
"""
return self.nodes.get(node_id)
def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get neighboring nodes connected by outgoing edges.
Args:
node_id: Node ID to find neighbors for.
rel_type: Optional filter by relationship type.
Returns:
List of neighboring CodeSymbolNode objects.
"""
neighbors = []
for from_id, to_id, edge_rel in self.edges:
if from_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(to_id)
if node:
neighbors.append(node)
return neighbors
def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
"""Get nodes connected by incoming edges.
Args:
node_id: Node ID to find incoming connections for.
rel_type: Optional filter by relationship type.
Returns:
List of CodeSymbolNode objects with edges pointing to node_id.
"""
incoming = []
for from_id, to_id, edge_rel in self.edges:
if to_id == node_id:
if rel_type is None or edge_rel == rel_type:
node = self.nodes.get(from_id)
if node:
incoming.append(node)
return incoming
def to_networkx(self) -> "nx.DiGraph":
"""Convert to NetworkX DiGraph for graph algorithms.
Returns:
NetworkX directed graph with nodes and edges.
Raises:
ImportError: If networkx is not installed.
"""
try:
import networkx as nx
except ImportError:
raise ImportError(
"networkx is required for graph algorithms. "
"Install with: pip install networkx"
)
graph = nx.DiGraph()
# Add nodes with attributes
for node_id, node in self.nodes.items():
graph.add_node(
node_id,
name=node.name,
kind=node.kind,
file_path=node.file_path,
score=node.score,
)
# Add edges with relationship type
for from_id, to_id, rel_type in self.edges:
graph.add_edge(from_id, to_id, relationship=rel_type)
return graph
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary with 'nodes' and 'edges' keys.
"""
return {
"nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
"edges": [
{"from": from_id, "to": to_id, "relationship": rel_type}
for from_id, to_id, rel_type in self.edges
],
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
"""Create CodeAssociationGraph from dictionary representation.
Args:
data: Dictionary with 'nodes' and 'edges' keys.
Returns:
New CodeAssociationGraph instance.
"""
graph = cls()
# Load nodes
for node_id, node_data in data.get("nodes", {}).items():
graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)
# Load edges
for edge_data in data.get("edges", []):
graph.edges.append((
edge_data["from"],
edge_data["to"],
edge_data["relationship"],
))
return graph
def __len__(self) -> int:
"""Return the number of nodes in the graph."""
return len(self.nodes)
@dataclass
class SearchResultCluster:
"""Clustered search result containing related code symbols.
Search results are grouped into clusters based on graph community
detection or embedding similarity. Each cluster represents a
conceptually related group of code symbols.
Attributes:
cluster_id: Unique cluster identifier.
score: Cluster relevance score (max of symbol scores).
title: Human-readable cluster title/summary.
symbols: List of CodeSymbolNode in this cluster.
metadata: Additional cluster metadata.
"""
cluster_id: str
score: float
title: str
symbols: List[CodeSymbolNode] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate cluster fields."""
if not self.cluster_id:
raise ValueError("cluster_id cannot be empty")
if self.score < 0:
raise ValueError("score must be >= 0")
def add_symbol(self, symbol: CodeSymbolNode) -> None:
"""Add a symbol to the cluster.
Args:
symbol: CodeSymbolNode to add.
"""
self.symbols.append(symbol)
def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
"""Get top N symbols by score.
Args:
n: Number of symbols to return.
Returns:
List of top N CodeSymbolNode objects sorted by score descending.
"""
sorted_symbols = sorted(self.symbols, key=lambda s: s.score, reverse=True)
return sorted_symbols[:n]
def update_score(self) -> None:
"""Update cluster score to max of symbol scores."""
if self.symbols:
self.score = max(s.score for s in self.symbols)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization.
Returns:
Dictionary representation of the cluster.
"""
return {
"cluster_id": self.cluster_id,
"score": self.score,
"title": self.title,
"symbols": [s.to_dict() for s in self.symbols],
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
"""Create SearchResultCluster from dictionary representation.
Args:
data: Dictionary with cluster data.
Returns:
New SearchResultCluster instance.
"""
return cls(
cluster_id=data["cluster_id"],
score=data["score"],
title=data["title"],
symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
metadata=data.get("metadata", {}),
)
def __len__(self) -> int:
"""Return the number of symbols in the cluster."""
return len(self.symbols)
@dataclass
class CallHierarchyItem:
"""LSP CallHierarchyItem for representing callers/callees.
Attributes:
name: Symbol name (function, method, etc.).
kind: Symbol kind (function, method, etc.).
file_path: Absolute file path.
range: Position range in the file.
detail: Optional additional detail (e.g., signature).
"""
name: str
kind: str
file_path: str
range: Range
detail: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: Dict[str, Any] = {
"name": self.name,
"kind": self.kind,
"file_path": self.file_path,
"range": self.range.to_dict(),
}
if self.detail:
result["detail"] = self.detail
return result
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from dictionary representation."""
return cls(
name=data.get("name", "unknown"),
kind=data.get("kind", "unknown"),
file_path=data.get("file_path", data.get("uri", "")),
range=Range.from_dict(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)
@classmethod
def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
"""Create CallHierarchyItem from LSP response format.
LSP uses 0-based line numbers and 'character' instead of 'char'.
"""
uri = data.get("uri", data.get("file_path", ""))
# Strip file:// prefix
file_path = uri
if file_path.startswith("file://"):
file_path = file_path[7:]
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
return cls(
name=data.get("name", "unknown"),
kind=str(data.get("kind", "unknown")),
file_path=file_path,
range=Range.from_lsp_range(data.get("range", {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}})),
detail=data.get("detail"),
)