mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
603 lines
19 KiB
Python
603 lines
19 KiB
Python
"""Core data structures for the hybrid search system.
|
|
|
|
This module defines the fundamental data structures used throughout the
|
|
hybrid search pipeline, including code symbol representations, association
|
|
graphs, and clustered search results.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
import networkx as nx
|
|
|
|
|
|
@dataclass
class Range:
    """Position span inside a source file.

    Coordinates use zero-based line numbers and character offsets within a
    line, matching the shape of an LSP ``Range`` object.

    Attributes:
        start_line: Zero-based line on which the span begins.
        start_character: Character offset on ``start_line`` where the span begins.
        end_line: Zero-based line on which the span ends.
        end_character: Character offset on ``end_line`` where the span ends.
    """

    start_line: int
    start_character: int
    end_line: int
    end_character: int

    def __post_init__(self) -> None:
        """Reject negative coordinates and spans whose end precedes their start.

        Raises:
            ValueError: On the first violated constraint, checked in field
                order and then for end-before-start ordering.
        """
        for attr in ("start_line", "start_character", "end_line", "end_character"):
            if getattr(self, attr) < 0:
                raise ValueError(f"{attr} must be >= 0")

        if self.start_line > self.end_line:
            raise ValueError("end_line must be >= start_line")
        single_line = self.start_line == self.end_line
        if single_line and self.start_character > self.end_character:
            raise ValueError("end_character must be >= start_character on the same line")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as a ``{"start": {...}, "end": {...}}`` mapping for JSON."""
        start = {"line": self.start_line, "character": self.start_character}
        end = {"line": self.end_line, "character": self.end_character}
        return {"start": start, "end": end}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Range:
        """Rebuild a Range from the mapping shape produced by :meth:`to_dict`.

        Raises:
            KeyError: If any of the nested ``start``/``end`` keys is missing.
            ValueError: If the coordinates fail validation.
        """
        start = data["start"]
        first_line, first_char = start["line"], start["character"]
        end = data["end"]
        return cls(
            start_line=first_line,
            start_character=first_char,
            end_line=end["line"],
            end_character=end["character"],
        )

    @classmethod
    def from_lsp_range(cls, lsp_range: Dict[str, Any]) -> Range:
        """Create a Range from an LSP Range object.

        LSP Range format:
            {"start": {"line": int, "character": int},
             "end": {"line": int, "character": int}}

        This is the same layout :meth:`from_dict` accepts, so parsing is
        delegated to it.
        """
        return cls.from_dict(lsp_range)
