Files
Claude-Code-Workflow/codex-lens/src/codexlens/search/association_tree/data_structures.py
catlog22 261c98549d feat: Implement association tree for LSP-based code relationship discovery
- Add `association_tree` module with components for building and processing call association trees using LSP call hierarchy capabilities.
- Introduce `AssociationTreeBuilder` for constructing call trees from seed locations with depth-first expansion.
- Create data structures: `TreeNode`, `CallTree`, and `UniqueNode` for representing nodes and relationships in the call tree.
- Implement `ResultDeduplicator` to extract unique nodes from call trees and assign relevance scores based on depth, frequency, and kind.
- Add unit tests for `AssociationTreeBuilder` and `ResultDeduplicator` to ensure functionality and correctness.
2026-01-20 22:09:04 +08:00

192 lines
6.1 KiB
Python

"""Data structures for association tree building.
Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
@dataclass
class TreeNode:
    """Node in the call association tree.

    Represents a single function/method in the tree, including its position
    in the hierarchy and its relationships to other nodes.

    Identity (equality and hashing) is derived solely from ``node_id``; the
    explicitly defined ``__eq__``, ``__hash__`` and ``__repr__`` below are
    kept by ``@dataclass`` instead of being replaced by generated versions.

    Attributes:
        item: LSP CallHierarchyItem containing symbol information.
        depth: Distance from the root node (seed) - 0 for roots.
        children: List of child nodes (functions called by this node).
        parents: List of parent nodes (functions that call this node).
        is_cycle: Whether this node creates a circular reference.
        path_from_root: Path (list of node IDs) from root to this node.
    """

    item: CallHierarchyItem
    depth: int = 0
    children: List[TreeNode] = field(default_factory=list)
    parents: List[TreeNode] = field(default_factory=list)
    is_cycle: bool = False
    path_from_root: List[str] = field(default_factory=list)

    @property
    def node_id(self) -> str:
        """Identifier built as ``<file_path>:<name>:<start_line>`` from ``item``."""
        item = self.item
        return f"{item.file_path}:{item.name}:{item.range.start_line}"

    def __hash__(self) -> int:
        """Hash based on ``node_id`` so it stays consistent with ``__eq__``."""
        return hash(self.node_id)

    def __eq__(self, other: object) -> bool:
        """Compare two nodes by ``node_id``.

        Returns:
            True/False when ``other`` is a TreeNode; ``NotImplemented``
            otherwise, so Python can try the reflected comparison of the
            other operand (``==`` still evaluates to False when neither
            side handles the pair).
        """
        if not isinstance(other, TreeNode):
            return NotImplemented
        return self.node_id == other.node_id

    def __repr__(self) -> str:
        """Short representation: symbol name, location and a cycle marker.

        Children and parents are not included in the output.
        """
        cycle_marker = " [CYCLE]" if self.is_cycle else ""
        item = self.item
        return f"TreeNode({item.name}@{item.file_path}:{item.range.start_line}){cycle_marker}"
@dataclass
class CallTree:
    """Container for a call tree expanded from a set of seed symbols.

    Holds every node found during recursive expansion together with the
    call relationships that link them.

    Attributes:
        roots: Root nodes of the tree (the seed symbols).
        all_nodes: Lookup table from node_id to its TreeNode.
        node_list: Flat list of all nodes; ``add_node`` appends each node
            the first time its ID is seen.
        edges: Calls recorded as (from_node_id, to_node_id) tuples.
        depth_reached: Maximum depth achieved in expansion.
    """

    roots: List[TreeNode] = field(default_factory=list)
    all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
    node_list: List[TreeNode] = field(default_factory=list)
    edges: List[tuple[str, str]] = field(default_factory=list)
    depth_reached: int = 0

    def add_node(self, node: TreeNode) -> None:
        """Register a node unless one with the same ID is already present.

        Args:
            node: TreeNode to add.
        """
        key = node.node_id
        if key in self.all_nodes:
            return
        self.all_nodes[key] = node
        self.node_list.append(node)

    def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
        """Record an edge between two nodes, ignoring repeats.

        Args:
            from_node: Source node.
            to_node: Target node.
        """
        pair = (from_node.node_id, to_node.node_id)
        if pair in self.edges:
            return
        self.edges.append(pair)

    def get_node(self, node_id: str) -> Optional[TreeNode]:
        """Look a node up by its identifier.

        Args:
            node_id: Node identifier.

        Returns:
            The matching TreeNode, or None when the ID is unknown.
        """
        return self.all_nodes.get(node_id, None)

    def __len__(self) -> int:
        """Return the number of distinct nodes held in ``all_nodes``."""
        node_count = len(self.all_nodes)
        return node_count

    def __repr__(self) -> str:
        """Summarise the tree by root count, node count and depth reached."""
        root_count = len(self.roots)
        node_count = len(self.all_nodes)
        return f"CallTree(roots={root_count}, nodes={node_count}, depth={self.depth_reached})"
@dataclass
class UniqueNode:
    """Deduplicated unique code symbol from the tree.

    Represents a single unique code location that may appear multiple times
    in the tree under different contexts. Contains aggregated information
    about all occurrences.

    Identity (equality and hashing) is derived solely from ``node_key``.

    Attributes:
        file_path: Path to the file (documented as absolute).
        name: Symbol name (function, method, class, etc.).
        kind: Symbol kind (function, method, class, etc.).
        range: Code range in the file.
        min_depth: Minimum depth at which this node appears in the tree.
        occurrences: Number of times this node appears in the tree.
        paths: List of paths from roots to this node.
        context_nodes: Related nodes from the tree.
        score: Composite relevance score (higher is better).
    """

    file_path: str
    name: str
    kind: str
    range: Range
    min_depth: int = 0
    occurrences: int = 1
    paths: List[List[str]] = field(default_factory=list)
    context_nodes: List[str] = field(default_factory=list)
    score: float = 0.0

    @property
    def node_key(self) -> tuple[str, int, int]:
        """Unique key for deduplication.

        Uses (file_path, start_line, end_line) as the unique identifier
        for this symbol across all occurrences.
        """
        return (
            self.file_path,
            self.range.start_line,
            self.range.end_line,
        )

    def add_path(self, path: List[str]) -> None:
        """Record a path from a root to this node, skipping duplicates.

        A copy of ``path`` is stored, so mutating the caller's list afterwards
        cannot alter the recorded path or defeat the duplicate check.

        Args:
            path: List of node IDs from root to this node.
        """
        if path not in self.paths:
            self.paths.append(list(path))

    def __hash__(self) -> int:
        """Hash based on ``node_key`` so it stays consistent with ``__eq__``."""
        return hash(self.node_key)

    def __eq__(self, other: object) -> bool:
        """Compare two unique nodes by ``node_key``.

        Returns:
            True/False when ``other`` is a UniqueNode; ``NotImplemented``
            otherwise, so Python can try the reflected comparison of the
            other operand (``==`` still evaluates to False when neither
            side handles the pair).
        """
        if not isinstance(other, UniqueNode):
            return NotImplemented
        return self.node_key == other.node_key

    def __repr__(self) -> str:
        """Short representation: name, location, depth, occurrences and score."""
        return (
            f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
            f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
        )