mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
feat: Implement association tree for LSP-based code relationship discovery
- Add `association_tree` module with components for building and processing call association trees using LSP call hierarchy capabilities.
- Introduce `AssociationTreeBuilder` for constructing call trees from seed locations with depth-first expansion.
- Create data structures: `TreeNode`, `CallTree`, and `UniqueNode` for representing nodes and relationships in the call tree.
- Implement `ResultDeduplicator` to extract unique nodes from call trees and assign relevance scores based on depth, frequency, and kind.
- Add unit tests for `AssociationTreeBuilder` and `ResultDeduplicator` to ensure functionality and correctness.
This commit is contained in:
400
codex-lens/tests/test_association_tree.py
Normal file
400
codex-lens/tests/test_association_tree.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Unit tests for association tree building and deduplication.
|
||||
|
||||
Tests the AssociationTreeBuilder and ResultDeduplicator components using
|
||||
mocked LSP responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
|
||||
from codexlens.search.association_tree import (
|
||||
AssociationTreeBuilder,
|
||||
CallTree,
|
||||
ResultDeduplicator,
|
||||
TreeNode,
|
||||
UniqueNode,
|
||||
)
|
||||
|
||||
|
||||
class MockLspManager:
    """In-memory stand-in for the LSP manager, driven by canned responses.

    Tests fill the three public dictionaries before building a tree:

    * ``call_hierarchy_items`` is keyed by ``"<file_path>:<line>:<character>"``.
    * ``incoming_calls`` and ``outgoing_calls`` are keyed by the value stored
      under the ``"name"`` key of the item being queried.

    Any lookup without a registered response yields an empty list.
    """

    def __init__(self):
        """Start with no canned responses registered."""
        self.call_hierarchy_items: Dict[str, List[Dict]] = {}
        self.incoming_calls: Dict[str, List[Dict]] = {}
        self.outgoing_calls: Dict[str, List[Dict]] = {}

    async def get_call_hierarchy_items(
        self, file_path: str, line: int, character: int
    ) -> List[Dict]:
        """Return the items registered for this position, or ``[]``."""
        return self.call_hierarchy_items.get(
            f"{file_path}:{line}:{character}", []
        )

    async def get_incoming_calls(self, item: Dict[str, Any]) -> List[Dict]:
        """Return the incoming calls registered for ``item["name"]``, or ``[]``."""
        # A missing "name" falls back to the empty-string key.
        return self.incoming_calls.get(item.get("name", ""), [])

    async def get_outgoing_calls(self, item: Dict[str, Any]) -> List[Dict]:
        """Return the outgoing calls registered for ``item["name"]``, or ``[]``."""
        # A missing "name" falls back to the empty-string key.
        return self.outgoing_calls.get(item.get("name", ""), [])