|
|
|
|
|
|
# NOTE: A second, earlier ``CallHierarchyItem`` dataclass used to be defined
# at this position. It was dead code: the ``CallHierarchyItem`` class defined
# later in this module rebinds the name at import time, so importers only
# ever received that later definition. The shadowed duplicate was removed to
# leave a single source of truth.
|
|
|
|
|
|
@dataclass
class CodeSymbolNode:
    """Graph node representing a code symbol.

    Equality and hashing use ``id`` only. Nodes describing the same symbol
    therefore collapse to one entry in sets and dict keys, regardless of
    their score or payload.

    Attributes:
        id: Unique identifier in format 'file_path:name:line' (see
            :meth:`create_id`).
        name: Symbol name (function, class, variable name).
        kind: Symbol kind (function, class, method, variable, etc.).
        file_path: Absolute file path where symbol is defined.
        range: Start/end position in the source file.
        embedding: Optional vector embedding for semantic search; omitted by
            :meth:`to_dict`.
        raw_code: Raw source code of the symbol.
        docstring: Documentation string (if available).
        score: Ranking score (used during reranking).
    """

    id: str
    name: str
    kind: str
    file_path: str
    range: Range
    embedding: Optional[List[float]] = None
    raw_code: str = ""
    docstring: str = ""
    score: float = 0.0

    def __post_init__(self) -> None:
        """Validate required fields.

        Raises:
            ValueError: If ``id``, ``name``, ``kind`` or ``file_path`` is empty.
        """
        if not self.id:
            raise ValueError("id cannot be empty")
        if not self.name:
            raise ValueError("name cannot be empty")
        if not self.kind:
            raise ValueError("kind cannot be empty")
        if not self.file_path:
            raise ValueError("file_path cannot be empty")

    def __hash__(self) -> int:
        """Hash based on unique ID, consistent with :meth:`__eq__`."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Equality based on unique ID.

        Returns ``NotImplemented`` for non-node operands, so Python can try
        the reflected comparison. For ordinary unrelated types, ``node == x``
        still evaluates to ``False``.
        """
        if not isinstance(other, CodeSymbolNode):
            return NotImplemented
        return self.id == other.id

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        ``raw_code`` and ``docstring`` are included only when non-empty.
        ``embedding`` is always excluded.
        """
        result: Dict[str, Any] = {
            "id": self.id,
            "name": self.name,
            "kind": self.kind,
            "file_path": self.file_path,
            "range": self.range.to_dict(),
            "score": self.score,
        }
        if self.raw_code:
            result["raw_code"] = self.raw_code
        if self.docstring:
            result["docstring"] = self.docstring
        # Exclude embedding from serialization (too large for JSON responses)
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> CodeSymbolNode:
        """Create CodeSymbolNode from dictionary representation.

        ``id``, ``name``, ``kind``, ``file_path`` and ``range`` are required
        keys. The remaining fields fall back to their dataclass defaults.
        """
        return cls(
            id=data["id"],
            name=data["name"],
            kind=data["kind"],
            file_path=data["file_path"],
            range=Range.from_dict(data["range"]),
            embedding=data.get("embedding"),
            raw_code=data.get("raw_code", ""),
            docstring=data.get("docstring", ""),
            score=data.get("score", 0.0),
        )

    @staticmethod
    def _uri_to_path(uri: str) -> str:
        """Convert a ``file://`` URI to a filesystem path; other strings pass through.

        File URIs may percent-encode characters, for example a space as
        ``%20`` or a Windows drive colon as ``%3A``. Decoding therefore
        happens before the drive-letter check; otherwise ``/c%3A/...`` would
        keep its leading slash.
        """
        from urllib.parse import unquote

        if not uri.startswith("file://"):
            return uri
        path = unquote(uri[len("file://"):])
        # Handle Windows paths (file:///C:/... -> "/C:/...") by dropping the
        # slash that precedes the drive letter.
        if len(path) > 2 and path[0] == "/" and path[2] == ":":
            path = path[1:]
        return path

    @classmethod
    def from_lsp_location(
        cls,
        uri: str,
        name: str,
        kind: str,
        lsp_range: Dict[str, Any],
        raw_code: str = "",
        docstring: str = "",
    ) -> CodeSymbolNode:
        """Create CodeSymbolNode from LSP location data.

        Args:
            uri: File URI. A ``file://`` URI is converted to a path: the prefix
                is stripped, percent-escapes are decoded, and a leading slash
                before a Windows drive letter is removed. Any other string is
                used as the path unchanged.
            name: Symbol name.
            kind: Symbol kind.
            lsp_range: LSP Range object.
            raw_code: Optional raw source code.
            docstring: Optional documentation string.

        Returns:
            New CodeSymbolNode whose ``id`` is built by :meth:`create_id`
            from the path, name and 0-based start line.

        Raises:
            ValueError: If the range is invalid or a required field is empty.
        """
        file_path = cls._uri_to_path(uri)
        range_obj = Range.from_lsp_range(lsp_range)
        return cls(
            id=cls.create_id(file_path, name, range_obj.start_line),
            name=name,
            kind=kind,
            file_path=file_path,
            range=range_obj,
            raw_code=raw_code,
            docstring=docstring,
        )

    @classmethod
    def create_id(cls, file_path: str, name: str, line: int) -> str:
        """Generate a unique symbol ID.

        Args:
            file_path: Absolute file path.
            name: Symbol name.
            line: Start line number.

        Returns:
            Unique ID string in format 'file_path:name:line'.
        """
        return f"{file_path}:{name}:{line}"
|
|
|
|
|
|
@dataclass
class CodeAssociationGraph:
    """Graph of code relationships between symbols.

    This graph represents the association between code symbols discovered
    through LSP queries (references, call hierarchy, etc.). Edges are
    directed and kept in insertion order. The ``add_edge*`` methods skip
    exact duplicate triples.

    Attributes:
        nodes: Dictionary mapping symbol IDs to CodeSymbolNode objects.
        edges: List of (from_id, to_id, relationship_type) tuples, where
            relationship_type is e.g. 'calls', 'references', 'inherits',
            'imports'.
    """

    nodes: Dict[str, CodeSymbolNode] = field(default_factory=dict)
    edges: List[Tuple[str, str, str]] = field(default_factory=list)

    def add_node(self, node: CodeSymbolNode) -> None:
        """Add a node to the graph.

        Args:
            node: CodeSymbolNode to add. If a node with the same ID exists,
                it will be replaced.
        """
        self.nodes[node.id] = node

    def add_edge(self, from_id: str, to_id: str, rel_type: str) -> None:
        """Add an edge to the graph after checking both endpoints exist.

        Args:
            from_id: Source node ID.
            to_id: Target node ID.
            rel_type: Relationship type ('calls', 'references', 'inherits', 'imports').

        Raises:
            ValueError: If from_id or to_id not in graph nodes.
        """
        if from_id not in self.nodes:
            raise ValueError(f"Source node '{from_id}' not found in graph")
        if to_id not in self.nodes:
            raise ValueError(f"Target node '{to_id}' not found in graph")
        # Dedup/append logic lives in one place.
        self.add_edge_unchecked(from_id, to_id, rel_type)

    def add_edge_unchecked(self, from_id: str, to_id: str, rel_type: str) -> None:
        """Add an edge without validating node existence.

        Use this method during bulk graph construction, where nodes may be
        added after edges. It skips only the node-existence checks: duplicate
        detection still scans ``self.edges`` linearly, so each call is O(E)
        in the current number of edges.

        Args:
            from_id: Source node ID.
            to_id: Target node ID.
            rel_type: Relationship type.
        """
        edge = (from_id, to_id, rel_type)
        if edge not in self.edges:
            self.edges.append(edge)

    def get_node(self, node_id: str) -> Optional[CodeSymbolNode]:
        """Get a node by ID.

        Args:
            node_id: Node ID to look up.

        Returns:
            CodeSymbolNode if found, None otherwise.
        """
        return self.nodes.get(node_id)

    def _adjacent(
        self, node_id: str, rel_type: Optional[str], outgoing: bool
    ) -> List[CodeSymbolNode]:
        """Collect nodes at the far end of edges touching ``node_id``.

        ``outgoing=True`` follows edges leaving ``node_id``; ``False`` follows
        edges arriving at it. Endpoints whose ID has no entry in ``nodes``
        (possible via :meth:`add_edge_unchecked`) are skipped. One node is
        returned per matching edge, so duplicates can appear when several
        relationship types link the same pair.
        """
        found: List[CodeSymbolNode] = []
        for from_id, to_id, edge_rel in self.edges:
            anchor, other = (from_id, to_id) if outgoing else (to_id, from_id)
            if anchor != node_id:
                continue
            if rel_type is not None and edge_rel != rel_type:
                continue
            node = self.nodes.get(other)
            if node is not None:
                found.append(node)
        return found

    def get_neighbors(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
        """Get neighboring nodes connected by outgoing edges.

        Args:
            node_id: Node ID to find neighbors for.
            rel_type: Optional filter by relationship type.

        Returns:
            List of neighboring CodeSymbolNode objects, in edge order.
        """
        return self._adjacent(node_id, rel_type, outgoing=True)

    def get_incoming(self, node_id: str, rel_type: Optional[str] = None) -> List[CodeSymbolNode]:
        """Get nodes connected by incoming edges.

        Args:
            node_id: Node ID to find incoming connections for.
            rel_type: Optional filter by relationship type.

        Returns:
            List of CodeSymbolNode objects with edges pointing to node_id,
            in edge order.
        """
        return self._adjacent(node_id, rel_type, outgoing=False)

    def to_networkx(self) -> "nx.DiGraph":
        """Convert to NetworkX DiGraph for graph algorithms.

        Node attributes are ``name``, ``kind``, ``file_path`` and ``score``.
        Each edge carries its type in the ``relationship`` attribute. An edge
        whose endpoint is not in ``nodes`` makes NetworkX add that endpoint
        as a bare node without attributes.

        Returns:
            NetworkX directed graph with nodes and edges.

        Raises:
            ImportError: If networkx is not installed.
        """
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError(
                "networkx is required for graph algorithms. "
                "Install with: pip install networkx"
            ) from err

        graph = nx.DiGraph()

        # Add nodes with attributes
        for node_id, node in self.nodes.items():
            graph.add_node(
                node_id,
                name=node.name,
                kind=node.kind,
                file_path=node.file_path,
                score=node.score,
            )

        # Add edges with relationship type
        for from_id, to_id, rel_type in self.edges:
            graph.add_edge(from_id, to_id, relationship=rel_type)

        return graph

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        Returns:
            Dictionary with 'nodes' (ID -> node dict) and 'edges' (list of
            ``{"from", "to", "relationship"}`` dicts) keys.
        """
        return {
            "nodes": {node_id: node.to_dict() for node_id, node in self.nodes.items()},
            "edges": [
                {"from": from_id, "to": to_id, "relationship": rel_type}
                for from_id, to_id, rel_type in self.edges
            ],
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> CodeAssociationGraph:
        """Create CodeAssociationGraph from dictionary representation.

        Edges are loaded verbatim. Node existence is not validated and
        duplicates are not removed, which mirrors data produced by
        :meth:`to_dict`.

        Args:
            data: Dictionary with optional 'nodes' and 'edges' keys.

        Returns:
            New CodeAssociationGraph instance.
        """
        graph = cls()

        # Load nodes (keyed by the mapping key, as written by to_dict)
        for node_id, node_data in data.get("nodes", {}).items():
            graph.nodes[node_id] = CodeSymbolNode.from_dict(node_data)

        # Load edges
        for edge_data in data.get("edges", []):
            graph.edges.append((
                edge_data["from"],
                edge_data["to"],
                edge_data["relationship"],
            ))

        return graph

    def __len__(self) -> int:
        """Return the number of nodes in the graph."""
        return len(self.nodes)
|
|
|
|
|
|
@dataclass
class SearchResultCluster:
    """Clustered search result containing related code symbols.

    Search results are grouped into clusters based on graph community
    detection or embedding similarity. Each cluster represents a
    conceptually related group of code symbols.

    Attributes:
        cluster_id: Unique cluster identifier.
        score: Cluster relevance score (max of symbol scores after
            :meth:`update_score`).
        title: Human-readable cluster title/summary.
        symbols: List of CodeSymbolNode in this cluster.
        metadata: Additional cluster metadata.
    """

    cluster_id: str
    score: float
    title: str
    symbols: List[CodeSymbolNode] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate cluster fields.

        Raises:
            ValueError: If ``cluster_id`` is empty or ``score`` is negative.
        """
        if not self.cluster_id:
            raise ValueError("cluster_id cannot be empty")
        if self.score < 0:
            raise ValueError("score must be >= 0")

    def add_symbol(self, symbol: CodeSymbolNode) -> None:
        """Add a symbol to the cluster.

        Does not recompute :attr:`score`; call :meth:`update_score`
        afterwards if the cluster score should reflect the new symbol.

        Args:
            symbol: CodeSymbolNode to add.
        """
        self.symbols.append(symbol)

    def get_top_symbols(self, n: int = 5) -> List[CodeSymbolNode]:
        """Get top N symbols by score.

        Args:
            n: Number of symbols to return. Values <= 0 yield an empty list.
                A plain slice with a negative ``n`` would instead drop
                symbols from the end.

        Returns:
            List of at most ``n`` CodeSymbolNode objects sorted by score
            descending (ties keep insertion order).
        """
        if n <= 0:
            return []
        ranked = sorted(self.symbols, key=lambda s: s.score, reverse=True)
        return ranked[:n]

    def update_score(self) -> None:
        """Update cluster score to max of symbol scores.

        Leaves :attr:`score` unchanged when the cluster has no symbols.
        """
        if self.symbols:
            self.score = max(s.score for s in self.symbols)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        ``metadata`` is shallow-copied, so mutating the returned dict does
        not alter this cluster.

        Returns:
            Dictionary representation of the cluster.
        """
        return {
            "cluster_id": self.cluster_id,
            "score": self.score,
            "title": self.title,
            "symbols": [s.to_dict() for s in self.symbols],
            "metadata": dict(self.metadata or {}),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> SearchResultCluster:
        """Create SearchResultCluster from dictionary representation.

        ``metadata`` is shallow-copied, so the new cluster does not share it
        with ``data``. A missing or ``None`` metadata value becomes ``{}``.

        Args:
            data: Dictionary with cluster data.

        Returns:
            New SearchResultCluster instance.
        """
        return cls(
            cluster_id=data["cluster_id"],
            score=data["score"],
            title=data["title"],
            symbols=[CodeSymbolNode.from_dict(s) for s in data.get("symbols", [])],
            metadata=dict(data.get("metadata") or {}),
        )

    def __len__(self) -> int:
        """Return the number of symbols in the cluster."""
        return len(self.symbols)
|
|
|
|
|
|
@dataclass
class CallHierarchyItem:
    """LSP CallHierarchyItem for representing callers/callees.

    Attributes:
        name: Symbol name (function, method, etc.).
        kind: Symbol kind as a string; :meth:`from_lsp` stringifies whatever
            value the server sent.
        file_path: Filesystem path of the file containing the symbol.
        range: Position range in the file.
        detail: Optional additional detail (e.g., signature).
    """

    name: str
    kind: str
    file_path: str
    range: Range
    detail: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        ``detail`` is emitted only when it is non-empty.
        """
        result: Dict[str, Any] = {
            "name": self.name,
            "kind": self.kind,
            "file_path": self.file_path,
            "range": self.range.to_dict(),
        }
        if self.detail:
            result["detail"] = self.detail
        return result

    @staticmethod
    def _empty_range() -> Dict[str, Any]:
        """Return a fresh zero-width range at 0:0, used when ``range`` is absent."""
        return {"start": {"line": 0, "character": 0}, "end": {"line": 0, "character": 0}}

    @staticmethod
    def _uri_to_path(uri: str) -> str:
        """Convert a ``file://`` URI to a filesystem path; other strings pass through.

        File URIs may percent-encode characters, for example a space as
        ``%20`` or a Windows drive colon as ``%3A``. Decoding therefore
        happens before the drive-letter check; otherwise ``/c%3A/...`` would
        keep its leading slash.
        """
        from urllib.parse import unquote

        if not uri.startswith("file://"):
            return uri
        path = unquote(uri[len("file://"):])
        # file:///C:/... -> "/C:/...": drop the slash preceding the drive letter.
        if len(path) > 2 and path[0] == "/" and path[2] == ":":
            path = path[1:]
        return path

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
        """Create CallHierarchyItem from dictionary representation.

        Lenient about missing keys:
        - missing ``name``/``kind`` become ``"unknown"``;
        - a missing ``range`` becomes a zero-width range at 0:0;
        - when ``file_path`` is absent, the ``uri`` key is used instead,
          converted from ``file://`` form into a path.
        """
        if "file_path" in data:
            file_path = data["file_path"]
        else:
            file_path = cls._uri_to_path(data.get("uri", ""))
        return cls(
            name=data.get("name", "unknown"),
            kind=data.get("kind", "unknown"),
            file_path=file_path,
            range=Range.from_dict(data.get("range", cls._empty_range())),
            detail=data.get("detail"),
        )

    @classmethod
    def from_lsp(cls, data: Dict[str, Any]) -> "CallHierarchyItem":
        """Create CallHierarchyItem from LSP response format.

        LSP uses 0-based line numbers and 'character' offsets. The item's
        ``uri`` (falling back to a ``file_path`` key) is converted to a
        filesystem path. ``kind`` is passed through ``str()`` because LSP
        may send a numeric SymbolKind.
        """
        file_path = cls._uri_to_path(data.get("uri", data.get("file_path", "")))
        return cls(
            name=data.get("name", "unknown"),
            kind=str(data.get("kind", "unknown")),
            file_path=file_path,
            range=Range.from_lsp_range(data.get("range", cls._empty_range())),
            detail=data.get("detail"),
        )
|