feat(cli): 添加 --rule 选项支持模板自动发现

重构 ccw cli 模板系统:

- 新增 template-discovery.ts 模块,支持扁平化模板自动发现
- 添加 --rule <template> 选项,自动加载 protocol 和 template
- 模板目录从嵌套结构 (prompts/category/file.txt) 迁移到扁平结构 (prompts/category-function.txt)
- 更新所有 agent/command 文件,使用 $PROTO $TMPL 环境变量替代 $(cat ...) 模式
- 支持模糊匹配:--rule 02-review-architecture 可匹配 analysis-review-architecture.txt

其他更新:
- Dashboard: 添加 Claude Manager 和 Issue Manager 页面
- Codex-lens: 增强 chain_search 和 clustering 模块

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
catlog22
2026-01-17 19:20:24 +08:00
parent 1fae35c05d
commit f14418603a
137 changed files with 13125 additions and 301 deletions

View File

@@ -0,0 +1,282 @@
"""Tests for codexlens.api.references module."""
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codexlens.api.references import (
find_references,
_read_line_from_file,
_proximity_score,
_group_references_by_definition,
_transform_to_reference_result,
)
from codexlens.api.models import (
DefinitionResult,
ReferenceResult,
GroupedReferences,
)
class TestReadLineFromFile:
"""Tests for _read_line_from_file helper."""
def test_read_existing_line(self, tmp_path):
"""Test reading an existing line from a file."""
test_file = tmp_path / "test.py"
test_file.write_text("line 1\nline 2\nline 3\n")
assert _read_line_from_file(str(test_file), 1) == "line 1"
assert _read_line_from_file(str(test_file), 2) == "line 2"
assert _read_line_from_file(str(test_file), 3) == "line 3"
def test_read_nonexistent_line(self, tmp_path):
"""Test reading a line that doesn't exist."""
test_file = tmp_path / "test.py"
test_file.write_text("line 1\nline 2\n")
assert _read_line_from_file(str(test_file), 10) == ""
def test_read_nonexistent_file(self):
"""Test reading from a file that doesn't exist."""
assert _read_line_from_file("/nonexistent/path/file.py", 1) == ""
def test_strips_trailing_whitespace(self, tmp_path):
"""Test that trailing whitespace is stripped."""
test_file = tmp_path / "test.py"
test_file.write_text("line with spaces \n")
assert _read_line_from_file(str(test_file), 1) == "line with spaces"
class TestProximityScore:
"""Tests for _proximity_score helper."""
def test_same_file(self):
"""Same file should return highest score."""
score = _proximity_score("/a/b/c.py", "/a/b/c.py")
assert score == 1000
def test_same_directory(self):
"""Same directory should return 100."""
score = _proximity_score("/a/b/x.py", "/a/b/y.py")
assert score == 100
def test_different_directories(self):
"""Different directories should return common prefix length."""
score = _proximity_score("/a/b/c/x.py", "/a/b/d/y.py")
# Common path is /a/b
assert score > 0
def test_empty_paths(self):
"""Empty paths should return 0."""
assert _proximity_score("", "/a/b/c.py") == 0
assert _proximity_score("/a/b/c.py", "") == 0
assert _proximity_score("", "") == 0
class TestGroupReferencesByDefinition:
"""Tests for _group_references_by_definition helper."""
def test_single_definition(self):
"""Single definition should have all references."""
definition = DefinitionResult(
name="foo",
kind="function",
file_path="/a/b/c.py",
line=10,
end_line=20,
)
references = [
ReferenceResult(
file_path="/a/b/d.py",
line=5,
column=0,
context_line="foo()",
relationship="call",
),
ReferenceResult(
file_path="/a/x/y.py",
line=10,
column=0,
context_line="foo()",
relationship="call",
),
]
result = _group_references_by_definition([definition], references)
assert len(result) == 1
assert result[0].definition == definition
assert len(result[0].references) == 2
def test_multiple_definitions(self):
"""Multiple definitions should group by proximity."""
def1 = DefinitionResult(
name="foo",
kind="function",
file_path="/a/b/c.py",
line=10,
end_line=20,
)
def2 = DefinitionResult(
name="foo",
kind="function",
file_path="/x/y/z.py",
line=10,
end_line=20,
)
# Reference closer to def1
ref1 = ReferenceResult(
file_path="/a/b/d.py",
line=5,
column=0,
context_line="foo()",
relationship="call",
)
# Reference closer to def2
ref2 = ReferenceResult(
file_path="/x/y/w.py",
line=10,
column=0,
context_line="foo()",
relationship="call",
)
result = _group_references_by_definition(
[def1, def2], [ref1, ref2], include_definition=True
)
assert len(result) == 2
# Each definition should have the closer reference
def1_refs = [g for g in result if g.definition == def1][0].references
def2_refs = [g for g in result if g.definition == def2][0].references
assert any(r.file_path == "/a/b/d.py" for r in def1_refs)
assert any(r.file_path == "/x/y/w.py" for r in def2_refs)
def test_empty_definitions(self):
"""Empty definitions should return empty result."""
result = _group_references_by_definition([], [])
assert result == []
class TestTransformToReferenceResult:
"""Tests for _transform_to_reference_result helper."""
def test_normalizes_relationship_type(self, tmp_path):
"""Test that relationship type is normalized."""
test_file = tmp_path / "test.py"
test_file.write_text("def foo(): pass\n")
# Create a mock raw reference
raw_ref = MagicMock()
raw_ref.file_path = str(test_file)
raw_ref.line = 1
raw_ref.column = 0
raw_ref.relationship_type = "calls" # Plural form
result = _transform_to_reference_result(raw_ref)
assert result.relationship == "call" # Normalized form
assert result.context_line == "def foo(): pass"
class TestFindReferences:
"""Tests for find_references API function."""
def test_raises_for_invalid_project_root(self):
"""Test that ValueError is raised for invalid project root."""
with pytest.raises(ValueError, match="does not exist"):
find_references("/nonexistent/path", "some_symbol")
@patch("codexlens.search.chain_search.ChainSearchEngine")
@patch("codexlens.storage.registry.RegistryStore")
@patch("codexlens.storage.path_mapper.PathMapper")
@patch("codexlens.config.Config")
def test_returns_grouped_references(
self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
):
"""Test that find_references returns GroupedReferences."""
# Setup mocks
mock_engine = MagicMock()
mock_engine_class.return_value = mock_engine
# Mock symbol search (for definitions)
mock_symbol = MagicMock()
mock_symbol.name = "test_func"
mock_symbol.kind = "function"
mock_symbol.file = str(tmp_path / "test.py")
mock_symbol.range = (10, 20)
mock_engine.search_symbols.return_value = [mock_symbol]
# Mock reference search
mock_ref = MagicMock()
mock_ref.file_path = str(tmp_path / "caller.py")
mock_ref.line = 5
mock_ref.column = 0
mock_ref.relationship_type = "call"
mock_engine.search_references.return_value = [mock_ref]
# Create test files
test_file = tmp_path / "test.py"
test_file.write_text("def test_func():\n pass\n")
caller_file = tmp_path / "caller.py"
caller_file.write_text("test_func()\n")
# Call find_references
result = find_references(str(tmp_path), "test_func")
# Verify result structure
assert isinstance(result, list)
assert len(result) == 1
assert isinstance(result[0], GroupedReferences)
assert result[0].definition.name == "test_func"
assert len(result[0].references) == 1
@patch("codexlens.search.chain_search.ChainSearchEngine")
@patch("codexlens.storage.registry.RegistryStore")
@patch("codexlens.storage.path_mapper.PathMapper")
@patch("codexlens.config.Config")
def test_respects_include_definition_false(
self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
):
"""Test include_definition=False behavior."""
mock_engine = MagicMock()
mock_engine_class.return_value = mock_engine
mock_engine.search_symbols.return_value = []
mock_engine.search_references.return_value = []
result = find_references(
str(tmp_path), "test_func", include_definition=False
)
# Should still return a result with placeholder definition
assert len(result) == 1
assert result[0].definition.name == "test_func"
class TestImports:
"""Tests for module imports and exports."""
def test_find_references_exported_from_api(self):
"""Test that find_references is exported from codexlens.api."""
from codexlens.api import find_references as api_find_references
assert callable(api_find_references)
def test_models_exported_from_api(self):
"""Test that result models are exported from codexlens.api."""
from codexlens.api import (
GroupedReferences,
ReferenceResult,
DefinitionResult,
)
assert GroupedReferences is not None
assert ReferenceResult is not None
assert DefinitionResult is not None

View File

