mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-10 02:24:35 +08:00
refactor: 移除图索引功能,修复内存泄露,优化嵌入生成
主要更改: 1. 移除图索引功能 (graph indexing) - 删除 graph_analyzer.py 及相关迁移文件 - 移除 CLI 的 graph 命令和 --enrich 标志 - 清理 chain_search.py 中的图查询方法 (370行) - 删除相关测试文件 2. 修复嵌入生成内存问题 - 重构 generate_embeddings.py 使用流式批处理 - 改用 embedding_manager 的内存安全实现 - 文件从 548 行精简到 259 行 (52.7% 减少) 3. 修复内存泄露 - chain_search.py: quick_search 使用 with 语句管理 ChainSearchEngine - embedding_manager.py: 使用 with 语句管理 VectorStore - vector_store.py: 添加暴力搜索内存警告 4. 代码清理 - 移除 Symbol 模型的 token_count 和 symbol_type 字段 - 清理相关测试用例 测试: 760 passed, 7 skipped 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,644 +0,0 @@
|
||||
"""Unit tests for ChainSearchEngine.
|
||||
|
||||
Tests the graph query methods (search_callers, search_callees, search_inheritance)
|
||||
with mocked SQLiteStore dependency to test logic in isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock, patch, call
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from codexlens.search.chain_search import (
|
||||
ChainSearchEngine,
|
||||
SearchOptions,
|
||||
SearchStats,
|
||||
ChainSearchResult,
|
||||
)
|
||||
from codexlens.entities import SearchResult, Symbol
|
||||
from codexlens.storage.registry import RegistryStore, DirMapping
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Create a mock RegistryStore."""
|
||||
registry = Mock(spec=RegistryStore)
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mapper():
|
||||
"""Create a mock PathMapper."""
|
||||
mapper = Mock(spec=PathMapper)
|
||||
return mapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_engine(mock_registry, mock_mapper):
|
||||
"""Create a ChainSearchEngine with mocked dependencies."""
|
||||
return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_index_path():
|
||||
"""Sample index database path."""
|
||||
return Path("/test/project/_index.db")
|
||||
|
||||
|
||||
class TestChainSearchEngineCallers:
|
||||
"""Tests for search_callers method."""
|
||||
|
||||
def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callers returns caller relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock finding the start index
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
# Mock collect_index_paths to return single index
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Mock the parallel search to return caller data
|
||||
expected_callers = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller_function"
|
||||
assert result[0]["target_symbol"] == "my_function"
|
||||
assert result[0]["relationship_type"] == "calls"
|
||||
assert result[0]["source_line"] == 42
|
||||
|
||||
def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callers handles no results gracefully."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "nonexistent_function"
|
||||
|
||||
# Mock finding the start index
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
# Mock collect_index_paths
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Mock empty results
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_search_callers_no_index_found(self, search_engine, mock_registry):
|
||||
"""Test that search_callers returns empty list when no index found."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock no index found
|
||||
mock_registry.find_nearest_index.return_value = None
|
||||
|
||||
with patch.object(search_engine, '_find_start_index', return_value=None):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path):
|
||||
"""Test that search_callers respects SearchOptions."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
options = SearchOptions(depth=1, total_limit=50)
|
||||
|
||||
# Configure mapper to return a path that exists
|
||||
mock_mapper.source_to_index_db.return_value = sample_index_path
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect:
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search:
|
||||
# Patch Path.exists to return True so the exact match is found
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
# Execute
|
||||
search_engine.search_callers(target_symbol, source_path, options)
|
||||
|
||||
# Assert that depth was passed to collect_index_paths
|
||||
mock_collect.assert_called_once_with(sample_index_path, 1)
|
||||
# Assert that total_limit was passed to parallel search
|
||||
mock_search.assert_called_once_with([sample_index_path], target_symbol, 50)
|
||||
|
||||
|
||||
class TestChainSearchEngineCallees:
|
||||
"""Tests for search_callees method."""
|
||||
|
||||
def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees returns callee relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "caller_function"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_callees = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "callee_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 15,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller_function"
|
||||
assert result[0]["target_symbol"] == "callee_function"
|
||||
assert result[0]["source_line"] == 15
|
||||
|
||||
def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees correctly handles file-specific queries."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "MyClass.method"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Multiple callees from same source symbol
|
||||
expected_callees = [
|
||||
{
|
||||
"source_symbol": "MyClass.method",
|
||||
"target_symbol": "helper_a",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/utils.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "MyClass.method",
|
||||
"target_symbol": "helper_b",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 20,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/utils.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0]["target_symbol"] == "helper_a"
|
||||
assert result[1]["target_symbol"] == "helper_b"
|
||||
|
||||
def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees handles no callees gracefully."""
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "leaf_function"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestChainSearchEngineInheritance:
|
||||
"""Tests for search_inheritance method."""
|
||||
|
||||
def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_inheritance returns inheritance relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
class_name = "BaseClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_inheritance = [
|
||||
{
|
||||
"source_symbol": "DerivedClass",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/project/derived.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "DerivedClass"
|
||||
assert result[0]["target_symbol"] == "BaseClass"
|
||||
assert result[0]["relationship_type"] == "inherits"
|
||||
|
||||
def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test inheritance search with multiple derived classes."""
|
||||
source_path = Path("/test/project")
|
||||
class_name = "BaseClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_inheritance = [
|
||||
{
|
||||
"source_symbol": "DerivedClassA",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/project/derived_a.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "DerivedClassB",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/derived_b.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0]["source_symbol"] == "DerivedClassA"
|
||||
assert result[1]["source_symbol"] == "DerivedClassB"
|
||||
|
||||
def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test inheritance search with no subclasses found."""
|
||||
source_path = Path("/test/project")
|
||||
class_name = "FinalClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestChainSearchEngineParallelSearch:
|
||||
"""Tests for parallel search aggregation."""
|
||||
|
||||
def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that parallel search aggregates results from multiple indexes."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
index_path_1 = Path("/test/project/_index.db")
|
||||
index_path_2 = Path("/test/project/subdir/_index.db")
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=index_path_1,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
||||
# Mock parallel search results from multiple indexes
|
||||
callers_from_multiple = [
|
||||
{
|
||||
"source_symbol": "caller_in_root",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/root.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "caller_in_subdir",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 20,
|
||||
"source_file": "/test/project/subdir/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert results from both indexes are included
|
||||
assert len(result) == 2
|
||||
assert any(r["source_file"] == "/test/project/root.py" for r in result)
|
||||
assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result)
|
||||
|
||||
def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that parallel search deduplicates results by (source_file, source_line)."""
|
||||
# Note: This test verifies the behavior of _search_callers_parallel deduplication
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
index_path_1 = Path("/test/project/_index.db")
|
||||
index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate)
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=index_path_1,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
||||
# Mock duplicate results from same location
|
||||
duplicate_callers = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert: even with duplicates in input, output may contain both
|
||||
# (actual deduplication happens in _search_callers_parallel)
|
||||
assert len(result) >= 1
|
||||
|
||||
|
||||
class TestChainSearchEngineContextManager:
|
||||
"""Tests for context manager functionality."""
|
||||
|
||||
def test_context_manager_closes_executor(self, mock_registry, mock_mapper):
|
||||
"""Test that context manager properly closes executor."""
|
||||
with ChainSearchEngine(mock_registry, mock_mapper) as engine:
|
||||
# Force executor creation
|
||||
engine._get_executor()
|
||||
assert engine._executor is not None
|
||||
|
||||
# Executor should be closed after exiting context
|
||||
assert engine._executor is None
|
||||
|
||||
def test_close_method_shuts_down_executor(self, search_engine):
|
||||
"""Test that close() method shuts down executor."""
|
||||
# Create executor
|
||||
search_engine._get_executor()
|
||||
assert search_engine._executor is not None
|
||||
|
||||
# Close
|
||||
search_engine.close()
|
||||
assert search_engine._executor is None
|
||||
|
||||
|
||||
class TestSearchCallersSingle:
|
||||
"""Tests for _search_callers_single internal method."""
|
||||
|
||||
def test_search_callers_single_queries_store(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callers_single queries SQLiteStore correctly."""
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MockStore.return_value.__enter__.return_value
|
||||
mock_store_instance.query_relationships_by_target.return_value = [
|
||||
{
|
||||
"source_symbol": "caller",
|
||||
"target_symbol": target_symbol,
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/file.py",
|
||||
"target_file": "/test/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller"
|
||||
mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol)
|
||||
|
||||
def test_search_callers_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callers_single returns empty list on error."""
|
||||
target_symbol = "my_function"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("Database error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestSearchCalleesSingle:
|
||||
"""Tests for _search_callees_single internal method."""
|
||||
|
||||
def test_search_callees_single_queries_database(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callees_single queries SQLiteStore correctly."""
|
||||
source_symbol = "caller_function"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MagicMock()
|
||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
||||
|
||||
# Mock execute_query to return relationship data (using new public API)
|
||||
mock_store_instance.execute_query.return_value = [
|
||||
{
|
||||
"source_symbol": source_symbol,
|
||||
"target_symbol": "callee_function",
|
||||
"relationship_type": "call",
|
||||
"source_line": 15,
|
||||
"source_file": "/test/module.py",
|
||||
"target_file": "/test/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
||||
|
||||
# Assert - verify execute_query was called (public API)
|
||||
assert mock_store_instance.execute_query.called
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == source_symbol
|
||||
assert result[0]["target_symbol"] == "callee_function"
|
||||
|
||||
def test_search_callees_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callees_single returns empty list on error."""
|
||||
source_symbol = "caller_function"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestSearchInheritanceSingle:
|
||||
"""Tests for _search_inheritance_single internal method."""
|
||||
|
||||
def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path):
|
||||
"""Test that _search_inheritance_single queries SQLiteStore correctly."""
|
||||
class_name = "BaseClass"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MagicMock()
|
||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
||||
|
||||
# Mock execute_query to return relationship data (using new public API)
|
||||
mock_store_instance.execute_query.return_value = [
|
||||
{
|
||||
"source_symbol": "DerivedClass",
|
||||
"target_qualified_name": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/derived.py",
|
||||
"target_file": "/test/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
||||
|
||||
# Assert
|
||||
assert mock_store_instance.execute_query.called
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "DerivedClass"
|
||||
assert result[0]["relationship_type"] == "inherits"
|
||||
|
||||
# Verify execute_query was called with 'inherits' filter
|
||||
call_args = mock_store_instance.execute_query.call_args
|
||||
sql_query = call_args[0][0]
|
||||
assert "relationship_type = 'inherits'" in sql_query
|
||||
|
||||
def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_inheritance_single returns empty list on error."""
|
||||
class_name = "BaseClass"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Tests for CLI search command with --enrich flag."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
from codexlens.cli.commands import app
|
||||
|
||||
|
||||
class TestCLISearchEnrich:
|
||||
"""Test CLI search command with --enrich flag integration."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_search_with_enrich_flag_help(self, runner):
|
||||
"""Test --enrich flag is documented in help."""
|
||||
result = runner.invoke(app, ["search", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "--enrich" in result.output
|
||||
assert "relationships" in result.output.lower() or "graph" in result.output.lower()
|
||||
|
||||
def test_search_with_enrich_flag_accepted(self, runner):
|
||||
"""Test --enrich flag is accepted by the CLI."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich"])
|
||||
# Should not show 'unknown option' error
|
||||
assert "No such option" not in result.output
|
||||
assert "error: unrecognized" not in result.output.lower()
|
||||
|
||||
def test_search_without_enrich_flag(self, runner):
|
||||
"""Test search without --enrich flag has no relationships."""
|
||||
result = runner.invoke(app, ["search", "test", "--json"])
|
||||
# Even without an index, JSON should be attempted
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
# If we get results, they should not have enriched=true
|
||||
if data.get("success") and "result" in data:
|
||||
assert data["result"].get("enriched", False) is False
|
||||
except json.JSONDecodeError:
|
||||
pass # Not JSON output, that's fine for error cases
|
||||
|
||||
def test_search_enrich_json_output_structure(self, runner):
|
||||
"""Test JSON output structure includes enriched flag."""
|
||||
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
|
||||
# If we get valid JSON output, check structure
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
if data.get("success") and "result" in data:
|
||||
# enriched field should exist
|
||||
assert "enriched" in data["result"]
|
||||
except json.JSONDecodeError:
|
||||
pass # Not JSON output
|
||||
|
||||
def test_search_enrich_with_mode(self, runner):
|
||||
"""Test --enrich works with different search modes."""
|
||||
modes = ["exact", "fuzzy", "hybrid"]
|
||||
for mode in modes:
|
||||
result = runner.invoke(
|
||||
app, ["search", "test", "--mode", mode, "--enrich"]
|
||||
)
|
||||
# Should not show validation errors
|
||||
assert "Invalid" not in result.output
|
||||
|
||||
|
||||
class TestEnrichFlagBehavior:
|
||||
"""Test behavioral aspects of --enrich flag."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_enrich_failure_does_not_break_search(self, runner):
|
||||
"""Test that enrichment failure doesn't prevent search from returning results."""
|
||||
# Even without proper index, search should not crash due to enrich
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
|
||||
# Should not have unhandled exception
|
||||
assert "Traceback" not in result.output
|
||||
|
||||
def test_enrich_flag_with_files_only(self, runner):
|
||||
"""Test --enrich is accepted with --files-only mode."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--files-only"])
|
||||
# Should not show option conflict error
|
||||
assert "conflict" not in result.output.lower()
|
||||
|
||||
def test_enrich_flag_with_limit(self, runner):
|
||||
"""Test --enrich works with --limit parameter."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--limit", "5"])
|
||||
# Should not show validation error
|
||||
assert "Invalid" not in result.output
|
||||
|
||||
|
||||
class TestEnrichOutputFormat:
|
||||
"""Test output format with --enrich flag."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_enrich_verbose_shows_status(self, runner):
|
||||
"""Test verbose mode shows enrichment status."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
|
||||
# Verbose mode may show enrichment info or warnings
|
||||
# Just ensure it doesn't crash
|
||||
assert result.exit_code in [0, 1] # 0 = success, 1 = no index
|
||||
|
||||
def test_json_output_has_enriched_field(self, runner):
|
||||
"""Test JSON output always has enriched field when --enrich used."""
|
||||
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
if data.get("success"):
|
||||
result_data = data.get("result", {})
|
||||
assert "enriched" in result_data
|
||||
assert isinstance(result_data["enriched"], bool)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
@@ -204,8 +204,6 @@ class TestEntitySerialization:
|
||||
"kind": "function",
|
||||
"range": (1, 10),
|
||||
"file": None,
|
||||
"token_count": None,
|
||||
"symbol_type": None,
|
||||
}
|
||||
|
||||
def test_indexed_file_model_dump(self):
|
||||
|
||||
@@ -1,436 +0,0 @@
|
||||
"""Tests for GraphAnalyzer - code relationship extraction."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.semantic.graph_analyzer import GraphAnalyzer
|
||||
|
||||
|
||||
TREE_SITTER_PYTHON_AVAILABLE = True
|
||||
try:
|
||||
import tree_sitter_python # type: ignore[import-not-found] # noqa: F401
|
||||
except Exception:
|
||||
TREE_SITTER_PYTHON_AVAILABLE = False
|
||||
|
||||
|
||||
TREE_SITTER_JS_AVAILABLE = True
|
||||
try:
|
||||
import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401
|
||||
except Exception:
|
||||
TREE_SITTER_JS_AVAILABLE = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
class TestPythonGraphAnalyzer:
|
||||
"""Tests for Python relationship extraction."""
|
||||
|
||||
def test_simple_function_call(self):
|
||||
"""Test extraction of simple function call."""
|
||||
code = """def helper():
|
||||
pass
|
||||
|
||||
def main():
|
||||
helper()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find main -> helper call
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "main"
|
||||
assert rel.target_symbol == "helper"
|
||||
assert rel.relationship_type == "call"
|
||||
assert rel.source_line == 5
|
||||
|
||||
def test_multiple_calls_in_function(self):
|
||||
"""Test extraction of multiple calls from same function."""
|
||||
code = """def foo():
|
||||
pass
|
||||
|
||||
def bar():
|
||||
pass
|
||||
|
||||
def main():
|
||||
foo()
|
||||
bar()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find main -> foo and main -> bar
|
||||
assert len(relationships) == 2
|
||||
targets = {rel.target_symbol for rel in relationships}
|
||||
assert targets == {"foo", "bar"}
|
||||
assert all(rel.source_symbol == "main" for rel in relationships)
|
||||
|
||||
def test_nested_function_calls(self):
|
||||
"""Test extraction of calls from nested functions."""
|
||||
code = """def inner_helper():
|
||||
pass
|
||||
|
||||
def outer():
|
||||
def inner():
|
||||
inner_helper()
|
||||
inner()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find outer.inner -> inner_helper and outer -> inner (with fully qualified names)
|
||||
assert len(relationships) == 2
|
||||
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
|
||||
assert ("outer.inner", "inner_helper") in call_pairs
|
||||
assert ("outer", "inner") in call_pairs
|
||||
|
||||
def test_method_call_in_class(self):
|
||||
"""Test extraction of method calls within class."""
|
||||
code = """class Calculator:
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def compute(self, x, y):
|
||||
result = self.add(x, y)
|
||||
return result
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find Calculator.compute -> add (with fully qualified source)
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "Calculator.compute"
|
||||
assert rel.target_symbol == "add"
|
||||
|
||||
def test_module_level_call(self):
|
||||
"""Test extraction of module-level function calls."""
|
||||
code = """def setup():
|
||||
pass
|
||||
|
||||
setup()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find <module> -> setup
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "<module>"
|
||||
assert rel.target_symbol == "setup"
|
||||
|
||||
def test_async_function_call(self):
|
||||
"""Test extraction of calls involving async functions."""
|
||||
code = """async def fetch_data():
|
||||
pass
|
||||
|
||||
async def process():
|
||||
await fetch_data()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Should find process -> fetch_data
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "process"
|
||||
assert rel.target_symbol == "fetch_data"
|
||||
|
||||
def test_complex_python_file(self):
|
||||
"""Test extraction from realistic Python file with multiple patterns."""
|
||||
code = """class DataProcessor:
|
||||
def __init__(self):
|
||||
self.data = []
|
||||
|
||||
def load(self, filename):
|
||||
self.data = read_file(filename)
|
||||
|
||||
def process(self):
|
||||
self.validate()
|
||||
self.transform()
|
||||
|
||||
def validate(self):
|
||||
pass
|
||||
|
||||
def transform(self):
|
||||
pass
|
||||
|
||||
def read_file(filename):
|
||||
pass
|
||||
|
||||
def main():
|
||||
processor = DataProcessor()
|
||||
processor.load("data.txt")
|
||||
processor.process()
|
||||
|
||||
main()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Extract call pairs
|
||||
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
|
||||
|
||||
# Expected relationships (with fully qualified source symbols for methods)
|
||||
expected = {
|
||||
("DataProcessor.load", "read_file"),
|
||||
("DataProcessor.process", "validate"),
|
||||
("DataProcessor.process", "transform"),
|
||||
("main", "DataProcessor"),
|
||||
("main", "load"),
|
||||
("main", "process"),
|
||||
("<module>", "main"),
|
||||
}
|
||||
|
||||
# Should find all expected relationships
|
||||
assert call_pairs >= expected
|
||||
|
||||
def test_empty_file(self):
|
||||
"""Test handling of empty file."""
|
||||
code = ""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
assert len(relationships) == 0
|
||||
|
||||
def test_file_with_no_calls(self):
|
||||
"""Test handling of file with definitions but no calls."""
|
||||
code = """def func1():
|
||||
pass
|
||||
|
||||
def func2():
|
||||
pass
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
assert len(relationships) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
|
||||
class TestJavaScriptGraphAnalyzer:
|
||||
"""Tests for JavaScript relationship extraction."""
|
||||
|
||||
def test_simple_function_call(self):
|
||||
"""Test extraction of simple JavaScript function call."""
|
||||
code = """function helper() {}
|
||||
|
||||
function main() {
|
||||
helper();
|
||||
}
|
||||
"""
|
||||
analyzer = GraphAnalyzer("javascript")
|
||||
relationships = analyzer.analyze_file(code, Path("test.js"))
|
||||
|
||||
# Should find main -> helper call
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "main"
|
||||
assert rel.target_symbol == "helper"
|
||||
assert rel.relationship_type == "call"
|
||||
|
||||
def test_arrow_function_call(self):
|
||||
"""Test extraction of calls from arrow functions."""
|
||||
code = """const helper = () => {};
|
||||
|
||||
const main = () => {
|
||||
helper();
|
||||
};
|
||||
"""
|
||||
analyzer = GraphAnalyzer("javascript")
|
||||
relationships = analyzer.analyze_file(code, Path("test.js"))
|
||||
|
||||
# Should find main -> helper call
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "main"
|
||||
assert rel.target_symbol == "helper"
|
||||
|
||||
def test_class_method_call(self):
|
||||
"""Test extraction of method calls in JavaScript class."""
|
||||
code = """class Calculator {
|
||||
add(a, b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
compute(x, y) {
|
||||
return this.add(x, y);
|
||||
}
|
||||
}
|
||||
"""
|
||||
analyzer = GraphAnalyzer("javascript")
|
||||
relationships = analyzer.analyze_file(code, Path("test.js"))
|
||||
|
||||
# Should find Calculator.compute -> add (with fully qualified source)
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_symbol == "Calculator.compute"
|
||||
assert rel.target_symbol == "add"
|
||||
|
||||
def test_complex_javascript_file(self):
|
||||
"""Test extraction from realistic JavaScript file."""
|
||||
code = """function readFile(filename) {
|
||||
return "";
|
||||
}
|
||||
|
||||
class DataProcessor {
|
||||
constructor() {
|
||||
this.data = [];
|
||||
}
|
||||
|
||||
load(filename) {
|
||||
this.data = readFile(filename);
|
||||
}
|
||||
|
||||
process() {
|
||||
this.validate();
|
||||
this.transform();
|
||||
}
|
||||
|
||||
validate() {}
|
||||
|
||||
transform() {}
|
||||
}
|
||||
|
||||
function main() {
|
||||
const processor = new DataProcessor();
|
||||
processor.load("data.txt");
|
||||
processor.process();
|
||||
}
|
||||
|
||||
main();
|
||||
"""
|
||||
analyzer = GraphAnalyzer("javascript")
|
||||
relationships = analyzer.analyze_file(code, Path("test.js"))
|
||||
|
||||
# Extract call pairs
|
||||
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
|
||||
|
||||
# Expected relationships (with fully qualified source symbols for methods)
|
||||
# Note: constructor calls like "new DataProcessor()" are not tracked
|
||||
expected = {
|
||||
("DataProcessor.load", "readFile"),
|
||||
("DataProcessor.process", "validate"),
|
||||
("DataProcessor.process", "transform"),
|
||||
("main", "load"),
|
||||
("main", "process"),
|
||||
("<module>", "main"),
|
||||
}
|
||||
|
||||
# Should find all expected relationships
|
||||
assert call_pairs >= expected
|
||||
|
||||
|
||||
class TestGraphAnalyzerEdgeCases:
|
||||
"""Edge case tests for GraphAnalyzer."""
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
def test_unavailable_language(self):
|
||||
"""Test handling of unsupported language."""
|
||||
code = "some code"
|
||||
analyzer = GraphAnalyzer("rust")
|
||||
relationships = analyzer.analyze_file(code, Path("test.rs"))
|
||||
assert len(relationships) == 0
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
def test_malformed_python_code(self):
|
||||
"""Test handling of malformed Python code."""
|
||||
code = "def broken(\n pass"
|
||||
analyzer = GraphAnalyzer("python")
|
||||
# Should not crash
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
assert isinstance(relationships, list)
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
def test_file_path_in_relationship(self):
|
||||
"""Test that file path is correctly set in relationships."""
|
||||
code = """def foo():
|
||||
pass
|
||||
|
||||
def bar():
|
||||
foo()
|
||||
"""
|
||||
test_path = Path("test.py")
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, test_path)
|
||||
|
||||
assert len(relationships) == 1
|
||||
rel = relationships[0]
|
||||
assert rel.source_file == str(test_path.resolve())
|
||||
assert rel.target_file is None # Intra-file
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
def test_performance_large_file(self):
|
||||
"""Test performance on larger file (1000 lines)."""
|
||||
import time
|
||||
|
||||
# Generate file with many functions and calls
|
||||
lines = []
|
||||
for i in range(100):
|
||||
lines.append(f"def func_{i}():")
|
||||
if i > 0:
|
||||
lines.append(f" func_{i-1}()")
|
||||
else:
|
||||
lines.append(" pass")
|
||||
|
||||
code = "\n".join(lines)
|
||||
|
||||
analyzer = GraphAnalyzer("python")
|
||||
start_time = time.time()
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Should complete in under 500ms
|
||||
assert elapsed_ms < 500
|
||||
|
||||
# Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
|
||||
assert len(relationships) == 99
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
|
||||
def test_call_accuracy_rate(self):
|
||||
"""Test >95% accuracy on known call graph."""
|
||||
code = """def a(): pass
|
||||
def b(): pass
|
||||
def c(): pass
|
||||
def d(): pass
|
||||
def e(): pass
|
||||
|
||||
def test1():
|
||||
a()
|
||||
b()
|
||||
|
||||
def test2():
|
||||
c()
|
||||
d()
|
||||
|
||||
def test3():
|
||||
e()
|
||||
|
||||
def main():
|
||||
test1()
|
||||
test2()
|
||||
test3()
|
||||
"""
|
||||
analyzer = GraphAnalyzer("python")
|
||||
relationships = analyzer.analyze_file(code, Path("test.py"))
|
||||
|
||||
# Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
|
||||
expected_calls = {
|
||||
("test1", "a"),
|
||||
("test1", "b"),
|
||||
("test2", "c"),
|
||||
("test2", "d"),
|
||||
("test3", "e"),
|
||||
("main", "test1"),
|
||||
("main", "test2"),
|
||||
("main", "test3"),
|
||||
}
|
||||
|
||||
found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
|
||||
|
||||
# Calculate accuracy
|
||||
correct = len(expected_calls & found_calls)
|
||||
total = len(expected_calls)
|
||||
accuracy = (correct / total) * 100 if total > 0 else 0
|
||||
|
||||
# Should have >95% accuracy
|
||||
assert accuracy >= 95.0
|
||||
assert correct == total # Should be 100% for this simple case
|
||||
@@ -1,392 +0,0 @@
|
||||
"""End-to-end tests for graph search CLI commands."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typer.testing import CliRunner
|
||||
import pytest
|
||||
|
||||
from codexlens.cli.commands import app
|
||||
from codexlens.storage.sqlite_store import SQLiteStore
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.entities import IndexedFile, Symbol, CodeRelationship
|
||||
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project():
|
||||
"""Create a temporary project with indexed code and relationships."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir) / "test_project"
|
||||
project_root.mkdir()
|
||||
|
||||
# Create test Python files
|
||||
(project_root / "main.py").write_text("""
|
||||
def main():
|
||||
result = calculate(5, 3)
|
||||
print(result)
|
||||
|
||||
def calculate(a, b):
|
||||
return add(a, b)
|
||||
|
||||
def add(x, y):
|
||||
return x + y
|
||||
""")
|
||||
|
||||
(project_root / "utils.py").write_text("""
|
||||
class BaseClass:
|
||||
def method(self):
|
||||
pass
|
||||
|
||||
class DerivedClass(BaseClass):
|
||||
def method(self):
|
||||
super().method()
|
||||
helper()
|
||||
|
||||
def helper():
|
||||
return True
|
||||
""")
|
||||
|
||||
# Create a custom index directory for graph testing
|
||||
# Skip the standard init to avoid schema conflicts
|
||||
mapper = PathMapper()
|
||||
index_root = mapper.source_to_index_dir(project_root)
|
||||
index_root.mkdir(parents=True, exist_ok=True)
|
||||
test_db = index_root / "_index.db"
|
||||
|
||||
# Register project manually
|
||||
registry = RegistryStore()
|
||||
registry.initialize()
|
||||
project_info = registry.register_project(
|
||||
source_root=project_root,
|
||||
index_root=index_root
|
||||
)
|
||||
registry.register_dir(
|
||||
project_id=project_info.id,
|
||||
source_path=project_root,
|
||||
index_path=test_db,
|
||||
depth=0,
|
||||
files_count=2
|
||||
)
|
||||
|
||||
# Initialize the store with proper SQLiteStore schema and add files
|
||||
with SQLiteStore(test_db) as store:
|
||||
# Read and add files to the store
|
||||
main_content = (project_root / "main.py").read_text()
|
||||
utils_content = (project_root / "utils.py").read_text()
|
||||
|
||||
main_indexed = IndexedFile(
|
||||
path=str(project_root / "main.py"),
|
||||
language="python",
|
||||
symbols=[
|
||||
Symbol(name="main", kind="function", range=(2, 4)),
|
||||
Symbol(name="calculate", kind="function", range=(6, 7)),
|
||||
Symbol(name="add", kind="function", range=(9, 10))
|
||||
]
|
||||
)
|
||||
utils_indexed = IndexedFile(
|
||||
path=str(project_root / "utils.py"),
|
||||
language="python",
|
||||
symbols=[
|
||||
Symbol(name="BaseClass", kind="class", range=(2, 4)),
|
||||
Symbol(name="DerivedClass", kind="class", range=(6, 9)),
|
||||
Symbol(name="helper", kind="function", range=(11, 12))
|
||||
]
|
||||
)
|
||||
|
||||
store.add_file(main_indexed, main_content)
|
||||
store.add_file(utils_indexed, utils_content)
|
||||
|
||||
with SQLiteStore(test_db) as store:
|
||||
# Add relationships for main.py
|
||||
main_file = project_root / "main.py"
|
||||
relationships_main = [
|
||||
CodeRelationship(
|
||||
source_symbol="main",
|
||||
target_symbol="calculate",
|
||||
relationship_type="call",
|
||||
source_file=str(main_file),
|
||||
source_line=3,
|
||||
target_file=str(main_file)
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="calculate",
|
||||
target_symbol="add",
|
||||
relationship_type="call",
|
||||
source_file=str(main_file),
|
||||
source_line=7,
|
||||
target_file=str(main_file)
|
||||
),
|
||||
]
|
||||
store.add_relationships(main_file, relationships_main)
|
||||
|
||||
# Add relationships for utils.py
|
||||
utils_file = project_root / "utils.py"
|
||||
relationships_utils = [
|
||||
CodeRelationship(
|
||||
source_symbol="DerivedClass",
|
||||
target_symbol="BaseClass",
|
||||
relationship_type="inherits",
|
||||
source_file=str(utils_file),
|
||||
source_line=6, # DerivedClass is defined on line 6
|
||||
target_file=str(utils_file)
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="DerivedClass.method",
|
||||
target_symbol="helper",
|
||||
relationship_type="call",
|
||||
source_file=str(utils_file),
|
||||
source_line=8,
|
||||
target_file=str(utils_file)
|
||||
),
|
||||
]
|
||||
store.add_relationships(utils_file, relationships_utils)
|
||||
|
||||
registry.close()
|
||||
|
||||
yield project_root
|
||||
|
||||
|
||||
class TestGraphCallers:
|
||||
"""Test callers query type."""
|
||||
|
||||
def test_find_callers_basic(self, temp_project):
|
||||
"""Test finding functions that call a given function."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "calculate" in result.stdout
|
||||
assert "Callers of 'add'" in result.stdout
|
||||
|
||||
def test_find_callers_json_mode(self, temp_project):
|
||||
"""Test callers query with JSON output."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project),
|
||||
"--json"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "success" in result.stdout
|
||||
assert "relationships" in result.stdout
|
||||
|
||||
def test_find_callers_no_results(self, temp_project):
|
||||
"""Test callers query when no callers exist."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"nonexistent_function",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No callers found" in result.stdout or "0 found" in result.stdout
|
||||
|
||||
|
||||
class TestGraphCallees:
|
||||
"""Test callees query type."""
|
||||
|
||||
def test_find_callees_basic(self, temp_project):
|
||||
"""Test finding functions called by a given function."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callees",
|
||||
"main",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "calculate" in result.stdout
|
||||
assert "Callees of 'main'" in result.stdout
|
||||
|
||||
def test_find_callees_chain(self, temp_project):
|
||||
"""Test finding callees in a call chain."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callees",
|
||||
"calculate",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "add" in result.stdout
|
||||
|
||||
def test_find_callees_json_mode(self, temp_project):
|
||||
"""Test callees query with JSON output."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callees",
|
||||
"main",
|
||||
"--path", str(temp_project),
|
||||
"--json"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "success" in result.stdout
|
||||
|
||||
|
||||
class TestGraphInheritance:
|
||||
"""Test inheritance query type."""
|
||||
|
||||
def test_find_inheritance_basic(self, temp_project):
|
||||
"""Test finding inheritance relationships."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"inheritance",
|
||||
"BaseClass",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "DerivedClass" in result.stdout
|
||||
assert "Inheritance relationships" in result.stdout
|
||||
|
||||
def test_find_inheritance_derived(self, temp_project):
|
||||
"""Test finding inheritance from derived class perspective."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"inheritance",
|
||||
"DerivedClass",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "BaseClass" in result.stdout
|
||||
|
||||
def test_find_inheritance_json_mode(self, temp_project):
|
||||
"""Test inheritance query with JSON output."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"inheritance",
|
||||
"BaseClass",
|
||||
"--path", str(temp_project),
|
||||
"--json"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "success" in result.stdout
|
||||
|
||||
|
||||
class TestGraphValidation:
|
||||
"""Test query validation and error handling."""
|
||||
|
||||
def test_invalid_query_type(self, temp_project):
|
||||
"""Test error handling for invalid query type."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"invalid_type",
|
||||
"symbol",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Invalid query type" in result.stdout
|
||||
|
||||
def test_invalid_path(self):
|
||||
"""Test error handling for non-existent path."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"symbol",
|
||||
"--path", "/nonexistent/path"
|
||||
])
|
||||
|
||||
# Should handle gracefully (may exit with error or return empty results)
|
||||
assert result.exit_code in [0, 1]
|
||||
|
||||
|
||||
class TestGraphPerformance:
|
||||
"""Test graph query performance requirements."""
|
||||
|
||||
def test_query_response_time(self, temp_project):
|
||||
"""Verify graph queries complete in under 1 second."""
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"
|
||||
|
||||
def test_multiple_query_types(self, temp_project):
|
||||
"""Test all three query types complete successfully."""
|
||||
import time
|
||||
|
||||
queries = [
|
||||
("callers", "add"),
|
||||
("callees", "main"),
|
||||
("inheritance", "BaseClass")
|
||||
]
|
||||
|
||||
total_start = time.time()
|
||||
|
||||
for query_type, symbol in queries:
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
query_type,
|
||||
symbol,
|
||||
"--path", str(temp_project)
|
||||
])
|
||||
assert result.exit_code == 0
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"
|
||||
|
||||
|
||||
class TestGraphOptions:
|
||||
"""Test graph command options."""
|
||||
|
||||
def test_limit_option(self, temp_project):
|
||||
"""Test limit option works correctly."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project),
|
||||
"--limit", "1"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_depth_option(self, temp_project):
|
||||
"""Test depth option works correctly."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project),
|
||||
"--depth", "0"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_verbose_option(self, temp_project):
|
||||
"""Test verbose option works correctly."""
|
||||
result = runner.invoke(app, [
|
||||
"graph",
|
||||
"callers",
|
||||
"add",
|
||||
"--path", str(temp_project),
|
||||
"--verbose"
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,355 +0,0 @@
|
||||
"""Tests for code relationship storage."""
|
||||
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, Symbol
|
||||
from codexlens.storage.migration_manager import MigrationManager
|
||||
from codexlens.storage.sqlite_store import SQLiteStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db():
|
||||
"""Create a temporary database for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
yield db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(temp_db):
|
||||
"""Create a SQLiteStore with migrations applied."""
|
||||
store = SQLiteStore(temp_db)
|
||||
store.initialize()
|
||||
|
||||
# Manually apply migration_003 (code_relationships table)
|
||||
conn = store._get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS code_relationships (
|
||||
id INTEGER PRIMARY KEY,
|
||||
source_symbol_id INTEGER NOT NULL,
|
||||
target_qualified_name TEXT NOT NULL,
|
||||
relationship_type TEXT NOT NULL,
|
||||
source_line INTEGER NOT NULL,
|
||||
target_file TEXT,
|
||||
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
yield store
|
||||
|
||||
# Cleanup
|
||||
store.close()
|
||||
|
||||
|
||||
def test_relationship_table_created(store):
|
||||
"""Test that the code_relationships table is created by migration."""
|
||||
conn = store._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check table exists
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
assert result is not None, "code_relationships table should exist"
|
||||
|
||||
# Check indexes exist
|
||||
cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
|
||||
)
|
||||
indexes = [row[0] for row in cursor.fetchall()]
|
||||
assert "idx_relationships_source" in indexes
|
||||
assert "idx_relationships_target" in indexes
|
||||
assert "idx_relationships_type" in indexes
|
||||
|
||||
|
||||
def test_add_relationships(store):
|
||||
"""Test storing code relationships."""
|
||||
# First add a file with symbols
|
||||
indexed_file = IndexedFile(
|
||||
path=str(Path(__file__).parent / "sample.py"),
|
||||
language="python",
|
||||
symbols=[
|
||||
Symbol(name="foo", kind="function", range=(1, 5)),
|
||||
Symbol(name="bar", kind="function", range=(7, 10)),
|
||||
]
|
||||
)
|
||||
|
||||
content = """def foo():
|
||||
bar()
|
||||
baz()
|
||||
|
||||
def bar():
|
||||
print("hello")
|
||||
"""
|
||||
|
||||
store.add_file(indexed_file, content)
|
||||
|
||||
# Add relationships
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=indexed_file.path,
|
||||
target_file=None,
|
||||
source_line=2
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="baz",
|
||||
relationship_type="call",
|
||||
source_file=indexed_file.path,
|
||||
target_file=None,
|
||||
source_line=3
|
||||
),
|
||||
]
|
||||
|
||||
store.add_relationships(indexed_file.path, relationships)
|
||||
|
||||
# Verify relationships were stored
|
||||
conn = store._get_connection()
|
||||
count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0]
|
||||
assert count == 2, "Should have stored 2 relationships"
|
||||
|
||||
|
||||
def test_query_relationships_by_target(store):
|
||||
"""Test querying relationships by target symbol (find callers)."""
|
||||
# Setup: Add file and relationships
|
||||
file_path = str(Path(__file__).parent / "sample.py")
|
||||
# Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main()
|
||||
indexed_file = IndexedFile(
|
||||
path=file_path,
|
||||
language="python",
|
||||
symbols=[
|
||||
Symbol(name="foo", kind="function", range=(1, 2)),
|
||||
Symbol(name="bar", kind="function", range=(4, 5)),
|
||||
Symbol(name="main", kind="function", range=(7, 8)),
|
||||
]
|
||||
)
|
||||
|
||||
content = "def foo():\n bar()\n\ndef bar():\n pass\n\ndef main():\n bar()\n"
|
||||
store.add_file(indexed_file, content)
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=2 # Call inside foo (line 2)
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="main",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=8 # Call inside main (line 8)
|
||||
),
|
||||
]
|
||||
|
||||
store.add_relationships(file_path, relationships)
|
||||
|
||||
# Query: Find all callers of "bar"
|
||||
callers = store.query_relationships_by_target("bar")
|
||||
|
||||
assert len(callers) == 2, "Should find 2 callers of bar"
|
||||
assert any(r["source_symbol"] == "foo" for r in callers)
|
||||
assert any(r["source_symbol"] == "main" for r in callers)
|
||||
assert all(r["target_symbol"] == "bar" for r in callers)
|
||||
assert all(r["relationship_type"] == "call" for r in callers)
|
||||
|
||||
|
||||
def test_query_relationships_by_source(store):
|
||||
"""Test querying relationships by source symbol (find callees)."""
|
||||
# Setup
|
||||
file_path = str(Path(__file__).parent / "sample.py")
|
||||
indexed_file = IndexedFile(
|
||||
path=file_path,
|
||||
language="python",
|
||||
symbols=[
|
||||
Symbol(name="foo", kind="function", range=(1, 6)),
|
||||
]
|
||||
)
|
||||
|
||||
content = "def foo():\n bar()\n baz()\n qux()\n"
|
||||
store.add_file(indexed_file, content)
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=2
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="baz",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=3
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="qux",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=4
|
||||
),
|
||||
]
|
||||
|
||||
store.add_relationships(file_path, relationships)
|
||||
|
||||
# Query: Find all functions called by foo
|
||||
callees = store.query_relationships_by_source("foo", file_path)
|
||||
|
||||
assert len(callees) == 3, "Should find 3 functions called by foo"
|
||||
targets = {r["target_symbol"] for r in callees}
|
||||
assert targets == {"bar", "baz", "qux"}
|
||||
assert all(r["source_symbol"] == "foo" for r in callees)
|
||||
|
||||
|
||||
def test_query_performance(store):
|
||||
"""Test that relationship queries execute within performance threshold."""
|
||||
import time
|
||||
|
||||
# Setup: Create a file with many relationships
|
||||
file_path = str(Path(__file__).parent / "large_file.py")
|
||||
symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]
|
||||
|
||||
indexed_file = IndexedFile(
|
||||
path=file_path,
|
||||
language="python",
|
||||
symbols=symbols
|
||||
)
|
||||
|
||||
content = "\n".join([f"def func_{i}():\n pass\n" for i in range(100)])
|
||||
store.add_file(indexed_file, content)
|
||||
|
||||
# Create many relationships
|
||||
relationships = []
|
||||
for i in range(100):
|
||||
for j in range(10):
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=f"func_{i}",
|
||||
target_symbol=f"target_{j}",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=i*10 + 1
|
||||
)
|
||||
)
|
||||
|
||||
store.add_relationships(file_path, relationships)
|
||||
|
||||
# Query and measure time
|
||||
start = time.time()
|
||||
results = store.query_relationships_by_target("target_5")
|
||||
elapsed_ms = (time.time() - start) * 1000
|
||||
|
||||
assert len(results) == 100, "Should find 100 callers"
|
||||
assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"
|
||||
|
||||
|
||||
def test_stats_includes_relationships(store):
|
||||
"""Test that stats() includes relationship count."""
|
||||
# Add a file with relationships
|
||||
file_path = str(Path(__file__).parent / "sample.py")
|
||||
indexed_file = IndexedFile(
|
||||
path=file_path,
|
||||
language="python",
|
||||
symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
|
||||
)
|
||||
|
||||
store.add_file(indexed_file, "def foo():\n bar()\n")
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=2
|
||||
)
|
||||
]
|
||||
|
||||
store.add_relationships(file_path, relationships)
|
||||
|
||||
# Check stats
|
||||
stats = store.stats()
|
||||
|
||||
assert "relationships" in stats
|
||||
assert stats["relationships"] == 1
|
||||
assert stats["files"] == 1
|
||||
assert stats["symbols"] == 1
|
||||
|
||||
|
||||
def test_update_relationships_on_file_reindex(store):
|
||||
"""Test that relationships are updated when file is re-indexed."""
|
||||
file_path = str(Path(__file__).parent / "sample.py")
|
||||
|
||||
# Initial index
|
||||
indexed_file = IndexedFile(
|
||||
path=file_path,
|
||||
language="python",
|
||||
symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
|
||||
)
|
||||
store.add_file(indexed_file, "def foo():\n bar()\n")
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="bar",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=2
|
||||
)
|
||||
]
|
||||
store.add_relationships(file_path, relationships)
|
||||
|
||||
# Re-index with different relationships
|
||||
new_relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="foo",
|
||||
target_symbol="baz",
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=2
|
||||
)
|
||||
]
|
||||
store.add_relationships(file_path, new_relationships)
|
||||
|
||||
# Verify old relationships are replaced
|
||||
all_rels = store.query_relationships_by_source("foo", file_path)
|
||||
assert len(all_rels) == 1
|
||||
assert all_rels[0]["target_symbol"] == "baz"
|
||||
@@ -188,60 +188,3 @@ class TestTokenCountPerformance:
|
||||
# Precomputed should be at least 10% faster
|
||||
speedup = ((computed_time - precomputed_time) / computed_time) * 100
|
||||
assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"
|
||||
|
||||
|
||||
class TestSymbolEntityTokenCount:
|
||||
"""Tests for Symbol entity token_count field."""
|
||||
|
||||
def test_symbol_with_token_count(self):
|
||||
"""Test creating Symbol with token_count."""
|
||||
symbol = Symbol(
|
||||
name="test_func",
|
||||
kind="function",
|
||||
range=(1, 10),
|
||||
token_count=42
|
||||
)
|
||||
|
||||
assert symbol.token_count == 42
|
||||
|
||||
def test_symbol_without_token_count(self):
|
||||
"""Test creating Symbol without token_count (defaults to None)."""
|
||||
symbol = Symbol(
|
||||
name="test_func",
|
||||
kind="function",
|
||||
range=(1, 10)
|
||||
)
|
||||
|
||||
assert symbol.token_count is None
|
||||
|
||||
def test_symbol_with_symbol_type(self):
|
||||
"""Test creating Symbol with symbol_type."""
|
||||
symbol = Symbol(
|
||||
name="TestClass",
|
||||
kind="class",
|
||||
range=(1, 20),
|
||||
symbol_type="class_definition"
|
||||
)
|
||||
|
||||
assert symbol.symbol_type == "class_definition"
|
||||
|
||||
def test_symbol_token_count_validation(self):
|
||||
"""Test that negative token counts are rejected."""
|
||||
with pytest.raises(ValueError, match="token_count must be >= 0"):
|
||||
Symbol(
|
||||
name="test",
|
||||
kind="function",
|
||||
range=(1, 2),
|
||||
token_count=-1
|
||||
)
|
||||
|
||||
def test_symbol_zero_token_count(self):
|
||||
"""Test that zero token count is allowed."""
|
||||
symbol = Symbol(
|
||||
name="empty",
|
||||
kind="function",
|
||||
range=(1, 1),
|
||||
token_count=0
|
||||
)
|
||||
|
||||
assert symbol.token_count == 0
|
||||
|
||||
Reference in New Issue
Block a user