Refactor code structure and remove redundant changes

This commit is contained in:
catlog22
2026-01-24 14:47:47 +08:00
parent cf5fecd66d
commit f2b0a5bbc9
113 changed files with 43217 additions and 235 deletions

View File

@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.
This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from .builder import AssociationTreeBuilder
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
from .deduplicator import ResultDeduplicator
__all__ = [
"AssociationTreeBuilder",
"CallTree",
"TreeNode",
"UniqueNode",
"ResultDeduplicator",
]

View File

@@ -0,0 +1,450 @@
"""Association tree builder using LSP call hierarchy.
Builds call relationship trees by recursively expanding from seed locations
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.lsp.standalone_manager import StandaloneLspManager
from .data_structures import CallTree, TreeNode
logger = logging.getLogger(__name__)
class AssociationTreeBuilder:
"""Builds association trees from seed locations using LSP call hierarchy.
Uses depth-first recursive expansion to build a tree of code relationships
starting from seed locations (typically from vector search results).
Strategy:
- Start from seed locations (vector search results)
- For each seed, get call hierarchy items via LSP
- Recursively expand incoming calls (callers) if expand_callers=True
- Recursively expand outgoing calls (callees) if expand_callees=True
- Track visited nodes to prevent cycles
- Stop at max_depth or when no more relations found
Attributes:
lsp_manager: StandaloneLspManager for LSP communication
visited: Set of visited node IDs to prevent cycles
timeout: Timeout for individual LSP requests (seconds)
"""
def __init__(
self,
lsp_manager: StandaloneLspManager,
timeout: float = 5.0,
analysis_wait: float = 2.0,
):
"""Initialize AssociationTreeBuilder.
Args:
lsp_manager: StandaloneLspManager instance for LSP communication
timeout: Timeout for individual LSP requests in seconds
analysis_wait: Time to wait for LSP analysis on first file (seconds)
"""
self.lsp_manager = lsp_manager
self.timeout = timeout
self.analysis_wait = analysis_wait
self.visited: Set[str] = set()
self._analyzed_files: Set[str] = set() # Track files already analyzed
async def build_tree(
self,
seed_file_path: str,
seed_line: int,
seed_character: int = 1,
max_depth: int = 5,
expand_callers: bool = True,
expand_callees: bool = True,
) -> CallTree:
"""Build call tree from a single seed location.
Args:
seed_file_path: Path to the seed file
seed_line: Line number of the seed symbol (1-based)
seed_character: Character position (1-based, default 1)
max_depth: Maximum recursion depth (default 5)
expand_callers: Whether to expand incoming calls (callers)
expand_callees: Whether to expand outgoing calls (callees)
Returns:
CallTree containing all discovered nodes and relationships
"""
tree = CallTree()
self.visited.clear()
# Determine wait time - only wait for analysis on first encounter of file
wait_time = 0.0
if seed_file_path not in self._analyzed_files:
wait_time = self.analysis_wait
self._analyzed_files.add(seed_file_path)
# Get call hierarchy items for the seed position
try:
hierarchy_items = await asyncio.wait_for(
self.lsp_manager.get_call_hierarchy_items(
file_path=seed_file_path,
line=seed_line,
character=seed_character,
wait_for_analysis=wait_time,
),
timeout=self.timeout + wait_time,
)
except asyncio.TimeoutError:
logger.warning(
"Timeout getting call hierarchy items for %s:%d",
seed_file_path,
seed_line,
)
return tree
except Exception as e:
logger.error(
"Error getting call hierarchy items for %s:%d: %s",
seed_file_path,
seed_line,
e,
)
return tree
if not hierarchy_items:
logger.debug(
"No call hierarchy items found for %s:%d",
seed_file_path,
seed_line,
)
return tree
# Create root nodes from hierarchy items
for item_dict in hierarchy_items:
# Convert LSP dict to CallHierarchyItem
item = self._dict_to_call_hierarchy_item(item_dict)
if not item:
continue
root_node = TreeNode(
item=item,
depth=0,
path_from_root=[self._create_node_id(item)],
)
tree.roots.append(root_node)
tree.add_node(root_node)
# Mark as visited
self.visited.add(root_node.node_id)
# Recursively expand the tree
await self._expand_node(
node=root_node,
node_dict=item_dict,
tree=tree,
current_depth=0,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
tree.depth_reached = max_depth
return tree
async def _expand_node(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Recursively expand a node by fetching its callers and callees.
Args:
node: TreeNode to expand
node_dict: LSP CallHierarchyItem dict (for LSP requests)
tree: CallTree to add discovered nodes to
current_depth: Current recursion depth
max_depth: Maximum allowed depth
expand_callers: Whether to expand incoming calls
expand_callees: Whether to expand outgoing calls
"""
# Stop if max depth reached
if current_depth >= max_depth:
return
# Prepare tasks for parallel expansion
tasks = []
if expand_callers:
tasks.append(
self._expand_incoming_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
if expand_callees:
tasks.append(
self._expand_outgoing_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
# Execute expansions in parallel
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _expand_incoming_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand incoming calls (callers) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to continue expanding callers
expand_callees: Whether to expand callees
"""
try:
incoming_calls = await asyncio.wait_for(
self.lsp_manager.get_incoming_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting incoming calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
return
if not incoming_calls:
return
# Process each incoming call
for call_dict in incoming_calls:
caller_dict = call_dict.get("from")
if not caller_dict:
continue
# Convert to CallHierarchyItem
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
if not caller_item:
continue
caller_id = self._create_node_id(caller_item)
# Check for cycles
if caller_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [caller_id],
)
node.parents.append(cycle_node)
continue
# Create new caller node
caller_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [caller_id],
)
# Add to tree
tree.add_node(caller_node)
tree.add_edge(caller_node, node)
# Update relationships
node.parents.append(caller_node)
caller_node.children.append(node)
# Mark as visited
self.visited.add(caller_id)
# Recursively expand the caller
await self._expand_node(
node=caller_node,
node_dict=caller_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
async def _expand_outgoing_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand outgoing calls (callees) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to expand callers
expand_callees: Whether to continue expanding callees
"""
try:
outgoing_calls = await asyncio.wait_for(
self.lsp_manager.get_outgoing_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
return
if not outgoing_calls:
return
# Process each outgoing call
for call_dict in outgoing_calls:
callee_dict = call_dict.get("to")
if not callee_dict:
continue
# Convert to CallHierarchyItem
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
if not callee_item:
continue
callee_id = self._create_node_id(callee_item)
# Check for cycles
if callee_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [callee_id],
)
node.children.append(cycle_node)
continue
# Create new callee node
callee_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [callee_id],
)
# Add to tree
tree.add_node(callee_node)
tree.add_edge(node, callee_node)
# Update relationships
node.children.append(callee_node)
callee_node.parents.append(node)
# Mark as visited
self.visited.add(callee_id)
# Recursively expand the callee
await self._expand_node(
node=callee_node,
node_dict=callee_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
def _dict_to_call_hierarchy_item(
self, item_dict: Dict
) -> Optional[CallHierarchyItem]:
"""Convert LSP dict to CallHierarchyItem.
Args:
item_dict: LSP CallHierarchyItem dictionary
Returns:
CallHierarchyItem or None if conversion fails
"""
try:
# Extract URI and convert to file path
uri = item_dict.get("uri", "")
file_path = uri.replace("file:///", "").replace("file://", "")
# Handle Windows paths (file:///C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
# Extract range
range_dict = item_dict.get("range", {})
start = range_dict.get("start", {})
end = range_dict.get("end", {})
# Create Range (convert from 0-based to 1-based)
item_range = Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
)
return CallHierarchyItem(
name=item_dict.get("name", "unknown"),
kind=str(item_dict.get("kind", "unknown")),
file_path=file_path,
range=item_range,
detail=item_dict.get("detail"),
)
except Exception as e:
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
return None
def _create_node_id(self, item: CallHierarchyItem) -> str:
"""Create unique node ID from CallHierarchyItem.
Args:
item: CallHierarchyItem
Returns:
Unique node ID string
"""
return f"{item.file_path}:{item.name}:{item.range.start_line}"

View File

@@ -0,0 +1,191 @@
"""Data structures for association tree building.
Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
@dataclass
class TreeNode:
"""Node in the call association tree.
Represents a single function/method in the tree, including its position
in the hierarchy and relationships.
Attributes:
item: LSP CallHierarchyItem containing symbol information
depth: Distance from the root node (seed) - 0 for roots
children: List of child nodes (functions called by this node)
parents: List of parent nodes (functions that call this node)
is_cycle: Whether this node creates a circular reference
path_from_root: Path (list of node IDs) from root to this node
"""
item: CallHierarchyItem
depth: int = 0
children: List[TreeNode] = field(default_factory=list)
parents: List[TreeNode] = field(default_factory=list)
is_cycle: bool = False
path_from_root: List[str] = field(default_factory=list)
@property
def node_id(self) -> str:
"""Unique identifier for this node."""
return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"
def __hash__(self) -> int:
"""Hash based on node ID."""
return hash(self.node_id)
def __eq__(self, other: object) -> bool:
"""Equality based on node ID."""
if not isinstance(other, TreeNode):
return False
return self.node_id == other.node_id
def __repr__(self) -> str:
"""String representation of the node."""
cycle_marker = " [CYCLE]" if self.is_cycle else ""
return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"
@dataclass
class CallTree:
"""Complete call tree structure built from seeds.
Contains all nodes discovered through recursive expansion and
the relationships between them.
Attributes:
roots: List of root nodes (seed symbols)
all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
node_list: Flat list of all nodes in tree order
edges: List of (from_node_id, to_node_id) tuples representing calls
depth_reached: Maximum depth achieved in expansion
"""
roots: List[TreeNode] = field(default_factory=list)
all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
node_list: List[TreeNode] = field(default_factory=list)
edges: List[tuple[str, str]] = field(default_factory=list)
depth_reached: int = 0
def add_node(self, node: TreeNode) -> None:
"""Add a node to the tree.
Args:
node: TreeNode to add
"""
if node.node_id not in self.all_nodes:
self.all_nodes[node.node_id] = node
self.node_list.append(node)
def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
"""Add an edge between two nodes.
Args:
from_node: Source node
to_node: Target node
"""
edge = (from_node.node_id, to_node.node_id)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[TreeNode]:
"""Get a node by ID.
Args:
node_id: Node identifier
Returns:
TreeNode if found, None otherwise
"""
return self.all_nodes.get(node_id)
def __len__(self) -> int:
"""Return total number of nodes in tree."""
return len(self.all_nodes)
def __repr__(self) -> str:
"""String representation of the tree."""
return (
f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
f"depth={self.depth_reached})"
)
@dataclass
class UniqueNode:
"""Deduplicated unique code symbol from the tree.
Represents a single unique code location that may appear multiple times
in the tree under different contexts. Contains aggregated information
about all occurrences.
Attributes:
file_path: Absolute path to the file
name: Symbol name (function, method, class, etc.)
kind: Symbol kind (function, method, class, etc.)
range: Code range in the file
min_depth: Minimum depth at which this node appears in the tree
occurrences: Number of times this node appears in the tree
paths: List of paths from roots to this node
context_nodes: Related nodes from the tree
score: Composite relevance score (higher is better)
"""
file_path: str
name: str
kind: str
range: Range
min_depth: int = 0
occurrences: int = 1
paths: List[List[str]] = field(default_factory=list)
context_nodes: List[str] = field(default_factory=list)
score: float = 0.0
@property
def node_key(self) -> tuple[str, int, int]:
"""Unique key for deduplication.
Uses (file_path, start_line, end_line) as the unique identifier
for this symbol across all occurrences.
"""
return (
self.file_path,
self.range.start_line,
self.range.end_line,
)
def add_path(self, path: List[str]) -> None:
"""Add a path from root to this node.
Args:
path: List of node IDs from root to this node
"""
if path not in self.paths:
self.paths.append(path)
def __hash__(self) -> int:
"""Hash based on node key."""
return hash(self.node_key)
def __eq__(self, other: object) -> bool:
"""Equality based on node key."""
if not isinstance(other, UniqueNode):
return False
return self.node_key == other.node_key
def __repr__(self) -> str:
"""String representation of the unique node."""
return (
f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
)

View File

@@ -0,0 +1,301 @@
"""Result deduplication for association tree nodes.
Provides functionality to extract unique nodes from a call tree and assign
relevance scores based on various factors.
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
logger = logging.getLogger(__name__)
# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
# Functions and methods are primary targets
"function": 1.0,
"method": 1.0,
"12": 1.0, # LSP SymbolKind.Function
"6": 1.0, # LSP SymbolKind.Method
# Classes are important but secondary
"class": 0.8,
"5": 0.8, # LSP SymbolKind.Class
# Interfaces and types
"interface": 0.7,
"11": 0.7, # LSP SymbolKind.Interface
"type": 0.6,
# Constructors
"constructor": 0.9,
"9": 0.9, # LSP SymbolKind.Constructor
# Variables and constants
"variable": 0.4,
"13": 0.4, # LSP SymbolKind.Variable
"constant": 0.5,
"14": 0.5, # LSP SymbolKind.Constant
# Default for unknown kinds
"unknown": 0.3,
}
class ResultDeduplicator:
"""Extracts and scores unique nodes from call trees.
Processes a CallTree to extract unique code locations, merging duplicates
and assigning relevance scores based on:
- Depth: Shallower nodes (closer to seeds) score higher
- Frequency: Nodes appearing multiple times score higher
- Kind: Function/method > class > variable
Attributes:
depth_weight: Weight for depth factor in scoring (default 0.4)
frequency_weight: Weight for frequency factor (default 0.3)
kind_weight: Weight for symbol kind factor (default 0.3)
max_depth_penalty: Maximum depth before full penalty applied
"""
def __init__(
self,
depth_weight: float = 0.4,
frequency_weight: float = 0.3,
kind_weight: float = 0.3,
max_depth_penalty: int = 10,
):
"""Initialize ResultDeduplicator.
Args:
depth_weight: Weight for depth factor (0.0-1.0)
frequency_weight: Weight for frequency factor (0.0-1.0)
kind_weight: Weight for symbol kind factor (0.0-1.0)
max_depth_penalty: Depth at which score becomes 0 for depth factor
"""
self.depth_weight = depth_weight
self.frequency_weight = frequency_weight
self.kind_weight = kind_weight
self.max_depth_penalty = max_depth_penalty
def deduplicate(
self,
tree: CallTree,
max_results: Optional[int] = None,
) -> List[UniqueNode]:
"""Extract unique nodes from the call tree.
Traverses the tree, groups nodes by their unique key (file_path,
start_line, end_line), and merges duplicate occurrences.
Args:
tree: CallTree to process
max_results: Maximum number of results to return (None = all)
Returns:
List of UniqueNode objects, sorted by score descending
"""
if not tree.node_list:
return []
# Group nodes by unique key
unique_map: Dict[tuple, UniqueNode] = {}
for node in tree.node_list:
if node.is_cycle:
# Skip cycle markers - they point to already-counted nodes
continue
key = self._get_node_key(node)
if key in unique_map:
# Update existing unique node
unique_node = unique_map[key]
unique_node.occurrences += 1
unique_node.min_depth = min(unique_node.min_depth, node.depth)
unique_node.add_path(node.path_from_root)
# Collect context from relationships
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
else:
# Create new unique node
unique_node = UniqueNode(
file_path=node.item.file_path,
name=node.item.name,
kind=node.item.kind,
range=node.item.range,
min_depth=node.depth,
occurrences=1,
paths=[node.path_from_root.copy()],
context_nodes=[],
score=0.0,
)
# Collect initial context
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
unique_map[key] = unique_node
# Calculate scores for all unique nodes
unique_nodes = list(unique_map.values())
# Find max frequency for normalization
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
for node in unique_nodes:
node.score = self._score_node(node, max_frequency)
# Sort by score descending
unique_nodes.sort(key=lambda n: n.score, reverse=True)
# Apply max_results limit
if max_results is not None and max_results > 0:
unique_nodes = unique_nodes[:max_results]
logger.debug(
"Deduplicated %d tree nodes to %d unique nodes",
len(tree.node_list),
len(unique_nodes),
)
return unique_nodes
def _score_node(
self,
node: UniqueNode,
max_frequency: int,
) -> float:
"""Calculate composite score for a unique node.
Score = depth_weight * depth_score +
frequency_weight * frequency_score +
kind_weight * kind_score
Args:
node: UniqueNode to score
max_frequency: Maximum occurrence count for normalization
Returns:
Composite score between 0.0 and 1.0
"""
# Depth score: closer to root = higher score
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
depth_score = max(
0.0,
1.0 - (node.min_depth / self.max_depth_penalty),
)
# Frequency score: more occurrences = higher score
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
# Kind score: function/method > class > variable
kind_str = str(node.kind).lower()
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
# Composite score
score = (
self.depth_weight * depth_score
+ self.frequency_weight * frequency_score
+ self.kind_weight * kind_score
)
return score
def _get_node_key(self, node: TreeNode) -> tuple:
"""Get unique key for a tree node.
Uses (file_path, start_line, end_line) as the unique identifier.
Args:
node: TreeNode
Returns:
Tuple key for deduplication
"""
return (
node.item.file_path,
node.item.range.start_line,
node.item.range.end_line,
)
def filter_by_kind(
self,
nodes: List[UniqueNode],
kinds: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by symbol kind.
Args:
nodes: List of UniqueNode to filter
kinds: List of allowed kinds (e.g., ["function", "method"])
Returns:
Filtered list of UniqueNode
"""
kinds_lower = [k.lower() for k in kinds]
return [
node
for node in nodes
if str(node.kind).lower() in kinds_lower
]
def filter_by_file(
self,
nodes: List[UniqueNode],
file_patterns: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by file path patterns.
Args:
nodes: List of UniqueNode to filter
file_patterns: List of path substrings to match
Returns:
Filtered list of UniqueNode
"""
return [
node
for node in nodes
if any(pattern in node.file_path for pattern in file_patterns)
]
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
"""Convert list of UniqueNode to JSON-serializable dicts.
Args:
nodes: List of UniqueNode
Returns:
List of dictionaries
"""
return [
{
"file_path": node.file_path,
"name": node.name,
"kind": node.kind,
"range": {
"start_line": node.range.start_line,
"start_character": node.range.start_character,
"end_line": node.range.end_line,
"end_character": node.range.end_character,
},
"min_depth": node.min_depth,
"occurrences": node.occurrences,
"path_count": len(node.paths),
"score": round(node.score, 4),
}
for node in nodes
]