@@ -0,0 +1,528 @@
"""Tests for semantic_search API."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codexlens.api import SemanticResult
from codexlens.api.semantic import (
semantic_search,
_build_search_options,
_generate_match_reason,
_split_camel_case,
_transform_results,
)
class TestSemanticSearchFunctionSignature:
"""Test that semantic_search has the correct function signature."""
def test_function_accepts_all_parameters(self):
"""Verify function signature matches spec."""
import inspect
sig = inspect.signature(semantic_search)
params = list(sig.parameters.keys())
expected_params = [
"project_root",
"query",
"mode",
"vector_weight",
"structural_weight",
"keyword_weight",
"fusion_strategy",
"kind_filter",
"limit",
"include_match_reason",
]
assert params == expected_params
def test_default_parameter_values(self):
"""Verify default parameter values match spec."""
import inspect
sig = inspect.signature(semantic_search)
assert sig.parameters["mode"].default == "fusion"
assert sig.parameters["vector_weight"].default == 0.5
assert sig.parameters["structural_weight"].default == 0.3
assert sig.parameters["keyword_weight"].default == 0.2
assert sig.parameters["fusion_strategy"].default == "rrf"
assert sig.parameters["kind_filter"].default is None
assert sig.parameters["limit"].default == 20
assert sig.parameters["include_match_reason"].default is False
class TestBuildSearchOptions:
"""Test _build_search_options helper function."""
def test_vector_mode_options(self):
"""Test options for pure vector mode."""
options = _build_search_options(
mode="vector",
vector_weight=1.0,
structural_weight=0.0,
keyword_weight=0.0,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is True
assert options.pure_vector is True
assert options.enable_fuzzy is False
def test_structural_mode_options(self):
"""Test options for structural mode."""
options = _build_search_options(
mode="structural",
vector_weight=0.0,
structural_weight=1.0,
keyword_weight=0.0,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is False
assert options.enable_fuzzy is True
assert options.include_symbols is True
def test_fusion_mode_options(self):
"""Test options for fusion mode (default)."""
options = _build_search_options(
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
limit=20,
)
assert options.hybrid_mode is True
assert options.enable_vector is True # vector_weight > 0
assert options.enable_fuzzy is True # keyword_weight > 0
assert options.include_symbols is True # structural_weight > 0
class TestTransformResults:
"""Test _transform_results helper function."""
def test_transforms_basic_result(self):
"""Test basic result transformation."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "def authenticate():"
mock_result.symbol_name = "authenticate"
mock_result.symbol_kind = "function"
mock_result.start_line = 10
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].symbol_name == "authenticate"
assert results[0].kind == "function"
assert results[0].file_path == "/project/src/auth.py"
assert results[0].line == 10
assert results[0].fusion_score == 0.85
def test_kind_filter_excludes_non_matching(self):
"""Test that kind_filter excludes non-matching results."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "AUTH_TOKEN = 'secret'"
mock_result.symbol_name = "AUTH_TOKEN"
mock_result.symbol_kind = "variable"
mock_result.start_line = 5
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=["function", "class"], # Exclude variable
include_match_reason=False,
query="auth",
)
assert len(results) == 0
def test_kind_filter_includes_matching(self):
"""Test that kind_filter includes matching results."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "class AuthManager:"
mock_result.symbol_name = "AuthManager"
mock_result.symbol_kind = "class"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=["function", "class"], # Include class
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].symbol_name == "AuthManager"
def test_include_match_reason_generates_reason(self):
"""Test that include_match_reason generates match reasons."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.85
mock_result.excerpt = "def authenticate(user, password):"
mock_result.symbol_name = "authenticate"
mock_result.symbol_kind = "function"
mock_result.start_line = 10
mock_result.symbol = None
mock_result.metadata = {}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=True,
query="authenticate",
)
assert len(results) == 1
assert results[0].match_reason is not None
assert "authenticate" in results[0].match_reason.lower()
class TestGenerateMatchReason:
"""Test _generate_match_reason helper function."""
def test_direct_name_match(self):
"""Test match reason for direct name match."""
reason = _generate_match_reason(
query="authenticate",
symbol_name="authenticate",
symbol_kind="function",
snippet="def authenticate(user): pass",
vector_score=0.8,
structural_score=None,
)
assert "authenticate" in reason.lower()
def test_keyword_match(self):
"""Test match reason for keyword match in snippet."""
reason = _generate_match_reason(
query="password validation",
symbol_name="verify_user",
symbol_kind="function",
snippet="def verify_user(password): validate(password)",
vector_score=0.6,
structural_score=None,
)
assert "password" in reason.lower() or "validation" in reason.lower()
def test_high_semantic_similarity(self):
"""Test match reason mentions semantic similarity for high vector score."""
reason = _generate_match_reason(
query="authentication",
symbol_name="login_handler",
symbol_kind="function",
snippet="def login_handler(): pass",
vector_score=0.85,
structural_score=None,
)
assert "semantic" in reason.lower()
def test_returns_string_even_with_no_matches(self):
"""Test that a reason string is always returned."""
reason = _generate_match_reason(
query="xyz123",
symbol_name="abc456",
symbol_kind="function",
snippet="completely unrelated code",
vector_score=0.3,
structural_score=None,
)
assert isinstance(reason, str)
assert len(reason) > 0
class TestSplitCamelCase:
"""Test _split_camel_case helper function."""
def test_camel_case(self):
"""Test splitting camelCase."""
result = _split_camel_case("authenticateUser")
assert "authenticate" in result.lower()
assert "user" in result.lower()
def test_pascal_case(self):
"""Test splitting PascalCase."""
result = _split_camel_case("AuthManager")
assert "auth" in result.lower()
assert "manager" in result.lower()
def test_snake_case(self):
"""Test splitting snake_case."""
result = _split_camel_case("auth_manager")
assert "auth" in result.lower()
assert "manager" in result.lower()
def test_mixed_case(self):
"""Test splitting mixed case."""
result = _split_camel_case("HTTPRequestHandler")
# Should handle acronyms
assert "http" in result.lower() or "request" in result.lower()
class TestSemanticResultDataclass:
"""Test SemanticResult dataclass structure."""
def test_semantic_result_fields(self):
"""Test SemanticResult has all required fields."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=0.8,
structural_score=0.6,
fusion_score=0.7,
snippet="def test(): pass",
match_reason="Test match",
)
assert result.symbol_name == "test"
assert result.kind == "function"
assert result.file_path == "/test.py"
assert result.line == 1
assert result.vector_score == 0.8
assert result.structural_score == 0.6
assert result.fusion_score == 0.7
assert result.snippet == "def test(): pass"
assert result.match_reason == "Test match"
def test_semantic_result_optional_fields(self):
"""Test SemanticResult with optional None fields."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=None, # Degraded - no vector index
structural_score=None, # Degraded - no relationships
fusion_score=0.5,
snippet="def test(): pass",
match_reason=None, # Not requested
)
assert result.vector_score is None
assert result.structural_score is None
assert result.match_reason is None
def test_semantic_result_to_dict(self):
"""Test SemanticResult.to_dict() filters None values."""
result = SemanticResult(
symbol_name="test",
kind="function",
file_path="/test.py",
line=1,
vector_score=None,
structural_score=0.6,
fusion_score=0.7,
snippet="def test(): pass",
match_reason=None,
)
d = result.to_dict()
assert "symbol_name" in d
assert "vector_score" not in d # None values filtered
assert "structural_score" in d
assert "match_reason" not in d # None values filtered
class TestFusionStrategyMapping:
"""Test fusion_strategy parameter mapping via _execute_search."""
def test_rrf_strategy_calls_search(self):
"""Test that rrf strategy maps to standard search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="rrf",
options=mock_options,
limit=20,
)
mock_engine.search.assert_called_once()
def test_staged_strategy_calls_staged_cascade_search(self):
"""Test that staged strategy maps to staged_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.staged_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="staged",
options=mock_options,
limit=20,
)
mock_engine.staged_cascade_search.assert_called_once()
def test_binary_strategy_calls_binary_cascade_search(self):
"""Test that binary strategy maps to binary_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.binary_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="binary",
options=mock_options,
limit=20,
)
mock_engine.binary_cascade_search.assert_called_once()
def test_hybrid_strategy_calls_hybrid_cascade_search(self):
"""Test that hybrid strategy maps to hybrid_cascade_search."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.hybrid_cascade_search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="hybrid",
options=mock_options,
limit=20,
)
mock_engine.hybrid_cascade_search.assert_called_once()
def test_unknown_strategy_defaults_to_rrf(self):
"""Test that unknown strategy defaults to standard search (rrf)."""
from codexlens.api.semantic import _execute_search
mock_engine = MagicMock()
mock_engine.search.return_value = MagicMock(results=[])
mock_options = MagicMock()
_execute_search(
engine=mock_engine,
query="test query",
source_path=Path("/test"),
fusion_strategy="unknown_strategy",
options=mock_options,
limit=20,
)
mock_engine.search.assert_called_once()
class TestGracefulDegradation:
"""Test graceful degradation behavior."""
def test_vector_score_none_when_no_vector_index(self):
"""Test vector_score=None when vector index unavailable."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.5
mock_result.excerpt = "def auth(): pass"
mock_result.symbol_name = "auth"
mock_result.symbol_kind = "function"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {} # No vector score in metadata
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
# When no source_scores in metadata, vector_score should be None
assert results[0].vector_score is None
def test_structural_score_extracted_from_fts(self):
"""Test structural_score extracted from FTS scores."""
mock_result = MagicMock()
mock_result.path = "/project/src/auth.py"
mock_result.score = 0.8
mock_result.excerpt = "def auth(): pass"
mock_result.symbol_name = "auth"
mock_result.symbol_kind = "function"
mock_result.start_line = 1
mock_result.symbol = None
mock_result.metadata = {
"source_scores": {
"exact": 0.9,
"fuzzy": 0.7,
}
}
results = _transform_results(
results=[mock_result],
mode="fusion",
vector_weight=0.5,
structural_weight=0.3,
keyword_weight=0.2,
kind_filter=None,
include_match_reason=False,
query="auth",
)
assert len(results) == 1
assert results[0].structural_score == 0.9 # max of exact/fuzzy

View File

@@ -0,0 +1 @@
"""Tests package for LSP module."""

View File