|
||||
|
||||
|
||||
def create_mock_item(
    name: str,
    file_path: str,
    start_line: int,
    end_line: int,
    kind: str = "function",
) -> Dict[str, Any]:
    """Build a dict shaped like an LSP CallHierarchyItem for use in tests.

    Args:
        name: Symbol name.
        file_path: File path; embedded in the ``uri`` after ``file:///``.
        start_line: First line of the symbol range (0-based for LSP).
        end_line: Last line of the symbol range (0-based for LSP).
        kind: Symbol kind.

    Returns:
        A dict with ``name``, ``kind``, ``uri``, ``range`` and ``detail`` keys.
        Both range endpoints use character offset 0.
    """
    span = {
        edge: {"line": line_no, "character": 0}
        for edge, line_no in (("start", start_line), ("end", end_line))
    }
    uri = f"file:///{file_path}"
    signature = f"def {name}(...)"
    return {
        "name": name,
        "kind": kind,
        "uri": uri,
        "range": span,
        "detail": signature,
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_simple_tree_building():
    """A seed with one outgoing call yields a root with exactly one child."""
    main_item = create_mock_item("main", "test.py", 10, 15)
    helper_item = create_mock_item("helper", "test.py", 20, 25)

    # Canned LSP answers: the seed position resolves to ``main``, which calls
    # ``helper``; no other relationships exist.
    lsp = MockLspManager()
    lsp.call_hierarchy_items["test.py:11:1"] = [main_item]
    lsp.outgoing_calls.update({"main": [{"to": helper_item}], "helper": []})
    lsp.incoming_calls.update({"main": [], "helper": []})

    # Expand callees only, starting from the seed position.
    tree = await AssociationTreeBuilder(lsp).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        seed_character=1,
        max_depth=2,
        expand_callers=False,
        expand_callees=True,
    )

    assert len(tree.roots) == 1
    root = tree.roots[0]
    assert root.item.name == "main"
    assert len(root.children) == 1
    assert root.children[0].item.name == "helper"
    assert len(tree.all_nodes) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tree_with_cycle_detection():
    """Verify that the call cycle A -> B -> A is flagged on the repeated node."""
    item_a = create_mock_item("func_a", "test.py", 10, 15)
    item_b = create_mock_item("func_b", "test.py", 20, 25)

    lsp = MockLspManager()
    lsp.call_hierarchy_items["test.py:11:1"] = [item_a]
    lsp.outgoing_calls.update(
        {
            "func_a": [{"to": item_b}],
            "func_b": [{"to": item_a}],  # closes the loop back to func_a
        }
    )
    lsp.incoming_calls.update({"func_a": [], "func_b": []})

    # max_depth is deliberately larger than the cycle length.
    tree = await AssociationTreeBuilder(lsp).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        seed_character=1,
        max_depth=5,
        expand_callers=False,
        expand_callees=True,
    )

    # Only func_a and func_b are expected as unique nodes.
    assert len(tree.all_nodes) == 2

    # func_b must carry a single child that points back at func_a and is
    # marked as a cycle.
    func_b_node = next(
        (node for node in tree.node_list if node.item.name == "func_b"),
        None,
    )
    assert func_b_node is not None
    assert len(func_b_node.children) == 1
    back_edge = func_b_node.children[0]
    assert back_edge.is_cycle
    assert back_edge.item.name == "func_a"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_depth_limit():
    """Test that expansion stops at max_depth.

    The mocked call chain is A -> B -> C -> D. With ``max_depth=2`` the tree
    is expected to contain A, B and C (depths 0, 1, 2) but not D.
    """
    mock_lsp = MockLspManager()

    # Chain: A -> B -> C -> D
    items = {
        "A": create_mock_item("func_a", "test.py", 10, 15),
        "B": create_mock_item("func_b", "test.py", 20, 25),
        "C": create_mock_item("func_c", "test.py", 30, 35),
        "D": create_mock_item("func_d", "test.py", 40, 45),
    }

    mock_lsp.call_hierarchy_items["test.py:11:1"] = [items["A"]]
    mock_lsp.outgoing_calls["func_a"] = [{"to": items["B"]}]
    mock_lsp.outgoing_calls["func_b"] = [{"to": items["C"]}]
    mock_lsp.outgoing_calls["func_c"] = [{"to": items["D"]}]
    mock_lsp.outgoing_calls["func_d"] = []

    for name in ["func_a", "func_b", "func_c", "func_d"]:
        mock_lsp.incoming_calls[name] = []

    # Build tree with max_depth=2
    builder = AssociationTreeBuilder(mock_lsp)
    tree = await builder.build_tree(
        seed_file_path="test.py",
        seed_line=11,
        # Passed explicitly so the mock lookup key matches "test.py:11:1"
        # above without depending on the builder's default for this argument.
        seed_character=1,
        max_depth=2,
        expand_callers=False,
        expand_callees=True,
    )

    # Should only have nodes A, B, C (depths 0, 1, 2)
    # D should not be included (would be depth 3)
    assert len(tree.all_nodes) == 3
    node_names = {node.item.name for node in tree.node_list}
    assert "func_a" in node_names
    assert "func_b" in node_names
    assert "func_c" in node_names
    assert "func_d" not in node_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_tree():
    """With no call hierarchy items registered, the built tree is empty."""
    # The mock is left unconfigured on purpose: every lookup returns [].
    unconfigured_lsp = MockLspManager()

    tree = await AssociationTreeBuilder(unconfigured_lsp).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        max_depth=2,
    )

    assert len(tree.roots) == 0
    assert len(tree.all_nodes) == 0
