feat: Implement association tree for LSP-based code relationship discovery

- Add `association_tree` module with components for building and processing call association trees using LSP call hierarchy capabilities.
- Introduce `AssociationTreeBuilder` for constructing call trees from seed locations with depth-first expansion.
- Create data structures: `TreeNode`, `CallTree`, and `UniqueNode` for representing nodes and relationships in the call tree.
- Implement `ResultDeduplicator` to extract unique nodes from call trees and assign relevance scores based on depth, frequency, and kind.
- Add unit tests for `AssociationTreeBuilder` and `ResultDeduplicator` to ensure functionality and correctness.
This commit is contained in:
catlog22
2026-01-20 22:09:04 +08:00
parent b85d9b9eb1
commit 261c98549d
21 changed files with 2826 additions and 94 deletions

View File

@@ -1053,7 +1053,38 @@ class StandaloneLspManager:
return []
return result
async def get_outgoing_calls(
self,
item: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Get outgoing calls for a call hierarchy item.
Args:
item: CallHierarchyItem from get_call_hierarchy_items
Returns:
List of CallHierarchyOutgoingCall dicts
"""
# Determine language from item's uri
uri = item.get("uri", "")
file_path = uri.replace("file:///", "").replace("file://", "")
state = await self._get_server(file_path)
if not state:
return []
result = await self._send_request(
state,
"callHierarchy/outgoingCalls",
{"item": item},
)
if not result or not isinstance(result, list):
return []
return result
async def __aenter__(self) -> "StandaloneLspManager":
"""Async context manager entry."""
await self.start()

View File

@@ -0,0 +1,257 @@
# Association Tree Quick Start
## Installation
No additional dependencies needed - uses existing CodexLens LSP infrastructure.
## Basic Usage
### 1. Import Components
```python
from codexlens.lsp.standalone_manager import StandaloneLspManager
from codexlens.search.association_tree import (
AssociationTreeBuilder,
ResultDeduplicator,
)
```
### 2. Build a Tree
```python
import asyncio
async def build_tree_example():
# Initialize LSP manager
async with StandaloneLspManager(workspace_root="/path/to/project") as lsp:
# Create builder
builder = AssociationTreeBuilder(lsp, timeout=5.0)
# Build tree from seed location
tree = await builder.build_tree(
seed_file_path="src/main.py",
seed_line=42, # 1-based line number
seed_character=1, # 1-based character position
max_depth=5, # Maximum recursion depth
expand_callers=True, # Find who calls this
expand_callees=True, # Find what this calls
)
return tree
tree = asyncio.run(build_tree_example())
print(f"Found {len(tree.all_nodes)} unique nodes")
```
### 3. Deduplicate and Score
```python
# Create deduplicator
deduplicator = ResultDeduplicator(
depth_weight=0.4, # Weight for depth score (0-1)
frequency_weight=0.3, # Weight for frequency score (0-1)
kind_weight=0.3, # Weight for symbol kind score (0-1)
)
# Extract unique nodes
unique_nodes = deduplicator.deduplicate(tree, max_results=20)
# Print results
for node in unique_nodes:
print(f"{node.name} @ {node.file_path}:{node.range.start_line}")
print(f" Score: {node.score:.2f}, Depth: {node.min_depth}, Occurs: {node.occurrences}")
```
### 4. Filter Results
```python
# Filter by symbol kind
functions = deduplicator.filter_by_kind(unique_nodes, ["function", "method"])
# Filter by file pattern
core_modules = deduplicator.filter_by_file(unique_nodes, ["src/core/"])
# Convert to JSON
json_data = deduplicator.to_dict_list(unique_nodes)
```
## Common Patterns
### Pattern 1: Find All Callers
```python
tree = await builder.build_tree(
seed_file_path=target_file,
seed_line=target_line,
max_depth=3,
expand_callers=True, # Only expand callers
expand_callees=False, # Don't expand callees
)
```
### Pattern 2: Find Call Chain
```python
tree = await builder.build_tree(
seed_file_path=entry_point,
seed_line=main_line,
max_depth=10,
expand_callers=False, # Don't expand callers
expand_callees=True, # Only expand callees (call chain)
)
```
### Pattern 3: Full Relationship Map
```python
tree = await builder.build_tree(
seed_file_path=target_file,
seed_line=target_line,
max_depth=5,
expand_callers=True, # Expand both directions
expand_callees=True,
)
```
## Configuration Tips
### Max Depth Guidelines
- **Depth 1-2**: Direct callers/callees only (fast, focused)
- **Depth 3-5**: Good balance of coverage and performance (recommended)
- **Depth 6-10**: Deep exploration (slower, may hit cycles)
### Timeout Settings
```python
builder = AssociationTreeBuilder(
lsp,
timeout=5.0, # 5 seconds per LSP request
)
# For slower language servers
builder = AssociationTreeBuilder(lsp, timeout=10.0)
```
### Score Weight Tuning
```python
# Emphasize proximity to seed
deduplicator = ResultDeduplicator(
depth_weight=0.7, # High weight for depth
frequency_weight=0.2,
kind_weight=0.1,
)
# Emphasize frequently-called functions
deduplicator = ResultDeduplicator(
depth_weight=0.2,
frequency_weight=0.7, # High weight for frequency
kind_weight=0.1,
)
```
## Error Handling
```python
try:
tree = await builder.build_tree(...)
if not tree.all_nodes:
print("No call hierarchy found - LSP may not support this file type")
except asyncio.TimeoutError:
print("LSP request timed out - try increasing timeout")
except Exception as e:
print(f"Error building tree: {e}")
```
## Performance Optimization
### 1. Limit Depth
```python
# Fast: max_depth=3
tree = await builder.build_tree(..., max_depth=3)
```
### 2. Filter Early
```python
# Get all nodes
unique_nodes = deduplicator.deduplicate(tree)
# Filter to relevant kinds immediately
functions = deduplicator.filter_by_kind(unique_nodes, ["function", "method"])
```
### 3. Use Timeouts
```python
# Set aggressive timeouts for fast iteration
builder = AssociationTreeBuilder(lsp, timeout=3.0)
```
## Common Issues
### Issue: Empty Tree Returned
**Causes**:
- File not supported by LSP server
- No call hierarchy at that position
- Position is not on a function/method
**Solutions**:
- Verify LSP server supports the language
- Check that position is on a function definition
- Try different seed locations
### Issue: Timeout Errors
**Causes**:
- LSP server slow or overloaded
- Network/connection issues
- Max depth too high
**Solutions**:
- Increase timeout value
- Reduce max_depth
- Check LSP server health
### Issue: Cycle Detected
**Behavior**: Cycles are automatically detected and marked
**Example**:
```python
for node in tree.node_list:
if node.is_cycle:
print(f"Cycle detected at {node.item.name}")
```
## Testing
Run the test suite:
```bash
# All tests
pytest tests/test_association_tree.py -v
# Specific test
pytest tests/test_association_tree.py::test_simple_tree_building -v
```
## Demo Script
Run the demo:
```bash
python examples/association_tree_demo.py
```
## Further Reading
- [Full Documentation](README.md)
- [Implementation Summary](../../ASSOCIATION_TREE_IMPLEMENTATION.md)
- [LSP Manager Documentation](../../lsp/standalone_manager.py)

View File

@@ -0,0 +1,188 @@
# Association Tree Module
LSP-based code relationship discovery using call hierarchy.
## Overview
This module provides components for building and analyzing call relationship trees using Language Server Protocol (LSP) call hierarchy capabilities. It consists of three main components:
1. **Data Structures** (`data_structures.py`) - Core data classes
2. **Association Tree Builder** (`builder.py`) - Tree construction via LSP
3. **Result Deduplicator** (`deduplicator.py`) - Node extraction and scoring
## Components
### 1. Data Structures
**TreeNode**: Represents a single node in the call tree.
- Contains LSP CallHierarchyItem
- Tracks depth, parents, children
- Detects and marks cycles
**CallTree**: Complete tree structure with roots and edges.
- Stores all discovered nodes
- Tracks edges (call relationships)
- Provides lookup by node_id
**UniqueNode**: Deduplicated code symbol with metadata.
- Aggregates multiple occurrences
- Tracks minimum depth
- Contains relevance score
### 2. AssociationTreeBuilder
Builds call trees using LSP call hierarchy:
**Strategy**:
- Depth-first recursive expansion
- Supports expanding callers (incoming calls) and callees (outgoing calls)
- Detects and marks circular references
- Respects max_depth limit
**Key Features**:
- Async/await for concurrent LSP requests
- Timeout handling (5s per node)
- Graceful error handling
- Cycle detection via visited set
### 3. ResultDeduplicator
Extracts unique nodes from trees and assigns scores:
**Scoring Factors**:
- **Depth** (40%): Shallower = more relevant
- **Frequency** (30%): More occurrences = more important
- **Kind** (30%): function/method > class > variable
**Features**:
- Merges duplicate nodes by (file_path, start_line, end_line)
- Tracks all paths to each node
- Supports filtering by kind or file pattern
- Configurable score weights
## Usage Example
```python
import asyncio
from codexlens.lsp.standalone_manager import StandaloneLspManager
from codexlens.search.association_tree import (
AssociationTreeBuilder,
ResultDeduplicator,
)
async def main():
# Initialize LSP manager
async with StandaloneLspManager(workspace_root="/path/to/project") as lsp:
# Create tree builder
builder = AssociationTreeBuilder(lsp, timeout=5.0)
# Build tree from seed location
tree = await builder.build_tree(
seed_file_path="src/main.py",
seed_line=42,
seed_character=1,
max_depth=5,
expand_callers=True, # Find who calls this
expand_callees=True, # Find what this calls
)
print(f"Tree: {tree}")
print(f" Roots: {len(tree.roots)}")
print(f" Total nodes: {len(tree.all_nodes)}")
print(f" Edges: {len(tree.edges)}")
# Deduplicate and score
deduplicator = ResultDeduplicator(
depth_weight=0.4,
frequency_weight=0.3,
kind_weight=0.3,
)
unique_nodes = deduplicator.deduplicate(tree, max_results=20)
print(f"\nTop unique nodes:")
for node in unique_nodes[:10]:
print(f" {node.name} ({node.file_path}:{node.range.start_line})")
print(f" Depth: {node.min_depth}, Occurrences: {node.occurrences}, Score: {node.score:.2f}")
# Filter by kind
functions_only = deduplicator.filter_by_kind(unique_nodes, ["function", "method"])
print(f"\nFunctions/methods: {len(functions_only)}")
asyncio.run(main())
```
## Integration with Hybrid Search
The association tree can be integrated with the hybrid search engine:
```python
from codexlens.search.hybrid_search import HybridSearchEngine
async def search_with_association_tree(query: str):
# 1. Get seed results from vector search
search_engine = HybridSearchEngine()
seed_results = await search_engine.search(query, limit=5)
# 2. Build association trees from top results
builder = AssociationTreeBuilder(lsp_manager)
trees = []
for result in seed_results:
tree = await builder.build_tree(
seed_file_path=result.file_path,
seed_line=result.line,
max_depth=3,
)
trees.append(tree)
# 3. Merge and deduplicate
merged_tree = merge_trees(trees) # Custom merge logic
deduplicator = ResultDeduplicator()
unique_nodes = deduplicator.deduplicate(merged_tree, max_results=50)
# 4. Convert to search results
final_results = convert_to_search_results(unique_nodes)
return final_results
```
## Testing
Run the test suite:
```bash
pytest tests/test_association_tree.py -v
```
Test coverage includes:
- Simple tree building
- Cycle detection
- Max depth limits
- Empty trees
- Deduplication logic
- Scoring algorithms
- Filtering operations
## Performance Considerations
1. **LSP Timeouts**: Set appropriate timeout values (default 5s)
2. **Max Depth**: Limit depth to avoid exponential expansion (recommended: 3-5)
3. **Caching**: LSP manager caches open documents
4. **Parallel Expansion**: Incoming/outgoing calls fetched in parallel
## Error Handling
The builder gracefully handles:
- LSP timeout errors (logs warning, continues)
- Missing call hierarchy support (returns empty)
- Network/connection failures (skips node)
- Invalid LSP responses (logs error, skips)
## Future Enhancements
- [ ] Multi-root tree building from multiple seeds
- [ ] Custom scoring functions
- [ ] Graph visualization export
- [ ] Incremental tree updates
- [ ] Cross-file relationship analysis

View File

@@ -0,0 +1,21 @@
"""Association tree module for LSP-based code relationship discovery.
This module provides components for building and processing call association trees
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from .builder import AssociationTreeBuilder
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
from .deduplicator import ResultDeduplicator
__all__ = [
"AssociationTreeBuilder",
"CallTree",
"TreeNode",
"UniqueNode",
"ResultDeduplicator",
]

View File

@@ -0,0 +1,439 @@
"""Association tree builder using LSP call hierarchy.
Builds call relationship trees by recursively expanding from seed locations
using Language Server Protocol (LSP) call hierarchy capabilities.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Optional, Set
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.lsp.standalone_manager import StandaloneLspManager
from .data_structures import CallTree, TreeNode
logger = logging.getLogger(__name__)
class AssociationTreeBuilder:
"""Builds association trees from seed locations using LSP call hierarchy.
Uses depth-first recursive expansion to build a tree of code relationships
starting from seed locations (typically from vector search results).
Strategy:
- Start from seed locations (vector search results)
- For each seed, get call hierarchy items via LSP
- Recursively expand incoming calls (callers) if expand_callers=True
- Recursively expand outgoing calls (callees) if expand_callees=True
- Track visited nodes to prevent cycles
- Stop at max_depth or when no more relations found
Attributes:
lsp_manager: StandaloneLspManager for LSP communication
visited: Set of visited node IDs to prevent cycles
timeout: Timeout for individual LSP requests (seconds)
"""
def __init__(
self,
lsp_manager: StandaloneLspManager,
timeout: float = 5.0,
):
"""Initialize AssociationTreeBuilder.
Args:
lsp_manager: StandaloneLspManager instance for LSP communication
timeout: Timeout for individual LSP requests in seconds
"""
self.lsp_manager = lsp_manager
self.timeout = timeout
self.visited: Set[str] = set()
async def build_tree(
self,
seed_file_path: str,
seed_line: int,
seed_character: int = 1,
max_depth: int = 5,
expand_callers: bool = True,
expand_callees: bool = True,
) -> CallTree:
"""Build call tree from a single seed location.
Args:
seed_file_path: Path to the seed file
seed_line: Line number of the seed symbol (1-based)
seed_character: Character position (1-based, default 1)
max_depth: Maximum recursion depth (default 5)
expand_callers: Whether to expand incoming calls (callers)
expand_callees: Whether to expand outgoing calls (callees)
Returns:
CallTree containing all discovered nodes and relationships
"""
tree = CallTree()
self.visited.clear()
# Get call hierarchy items for the seed position
try:
hierarchy_items = await asyncio.wait_for(
self.lsp_manager.get_call_hierarchy_items(
file_path=seed_file_path,
line=seed_line,
character=seed_character,
),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.warning(
"Timeout getting call hierarchy items for %s:%d",
seed_file_path,
seed_line,
)
return tree
except Exception as e:
logger.error(
"Error getting call hierarchy items for %s:%d: %s",
seed_file_path,
seed_line,
e,
)
return tree
if not hierarchy_items:
logger.debug(
"No call hierarchy items found for %s:%d",
seed_file_path,
seed_line,
)
return tree
# Create root nodes from hierarchy items
for item_dict in hierarchy_items:
# Convert LSP dict to CallHierarchyItem
item = self._dict_to_call_hierarchy_item(item_dict)
if not item:
continue
root_node = TreeNode(
item=item,
depth=0,
path_from_root=[self._create_node_id(item)],
)
tree.roots.append(root_node)
tree.add_node(root_node)
# Mark as visited
self.visited.add(root_node.node_id)
# Recursively expand the tree
await self._expand_node(
node=root_node,
node_dict=item_dict,
tree=tree,
current_depth=0,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
tree.depth_reached = max_depth
return tree
async def _expand_node(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Recursively expand a node by fetching its callers and callees.
Args:
node: TreeNode to expand
node_dict: LSP CallHierarchyItem dict (for LSP requests)
tree: CallTree to add discovered nodes to
current_depth: Current recursion depth
max_depth: Maximum allowed depth
expand_callers: Whether to expand incoming calls
expand_callees: Whether to expand outgoing calls
"""
# Stop if max depth reached
if current_depth >= max_depth:
return
# Prepare tasks for parallel expansion
tasks = []
if expand_callers:
tasks.append(
self._expand_incoming_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
if expand_callees:
tasks.append(
self._expand_outgoing_calls(
node=node,
node_dict=node_dict,
tree=tree,
current_depth=current_depth,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
)
# Execute expansions in parallel
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _expand_incoming_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand incoming calls (callers) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to continue expanding callers
expand_callees: Whether to expand callees
"""
try:
incoming_calls = await asyncio.wait_for(
self.lsp_manager.get_incoming_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting incoming calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting incoming calls for %s: %s", node.node_id, e)
return
if not incoming_calls:
return
# Process each incoming call
for call_dict in incoming_calls:
caller_dict = call_dict.get("from")
if not caller_dict:
continue
# Convert to CallHierarchyItem
caller_item = self._dict_to_call_hierarchy_item(caller_dict)
if not caller_item:
continue
caller_id = self._create_node_id(caller_item)
# Check for cycles
if caller_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [caller_id],
)
node.parents.append(cycle_node)
continue
# Create new caller node
caller_node = TreeNode(
item=caller_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [caller_id],
)
# Add to tree
tree.add_node(caller_node)
tree.add_edge(caller_node, node)
# Update relationships
node.parents.append(caller_node)
caller_node.children.append(node)
# Mark as visited
self.visited.add(caller_id)
# Recursively expand the caller
await self._expand_node(
node=caller_node,
node_dict=caller_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
async def _expand_outgoing_calls(
self,
node: TreeNode,
node_dict: Dict,
tree: CallTree,
current_depth: int,
max_depth: int,
expand_callers: bool,
expand_callees: bool,
) -> None:
"""Expand outgoing calls (callees) for a node.
Args:
node: TreeNode being expanded
node_dict: LSP dict for the node
tree: CallTree to add nodes to
current_depth: Current depth
max_depth: Maximum depth
expand_callers: Whether to expand callers
expand_callees: Whether to continue expanding callees
"""
try:
outgoing_calls = await asyncio.wait_for(
self.lsp_manager.get_outgoing_calls(item=node_dict),
timeout=self.timeout,
)
except asyncio.TimeoutError:
logger.debug("Timeout getting outgoing calls for %s", node.node_id)
return
except Exception as e:
logger.debug("Error getting outgoing calls for %s: %s", node.node_id, e)
return
if not outgoing_calls:
return
# Process each outgoing call
for call_dict in outgoing_calls:
callee_dict = call_dict.get("to")
if not callee_dict:
continue
# Convert to CallHierarchyItem
callee_item = self._dict_to_call_hierarchy_item(callee_dict)
if not callee_item:
continue
callee_id = self._create_node_id(callee_item)
# Check for cycles
if callee_id in self.visited:
# Create cycle marker node
cycle_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
is_cycle=True,
path_from_root=node.path_from_root + [callee_id],
)
node.children.append(cycle_node)
continue
# Create new callee node
callee_node = TreeNode(
item=callee_item,
depth=current_depth + 1,
path_from_root=node.path_from_root + [callee_id],
)
# Add to tree
tree.add_node(callee_node)
tree.add_edge(node, callee_node)
# Update relationships
node.children.append(callee_node)
callee_node.parents.append(node)
# Mark as visited
self.visited.add(callee_id)
# Recursively expand the callee
await self._expand_node(
node=callee_node,
node_dict=callee_dict,
tree=tree,
current_depth=current_depth + 1,
max_depth=max_depth,
expand_callers=expand_callers,
expand_callees=expand_callees,
)
def _dict_to_call_hierarchy_item(
self, item_dict: Dict
) -> Optional[CallHierarchyItem]:
"""Convert LSP dict to CallHierarchyItem.
Args:
item_dict: LSP CallHierarchyItem dictionary
Returns:
CallHierarchyItem or None if conversion fails
"""
try:
# Extract URI and convert to file path
uri = item_dict.get("uri", "")
file_path = uri.replace("file:///", "").replace("file://", "")
# Handle Windows paths (file:///C:/...)
if len(file_path) > 2 and file_path[0] == "/" and file_path[2] == ":":
file_path = file_path[1:]
# Extract range
range_dict = item_dict.get("range", {})
start = range_dict.get("start", {})
end = range_dict.get("end", {})
# Create Range (convert from 0-based to 1-based)
item_range = Range(
start_line=start.get("line", 0) + 1,
start_character=start.get("character", 0) + 1,
end_line=end.get("line", 0) + 1,
end_character=end.get("character", 0) + 1,
)
return CallHierarchyItem(
name=item_dict.get("name", "unknown"),
kind=str(item_dict.get("kind", "unknown")),
file_path=file_path,
range=item_range,
detail=item_dict.get("detail"),
)
except Exception as e:
logger.debug("Failed to convert dict to CallHierarchyItem: %s", e)
return None
def _create_node_id(self, item: CallHierarchyItem) -> str:
"""Create unique node ID from CallHierarchyItem.
Args:
item: CallHierarchyItem
Returns:
Unique node ID string
"""
return f"{item.file_path}:{item.name}:{item.range.start_line}"

View File

@@ -0,0 +1,191 @@
"""Data structures for association tree building.
Defines the core data classes for representing call hierarchy trees and
deduplicated results.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
@dataclass
class TreeNode:
"""Node in the call association tree.
Represents a single function/method in the tree, including its position
in the hierarchy and relationships.
Attributes:
item: LSP CallHierarchyItem containing symbol information
depth: Distance from the root node (seed) - 0 for roots
children: List of child nodes (functions called by this node)
parents: List of parent nodes (functions that call this node)
is_cycle: Whether this node creates a circular reference
path_from_root: Path (list of node IDs) from root to this node
"""
item: CallHierarchyItem
depth: int = 0
children: List[TreeNode] = field(default_factory=list)
parents: List[TreeNode] = field(default_factory=list)
is_cycle: bool = False
path_from_root: List[str] = field(default_factory=list)
@property
def node_id(self) -> str:
"""Unique identifier for this node."""
return f"{self.item.file_path}:{self.item.name}:{self.item.range.start_line}"
def __hash__(self) -> int:
"""Hash based on node ID."""
return hash(self.node_id)
def __eq__(self, other: object) -> bool:
"""Equality based on node ID."""
if not isinstance(other, TreeNode):
return False
return self.node_id == other.node_id
def __repr__(self) -> str:
"""String representation of the node."""
cycle_marker = " [CYCLE]" if self.is_cycle else ""
return f"TreeNode({self.item.name}@{self.item.file_path}:{self.item.range.start_line}){cycle_marker}"
@dataclass
class CallTree:
"""Complete call tree structure built from seeds.
Contains all nodes discovered through recursive expansion and
the relationships between them.
Attributes:
roots: List of root nodes (seed symbols)
all_nodes: Dictionary mapping node_id -> TreeNode for quick lookup
node_list: Flat list of all nodes in tree order
edges: List of (from_node_id, to_node_id) tuples representing calls
depth_reached: Maximum depth achieved in expansion
"""
roots: List[TreeNode] = field(default_factory=list)
all_nodes: Dict[str, TreeNode] = field(default_factory=dict)
node_list: List[TreeNode] = field(default_factory=list)
edges: List[tuple[str, str]] = field(default_factory=list)
depth_reached: int = 0
def add_node(self, node: TreeNode) -> None:
"""Add a node to the tree.
Args:
node: TreeNode to add
"""
if node.node_id not in self.all_nodes:
self.all_nodes[node.node_id] = node
self.node_list.append(node)
def add_edge(self, from_node: TreeNode, to_node: TreeNode) -> None:
"""Add an edge between two nodes.
Args:
from_node: Source node
to_node: Target node
"""
edge = (from_node.node_id, to_node.node_id)
if edge not in self.edges:
self.edges.append(edge)
def get_node(self, node_id: str) -> Optional[TreeNode]:
"""Get a node by ID.
Args:
node_id: Node identifier
Returns:
TreeNode if found, None otherwise
"""
return self.all_nodes.get(node_id)
def __len__(self) -> int:
"""Return total number of nodes in tree."""
return len(self.all_nodes)
def __repr__(self) -> str:
"""String representation of the tree."""
return (
f"CallTree(roots={len(self.roots)}, nodes={len(self.all_nodes)}, "
f"depth={self.depth_reached})"
)
@dataclass
class UniqueNode:
"""Deduplicated unique code symbol from the tree.
Represents a single unique code location that may appear multiple times
in the tree under different contexts. Contains aggregated information
about all occurrences.
Attributes:
file_path: Absolute path to the file
name: Symbol name (function, method, class, etc.)
kind: Symbol kind (function, method, class, etc.)
range: Code range in the file
min_depth: Minimum depth at which this node appears in the tree
occurrences: Number of times this node appears in the tree
paths: List of paths from roots to this node
context_nodes: Related nodes from the tree
score: Composite relevance score (higher is better)
"""
file_path: str
name: str
kind: str
range: Range
min_depth: int = 0
occurrences: int = 1
paths: List[List[str]] = field(default_factory=list)
context_nodes: List[str] = field(default_factory=list)
score: float = 0.0
@property
def node_key(self) -> tuple[str, int, int]:
"""Unique key for deduplication.
Uses (file_path, start_line, end_line) as the unique identifier
for this symbol across all occurrences.
"""
return (
self.file_path,
self.range.start_line,
self.range.end_line,
)
def add_path(self, path: List[str]) -> None:
"""Add a path from root to this node.
Args:
path: List of node IDs from root to this node
"""
if path not in self.paths:
self.paths.append(path)
def __hash__(self) -> int:
"""Hash based on node key."""
return hash(self.node_key)
def __eq__(self, other: object) -> bool:
"""Equality based on node key."""
if not isinstance(other, UniqueNode):
return False
return self.node_key == other.node_key
def __repr__(self) -> str:
"""String representation of the unique node."""
return (
f"UniqueNode({self.name}@{self.file_path}:{self.range.start_line}, "
f"depth={self.min_depth}, occ={self.occurrences}, score={self.score:.2f})"
)

View File

@@ -0,0 +1,301 @@
"""Result deduplication for association tree nodes.
Provides functionality to extract unique nodes from a call tree and assign
relevance scores based on various factors.
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional
from .data_structures import (
CallTree,
TreeNode,
UniqueNode,
)
logger = logging.getLogger(__name__)
# Symbol kind weights for scoring (higher = more relevant)
KIND_WEIGHTS: Dict[str, float] = {
# Functions and methods are primary targets
"function": 1.0,
"method": 1.0,
"12": 1.0, # LSP SymbolKind.Function
"6": 1.0, # LSP SymbolKind.Method
# Classes are important but secondary
"class": 0.8,
"5": 0.8, # LSP SymbolKind.Class
# Interfaces and types
"interface": 0.7,
"11": 0.7, # LSP SymbolKind.Interface
"type": 0.6,
# Constructors
"constructor": 0.9,
"9": 0.9, # LSP SymbolKind.Constructor
# Variables and constants
"variable": 0.4,
"13": 0.4, # LSP SymbolKind.Variable
"constant": 0.5,
"14": 0.5, # LSP SymbolKind.Constant
# Default for unknown kinds
"unknown": 0.3,
}
class ResultDeduplicator:
"""Extracts and scores unique nodes from call trees.
Processes a CallTree to extract unique code locations, merging duplicates
and assigning relevance scores based on:
- Depth: Shallower nodes (closer to seeds) score higher
- Frequency: Nodes appearing multiple times score higher
- Kind: Function/method > class > variable
Attributes:
depth_weight: Weight for depth factor in scoring (default 0.4)
frequency_weight: Weight for frequency factor (default 0.3)
kind_weight: Weight for symbol kind factor (default 0.3)
max_depth_penalty: Maximum depth before full penalty applied
"""
def __init__(
self,
depth_weight: float = 0.4,
frequency_weight: float = 0.3,
kind_weight: float = 0.3,
max_depth_penalty: int = 10,
):
"""Initialize ResultDeduplicator.
Args:
depth_weight: Weight for depth factor (0.0-1.0)
frequency_weight: Weight for frequency factor (0.0-1.0)
kind_weight: Weight for symbol kind factor (0.0-1.0)
max_depth_penalty: Depth at which score becomes 0 for depth factor
"""
self.depth_weight = depth_weight
self.frequency_weight = frequency_weight
self.kind_weight = kind_weight
self.max_depth_penalty = max_depth_penalty
def deduplicate(
self,
tree: CallTree,
max_results: Optional[int] = None,
) -> List[UniqueNode]:
"""Extract unique nodes from the call tree.
Traverses the tree, groups nodes by their unique key (file_path,
start_line, end_line), and merges duplicate occurrences.
Args:
tree: CallTree to process
max_results: Maximum number of results to return (None = all)
Returns:
List of UniqueNode objects, sorted by score descending
"""
if not tree.node_list:
return []
# Group nodes by unique key
unique_map: Dict[tuple, UniqueNode] = {}
for node in tree.node_list:
if node.is_cycle:
# Skip cycle markers - they point to already-counted nodes
continue
key = self._get_node_key(node)
if key in unique_map:
# Update existing unique node
unique_node = unique_map[key]
unique_node.occurrences += 1
unique_node.min_depth = min(unique_node.min_depth, node.depth)
unique_node.add_path(node.path_from_root)
# Collect context from relationships
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
else:
# Create new unique node
unique_node = UniqueNode(
file_path=node.item.file_path,
name=node.item.name,
kind=node.item.kind,
range=node.item.range,
min_depth=node.depth,
occurrences=1,
paths=[node.path_from_root.copy()],
context_nodes=[],
score=0.0,
)
# Collect initial context
for parent in node.parents:
if not parent.is_cycle:
unique_node.context_nodes.append(parent.node_id)
for child in node.children:
if not child.is_cycle:
unique_node.context_nodes.append(child.node_id)
unique_map[key] = unique_node
# Calculate scores for all unique nodes
unique_nodes = list(unique_map.values())
# Find max frequency for normalization
max_frequency = max((n.occurrences for n in unique_nodes), default=1)
for node in unique_nodes:
node.score = self._score_node(node, max_frequency)
# Sort by score descending
unique_nodes.sort(key=lambda n: n.score, reverse=True)
# Apply max_results limit
if max_results is not None and max_results > 0:
unique_nodes = unique_nodes[:max_results]
logger.debug(
"Deduplicated %d tree nodes to %d unique nodes",
len(tree.node_list),
len(unique_nodes),
)
return unique_nodes
def _score_node(
self,
node: UniqueNode,
max_frequency: int,
) -> float:
"""Calculate composite score for a unique node.
Score = depth_weight * depth_score +
frequency_weight * frequency_score +
kind_weight * kind_score
Args:
node: UniqueNode to score
max_frequency: Maximum occurrence count for normalization
Returns:
Composite score between 0.0 and 1.0
"""
# Depth score: closer to root = higher score
# Score of 1.0 at depth 0, decreasing to 0.0 at max_depth_penalty
depth_score = max(
0.0,
1.0 - (node.min_depth / self.max_depth_penalty),
)
# Frequency score: more occurrences = higher score
frequency_score = node.occurrences / max_frequency if max_frequency > 0 else 0.0
# Kind score: function/method > class > variable
kind_str = str(node.kind).lower()
kind_score = KIND_WEIGHTS.get(kind_str, KIND_WEIGHTS["unknown"])
# Composite score
score = (
self.depth_weight * depth_score
+ self.frequency_weight * frequency_score
+ self.kind_weight * kind_score
)
return score
def _get_node_key(self, node: TreeNode) -> tuple:
"""Get unique key for a tree node.
Uses (file_path, start_line, end_line) as the unique identifier.
Args:
node: TreeNode
Returns:
Tuple key for deduplication
"""
return (
node.item.file_path,
node.item.range.start_line,
node.item.range.end_line,
)
def filter_by_kind(
self,
nodes: List[UniqueNode],
kinds: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by symbol kind.
Args:
nodes: List of UniqueNode to filter
kinds: List of allowed kinds (e.g., ["function", "method"])
Returns:
Filtered list of UniqueNode
"""
kinds_lower = [k.lower() for k in kinds]
return [
node
for node in nodes
if str(node.kind).lower() in kinds_lower
]
def filter_by_file(
self,
nodes: List[UniqueNode],
file_patterns: List[str],
) -> List[UniqueNode]:
"""Filter unique nodes by file path patterns.
Args:
nodes: List of UniqueNode to filter
file_patterns: List of path substrings to match
Returns:
Filtered list of UniqueNode
"""
return [
node
for node in nodes
if any(pattern in node.file_path for pattern in file_patterns)
]
def to_dict_list(self, nodes: List[UniqueNode]) -> List[Dict]:
"""Convert list of UniqueNode to JSON-serializable dicts.
Args:
nodes: List of UniqueNode
Returns:
List of dictionaries
"""
return [
{
"file_path": node.file_path,
"name": node.name,
"kind": node.kind,
"range": {
"start_line": node.range.start_line,
"start_character": node.range.start_character,
"end_line": node.range.end_line,
"end_character": node.range.end_character,
},
"min_depth": node.min_depth,
"occurrences": node.occurrences,
"path_count": len(node.paths),
"score": round(node.score, 4),
}
for node in nodes
]