@@ -0,0 +1,477 @@
"""Tests for hover provider."""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock
import tempfile
from codexlens.entities import Symbol
class TestHoverInfo:
"""Test HoverInfo dataclass."""
def test_hover_info_import(self):
"""HoverInfo can be imported."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.providers import HoverInfo
assert HoverInfo is not None
def test_hover_info_fields(self):
"""HoverInfo has all required fields."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo
info = HoverInfo(
name="my_function",
kind="function",
signature="def my_function(x: int) -> str:",
documentation="A test function.",
file_path="/test/file.py",
line_range=(10, 15),
)
assert info.name == "my_function"
assert info.kind == "function"
assert info.signature == "def my_function(x: int) -> str:"
assert info.documentation == "A test function."
assert info.file_path == "/test/file.py"
assert info.line_range == (10, 15)
def test_hover_info_optional_documentation(self):
"""Documentation can be None."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="/test.py",
line_range=(1, 2),
)
assert info.documentation is None
class TestHoverProvider:
"""Test HoverProvider class."""
def test_provider_import(self):
"""HoverProvider can be imported."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
assert HoverProvider is not None
def test_returns_none_for_unknown_symbol(self):
"""Returns None when symbol not found."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_index.search.return_value = []
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("unknown_symbol")
assert result is None
mock_index.search.assert_called_once_with(
name="unknown_symbol", limit=1, prefix_mode=False
)
def test_returns_none_for_non_exact_match(self):
"""Returns None when search returns non-exact matches."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Return a symbol with different name (prefix match but not exact)
mock_symbol = Mock()
mock_symbol.name = "my_function_extended"
mock_symbol.kind = "function"
mock_symbol.file = "/test/file.py"
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("my_function")
assert result is None
def test_returns_hover_info_for_known_symbol(self):
"""Returns HoverInfo for found symbol."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = None # No file, will use fallback signature
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
result = provider.get_hover_info("my_func")
assert result is not None
assert result.name == "my_func"
assert result.kind == "function"
assert result.line_range == (10, 15)
assert result.signature == "function my_func"
def test_extracts_signature_from_file(self):
"""Extracts signature from actual file content."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Create a temporary file with Python content
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("# comment\n")
f.write("def test_function(x: int, y: str) -> bool:\n")
f.write(" return True\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "test_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (2, 3) # Line 2 (1-based)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test_function")
assert result is not None
assert "def test_function(x: int, y: str) -> bool:" in result.signature
finally:
Path(temp_path).unlink(missing_ok=True)
def test_extracts_multiline_signature(self):
"""Extracts multiline function signature."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
# Create a temporary file with multiline signature
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("def complex_function(\n")
f.write(" arg1: int,\n")
f.write(" arg2: str,\n")
f.write(") -> bool:\n")
f.write(" return True\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "complex_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 5) # Line 1 (1-based)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("complex_function")
assert result is not None
assert "def complex_function(" in result.signature
# Should capture multiline signature
assert "arg1: int" in result.signature
finally:
Path(temp_path).unlink(missing_ok=True)
def test_handles_nonexistent_file_gracefully(self):
"""Returns fallback signature when file doesn't exist."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/nonexistent/path/file.py"
mock_symbol.range = (10, 15)
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("my_func")
assert result is not None
assert result.signature == "function my_func"
def test_handles_invalid_line_range(self):
"""Returns fallback signature when line range is invalid."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
f.write("def test():\n")
f.write(" pass\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "test"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (100, 105) # Line beyond file length
mock_index = Mock()
mock_index.search.return_value = [mock_symbol]
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test")
assert result is not None
assert result.signature == "function test"
finally:
Path(temp_path).unlink(missing_ok=True)
class TestFormatHoverMarkdown:
"""Test markdown formatting."""
def test_format_python_signature(self):
"""Formats Python signature with python code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func(x: int) -> str:",
documentation=None,
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```python" in result
assert "def func(x: int) -> str:" in result
assert "function" in result
assert "file.py" in result
assert "line 10" in result
def test_format_javascript_signature(self):
"""Formats JavaScript signature with javascript code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="myFunc",
kind="function",
signature="function myFunc(x) {",
documentation=None,
file_path="/test/file.js",
line_range=(5, 10),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```javascript" in result
assert "function myFunc(x) {" in result
def test_format_typescript_signature(self):
"""Formats TypeScript signature with typescript code fence."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="myFunc",
kind="function",
signature="function myFunc(x: number): string {",
documentation=None,
file_path="/test/file.ts",
line_range=(5, 10),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "```typescript" in result
def test_format_with_documentation(self):
"""Includes documentation when available."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation="This is a test function.",
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "This is a test function." in result
assert "---" in result # Separator before docs
def test_format_without_documentation(self):
"""Does not include documentation section when None."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="/test/file.py",
line_range=(10, 15),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
# Should have one separator for location, not two
# The result should not have duplicate doc separator
lines = result.split("\n")
separator_count = sum(1 for line in lines if line.strip() == "---")
assert separator_count == 1 # Only location separator
def test_format_unknown_extension(self):
"""Uses empty code fence for unknown file extensions."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="func code here",
documentation=None,
file_path="/test/file.xyz",
line_range=(1, 2),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
# Should have code fence without language specifier
assert "```\n" in result or "```xyz" not in result
def test_format_class_symbol(self):
"""Formats class symbol correctly."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="MyClass",
kind="class",
signature="class MyClass:",
documentation=None,
file_path="/test/file.py",
line_range=(1, 20),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "class MyClass:" in result
assert "*class*" in result
assert "line 1" in result
def test_format_empty_file_path(self):
"""Handles empty file path gracefully."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverInfo, HoverProvider
info = HoverInfo(
name="func",
kind="function",
signature="def func():",
documentation=None,
file_path="",
line_range=(1, 2),
)
mock_index = Mock()
provider = HoverProvider(mock_index, None)
result = provider.format_hover_markdown(info)
assert "unknown" in result or "```" in result
class TestHoverProviderRegistry:
"""Test HoverProvider with registry integration."""
def test_provider_accepts_none_registry(self):
"""HoverProvider works without registry."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_index.search.return_value = []
provider = HoverProvider(mock_index, None)
result = provider.get_hover_info("test")
assert result is None
assert provider.registry is None
def test_provider_stores_registry(self):
"""HoverProvider stores registry reference."""
pytest.importorskip("pygls")
from codexlens.lsp.providers import HoverProvider
mock_index = Mock()
mock_registry = Mock()
provider = HoverProvider(mock_index, mock_registry)
assert provider.global_index is mock_index
assert provider.registry is mock_registry

View File

@@ -0,0 +1,497 @@
"""Tests for reference search functionality.
This module tests the ReferenceResult dataclass and search_references method
in ChainSearchEngine, as well as the updated lsp_references handler.
"""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch
import sqlite3
import tempfile
import os
class TestReferenceResult:
"""Test ReferenceResult dataclass."""
def test_reference_result_fields(self):
"""ReferenceResult has all required fields."""
from codexlens.search.chain_search import ReferenceResult
ref = ReferenceResult(
file_path="/test/file.py",
line=10,
column=5,
context="def foo():",
relationship_type="call",
)
assert ref.file_path == "/test/file.py"
assert ref.line == 10
assert ref.column == 5
assert ref.context == "def foo():"
assert ref.relationship_type == "call"
def test_reference_result_with_empty_context(self):
"""ReferenceResult can have empty context."""
from codexlens.search.chain_search import ReferenceResult
ref = ReferenceResult(
file_path="/test/file.py",
line=1,
column=0,
context="",
relationship_type="import",
)
assert ref.context == ""
def test_reference_result_different_relationship_types(self):
"""ReferenceResult supports different relationship types."""
from codexlens.search.chain_search import ReferenceResult
types = ["call", "import", "inheritance", "implementation", "usage"]
for rel_type in types:
ref = ReferenceResult(
file_path="/test/file.py",
line=1,
column=0,
context="test",
relationship_type=rel_type,
)
assert ref.relationship_type == rel_type
class TestExtractContext:
"""Test the _extract_context helper method."""
def test_extract_context_middle_of_file(self):
"""Extract context from middle of file."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
content = "\n".join([
"line 1",
"line 2",
"line 3",
"line 4", # target line
"line 5",
"line 6",
"line 7",
])
# Create minimal mock engine to test _extract_context
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=4, context_lines=2)
assert "line 2" in context
assert "line 3" in context
assert "line 4" in context
assert "line 5" in context
assert "line 6" in context
def test_extract_context_start_of_file(self):
"""Extract context at start of file."""
from codexlens.search.chain_search import ChainSearchEngine
content = "\n".join([
"line 1", # target
"line 2",
"line 3",
"line 4",
])
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=1, context_lines=2)
assert "line 1" in context
assert "line 2" in context
assert "line 3" in context
def test_extract_context_end_of_file(self):
"""Extract context at end of file."""
from codexlens.search.chain_search import ChainSearchEngine
content = "\n".join([
"line 1",
"line 2",
"line 3",
"line 4", # target
])
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context(content, line=4, context_lines=2)
assert "line 2" in context
assert "line 3" in context
assert "line 4" in context
def test_extract_context_empty_content(self):
"""Extract context from empty content."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
context = engine._extract_context("", line=1, context_lines=3)
assert context == ""
def test_extract_context_invalid_line(self):
"""Extract context with invalid line number."""
from codexlens.search.chain_search import ChainSearchEngine
content = "line 1\nline 2\nline 3"
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
# Line 0 (invalid)
assert engine._extract_context(content, line=0, context_lines=1) == ""
# Line beyond end
assert engine._extract_context(content, line=100, context_lines=1) == ""
class TestSearchReferences:
"""Test search_references method."""
def test_returns_empty_for_no_source_path_and_no_registry(self):
"""Returns empty list when no source path and registry has no mappings."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_registry.list_mappings.return_value = []
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
results = engine.search_references("test_symbol")
assert results == []
def test_returns_empty_for_no_indexes(self):
"""Returns empty list when no indexes found."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
mock_mapper.source_to_index_db.return_value = Path("/nonexistent/_index.db")
engine = ChainSearchEngine(mock_registry, mock_mapper)
with patch.object(engine, "_find_start_index", return_value=None):
results = engine.search_references("test_symbol", Path("/some/path"))
assert results == []
def test_deduplicates_results(self):
"""Removes duplicate file:line references."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
# Create a temporary database with duplicate relationships
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'def test(): pass');
INSERT INTO symbols VALUES (1, 1, 'test_func', 'function', 1, 1);
INSERT INTO code_relationships VALUES (1, 1, 'target_func', 'call', 10, NULL);
INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 10, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target_func", Path(tmpdir))
# Should only have 1 result due to deduplication
assert len(results) == 1
assert results[0].line == 10
def test_sorts_by_file_and_line(self):
"""Results sorted by file path then line number."""
from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/b_file.py', 'python', 'content');
INSERT INTO files VALUES (2, '/test/a_file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'func1', 'function', 1, 1);
INSERT INTO symbols VALUES (2, 2, 'func2', 'function', 1, 1);
INSERT INTO code_relationships VALUES (1, 1, 'target', 'call', 20, NULL);
INSERT INTO code_relationships VALUES (2, 1, 'target', 'call', 10, NULL);
INSERT INTO code_relationships VALUES (3, 2, 'target', 'call', 5, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target", Path(tmpdir))
# Should be sorted: a_file.py:5, b_file.py:10, b_file.py:20
assert len(results) == 3
assert results[0].file_path == "/test/a_file.py"
assert results[0].line == 5
assert results[1].file_path == "/test/b_file.py"
assert results[1].line == 10
assert results[2].file_path == "/test/b_file.py"
assert results[2].line == 20
def test_respects_limit(self):
"""Returns at most limit results."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'func', 'function', 1, 1);
""")
# Insert many relationships
for i in range(50):
conn.execute(
"INSERT INTO code_relationships VALUES (?, 1, 'target', 'call', ?, NULL)",
(i + 1, i + 1)
)
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target", Path(tmpdir), limit=10)
assert len(results) == 10
def test_matches_qualified_name(self):
"""Matches symbols by qualified name suffix."""
from codexlens.search.chain_search import ChainSearchEngine
mock_registry = Mock()
mock_mapper = Mock()
engine = ChainSearchEngine(mock_registry, mock_mapper)
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "_index.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
CREATE TABLE files (
id INTEGER PRIMARY KEY,
path TEXT NOT NULL,
language TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE symbols (
id INTEGER PRIMARY KEY,
file_id INTEGER NOT NULL,
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL
);
CREATE TABLE code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
);
INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
INSERT INTO symbols VALUES (1, 1, 'caller', 'function', 1, 1);
-- Fully qualified name
INSERT INTO code_relationships VALUES (1, 1, 'module.submodule.target_func', 'call', 10, NULL);
-- Simple name
INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 20, NULL);
""")
conn.commit()
conn.close()
with patch.object(engine, "_find_start_index", return_value=db_path):
with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
results = engine.search_references("target_func", Path(tmpdir))
# Should find both references
assert len(results) == 2
class TestLspReferencesHandler:
"""Test the LSP references handler."""
def test_handler_uses_search_engine(self):
"""Handler uses search_engine.search_references when available."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _path_to_uri
from codexlens.search.chain_search import ReferenceResult
# Create mock references
mock_references = [
ReferenceResult(
file_path="/test/file1.py",
line=10,
column=5,
context="def foo():",
relationship_type="call",
),
ReferenceResult(
file_path="/test/file2.py",
line=20,
column=0,
context="import foo",
relationship_type="import",
),
]
# Verify conversion to LSP Location
locations = []
for ref in mock_references:
locations.append(
lsp.Location(
uri=_path_to_uri(ref.file_path),
range=lsp.Range(
start=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column,
),
end=lsp.Position(
line=max(0, ref.line - 1),
character=ref.column + len("foo"),
),
),
)
)
assert len(locations) == 2
# First reference at line 10 (0-indexed = 9)
assert locations[0].range.start.line == 9
assert locations[0].range.start.character == 5
# Second reference at line 20 (0-indexed = 19)
assert locations[1].range.start.line == 19
assert locations[1].range.start.character == 0
def test_handler_falls_back_to_global_index(self):
"""Handler falls back to global_index when search_engine unavailable."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.handlers import symbol_to_location
from codexlens.entities import Symbol
# Test fallback path converts Symbol to Location
symbol = Symbol(
name="test_func",
kind="function",
range=(10, 15),
file="/test/file.py",
)
location = symbol_to_location(symbol)
assert location is not None
# LSP uses 0-based lines
assert location.range.start.line == 9
assert location.range.end.line == 14

View File

@@ -0,0 +1,210 @@
"""Tests for codex-lens LSP server."""
from __future__ import annotations
import pytest
from pathlib import Path
from unittest.mock import MagicMock, patch
from codexlens.entities import Symbol
class TestCodexLensLanguageServer:
"""Tests for CodexLensLanguageServer."""
def test_server_import(self):
"""Test that server module can be imported."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from codexlens.lsp.server import CodexLensLanguageServer, server
assert CodexLensLanguageServer is not None
assert server is not None
assert server.name == "codexlens-lsp"
def test_server_initialization(self):
"""Test server instance creation."""
pytest.importorskip("pygls")
from codexlens.lsp.server import CodexLensLanguageServer
ls = CodexLensLanguageServer()
assert ls.registry is None
assert ls.mapper is None
assert ls.global_index is None
assert ls.search_engine is None
assert ls.workspace_root is None
class TestDefinitionHandler:
"""Tests for definition handler."""
def test_definition_lookup(self):
"""Test definition lookup returns location for known symbol."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import symbol_to_location
symbol = Symbol(
name="test_function",
kind="function",
range=(10, 15),
file="/path/to/file.py",
)
location = symbol_to_location(symbol)
assert location is not None
assert isinstance(location, lsp.Location)
# LSP uses 0-based lines
assert location.range.start.line == 9
assert location.range.end.line == 14
def test_definition_no_file(self):
"""Test definition lookup returns None for symbol without file."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import symbol_to_location
symbol = Symbol(
name="test_function",
kind="function",
range=(10, 15),
file=None,
)
location = symbol_to_location(symbol)
assert location is None
class TestCompletionHandler:
"""Tests for completion handler."""
def test_get_prefix_at_position(self):
"""Test extracting prefix at cursor position."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _get_prefix_at_position
document_text = "def hello_world():\n print(hel"
# Cursor at end of "hel"
prefix = _get_prefix_at_position(document_text, 1, 14)
assert prefix == "hel"
# Cursor at beginning of line (after whitespace)
prefix = _get_prefix_at_position(document_text, 1, 4)
assert prefix == ""
# Cursor after "he" in "hello_world" - returns text before cursor
prefix = _get_prefix_at_position(document_text, 0, 6)
assert prefix == "he"
# Cursor at end of "hello_world"
prefix = _get_prefix_at_position(document_text, 0, 15)
assert prefix == "hello_world"
def test_get_word_at_position(self):
"""Test extracting word at cursor position."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _get_word_at_position
document_text = "def hello_world():\n print(msg)"
# Cursor on "hello_world"
word = _get_word_at_position(document_text, 0, 6)
assert word == "hello_world"
# Cursor on "print"
word = _get_word_at_position(document_text, 1, 6)
assert word == "print"
# Cursor on "msg"
word = _get_word_at_position(document_text, 1, 11)
assert word == "msg"
def test_symbol_kind_mapping(self):
"""Test symbol kind to completion kind mapping."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _symbol_kind_to_completion_kind
assert _symbol_kind_to_completion_kind("function") == lsp.CompletionItemKind.Function
assert _symbol_kind_to_completion_kind("class") == lsp.CompletionItemKind.Class
assert _symbol_kind_to_completion_kind("method") == lsp.CompletionItemKind.Method
assert _symbol_kind_to_completion_kind("variable") == lsp.CompletionItemKind.Variable
# Unknown kind should default to Text
assert _symbol_kind_to_completion_kind("unknown") == lsp.CompletionItemKind.Text
class TestWorkspaceSymbolHandler:
"""Tests for workspace symbol handler."""
def test_symbol_kind_to_lsp(self):
"""Test symbol kind to LSP SymbolKind mapping."""
pytest.importorskip("pygls")
pytest.importorskip("lsprotocol")
from lsprotocol import types as lsp
from codexlens.lsp.handlers import _symbol_kind_to_lsp
assert _symbol_kind_to_lsp("function") == lsp.SymbolKind.Function
assert _symbol_kind_to_lsp("class") == lsp.SymbolKind.Class
assert _symbol_kind_to_lsp("method") == lsp.SymbolKind.Method
assert _symbol_kind_to_lsp("interface") == lsp.SymbolKind.Interface
# Unknown kind should default to Variable
assert _symbol_kind_to_lsp("unknown") == lsp.SymbolKind.Variable
class TestUriConversion:
"""Tests for URI path conversion."""
def test_path_to_uri(self):
"""Test path to URI conversion."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _path_to_uri
# Unix path
uri = _path_to_uri("/home/user/file.py")
assert uri.startswith("file://")
assert "file.py" in uri
def test_uri_to_path(self):
"""Test URI to path conversion."""
pytest.importorskip("pygls")
from codexlens.lsp.handlers import _uri_to_path
# Basic URI
path = _uri_to_path("file:///home/user/file.py")
assert path.name == "file.py"
class TestMainEntryPoint:
"""Tests for main entry point."""
def test_main_help(self):
"""Test that main shows help without errors."""
pytest.importorskip("pygls")
import sys
from unittest.mock import patch
# Patch sys.argv to show help
with patch.object(sys, 'argv', ['codexlens-lsp', '--help']):
from codexlens.lsp.server import main
with pytest.raises(SystemExit) as exc_info:
main()
# Help exits with 0
assert exc_info.value.code == 0