|
||||
|
||||
|
||||
def test_deduplication_basic():
    """Occurrences of the same symbol collapse into one unique node."""
    # Two separate item objects describing the same symbol (same name, file
    # and range), as happens when a function is reached via different paths.
    first_a = CallHierarchyItem(
        name="func_a",
        kind="function",
        file_path="test.py",
        range=Range(10, 0, 15, 0),
    )
    second_a = CallHierarchyItem(
        name="func_a",
        kind="function",
        file_path="test.py",
        range=Range(10, 0, 15, 0),
    )
    # An unrelated symbol that must stay separate.
    only_b = CallHierarchyItem(
        name="func_b",
        kind="function",
        file_path="test.py",
        range=Range(20, 0, 25, 0),
    )

    occurrences = (
        TreeNode(item=first_a, depth=0, path_from_root=["node1"]),
        TreeNode(item=second_a, depth=2, path_from_root=["root", "mid", "node2"]),
        TreeNode(item=only_b, depth=1, path_from_root=["root", "node3"]),
    )

    # Append straight to node_list (rather than going through add_node) to
    # simulate the same symbol being recorded from different paths.
    tree = CallTree()
    for occurrence in occurrences:
        tree.node_list.append(occurrence)

    unique_nodes = ResultDeduplicator().deduplicate(tree)

    # func_a merged into one entry, func_b kept on its own.
    assert len(unique_nodes) == 2

    merged_a = next(n for n in unique_nodes if n.name == "func_a")
    assert merged_a.occurrences == 2
    assert merged_a.min_depth == 0

    single_b = next(n for n in unique_nodes if n.name == "func_b")
    assert single_b.occurrences == 1
    assert single_b.min_depth == 1
|
||||
|
||||
|
||||
def test_deduplication_scoring():
    """A node at depth 0 must receive a higher score than one at depth 5."""
    shallow = TreeNode(
        item=CallHierarchyItem(
            name="root_func",
            kind="function",
            file_path="test.py",
            range=Range(10, 0, 15, 0),
        ),
        depth=0,
    )
    deep = TreeNode(
        item=CallHierarchyItem(
            name="deep_func",
            kind="function",
            file_path="test.py",
            range=Range(20, 0, 25, 0),
        ),
        depth=5,
    )

    tree = CallTree()
    for tree_node in (shallow, deep):
        tree.add_node(tree_node)

    scored = ResultDeduplicator().deduplicate(tree)

    root_node = next(n for n in scored if n.name == "root_func")
    deep_node = next(n for n in scored if n.name == "deep_func")
    assert root_node.score > deep_node.score
|
||||
|
||||
|
||||
def test_deduplication_max_results():
    """``max_results`` caps the number of unique nodes returned."""
    tree = CallTree()

    # Five distinct symbols, one per depth level 0..4, with
    # non-overlapping ranges.
    for depth in range(5):
        first_line = depth * 10
        symbol = CallHierarchyItem(
            name=f"func_{depth}",
            kind="function",
            file_path="test.py",
            range=Range(first_line, 0, first_line + 5, 0),
        )
        tree.add_node(TreeNode(item=symbol, depth=depth))

    limited = ResultDeduplicator().deduplicate(tree, max_results=3)

    assert len(limited) == 3
|
||||
|
||||
|
||||
def test_filter_by_kind():
    """``filter_by_kind`` keeps only nodes whose kind is in the given list."""
    # One node per kind: function, class, variable.
    specs = (
        ("func1", "function", Range(10, 0, 15, 0)),
        ("cls1", "class", Range(20, 0, 30, 0)),
        ("var1", "variable", Range(40, 0, 40, 10)),
    )
    nodes = [
        UniqueNode(file_path="test.py", name=symbol, kind=kind, range=span)
        for symbol, kind, span in specs
    ]

    deduplicator = ResultDeduplicator()

    # A single requested kind keeps just the function.
    functions_only = deduplicator.filter_by_kind(nodes, ["function"])
    assert len(functions_only) == 1
    assert functions_only[0].name == "func1"

    # Two requested kinds keep the function and the class.
    functions_and_classes = deduplicator.filter_by_kind(
        nodes, ["function", "class"]
    )
    assert len(functions_and_classes) == 2
|
||||
|
||||
|
||||
def test_to_dict_list():
    """``to_dict_list`` exposes each node's fields under matching dict keys."""
    sample = UniqueNode(
        file_path="test.py",
        name="func1",
        kind="function",
        range=Range(10, 0, 15, 0),
        min_depth=0,
        occurrences=2,
        score=0.85,
    )

    dict_list = ResultDeduplicator().to_dict_list([sample])

    assert len(dict_list) == 1
    entry = dict_list[0]
    assert entry["name"] == "func1"
    assert entry["kind"] == "function"
    assert entry["min_depth"] == 0
    assert entry["occurrences"] == 2
    assert entry["score"] == 0.85
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the exit code
    # rather than exiting, so hand it to SystemExit; otherwise a failing run
    # would still exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user