Files
Claude-Code-Workflow/codex-lens/tests/test_association_tree.py

401 lines
12 KiB
Python

"""Unit tests for association tree building and deduplication.

AssociationTreeBuilder is exercised against a mocked LSP manager, while
ResultDeduplicator is exercised against hand-built CallTree and UniqueNode
instances.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from codexlens.hybrid_search.data_structures import CallHierarchyItem, Range
from codexlens.search.association_tree import (
AssociationTreeBuilder,
CallTree,
ResultDeduplicator,
TreeNode,
UniqueNode,
)
class MockLspManager:
    """Scripted stand-in for the LSP manager consumed by AssociationTreeBuilder.

    Tests fill the three response tables below before building a tree; any
    lookup that has no matching entry yields an empty list.
    """

    def __init__(self):
        """Create the manager with every response table empty."""
        # "<file_path>:<line>:<character>" -> call hierarchy items at that position.
        self.call_hierarchy_items: Dict[str, List[Dict]] = {}
        # Item "name" -> incoming-call entries registered for that item.
        self.incoming_calls: Dict[str, List[Dict]] = {}
        # Item "name" -> outgoing-call entries registered for that item.
        self.outgoing_calls: Dict[str, List[Dict]] = {}

    async def get_call_hierarchy_items(
        self, file_path: str, line: int, character: int, wait_for_analysis: float = 0.0
    ) -> List[Dict]:
        """Return the items registered for this exact (file, line, character) triple.

        ``wait_for_analysis`` only mirrors the real manager's signature; it is unused.
        """
        position_key = f"{file_path}:{line}:{character}"
        return self.call_hierarchy_items.get(position_key, [])

    async def get_incoming_calls(self, item: Dict[str, Any]) -> List[Dict]:
        """Return the scripted callers of ``item``, looked up by its ``name``."""
        return self.incoming_calls.get(item.get("name", ""), [])

    async def get_outgoing_calls(self, item: Dict[str, Any]) -> List[Dict]:
        """Return the scripted callees of ``item``, looked up by its ``name``."""
        return self.outgoing_calls.get(item.get("name", ""), [])
def create_mock_item(
    name: str,
    file_path: str,
    start_line: int,
    end_line: int,
    kind: str = "function",
) -> Dict[str, Any]:
    """Build an LSP-style CallHierarchyItem dict for use as a canned response.

    Args:
        name: Symbol name; also used to build the ``detail`` string.
        file_path: Path appended to a ``file:///`` URI prefix.
        start_line: 0-based first line of the symbol's range.
        end_line: 0-based last line of the symbol's range.
        kind: Symbol kind string (defaults to ``"function"``).

    Returns:
        A dict with ``name``, ``kind``, ``uri``, ``range`` and ``detail`` keys;
        both range endpoints use character 0.
    """
    start = {"line": start_line, "character": 0}
    end = {"line": end_line, "character": 0}
    item: Dict[str, Any] = {
        "name": name,
        "kind": kind,
        "uri": f"file:///{file_path}",
        "range": {"start": start, "end": end},
        "detail": f"def {name}(...)",
    }
    return item
@pytest.mark.asyncio
async def test_simple_tree_building():
    """A single root with one outgoing call yields a two-node tree."""
    lsp = MockLspManager()
    main_item = create_mock_item("main", "test.py", 10, 15)
    helper_item = create_mock_item("helper", "test.py", 20, 25)

    # The seed position resolves to `main`, which calls `helper`;
    # neither function has callers, and `helper` calls nothing.
    lsp.call_hierarchy_items["test.py:11:1"] = [main_item]
    lsp.outgoing_calls.update({"main": [{"to": helper_item}], "helper": []})
    lsp.incoming_calls.update({"main": [], "helper": []})

    tree = await AssociationTreeBuilder(lsp).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        seed_character=1,
        max_depth=2,
        expand_callers=False,
        expand_callees=True,
    )

    assert len(tree.roots) == 1
    root = tree.roots[0]
    assert root.item.name == "main"
    assert len(root.children) == 1
    assert root.children[0].item.name == "helper"
    assert len(tree.all_nodes) == 2
@pytest.mark.asyncio
async def test_tree_with_cycle_detection():
    """An A -> B -> A call loop produces a child of func_b flagged as a cycle."""
    lsp = MockLspManager()
    func_a = create_mock_item("func_a", "test.py", 10, 15)
    func_b = create_mock_item("func_b", "test.py", 20, 25)

    lsp.call_hierarchy_items["test.py:11:1"] = [func_a]
    # func_b calls back into func_a, closing the loop.
    lsp.outgoing_calls.update({"func_a": [{"to": func_b}], "func_b": [{"to": func_a}]})
    lsp.incoming_calls.update({"func_a": [], "func_b": []})

    tree = await AssociationTreeBuilder(lsp).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        seed_character=1,
        max_depth=5,
        expand_callers=False,
        expand_callees=True,
    )

    # Only the two distinct functions count as unique nodes.
    assert len(tree.all_nodes) == 2

    # func_b's single child is the back-edge to func_a, marked as a cycle.
    func_b_node = next(
        (node for node in tree.node_list if node.item.name == "func_b"), None
    )
    assert func_b_node is not None
    assert len(func_b_node.children) == 1
    back_edge = func_b_node.children[0]
    assert back_edge.is_cycle
    assert back_edge.item.name == "func_a"
@pytest.mark.asyncio
async def test_max_depth_limit():
    """Test that expansion stops at max_depth.

    Builds the linear call chain func_a -> func_b -> func_c -> func_d and
    expands callees with max_depth=2. func_d must be left out of the tree.
    """
    mock_lsp = MockLspManager()
    # Chain: A -> B -> C -> D
    items = {
        "A": create_mock_item("func_a", "test.py", 10, 15),
        "B": create_mock_item("func_b", "test.py", 20, 25),
        "C": create_mock_item("func_c", "test.py", 30, 35),
        "D": create_mock_item("func_d", "test.py", 40, 45),
    }
    mock_lsp.call_hierarchy_items["test.py:11:1"] = [items["A"]]
    mock_lsp.outgoing_calls["func_a"] = [{"to": items["B"]}]
    mock_lsp.outgoing_calls["func_b"] = [{"to": items["C"]}]
    mock_lsp.outgoing_calls["func_c"] = [{"to": items["D"]}]
    mock_lsp.outgoing_calls["func_d"] = []
    for name in ["func_a", "func_b", "func_c", "func_d"]:
        mock_lsp.incoming_calls[name] = []
    # Build tree with max_depth=2
    builder = AssociationTreeBuilder(mock_lsp)
    tree = await builder.build_tree(
        seed_file_path="test.py",
        seed_line=11,
        # Passed explicitly: the mock only answers for "test.py:11:1", so the
        # test must not rely on build_tree's default seed_character.
        seed_character=1,
        max_depth=2,
        expand_callers=False,
        expand_callees=True,
    )
    # Should only have nodes A, B, C (depths 0, 1, 2)
    # D should not be included (would be depth 3)
    assert len(tree.all_nodes) == 3
    node_names = {node.item.name for node in tree.node_list}
    assert "func_a" in node_names
    assert "func_b" in node_names
    assert "func_c" in node_names
    assert "func_d" not in node_names
@pytest.mark.asyncio
async def test_empty_tree():
    """With nothing registered at the seed position, the tree stays empty."""
    # A fresh mock has no canned responses, so every lookup returns [].
    tree = await AssociationTreeBuilder(MockLspManager()).build_tree(
        seed_file_path="test.py",
        seed_line=11,
        max_depth=2,
    )

    assert len(tree.roots) == 0
    assert len(tree.all_nodes) == 0
def test_deduplication_basic():
    """Two occurrences of the same symbol collapse into one UniqueNode."""
    tree = CallTree()

    # func_a is reached twice via different paths (depth 0 and depth 2) with
    # an identical range; each occurrence gets its own item instance.
    shallow_a = TreeNode(
        item=CallHierarchyItem(
            name="func_a",
            kind="function",
            file_path="test.py",
            range=Range(10, 0, 15, 0),
        ),
        depth=0,
        path_from_root=["node1"],
    )
    deep_a = TreeNode(
        item=CallHierarchyItem(
            name="func_a",
            kind="function",
            file_path="test.py",
            range=Range(10, 0, 15, 0),
        ),
        depth=2,
        path_from_root=["root", "mid", "node2"],
    )
    # func_b is a distinct symbol seen once.
    lone_b = TreeNode(
        item=CallHierarchyItem(
            name="func_b",
            kind="function",
            file_path="test.py",
            range=Range(20, 0, 25, 0),
        ),
        depth=1,
        path_from_root=["root", "node3"],
    )
    # Appended straight to node_list (not via add_node) to simulate the same
    # symbol being reached from different callers.
    tree.node_list.extend([shallow_a, deep_a, lone_b])

    unique_nodes = ResultDeduplicator().deduplicate(tree)

    # func_a merges into one entry; func_b stays separate.
    assert len(unique_nodes) == 2

    merged_a = next(n for n in unique_nodes if n.name == "func_a")
    assert merged_a.occurrences == 2
    assert merged_a.min_depth == 0

    single_b = next(n for n in unique_nodes if n.name == "func_b")
    assert single_b.occurrences == 1
    assert single_b.min_depth == 1
def test_deduplication_scoring():
    """A shallow (root-level) symbol outscores a deeply nested one."""
    tree = CallTree()

    # (name, first line of range, depth): one root node and one node at depth 5.
    specs = (("root_func", 10, 0), ("deep_func", 20, 5))
    for symbol, first_line, depth in specs:
        symbol_item = CallHierarchyItem(
            name=symbol,
            kind="function",
            file_path="test.py",
            range=Range(first_line, 0, first_line + 5, 0),
        )
        tree.add_node(TreeNode(item=symbol_item, depth=depth))

    unique_nodes = ResultDeduplicator().deduplicate(tree)

    root_node = next(n for n in unique_nodes if n.name == "root_func")
    deep_node = next(n for n in unique_nodes if n.name == "deep_func")
    assert root_node.score > deep_node.score
def test_deduplication_max_results():
    """deduplicate() returns no more than ``max_results`` entries."""
    tree = CallTree()

    # Five distinct functions, func_0..func_4, at increasing depth.
    for index in range(5):
        first_line = index * 10
        tree.add_node(
            TreeNode(
                item=CallHierarchyItem(
                    name=f"func_{index}",
                    kind="function",
                    file_path="test.py",
                    range=Range(first_line, 0, first_line + 5, 0),
                ),
                depth=index,
            )
        )

    limited = ResultDeduplicator().deduplicate(tree, max_results=3)

    assert len(limited) == 3
def test_filter_by_kind():
    """Test filtering unique nodes by symbol kind.

    Three nodes (one function, one class, one variable) are filtered with one
    kind and then two kinds. The exact surviving names are checked, not just
    how many there are.
    """
    # Create unique nodes with different kinds
    nodes = [
        UniqueNode(
            file_path="test.py",
            name="func1",
            kind="function",
            range=Range(10, 0, 15, 0),
        ),
        UniqueNode(
            file_path="test.py",
            name="cls1",
            kind="class",
            range=Range(20, 0, 30, 0),
        ),
        UniqueNode(
            file_path="test.py",
            name="var1",
            kind="variable",
            range=Range(40, 0, 40, 10),
        ),
    ]
    deduplicator = ResultDeduplicator()
    # Filter for functions only
    filtered = deduplicator.filter_by_kind(nodes, ["function"])
    assert len(filtered) == 1
    assert filtered[0].name == "func1"
    # Filter for functions and classes. Checking only the count would also
    # accept a wrong pair such as func1 + var1, so pin the names too.
    filtered = deduplicator.filter_by_kind(nodes, ["function", "class"])
    assert len(filtered) == 2
    assert {node.name for node in filtered} == {"func1", "cls1"}
def test_to_dict_list():
    """to_dict_list() exposes a UniqueNode's fields as dictionary entries."""
    source_node = UniqueNode(
        file_path="test.py",
        name="func1",
        kind="function",
        range=Range(10, 0, 15, 0),
        min_depth=0,
        occurrences=2,
        score=0.85,
    )

    dict_list = ResultDeduplicator().to_dict_list([source_node])

    assert len(dict_list) == 1
    entry = dict_list[0]
    expected = {
        "name": "func1",
        "kind": "function",
        "min_depth": 0,
        "occurrences": 2,
        "score": 0.85,
    }
    # Compare only the keys this test pins down; other keys may also be present.
    assert {key: entry[key] for key in expected} == expected
if __name__ == "__main__":
    # Propagate pytest's exit code so a direct `python test_association_tree.py`
    # run exits non-zero when tests fail. Previously the code was discarded and
    # the process always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))