View File

@@ -0,0 +1 @@
"""Tests for MCP (Model Context Protocol) module."""

View File

@@ -0,0 +1,208 @@
"""Tests for MCP hooks module."""
import pytest
from unittest.mock import Mock, patch
from pathlib import Path
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
from codexlens.mcp.schema import MCPContext, SymbolInfo
class TestHookManager:
"""Test HookManager class."""
@pytest.fixture
def mock_provider(self):
"""Create a mock MCP provider."""
provider = Mock()
provider.build_context.return_value = MCPContext(
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
context_type="symbol_explanation",
)
provider.build_context_for_file.return_value = MCPContext(
context_type="file_overview",
)
return provider
@pytest.fixture
def hook_manager(self, mock_provider):
"""Create a HookManager with mocked provider."""
return HookManager(mock_provider)
def test_default_hooks_registered(self, hook_manager):
"""Default hooks are registered on initialization."""
assert "explain" in hook_manager._pre_hooks
assert "refactor" in hook_manager._pre_hooks
assert "document" in hook_manager._pre_hooks
def test_execute_pre_hook_returns_context(self, hook_manager, mock_provider):
"""execute_pre_hook returns MCPContext for registered hook."""
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
assert result is not None
assert isinstance(result, MCPContext)
mock_provider.build_context.assert_called_once()
def test_execute_pre_hook_returns_none_for_unknown_action(self, hook_manager):
"""execute_pre_hook returns None for unregistered action."""
result = hook_manager.execute_pre_hook("unknown_action", {"symbol": "test"})
assert result is None
def test_execute_pre_hook_handles_exception(self, hook_manager, mock_provider):
"""execute_pre_hook handles provider exceptions gracefully."""
mock_provider.build_context.side_effect = Exception("Provider failed")
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
assert result is None
def test_execute_post_hook_no_error_for_unregistered(self, hook_manager):
"""execute_post_hook doesn't error for unregistered action."""
# Should not raise
hook_manager.execute_post_hook("unknown", {"result": "data"})
def test_pre_explain_hook_calls_build_context(self, hook_manager, mock_provider):
"""_pre_explain_hook calls build_context correctly."""
hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
mock_provider.build_context.assert_called_with(
symbol_name="my_func",
context_type="symbol_explanation",
include_references=True,
include_related=True,
)
def test_pre_explain_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
"""_pre_explain_hook returns None when symbol param missing."""
result = hook_manager.execute_pre_hook("explain", {})
assert result is None
mock_provider.build_context.assert_not_called()
def test_pre_refactor_hook_calls_build_context(self, hook_manager, mock_provider):
"""_pre_refactor_hook calls build_context with refactor settings."""
hook_manager.execute_pre_hook("refactor", {"symbol": "my_class"})
mock_provider.build_context.assert_called_with(
symbol_name="my_class",
context_type="refactor_context",
include_references=True,
include_related=True,
max_references=20,
)
def test_pre_refactor_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
"""_pre_refactor_hook returns None when symbol param missing."""
result = hook_manager.execute_pre_hook("refactor", {})
assert result is None
mock_provider.build_context.assert_not_called()
def test_pre_document_hook_with_symbol(self, hook_manager, mock_provider):
"""_pre_document_hook uses build_context when symbol provided."""
hook_manager.execute_pre_hook("document", {"symbol": "my_func"})
mock_provider.build_context.assert_called_with(
symbol_name="my_func",
context_type="documentation_context",
include_references=False,
include_related=True,
)
def test_pre_document_hook_with_file_path(self, hook_manager, mock_provider):
"""_pre_document_hook uses build_context_for_file when file_path provided."""
hook_manager.execute_pre_hook("document", {"file_path": "/src/module.py"})
mock_provider.build_context_for_file.assert_called_once()
call_args = mock_provider.build_context_for_file.call_args
assert call_args[0][0] == Path("/src/module.py")
assert call_args[1].get("context_type") == "file_documentation"
def test_pre_document_hook_prefers_symbol_over_file(self, hook_manager, mock_provider):
"""_pre_document_hook prefers symbol when both provided."""
hook_manager.execute_pre_hook(
"document", {"symbol": "my_func", "file_path": "/src/module.py"}
)
mock_provider.build_context.assert_called_once()
mock_provider.build_context_for_file.assert_not_called()
def test_pre_document_hook_returns_none_without_params(self, hook_manager, mock_provider):
"""_pre_document_hook returns None when neither symbol nor file_path provided."""
result = hook_manager.execute_pre_hook("document", {})
assert result is None
mock_provider.build_context.assert_not_called()
mock_provider.build_context_for_file.assert_not_called()
def test_register_pre_hook(self, hook_manager):
"""register_pre_hook adds custom hook."""
custom_hook = Mock(return_value=MCPContext())
hook_manager.register_pre_hook("custom_action", custom_hook)
assert "custom_action" in hook_manager._pre_hooks
hook_manager.execute_pre_hook("custom_action", {"data": "value"})
custom_hook.assert_called_once_with({"data": "value"})
def test_register_post_hook(self, hook_manager):
"""register_post_hook adds custom hook."""
custom_hook = Mock()
hook_manager.register_post_hook("custom_action", custom_hook)
assert "custom_action" in hook_manager._post_hooks
hook_manager.execute_post_hook("custom_action", {"result": "data"})
custom_hook.assert_called_once_with({"result": "data"})
def test_execute_post_hook_handles_exception(self, hook_manager):
"""execute_post_hook handles hook exceptions gracefully."""
failing_hook = Mock(side_effect=Exception("Hook failed"))
hook_manager.register_post_hook("failing", failing_hook)
# Should not raise
hook_manager.execute_post_hook("failing", {"data": "value"})
class TestCreateContextForPrompt:
"""Test create_context_for_prompt function."""
def test_returns_prompt_injection_string(self):
"""create_context_for_prompt returns formatted string."""
mock_provider = Mock()
mock_provider.build_context.return_value = MCPContext(
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
definition="def test_func(): pass",
)
result = create_context_for_prompt(
mock_provider, "explain", {"symbol": "test_func"}
)
assert isinstance(result, str)
assert "<code_context>" in result
assert "test_func" in result
assert "</code_context>" in result
def test_returns_empty_string_when_no_context(self):
"""create_context_for_prompt returns empty string when no context built."""
mock_provider = Mock()
mock_provider.build_context.return_value = None
result = create_context_for_prompt(
mock_provider, "explain", {"symbol": "nonexistent"}
)
assert result == ""
def test_returns_empty_string_for_unknown_action(self):
"""create_context_for_prompt returns empty string for unregistered action."""
mock_provider = Mock()
result = create_context_for_prompt(
mock_provider, "unknown_action", {"data": "value"}
)
assert result == ""
mock_provider.build_context.assert_not_called()

View File

@@ -0,0 +1,383 @@
"""Tests for MCP provider."""
import pytest
from unittest.mock import Mock, MagicMock, patch
from pathlib import Path
import tempfile
import os
from codexlens.mcp.provider import MCPProvider
from codexlens.mcp.schema import MCPContext, SymbolInfo, ReferenceInfo
class TestMCPProvider:
"""Test MCPProvider class."""
@pytest.fixture
def mock_global_index(self):
"""Create a mock global index."""
return Mock()
@pytest.fixture
def mock_search_engine(self):
"""Create a mock search engine."""
return Mock()
@pytest.fixture
def mock_registry(self):
"""Create a mock registry."""
return Mock()
@pytest.fixture
def provider(self, mock_global_index, mock_search_engine, mock_registry):
"""Create an MCPProvider with mocked dependencies."""
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
def test_build_context_returns_none_for_unknown_symbol(self, provider, mock_global_index):
"""build_context returns None when symbol is not found."""
mock_global_index.search.return_value = []
result = provider.build_context("unknown_symbol")
assert result is None
mock_global_index.search.assert_called_once_with(
"unknown_symbol", prefix_mode=False, limit=1
)
def test_build_context_returns_mcp_context(
self, provider, mock_global_index, mock_search_engine
):
"""build_context returns MCPContext for known symbol."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func")
assert result is not None
assert isinstance(result, MCPContext)
assert result.symbol is not None
assert result.symbol.name == "my_func"
assert result.symbol.kind == "function"
assert result.context_type == "symbol_explanation"
def test_build_context_with_custom_context_type(
self, provider, mock_global_index, mock_search_engine
):
"""build_context respects custom context_type."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func", context_type="refactor_context")
assert result is not None
assert result.context_type == "refactor_context"
def test_build_context_includes_references(
self, provider, mock_global_index, mock_search_engine
):
"""build_context includes references when include_references=True."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_ref = Mock()
mock_ref.file_path = "/caller.py"
mock_ref.line = 25
mock_ref.column = 4
mock_ref.context = "result = my_func()"
mock_ref.relationship_type = "call"
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = [mock_ref]
result = provider.build_context("my_func", include_references=True)
assert result is not None
assert len(result.references) == 1
assert result.references[0].file_path == "/caller.py"
assert result.references[0].line == 25
assert result.references[0].relationship_type == "call"
def test_build_context_excludes_references_when_disabled(
self, provider, mock_global_index, mock_search_engine
):
"""build_context excludes references when include_references=False."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
# Disable both references and related to avoid any search_references calls
result = provider.build_context(
"my_func", include_references=False, include_related=False
)
assert result is not None
assert len(result.references) == 0
mock_search_engine.search_references.assert_not_called()
def test_build_context_respects_max_references(
self, provider, mock_global_index, mock_search_engine
):
"""build_context passes max_references to search engine."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
# Disable include_related to test only the references call
provider.build_context("my_func", max_references=5, include_related=False)
mock_search_engine.search_references.assert_called_once_with(
"my_func", limit=5
)
def test_build_context_includes_metadata(
self, provider, mock_global_index, mock_search_engine
):
"""build_context includes source metadata."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.kind = "function"
mock_symbol.file = "/test.py"
mock_symbol.range = (10, 20)
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = []
result = provider.build_context("my_func")
assert result is not None
assert result.metadata.get("source") == "codex-lens"
def test_extract_definition_with_valid_file(self, provider):
"""_extract_definition reads file content correctly."""
# Create a temporary file with some content
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("# Line 1\n")
f.write("# Line 2\n")
f.write("def my_func():\n") # Line 3
f.write(" pass\n") # Line 4
f.write("# Line 5\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.file = temp_path
mock_symbol.range = (3, 4) # 1-based line numbers
definition = provider._extract_definition(mock_symbol)
assert definition is not None
assert "def my_func():" in definition
assert "pass" in definition
finally:
os.unlink(temp_path)
def test_extract_definition_returns_none_for_missing_file(self, provider):
"""_extract_definition returns None for non-existent file."""
mock_symbol = Mock()
mock_symbol.file = "/nonexistent/path/file.py"
mock_symbol.range = (1, 5)
definition = provider._extract_definition(mock_symbol)
assert definition is None
def test_extract_definition_returns_none_for_none_file(self, provider):
"""_extract_definition returns None when symbol.file is None."""
mock_symbol = Mock()
mock_symbol.file = None
mock_symbol.range = (1, 5)
definition = provider._extract_definition(mock_symbol)
assert definition is None
def test_build_context_for_file_returns_context(
self, provider, mock_global_index
):
"""build_context_for_file returns MCPContext."""
mock_global_index.search.return_value = []
result = provider.build_context_for_file(
Path("/test/file.py"),
context_type="file_overview",
)
assert result is not None
assert isinstance(result, MCPContext)
assert result.context_type == "file_overview"
assert result.metadata.get("file_path") == str(Path("/test/file.py"))
def test_build_context_for_file_includes_symbols(
self, provider, mock_global_index
):
"""build_context_for_file includes symbols from the file."""
# Create temp file to get resolved path
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("def func(): pass\n")
temp_path = f.name
try:
mock_symbol = Mock()
mock_symbol.name = "func"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 1)
mock_global_index.search.return_value = [mock_symbol]
result = provider.build_context_for_file(Path(temp_path))
assert result is not None
# Symbols from this file should be in related_symbols
assert len(result.related_symbols) >= 0 # May be 0 if filtering doesn't match
finally:
os.unlink(temp_path)
class TestMCPProviderRelatedSymbols:
"""Test related symbols functionality."""
@pytest.fixture
def provider(self):
"""Create provider with mocks."""
mock_global_index = Mock()
mock_search_engine = Mock()
mock_registry = Mock()
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
def test_get_related_symbols_from_references(self, provider):
"""_get_related_symbols extracts symbols from references."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
mock_ref1 = Mock()
mock_ref1.file_path = "/caller1.py"
mock_ref1.relationship_type = "call"
mock_ref2 = Mock()
mock_ref2.file_path = "/caller2.py"
mock_ref2.relationship_type = "import"
provider.search_engine.search_references.return_value = [mock_ref1, mock_ref2]
related = provider._get_related_symbols(mock_symbol)
assert len(related) == 2
assert related[0].relationship == "call"
assert related[1].relationship == "import"
def test_get_related_symbols_limits_results(self, provider):
"""_get_related_symbols limits to 10 unique relationship types."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
# Create 15 references with unique relationship types
refs = []
for i in range(15):
ref = Mock()
ref.file_path = f"/file{i}.py"
ref.relationship_type = f"type{i}"
refs.append(ref)
provider.search_engine.search_references.return_value = refs
related = provider._get_related_symbols(mock_symbol)
assert len(related) <= 10
def test_get_related_symbols_handles_exception(self, provider):
"""_get_related_symbols handles exceptions gracefully."""
mock_symbol = Mock()
mock_symbol.name = "my_func"
mock_symbol.file = "/test.py"
provider.search_engine.search_references.side_effect = Exception("Search failed")
related = provider._get_related_symbols(mock_symbol)
assert related == []
class TestMCPProviderIntegration:
"""Integration-style tests for MCPProvider."""
def test_full_context_workflow(self):
"""Test complete context building workflow."""
# Create temp file
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write("def my_function(arg1, arg2):\n")
f.write(" '''This is my function.'''\n")
f.write(" return arg1 + arg2\n")
temp_path = f.name
try:
# Setup mocks
mock_global_index = Mock()
mock_search_engine = Mock()
mock_registry = Mock()
mock_symbol = Mock()
mock_symbol.name = "my_function"
mock_symbol.kind = "function"
mock_symbol.file = temp_path
mock_symbol.range = (1, 3)
mock_ref = Mock()
mock_ref.file_path = "/user.py"
mock_ref.line = 10
mock_ref.column = 4
mock_ref.context = "result = my_function(1, 2)"
mock_ref.relationship_type = "call"
mock_global_index.search.return_value = [mock_symbol]
mock_search_engine.search_references.return_value = [mock_ref]
provider = MCPProvider(mock_global_index, mock_search_engine, mock_registry)
context = provider.build_context("my_function")
assert context is not None
assert context.symbol.name == "my_function"
assert context.definition is not None
assert "def my_function" in context.definition
assert len(context.references) == 1
assert context.references[0].relationship_type == "call"
# Test serialization
json_str = context.to_json()
assert "my_function" in json_str
# Test prompt injection
prompt = context.to_prompt_injection()
assert "<code_context>" in prompt
assert "my_function" in prompt
assert "</code_context>" in prompt
finally:
os.unlink(temp_path)

View File

@@ -0,0 +1,288 @@
"""Tests for MCP schema."""
import pytest
import json
from codexlens.mcp.schema import (
MCPContext,
SymbolInfo,
ReferenceInfo,
RelatedSymbol,
)
class TestSymbolInfo:
"""Test SymbolInfo dataclass."""
def test_to_dict_includes_all_fields(self):
"""SymbolInfo.to_dict() includes all non-None fields."""
info = SymbolInfo(
name="func",
kind="function",
file_path="/test.py",
line_start=10,
line_end=20,
signature="def func():",
documentation="Test doc",
)
d = info.to_dict()
assert d["name"] == "func"
assert d["kind"] == "function"
assert d["file_path"] == "/test.py"
assert d["line_start"] == 10
assert d["line_end"] == 20
assert d["signature"] == "def func():"
assert d["documentation"] == "Test doc"
def test_to_dict_excludes_none(self):
"""SymbolInfo.to_dict() excludes None fields."""
info = SymbolInfo(
name="func",
kind="function",
file_path="/test.py",
line_start=10,
line_end=20,
)
d = info.to_dict()
assert "signature" not in d
assert "documentation" not in d
assert "name" in d
assert "kind" in d
def test_basic_creation(self):
"""SymbolInfo can be created with required fields only."""
info = SymbolInfo(
name="MyClass",
kind="class",
file_path="/src/module.py",
line_start=1,
line_end=50,
)
assert info.name == "MyClass"
assert info.kind == "class"
assert info.signature is None
assert info.documentation is None
class TestReferenceInfo:
"""Test ReferenceInfo dataclass."""
def test_to_dict(self):
"""ReferenceInfo.to_dict() returns all fields."""
ref = ReferenceInfo(
file_path="/src/main.py",
line=25,
column=4,
context="result = func()",
relationship_type="call",
)
d = ref.to_dict()
assert d["file_path"] == "/src/main.py"
assert d["line"] == 25
assert d["column"] == 4
assert d["context"] == "result = func()"
assert d["relationship_type"] == "call"
def test_all_fields_required(self):
"""ReferenceInfo requires all fields."""
ref = ReferenceInfo(
file_path="/test.py",
line=10,
column=0,
context="import module",
relationship_type="import",
)
assert ref.file_path == "/test.py"
assert ref.relationship_type == "import"
class TestRelatedSymbol:
"""Test RelatedSymbol dataclass."""
def test_to_dict_includes_all_fields(self):
"""RelatedSymbol.to_dict() includes all non-None fields."""
sym = RelatedSymbol(
name="BaseClass",
kind="class",
relationship="inherits",
file_path="/src/base.py",
)
d = sym.to_dict()
assert d["name"] == "BaseClass"
assert d["kind"] == "class"
assert d["relationship"] == "inherits"
assert d["file_path"] == "/src/base.py"
def test_to_dict_excludes_none(self):
"""RelatedSymbol.to_dict() excludes None file_path."""
sym = RelatedSymbol(
name="helper",
kind="function",
relationship="calls",
)
d = sym.to_dict()
assert "file_path" not in d
assert d["name"] == "helper"
assert d["relationship"] == "calls"
class TestMCPContext:
"""Test MCPContext dataclass."""
def test_to_dict_basic(self):
"""MCPContext.to_dict() returns basic structure."""
ctx = MCPContext(context_type="test")
d = ctx.to_dict()
assert d["version"] == "1.0"
assert d["context_type"] == "test"
assert d["metadata"] == {}
def test_to_dict_with_symbol(self):
"""MCPContext.to_dict() includes symbol when present."""
ctx = MCPContext(
context_type="test",
symbol=SymbolInfo("f", "function", "/t.py", 1, 2),
)
d = ctx.to_dict()
assert "symbol" in d
assert d["symbol"]["name"] == "f"
assert d["symbol"]["kind"] == "function"
def test_to_dict_with_references(self):
"""MCPContext.to_dict() includes references when present."""
ctx = MCPContext(
context_type="test",
references=[
ReferenceInfo("/a.py", 10, 0, "call()", "call"),
ReferenceInfo("/b.py", 20, 5, "import x", "import"),
],
)
d = ctx.to_dict()
assert "references" in d
assert len(d["references"]) == 2
assert d["references"][0]["line"] == 10
def test_to_dict_with_related_symbols(self):
"""MCPContext.to_dict() includes related_symbols when present."""
ctx = MCPContext(
context_type="test",
related_symbols=[
RelatedSymbol("Base", "class", "inherits"),
RelatedSymbol("helper", "function", "calls"),
],
)
d = ctx.to_dict()
assert "related_symbols" in d
assert len(d["related_symbols"]) == 2
def test_to_json(self):
"""MCPContext.to_json() returns valid JSON."""
ctx = MCPContext(context_type="test")
j = ctx.to_json()
parsed = json.loads(j)
assert parsed["version"] == "1.0"
assert parsed["context_type"] == "test"
def test_to_json_with_indent(self):
"""MCPContext.to_json() respects indent parameter."""
ctx = MCPContext(context_type="test")
j = ctx.to_json(indent=4)
# Check it's properly indented
assert " " in j
def test_to_prompt_injection_basic(self):
"""MCPContext.to_prompt_injection() returns formatted string."""
ctx = MCPContext(
symbol=SymbolInfo("my_func", "function", "/test.py", 10, 20),
definition="def my_func(): pass",
)
prompt = ctx.to_prompt_injection()
assert "<code_context>" in prompt
assert "my_func" in prompt
assert "def my_func()" in prompt
assert "</code_context>" in prompt
def test_to_prompt_injection_with_references(self):
"""MCPContext.to_prompt_injection() includes references."""
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
references=[
ReferenceInfo("/a.py", 10, 0, "func()", "call"),
ReferenceInfo("/b.py", 20, 0, "from x import func", "import"),
],
)
prompt = ctx.to_prompt_injection()
assert "References (2 found)" in prompt
assert "/a.py:10" in prompt
assert "call" in prompt
def test_to_prompt_injection_limits_references(self):
"""MCPContext.to_prompt_injection() limits references to 5."""
refs = [
ReferenceInfo(f"/file{i}.py", i, 0, f"ref{i}", "call")
for i in range(10)
]
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
references=refs,
)
prompt = ctx.to_prompt_injection()
# Should show "10 found" but only include 5
assert "References (10 found)" in prompt
assert "/file0.py" in prompt
assert "/file4.py" in prompt
assert "/file5.py" not in prompt
def test_to_prompt_injection_with_related_symbols(self):
"""MCPContext.to_prompt_injection() includes related symbols."""
ctx = MCPContext(
symbol=SymbolInfo("MyClass", "class", "/test.py", 1, 50),
related_symbols=[
RelatedSymbol("BaseClass", "class", "inherits"),
RelatedSymbol("helper", "function", "calls"),
],
)
prompt = ctx.to_prompt_injection()
assert "Related Symbols" in prompt
assert "BaseClass (inherits)" in prompt
assert "helper (calls)" in prompt
def test_to_prompt_injection_limits_related_symbols(self):
"""MCPContext.to_prompt_injection() limits related symbols to 10."""
related = [
RelatedSymbol(f"sym{i}", "function", "calls")
for i in range(15)
]
ctx = MCPContext(
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
related_symbols=related,
)
prompt = ctx.to_prompt_injection()
assert "sym0 (calls)" in prompt
assert "sym9 (calls)" in prompt
assert "sym10 (calls)" not in prompt
def test_empty_context(self):
"""MCPContext works with minimal data."""
ctx = MCPContext()
d = ctx.to_dict()
assert d["version"] == "1.0"
assert d["context_type"] == "code_context"
prompt = ctx.to_prompt_injection()
assert "<code_context>" in prompt
assert "</code_context>" in prompt
def test_metadata_preserved(self):
"""MCPContext preserves custom metadata."""
ctx = MCPContext(
context_type="custom",
metadata={
"source": "codex-lens",
"indexed_at": "2024-01-01",
"custom_key": "custom_value",
},
)
d = ctx.to_dict()
assert d["metadata"]["source"] == "codex-lens"
assert d["metadata"]["custom_key"] == "custom_value"

View File

@@ -79,3 +79,87 @@ def test_symbol_filtering_handles_path_failures(monkeypatch: pytest.MonkeyPatch,
if os.name == "nt":
assert "CrossDrive" in caplog.text
def test_cascade_search_strategy_routing(temp_paths: Path) -> None:
"""Test cascade_search() routes to correct strategy implementation."""
from unittest.mock import patch
from codexlens.search.chain_search import ChainSearchResult, SearchStats
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
mapper = PathMapper(index_root=temp_paths / "indexes")
config = Config(data_dir=temp_paths / "data")
engine = ChainSearchEngine(registry, mapper, config=config)
source_path = temp_paths / "src"
# Test strategy='staged' routing
with patch.object(engine, "staged_cascade_search") as mock_staged:
mock_staged.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="staged")
mock_staged.assert_called_once()
# Test strategy='binary' routing
with patch.object(engine, "binary_cascade_search") as mock_binary:
mock_binary.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="binary")
mock_binary.assert_called_once()
# Test strategy='hybrid' routing
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="hybrid")
mock_hybrid.assert_called_once()
# Test strategy='binary_rerank' routing
with patch.object(engine, "binary_rerank_cascade_search") as mock_br:
mock_br.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="binary_rerank")
mock_br.assert_called_once()
# Test strategy='dense_rerank' routing
with patch.object(engine, "dense_rerank_cascade_search") as mock_dr:
mock_dr.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="dense_rerank")
mock_dr.assert_called_once()
# Test default routing (no strategy specified) - defaults to binary
with patch.object(engine, "binary_cascade_search") as mock_default:
mock_default.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path)
mock_default.assert_called_once()
def test_cascade_search_invalid_strategy(temp_paths: Path) -> None:
"""Test cascade_search() defaults to 'binary' for invalid strategy."""
from unittest.mock import patch
from codexlens.search.chain_search import ChainSearchResult, SearchStats
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
mapper = PathMapper(index_root=temp_paths / "indexes")
config = Config(data_dir=temp_paths / "data")
engine = ChainSearchEngine(registry, mapper, config=config)
source_path = temp_paths / "src"
# Invalid strategy should default to binary
with patch.object(engine, "binary_cascade_search") as mock_binary:
mock_binary.return_value = ChainSearchResult(
query="query", results=[], symbols=[], stats=SearchStats()
)
engine.cascade_search("query", source_path, strategy="invalid_strategy")
mock_binary.assert_called_once()

View File

@@ -0,0 +1,766 @@
"""Unit tests for clustering strategies in the hybrid search pipeline.
Tests cover:
1. HDBSCANStrategy - Primary HDBSCAN clustering
2. DBSCANStrategy - Fallback DBSCAN clustering
3. NoOpStrategy - No-op fallback when clustering unavailable
4. ClusteringStrategyFactory - Factory with fallback chain
"""
from __future__ import annotations
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from codexlens.entities import SearchResult
from codexlens.search.clustering import (
BaseClusteringStrategy,
ClusteringConfig,
ClusteringStrategyFactory,
NoOpStrategy,
check_clustering_strategy_available,
get_strategy,
)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def sample_results() -> List[SearchResult]:
"""Create sample search results for testing."""
return [
SearchResult(path="a.py", score=0.9, excerpt="def foo(): pass"),
SearchResult(path="b.py", score=0.8, excerpt="def foo(): pass"),
SearchResult(path="c.py", score=0.7, excerpt="def bar(): pass"),
SearchResult(path="d.py", score=0.6, excerpt="def bar(): pass"),
SearchResult(path="e.py", score=0.5, excerpt="def baz(): pass"),
]
@pytest.fixture
def mock_embeddings():
"""Create mock embeddings for 5 results.
Creates embeddings that should form 2 clusters:
- Results 0, 1 (similar to each other)
- Results 2, 3 (similar to each other)
- Result 4 (noise/singleton)
"""
import numpy as np
# Create embeddings in 3D for simplicity
return np.array(
[
[1.0, 0.0, 0.0], # Result 0 - cluster A
[0.9, 0.1, 0.0], # Result 1 - cluster A
[0.0, 1.0, 0.0], # Result 2 - cluster B
[0.1, 0.9, 0.0], # Result 3 - cluster B
[0.0, 0.0, 1.0], # Result 4 - noise/singleton
],
dtype=np.float32,
)
@pytest.fixture
def default_config() -> ClusteringConfig:
"""Create default clustering configuration."""
return ClusteringConfig(
min_cluster_size=2,
min_samples=1,
metric="euclidean",
)
# =============================================================================
# Test ClusteringConfig
# =============================================================================
class TestClusteringConfig:
"""Tests for ClusteringConfig validation."""
def test_default_values(self):
"""Test default configuration values."""
config = ClusteringConfig()
assert config.min_cluster_size == 3
assert config.min_samples == 2
assert config.metric == "cosine"
assert config.cluster_selection_epsilon == 0.0
assert config.allow_single_cluster is True
assert config.prediction_data is False
def test_custom_values(self):
"""Test custom configuration values."""
config = ClusteringConfig(
min_cluster_size=5,
min_samples=3,
metric="euclidean",
cluster_selection_epsilon=0.1,
allow_single_cluster=False,
prediction_data=True,
)
assert config.min_cluster_size == 5
assert config.min_samples == 3
assert config.metric == "euclidean"
def test_invalid_min_cluster_size(self):
"""Test validation rejects min_cluster_size < 2."""
with pytest.raises(ValueError, match="min_cluster_size must be >= 2"):
ClusteringConfig(min_cluster_size=1)
def test_invalid_min_samples(self):
"""Test validation rejects min_samples < 1."""
with pytest.raises(ValueError, match="min_samples must be >= 1"):
ClusteringConfig(min_samples=0)
def test_invalid_metric(self):
"""Test validation rejects invalid metric."""
with pytest.raises(ValueError, match="metric must be one of"):
ClusteringConfig(metric="invalid")
def test_invalid_epsilon(self):
"""Test validation rejects negative epsilon."""
with pytest.raises(ValueError, match="cluster_selection_epsilon must be >= 0"):
ClusteringConfig(cluster_selection_epsilon=-0.1)
# =============================================================================
# Test NoOpStrategy
# =============================================================================
class TestNoOpStrategy:
"""Tests for NoOpStrategy - always available."""
def test_cluster_returns_singleton_clusters(
self, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns each result as singleton cluster."""
strategy = NoOpStrategy()
clusters = strategy.cluster(mock_embeddings, sample_results)
assert len(clusters) == 5
for i, cluster in enumerate(clusters):
assert cluster == [i]
def test_cluster_empty_results(self):
"""Test cluster() with empty results."""
import numpy as np
strategy = NoOpStrategy()
clusters = strategy.cluster(np.array([]), [])
assert clusters == []
def test_select_representatives_returns_all_sorted(
self, sample_results: List[SearchResult]
):
"""Test select_representatives() returns all results sorted by score."""
strategy = NoOpStrategy()
clusters = [[i] for i in range(len(sample_results))]
representatives = strategy.select_representatives(clusters, sample_results)
assert len(representatives) == 5
# Check sorted by score descending
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
def test_select_representatives_empty(self):
"""Test select_representatives() with empty input."""
strategy = NoOpStrategy()
representatives = strategy.select_representatives([], [])
assert representatives == []
def test_fit_predict_convenience_method(
self, sample_results: List[SearchResult], mock_embeddings
):
"""Test fit_predict() convenience method."""
strategy = NoOpStrategy()
representatives = strategy.fit_predict(mock_embeddings, sample_results)
assert len(representatives) == 5
# All results returned, sorted by score
assert representatives[0].score >= representatives[-1].score
# =============================================================================
# Test HDBSCANStrategy
# =============================================================================
class TestHDBSCANStrategy:
"""Tests for HDBSCANStrategy - requires hdbscan package."""
@pytest.fixture
def hdbscan_strategy(self, default_config):
"""Create HDBSCANStrategy if available."""
try:
from codexlens.search.clustering import HDBSCANStrategy
return HDBSCANStrategy(default_config)
except ImportError:
pytest.skip("hdbscan not installed")
def test_cluster_returns_list_of_lists(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns List[List[int]]."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
for cluster in clusters:
assert isinstance(cluster, list)
for idx in cluster:
assert isinstance(idx, int)
assert 0 <= idx < len(sample_results)
def test_cluster_covers_all_results(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test all result indices appear in clusters."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
all_indices = set()
for cluster in clusters:
all_indices.update(cluster)
assert all_indices == set(range(len(sample_results)))
def test_cluster_empty_results(self, hdbscan_strategy):
"""Test cluster() with empty results."""
import numpy as np
clusters = hdbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
assert clusters == []
def test_cluster_single_result(self, hdbscan_strategy):
"""Test cluster() with single result."""
import numpy as np
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = hdbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_cluster_fewer_than_min_cluster_size(self, hdbscan_strategy):
"""Test cluster() with fewer results than min_cluster_size."""
import numpy as np
# Strategy has min_cluster_size=2, so 1 result returns singleton
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = hdbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_select_representatives_picks_highest_score(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test select_representatives() picks highest score per cluster."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = hdbscan_strategy.select_representatives(
clusters, sample_results
)
# Each representative should be the highest-scored in its cluster
for rep in representatives:
# Find the cluster containing this representative
rep_idx = next(
i for i, r in enumerate(sample_results) if r.path == rep.path
)
for cluster in clusters:
if rep_idx in cluster:
cluster_scores = [sample_results[i].score for i in cluster]
assert rep.score == max(cluster_scores)
break
def test_select_representatives_sorted_by_score(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test representatives are sorted by score descending."""
clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = hdbscan_strategy.select_representatives(
clusters, sample_results
)
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
def test_fit_predict_end_to_end(
self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test fit_predict() end-to-end clustering."""
representatives = hdbscan_strategy.fit_predict(mock_embeddings, sample_results)
# Should have fewer or equal representatives than input
assert len(representatives) <= len(sample_results)
# All representatives should be from original results
rep_paths = {r.path for r in representatives}
original_paths = {r.path for r in sample_results}
assert rep_paths.issubset(original_paths)
# =============================================================================
# Test DBSCANStrategy
# =============================================================================
class TestDBSCANStrategy:
"""Tests for DBSCANStrategy - requires sklearn."""
@pytest.fixture
def dbscan_strategy(self, default_config):
"""Create DBSCANStrategy if available."""
try:
from codexlens.search.clustering import DBSCANStrategy
return DBSCANStrategy(default_config)
except ImportError:
pytest.skip("sklearn not installed")
def test_cluster_returns_list_of_lists(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test cluster() returns List[List[int]]."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
for cluster in clusters:
assert isinstance(cluster, list)
for idx in cluster:
assert isinstance(idx, int)
assert 0 <= idx < len(sample_results)
def test_cluster_covers_all_results(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test all result indices appear in clusters."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
all_indices = set()
for cluster in clusters:
all_indices.update(cluster)
assert all_indices == set(range(len(sample_results)))
def test_cluster_empty_results(self, dbscan_strategy):
"""Test cluster() with empty results."""
import numpy as np
clusters = dbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
assert clusters == []
def test_cluster_single_result(self, dbscan_strategy):
"""Test cluster() with single result."""
import numpy as np
result = SearchResult(path="a.py", score=0.9, excerpt="test")
embeddings = np.array([[1.0, 0.0, 0.0]])
clusters = dbscan_strategy.cluster(embeddings, [result])
assert len(clusters) == 1
assert clusters[0] == [0]
def test_cluster_with_explicit_eps(self, default_config):
"""Test cluster() with explicit eps parameter."""
try:
from codexlens.search.clustering import DBSCANStrategy
except ImportError:
pytest.skip("sklearn not installed")
import numpy as np
strategy = DBSCANStrategy(default_config, eps=0.5)
results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(3)]
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0]])
clusters = strategy.cluster(embeddings, results)
# With eps=0.5, first two should cluster, third should be separate
assert len(clusters) >= 2
def test_auto_compute_eps(self, dbscan_strategy, mock_embeddings):
"""Test eps auto-computation from distance distribution."""
# Should not raise - eps is computed automatically
results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(5)]
clusters = dbscan_strategy.cluster(mock_embeddings, results)
assert len(clusters) > 0
def test_select_representatives_picks_highest_score(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test select_representatives() picks highest score per cluster."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = dbscan_strategy.select_representatives(
clusters, sample_results
)
# Each representative should be the highest-scored in its cluster
for rep in representatives:
rep_idx = next(
i for i, r in enumerate(sample_results) if r.path == rep.path
)
for cluster in clusters:
if rep_idx in cluster:
cluster_scores = [sample_results[i].score for i in cluster]
assert rep.score == max(cluster_scores)
break
def test_select_representatives_sorted_by_score(
self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
):
"""Test representatives are sorted by score descending."""
clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
representatives = dbscan_strategy.select_representatives(
clusters, sample_results
)
scores = [r.score for r in representatives]
assert scores == sorted(scores, reverse=True)
# =============================================================================
# Test ClusteringStrategyFactory
# =============================================================================
class TestClusteringStrategyFactory:
"""Tests for ClusteringStrategyFactory."""
def test_check_noop_always_available(self):
"""Test noop strategy is always available."""
ok, err = check_clustering_strategy_available("noop")
assert ok is True
assert err is None
def test_check_invalid_strategy(self):
"""Test invalid strategy name returns error."""
ok, err = check_clustering_strategy_available("invalid")
assert ok is False
assert "Invalid clustering strategy" in err
def test_get_strategy_noop(self, default_config):
"""Test get_strategy('noop') returns NoOpStrategy."""
strategy = get_strategy("noop", default_config)
assert isinstance(strategy, NoOpStrategy)
def test_get_strategy_auto_returns_something(self, default_config):
"""Test get_strategy('auto') returns a strategy."""
strategy = get_strategy("auto", default_config)
assert isinstance(strategy, BaseClusteringStrategy)
def test_get_strategy_with_fallback_enabled(self, default_config):
"""Test fallback when primary strategy unavailable."""
# Mock hdbscan unavailable
with patch.dict("sys.modules", {"hdbscan": None}):
# Should fall back to dbscan or noop
strategy = get_strategy("hdbscan", default_config, fallback=True)
assert isinstance(strategy, BaseClusteringStrategy)
def test_get_strategy_fallback_disabled_raises(self, default_config):
"""Test ImportError when fallback disabled and strategy unavailable."""
with patch(
"codexlens.search.clustering.factory.check_clustering_strategy_available"
) as mock_check:
mock_check.return_value = (False, "Test error")
with pytest.raises(ImportError, match="Test error"):
get_strategy("hdbscan", default_config, fallback=False)
def test_get_strategy_invalid_raises(self, default_config):
"""Test ValueError for invalid strategy name."""
with pytest.raises(ValueError, match="Unknown clustering strategy"):
get_strategy("invalid", default_config)
def test_factory_class_interface(self, default_config):
"""Test ClusteringStrategyFactory class interface."""
strategy = ClusteringStrategyFactory.get_strategy("noop", default_config)
assert isinstance(strategy, NoOpStrategy)
ok, err = ClusteringStrategyFactory.check_available("noop")
assert ok is True
@pytest.mark.skipif(
not check_clustering_strategy_available("hdbscan")[0],
reason="hdbscan not installed",
)
def test_get_strategy_hdbscan(self, default_config):
"""Test get_strategy('hdbscan') returns HDBSCANStrategy."""
from codexlens.search.clustering import HDBSCANStrategy
strategy = get_strategy("hdbscan", default_config)
assert isinstance(strategy, HDBSCANStrategy)
@pytest.mark.skipif(
not check_clustering_strategy_available("dbscan")[0],
reason="sklearn not installed",
)
def test_get_strategy_dbscan(self, default_config):
"""Test get_strategy('dbscan') returns DBSCANStrategy."""
from codexlens.search.clustering import DBSCANStrategy
strategy = get_strategy("dbscan", default_config)
assert isinstance(strategy, DBSCANStrategy)
@pytest.mark.skipif(
not check_clustering_strategy_available("dbscan")[0],
reason="sklearn not installed",
)
def test_get_strategy_dbscan_with_kwargs(self, default_config):
"""Test DBSCANStrategy kwargs passed through factory."""
strategy = get_strategy("dbscan", default_config, eps=0.3, eps_percentile=20.0)
assert strategy.eps == 0.3
assert strategy.eps_percentile == 20.0
# =============================================================================
# Integration Tests
# =============================================================================
class TestClusteringIntegration:
"""Integration tests for clustering strategies."""
def test_all_strategies_same_interface(
self, sample_results: List[SearchResult], mock_embeddings, default_config
):
"""Test all strategies have consistent interface."""
strategies = [NoOpStrategy(default_config)]
# Add available strategies
try:
from codexlens.search.clustering import HDBSCANStrategy
strategies.append(HDBSCANStrategy(default_config))
except ImportError:
pass
try:
from codexlens.search.clustering import DBSCANStrategy
strategies.append(DBSCANStrategy(default_config))
except ImportError:
pass
for strategy in strategies:
# All should implement cluster()
clusters = strategy.cluster(mock_embeddings, sample_results)
assert isinstance(clusters, list)
# All should implement select_representatives()
reps = strategy.select_representatives(clusters, sample_results)
assert isinstance(reps, list)
assert all(isinstance(r, SearchResult) for r in reps)
# All should implement fit_predict()
reps = strategy.fit_predict(mock_embeddings, sample_results)
assert isinstance(reps, list)
def test_clustering_reduces_redundancy(
self, default_config
):
"""Test clustering reduces redundant similar results."""
import numpy as np
# Create results with very similar embeddings
results = [
SearchResult(path=f"{i}.py", score=0.9 - i * 0.01, excerpt="def foo(): pass")
for i in range(10)
]
# Very similar embeddings - should cluster together
embeddings = np.array(
[[1.0 + i * 0.01, 0.0, 0.0] for i in range(10)], dtype=np.float32
)
strategy = get_strategy("auto", default_config)
representatives = strategy.fit_predict(embeddings, results)
# Should have fewer representatives than input (clustering reduced redundancy)
# NoOp returns all, but HDBSCAN/DBSCAN should reduce
assert len(representatives) <= len(results)
# =============================================================================
# Test FrequencyStrategy
# =============================================================================
class TestFrequencyStrategy:
"""Tests for FrequencyStrategy - frequency-based clustering."""
@pytest.fixture
def frequency_config(self):
"""Create FrequencyConfig for testing."""
from codexlens.search.clustering import FrequencyConfig
return FrequencyConfig(min_frequency=1, max_representatives_per_group=3)
@pytest.fixture
def frequency_strategy(self, frequency_config):
"""Create FrequencyStrategy instance."""
from codexlens.search.clustering import FrequencyStrategy
return FrequencyStrategy(frequency_config)
@pytest.fixture
def symbol_results(self) -> List[SearchResult]:
"""Create sample results with symbol names for frequency testing."""
return [
SearchResult(path="auth.py", score=0.9, excerpt="authenticate user", symbol_name="authenticate"),
SearchResult(path="login.py", score=0.85, excerpt="authenticate login", symbol_name="authenticate"),
SearchResult(path="session.py", score=0.8, excerpt="authenticate session", symbol_name="authenticate"),
SearchResult(path="utils.py", score=0.7, excerpt="helper function", symbol_name="helper_func"),
SearchResult(path="validate.py", score=0.6, excerpt="validate input", symbol_name="validate"),
SearchResult(path="check.py", score=0.55, excerpt="validate data", symbol_name="validate"),
]
def test_frequency_strategy_available(self):
"""Test FrequencyStrategy is always available (no deps)."""
ok, err = check_clustering_strategy_available("frequency")
assert ok is True
assert err is None
def test_get_strategy_frequency(self):
"""Test get_strategy('frequency') returns FrequencyStrategy."""
from codexlens.search.clustering import FrequencyStrategy
strategy = get_strategy("frequency")
assert isinstance(strategy, FrequencyStrategy)
def test_cluster_groups_by_symbol(self, frequency_strategy, symbol_results):
"""Test cluster() groups results by symbol name."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
# Should have 3 groups: authenticate(3), validate(2), helper_func(1)
assert len(clusters) == 3
# First cluster should be authenticate (highest frequency)
first_cluster_symbols = [symbol_results[i].symbol_name for i in clusters[0]]
assert all(s == "authenticate" for s in first_cluster_symbols)
assert len(clusters[0]) == 3
def test_cluster_orders_by_frequency(self, frequency_strategy, symbol_results):
"""Test clusters are ordered by frequency (descending)."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
# Verify frequency ordering
frequencies = [len(c) for c in clusters]
assert frequencies == sorted(frequencies, reverse=True)
def test_select_representatives_adds_frequency_metadata(self, frequency_strategy, symbol_results):
"""Test representatives have frequency metadata."""
import numpy as np
embeddings = np.random.rand(len(symbol_results), 128)
clusters = frequency_strategy.cluster(embeddings, symbol_results)
reps = frequency_strategy.select_representatives(clusters, symbol_results, embeddings)
# Check frequency metadata
for rep in reps:
assert "frequency" in rep.metadata
assert rep.metadata["frequency"] >= 1
def test_min_frequency_filter_mode(self, symbol_results):
"""Test min_frequency with filter mode removes low-frequency results."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(min_frequency=2, keep_mode="filter")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# helper_func (freq=1) should be filtered out
rep_symbols = [r.symbol_name for r in reps]
assert "helper_func" not in rep_symbols
assert "authenticate" in rep_symbols
assert "validate" in rep_symbols
def test_min_frequency_demote_mode(self, symbol_results):
"""Test min_frequency with demote mode keeps but deprioritizes low-frequency."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(min_frequency=2, keep_mode="demote")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# helper_func should still be present but at the end
rep_symbols = [r.symbol_name for r in reps]
assert "helper_func" in rep_symbols
# Should be demoted to end
helper_idx = rep_symbols.index("helper_func")
assert helper_idx == len(rep_symbols) - 1
def test_group_by_file(self, symbol_results):
"""Test grouping by file path instead of symbol."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(group_by="file")
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
clusters = strategy.cluster(embeddings, symbol_results)
# Each file should be its own group (all unique paths)
assert len(clusters) == 6
def test_max_representatives_per_group(self, symbol_results):
"""Test max_representatives_per_group limits output per symbol."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(max_representatives_per_group=1)
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# Should have at most 1 per group = 3 groups = 3 reps
assert len(reps) == 3
def test_frequency_boost_score(self, symbol_results):
"""Test frequency_weight boosts high-frequency results."""
from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
import numpy as np
config = FrequencyConfig(frequency_weight=0.5) # Strong boost
strategy = FrequencyStrategy(config)
embeddings = np.random.rand(len(symbol_results), 128)
reps = strategy.fit_predict(embeddings, symbol_results)
# High-frequency results should have boosted scores in metadata
for rep in reps:
if rep.metadata.get("frequency", 1) > 1:
assert rep.metadata.get("frequency_boosted_score", 0) > rep.score
def test_empty_results(self, frequency_strategy):
"""Test handling of empty results."""
import numpy as np
clusters = frequency_strategy.cluster(np.array([]).reshape(0, 128), [])
assert clusters == []
reps = frequency_strategy.select_representatives([], [], None)
assert reps == []
def test_factory_with_kwargs(self):
"""Test factory passes kwargs to FrequencyConfig."""
strategy = get_strategy("frequency", min_frequency=3, group_by="file")
assert strategy.config.min_frequency == 3
assert strategy.config.group_by == "file"

View File

@@ -0,0 +1,698 @@
"""Integration tests for staged cascade search pipeline.
Tests the 4-stage pipeline:
1. Stage 1: Binary coarse search
2. Stage 2: LSP graph expansion
3. Stage 3: Clustering and representative selection
4. Stage 4: Optional cross-encoder reranking
"""
from __future__ import annotations
import json
import tempfile
from pathlib import Path
from typing import List
from unittest.mock import MagicMock, Mock, patch
import pytest
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def temp_paths():
"""Create temporary directory structure."""
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
root = Path(tmpdir.name)
yield root
try:
tmpdir.cleanup()
except (PermissionError, OSError):
pass
@pytest.fixture
def mock_registry(temp_paths: Path):
"""Create mock registry store."""
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
return registry
@pytest.fixture
def mock_mapper(temp_paths: Path):
"""Create path mapper."""
return PathMapper(index_root=temp_paths / "indexes")
@pytest.fixture
def mock_config():
"""Create mock config with staged cascade settings."""
config = MagicMock(spec=Config)
config.cascade_coarse_k = 100
config.cascade_fine_k = 10
config.enable_staged_rerank = False
config.staged_clustering_strategy = "auto"
config.staged_clustering_min_size = 3
config.graph_expansion_depth = 2
return config
@pytest.fixture
def sample_binary_results() -> List[SearchResult]:
"""Create sample binary search results for testing."""
return [
SearchResult(
path="a.py",
score=0.95,
excerpt="def authenticate_user(username, password):",
symbol_name="authenticate_user",
symbol_kind="function",
start_line=10,
end_line=15,
),
SearchResult(
path="b.py",
score=0.85,
excerpt="class AuthManager:",
symbol_name="AuthManager",
symbol_kind="class",
start_line=5,
end_line=20,
),
SearchResult(
path="c.py",
score=0.75,
excerpt="def check_credentials(user, pwd):",
symbol_name="check_credentials",
symbol_kind="function",
start_line=30,
end_line=35,
),
]
@pytest.fixture
def sample_expanded_results() -> List[SearchResult]:
"""Create sample expanded results (after LSP expansion)."""
return [
SearchResult(
path="a.py",
score=0.95,
excerpt="def authenticate_user(username, password):",
symbol_name="authenticate_user",
symbol_kind="function",
),
SearchResult(
path="a.py",
score=0.90,
excerpt="def verify_password(pwd):",
symbol_name="verify_password",
symbol_kind="function",
),
SearchResult(
path="b.py",
score=0.85,
excerpt="class AuthManager:",
symbol_name="AuthManager",
symbol_kind="class",
),
SearchResult(
path="b.py",
score=0.80,
excerpt="def login(self, user):",
symbol_name="login",
symbol_kind="function",
),
SearchResult(
path="c.py",
score=0.75,
excerpt="def check_credentials(user, pwd):",
symbol_name="check_credentials",
symbol_kind="function",
),
SearchResult(
path="d.py",
score=0.70,
excerpt="class UserModel:",
symbol_name="UserModel",
symbol_kind="class",
),
]
# =============================================================================
# Test Stage Methods
# =============================================================================
class TestStage1BinarySearch:
"""Tests for Stage 1: Binary coarse search."""
def test_stage1_returns_results_with_index_root(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search returns results and index_root."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock the binary embedding backend (import is inside the method)
with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
mock_index = MagicMock()
mock_index.count.return_value = 10
mock_index.search.return_value = ([1, 2, 3], [10, 20, 30])
mock_binary_idx.return_value = mock_index
index_paths = [Path("/fake/index1/_index.db")]
stats = SearchStats()
results, index_root = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
assert isinstance(results, list)
assert isinstance(index_root, (Path, type(None)))
def test_stage1_handles_empty_index_paths(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search handles empty index paths."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
index_paths = []
stats = SearchStats()
results, index_root = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
assert results == []
assert index_root is None
def test_stage1_aggregates_results_from_multiple_indexes(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage1_binary_search aggregates results from multiple indexes."""
from codexlens.search.chain_search import SearchStats
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
mock_index = MagicMock()
mock_index.count.return_value = 10
# Return different results for different calls
mock_index.search.side_effect = [
([1, 2], [10, 20]),
([3, 4], [15, 25]),
]
mock_binary_idx.return_value = mock_index
index_paths = [
Path("/fake/index1/_index.db"),
Path("/fake/index2/_index.db"),
]
stats = SearchStats()
results, _ = engine._stage1_binary_search(
"query", index_paths, coarse_k=10, stats=stats
)
# Should aggregate candidates from both indexes
assert isinstance(results, list)
class TestStage2LSPExpand:
"""Tests for Stage 2: LSP graph expansion."""
def test_stage2_returns_expanded_results(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand returns expanded results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Import is inside the method, so we need to patch it there
with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
mock_expander = MagicMock()
mock_expander.expand.return_value = [
SearchResult(path="related.py", score=0.7, excerpt="related")
]
mock_expander_cls.return_value = mock_expander
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake/index")
)
assert isinstance(expanded, list)
# Should include original results
assert len(expanded) >= len(sample_binary_results)
def test_stage2_handles_no_index_root(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand handles missing index_root."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
expanded = engine._stage2_lsp_expand(sample_binary_results, index_root=None)
# Should return original results unchanged
assert expanded == sample_binary_results
def test_stage2_handles_empty_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage2_lsp_expand handles empty input."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
expanded = engine._stage2_lsp_expand([], index_root=Path("/fake"))
assert expanded == []
def test_stage2_deduplicates_results(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test _stage2_lsp_expand deduplicates by (path, symbol_name, start_line)."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock expander to return duplicate of first result
with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
mock_expander = MagicMock()
duplicate = SearchResult(
path=sample_binary_results[0].path,
score=0.5,
excerpt="duplicate",
symbol_name=sample_binary_results[0].symbol_name,
start_line=sample_binary_results[0].start_line,
)
mock_expander.expand.return_value = [duplicate]
mock_expander_cls.return_value = mock_expander
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake")
)
# Should not include duplicate
assert len(expanded) == len(sample_binary_results)
class TestStage3ClusterPrune:
"""Tests for Stage 3: Clustering and representative selection."""
def test_stage3_returns_representatives(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test _stage3_cluster_prune returns representative results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
import numpy as np
# Mock embeddings
mock_embed.return_value = np.random.rand(
len(sample_expanded_results), 128
).astype(np.float32)
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
assert isinstance(clustered, list)
assert len(clustered) <= len(sample_expanded_results)
assert all(isinstance(r, SearchResult) for r in clustered)
def test_stage3_handles_few_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage3_cluster_prune skips clustering for few results."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
few_results = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="b.py", score=0.8, excerpt="b"),
]
clustered = engine._stage3_cluster_prune(few_results, target_count=5)
# Should return all results unchanged
assert clustered == few_results
def test_stage3_handles_no_embeddings(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test _stage3_cluster_prune falls back to score-based selection without embeddings."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
mock_embed.return_value = None
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should return top-scored results
assert len(clustered) <= 3
# Should be sorted by score descending
scores = [r.score for r in clustered]
assert scores == sorted(scores, reverse=True)
def test_stage3_uses_config_clustering_strategy(
self, mock_registry, mock_mapper, sample_expanded_results
):
"""Test _stage3_cluster_prune uses config clustering strategy."""
config = MagicMock(spec=Config)
config.staged_clustering_strategy = "auto"
config.staged_clustering_min_size = 2
engine = ChainSearchEngine(mock_registry, PathMapper(), config=config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
import numpy as np
mock_embed.return_value = np.random.rand(
len(sample_expanded_results), 128
).astype(np.float32)
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should use clustering (auto will pick best available)
# Result should be a list of SearchResult objects
assert isinstance(clustered, list)
assert all(isinstance(r, SearchResult) for r in clustered)
class TestStage4OptionalRerank:
"""Tests for Stage 4: Optional cross-encoder reranking."""
def test_stage4_reranks_with_reranker(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage4_optional_rerank uses _cross_encoder_rerank."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
results = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="b.py", score=0.8, excerpt="b"),
SearchResult(path="c.py", score=0.7, excerpt="c"),
]
# Mock the _cross_encoder_rerank method that _stage4 calls
with patch.object(engine, "_cross_encoder_rerank") as mock_rerank:
mock_rerank.return_value = [
SearchResult(path="c.py", score=0.95, excerpt="c"),
SearchResult(path="a.py", score=0.85, excerpt="a"),
]
reranked = engine._stage4_optional_rerank("query", results, k=2)
mock_rerank.assert_called_once_with("query", results, 2)
assert len(reranked) <= 2
# First result should be reranked winner
assert reranked[0].path == "c.py"
def test_stage4_handles_empty_results(
self, mock_registry, mock_mapper, mock_config
):
"""Test _stage4_optional_rerank handles empty input."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
reranked = engine._stage4_optional_rerank("query", [], k=2)
# Should return empty list
assert reranked == []
# =============================================================================
# Integration Tests
# =============================================================================
class TestStagedCascadeIntegration:
"""Integration tests for staged_cascade_search() end-to-end."""
def test_staged_cascade_returns_chain_result(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search returns ChainSearchResult."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Mock all stages
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src", k=10, coarse_k=100
)
from codexlens.search.chain_search import ChainSearchResult
assert isinstance(result, ChainSearchResult)
assert result.query == "query"
assert len(result.results) <= 10
def test_staged_cascade_includes_stage_stats(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search includes per-stage timing stats."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src"
)
# Check for stage stats in errors field
stage_stats = None
for err in result.stats.errors:
if err.startswith("STAGE_STATS:"):
stage_stats = json.loads(err.replace("STAGE_STATS:", ""))
break
assert stage_stats is not None
assert "stage_times" in stage_stats
assert "stage_counts" in stage_stats
assert "stage1_binary_ms" in stage_stats["stage_times"]
assert "stage1_candidates" in stage_stats["stage_counts"]
def test_staged_cascade_with_rerank_enabled(
self, mock_registry, mock_mapper, temp_paths
):
"""Test staged_cascade_search with reranking enabled."""
config = MagicMock(spec=Config)
config.cascade_coarse_k = 100
config.cascade_fine_k = 10
config.enable_staged_rerank = True
config.staged_clustering_strategy = "auto"
config.graph_expansion_depth = 2
engine = ChainSearchEngine(mock_registry, mock_mapper, config=config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage4_optional_rerank") as mock_stage4:
mock_stage4.return_value = [
SearchResult(path="a.py", score=0.95, excerpt="a")
]
result = engine.staged_cascade_search(
"query", temp_paths / "src"
)
# Verify stage 4 was called
mock_stage4.assert_called_once()
def test_staged_cascade_fallback_to_hybrid(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search falls back to hybrid when numpy unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch("codexlens.search.chain_search.NUMPY_AVAILABLE", False):
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = MagicMock()
engine.staged_cascade_search("query", temp_paths / "src")
# Should fall back to hybrid cascade
mock_hybrid.assert_called_once()
def test_staged_cascade_deduplicates_final_results(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged_cascade_search deduplicates results by path."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
mock_stage1.return_value = (
[SearchResult(path="a.py", score=0.9, excerpt="a")],
temp_paths / "index",
)
with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
mock_stage2.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a")
]
with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
# Return duplicates with different scores
mock_stage3.return_value = [
SearchResult(path="a.py", score=0.9, excerpt="a"),
SearchResult(path="a.py", score=0.8, excerpt="a duplicate"),
SearchResult(path="b.py", score=0.7, excerpt="b"),
]
result = engine.staged_cascade_search(
"query", temp_paths / "src", k=10
)
# Should deduplicate a.py (keep higher score)
paths = [r.path for r in result.results]
assert len(paths) == len(set(paths))
# a.py should have score 0.9
a_result = next(r for r in result.results if r.path == "a.py")
assert a_result.score == 0.9
# =============================================================================
# Graceful Degradation Tests
# =============================================================================
class TestStagedCascadeGracefulDegradation:
"""Tests for graceful degradation when dependencies unavailable."""
def test_falls_back_when_clustering_unavailable(
self, mock_registry, mock_mapper, mock_config, sample_expanded_results
):
"""Test clustering stage falls back gracefully when clustering unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
mock_embed.return_value = None
clustered = engine._stage3_cluster_prune(
sample_expanded_results, target_count=3
)
# Should fall back to score-based selection
assert len(clustered) <= 3
def test_falls_back_when_graph_expander_unavailable(
self, mock_registry, mock_mapper, mock_config, sample_binary_results
):
"""Test LSP expansion falls back when GraphExpander unavailable."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
# Patch the import inside the method
with patch("codexlens.search.graph_expander.GraphExpander", side_effect=ImportError):
expanded = engine._stage2_lsp_expand(
sample_binary_results, index_root=Path("/fake")
)
# Should return original results
assert expanded == sample_binary_results
def test_handles_stage_failures_gracefully(
self, mock_registry, mock_mapper, mock_config, temp_paths
):
"""Test staged pipeline handles stage failures gracefully."""
engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)
with patch.object(engine, "_find_start_index") as mock_find:
mock_find.return_value = temp_paths / "index" / "_index.db"
with patch.object(engine, "_collect_index_paths") as mock_collect:
mock_collect.return_value = [temp_paths / "index" / "_index.db"]
with patch.object(engine, "_stage1_binary_search") as mock_stage1:
# Stage 1 returns no results
mock_stage1.return_value = ([], None)
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
mock_hybrid.return_value = MagicMock()
engine.staged_cascade_search("query", temp_paths / "src")
# Should fall back to hybrid when stage 1 fails
mock_hybrid.assert_called_once()