Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-10 02:24:35 +08:00
feat(cli): add --rule option with automatic template discovery
Refactor the ccw CLI template system:
- Add a template-discovery.ts module that supports automatic discovery of flattened templates
- Add a --rule <template> option that automatically loads the matching protocol and template
- Migrate the template directory from a nested layout (prompts/category/file.txt) to a flat layout (prompts/category-function.txt)
- Update all agent/command files to use the $PROTO and $TMPL environment variables instead of the $(cat ...) pattern
- Support fuzzy matching: --rule 02-review-architecture matches analysis-review-architecture.txt

Other updates:
- Dashboard: add Claude Manager and Issue Manager pages
- Codex-lens: enhance the chain_search and clustering modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
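For illustration only: the discovery logic itself lives in template-discovery.ts, but a minimal Python sketch of the fuzzy-matching behavior described above could look like the following (the prompts/ directory layout and the resolve_rule helper name are assumptions made for this sketch, not part of the commit).

import re
from pathlib import Path

def resolve_rule(rule: str, prompts_dir: Path = Path("prompts")) -> Path | None:
    """Resolve a --rule value against flat templates such as prompts/analysis-review-architecture.txt."""
    # Drop an optional numeric ordering prefix such as "02-".
    needle = re.sub(r"^\d+-", "", rule).lower()
    for candidate in sorted(prompts_dir.glob("*.txt")):
        stem = candidate.stem.lower()
        # Accept an exact stem match or a suffix match, so "review-architecture"
        # resolves to "analysis-review-architecture.txt".
        if stem == needle or stem.endswith("-" + needle):
            return candidate
    return None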
codex-lens/tests/api/test_references.py (new file, 282 lines)
@@ -0,0 +1,282 @@
"""Tests for codexlens.api.references module."""

import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from codexlens.api.references import (
    find_references,
    _read_line_from_file,
    _proximity_score,
    _group_references_by_definition,
    _transform_to_reference_result,
)
from codexlens.api.models import (
    DefinitionResult,
    ReferenceResult,
    GroupedReferences,
)


class TestReadLineFromFile:
    """Tests for _read_line_from_file helper."""

    def test_read_existing_line(self, tmp_path):
        """Test reading an existing line from a file."""
        test_file = tmp_path / "test.py"
        test_file.write_text("line 1\nline 2\nline 3\n")

        assert _read_line_from_file(str(test_file), 1) == "line 1"
        assert _read_line_from_file(str(test_file), 2) == "line 2"
        assert _read_line_from_file(str(test_file), 3) == "line 3"

    def test_read_nonexistent_line(self, tmp_path):
        """Test reading a line that doesn't exist."""
        test_file = tmp_path / "test.py"
        test_file.write_text("line 1\nline 2\n")

        assert _read_line_from_file(str(test_file), 10) == ""

    def test_read_nonexistent_file(self):
        """Test reading from a file that doesn't exist."""
        assert _read_line_from_file("/nonexistent/path/file.py", 1) == ""

    def test_strips_trailing_whitespace(self, tmp_path):
        """Test that trailing whitespace is stripped."""
        test_file = tmp_path / "test.py"
        test_file.write_text("line with spaces \n")

        assert _read_line_from_file(str(test_file), 1) == "line with spaces"


class TestProximityScore:
    """Tests for _proximity_score helper."""

    def test_same_file(self):
        """Same file should return highest score."""
        score = _proximity_score("/a/b/c.py", "/a/b/c.py")
        assert score == 1000

    def test_same_directory(self):
        """Same directory should return 100."""
        score = _proximity_score("/a/b/x.py", "/a/b/y.py")
        assert score == 100

    def test_different_directories(self):
        """Different directories should return common prefix length."""
        score = _proximity_score("/a/b/c/x.py", "/a/b/d/y.py")
        # Common path is /a/b
        assert score > 0

    def test_empty_paths(self):
        """Empty paths should return 0."""
        assert _proximity_score("", "/a/b/c.py") == 0
        assert _proximity_score("/a/b/c.py", "") == 0
        assert _proximity_score("", "") == 0


class TestGroupReferencesByDefinition:
    """Tests for _group_references_by_definition helper."""

    def test_single_definition(self):
        """Single definition should have all references."""
        definition = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/a/b/c.py",
            line=10,
            end_line=20,
        )
        references = [
            ReferenceResult(
                file_path="/a/b/d.py",
                line=5,
                column=0,
                context_line="foo()",
                relationship="call",
            ),
            ReferenceResult(
                file_path="/a/x/y.py",
                line=10,
                column=0,
                context_line="foo()",
                relationship="call",
            ),
        ]

        result = _group_references_by_definition([definition], references)

        assert len(result) == 1
        assert result[0].definition == definition
        assert len(result[0].references) == 2

    def test_multiple_definitions(self):
        """Multiple definitions should group by proximity."""
        def1 = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/a/b/c.py",
            line=10,
            end_line=20,
        )
        def2 = DefinitionResult(
            name="foo",
            kind="function",
            file_path="/x/y/z.py",
            line=10,
            end_line=20,
        )

        # Reference closer to def1
        ref1 = ReferenceResult(
            file_path="/a/b/d.py",
            line=5,
            column=0,
            context_line="foo()",
            relationship="call",
        )
        # Reference closer to def2
        ref2 = ReferenceResult(
            file_path="/x/y/w.py",
            line=10,
            column=0,
            context_line="foo()",
            relationship="call",
        )

        result = _group_references_by_definition(
            [def1, def2], [ref1, ref2], include_definition=True
        )

        assert len(result) == 2
        # Each definition should have the closer reference
        def1_refs = [g for g in result if g.definition == def1][0].references
        def2_refs = [g for g in result if g.definition == def2][0].references

        assert any(r.file_path == "/a/b/d.py" for r in def1_refs)
        assert any(r.file_path == "/x/y/w.py" for r in def2_refs)

    def test_empty_definitions(self):
        """Empty definitions should return empty result."""
        result = _group_references_by_definition([], [])
        assert result == []


class TestTransformToReferenceResult:
    """Tests for _transform_to_reference_result helper."""

    def test_normalizes_relationship_type(self, tmp_path):
        """Test that relationship type is normalized."""
        test_file = tmp_path / "test.py"
        test_file.write_text("def foo(): pass\n")

        # Create a mock raw reference
        raw_ref = MagicMock()
        raw_ref.file_path = str(test_file)
        raw_ref.line = 1
        raw_ref.column = 0
        raw_ref.relationship_type = "calls"  # Plural form

        result = _transform_to_reference_result(raw_ref)

        assert result.relationship == "call"  # Normalized form
        assert result.context_line == "def foo(): pass"


class TestFindReferences:
    """Tests for find_references API function."""

    def test_raises_for_invalid_project_root(self):
        """Test that ValueError is raised for invalid project root."""
        with pytest.raises(ValueError, match="does not exist"):
            find_references("/nonexistent/path", "some_symbol")

    @patch("codexlens.search.chain_search.ChainSearchEngine")
    @patch("codexlens.storage.registry.RegistryStore")
    @patch("codexlens.storage.path_mapper.PathMapper")
    @patch("codexlens.config.Config")
    def test_returns_grouped_references(
        self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
    ):
        """Test that find_references returns GroupedReferences."""
        # Setup mocks
        mock_engine = MagicMock()
        mock_engine_class.return_value = mock_engine

        # Mock symbol search (for definitions)
        mock_symbol = MagicMock()
        mock_symbol.name = "test_func"
        mock_symbol.kind = "function"
        mock_symbol.file = str(tmp_path / "test.py")
        mock_symbol.range = (10, 20)
        mock_engine.search_symbols.return_value = [mock_symbol]

        # Mock reference search
        mock_ref = MagicMock()
        mock_ref.file_path = str(tmp_path / "caller.py")
        mock_ref.line = 5
        mock_ref.column = 0
        mock_ref.relationship_type = "call"
        mock_engine.search_references.return_value = [mock_ref]

        # Create test files
        test_file = tmp_path / "test.py"
        test_file.write_text("def test_func():\n    pass\n")
        caller_file = tmp_path / "caller.py"
        caller_file.write_text("test_func()\n")

        # Call find_references
        result = find_references(str(tmp_path), "test_func")

        # Verify result structure
        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], GroupedReferences)
        assert result[0].definition.name == "test_func"
        assert len(result[0].references) == 1

    @patch("codexlens.search.chain_search.ChainSearchEngine")
    @patch("codexlens.storage.registry.RegistryStore")
    @patch("codexlens.storage.path_mapper.PathMapper")
    @patch("codexlens.config.Config")
    def test_respects_include_definition_false(
        self, mock_config, mock_mapper, mock_registry, mock_engine_class, tmp_path
    ):
        """Test include_definition=False behavior."""
        mock_engine = MagicMock()
        mock_engine_class.return_value = mock_engine
        mock_engine.search_symbols.return_value = []
        mock_engine.search_references.return_value = []

        result = find_references(
            str(tmp_path), "test_func", include_definition=False
        )

        # Should still return a result with placeholder definition
        assert len(result) == 1
        assert result[0].definition.name == "test_func"


class TestImports:
    """Tests for module imports and exports."""

    def test_find_references_exported_from_api(self):
        """Test that find_references is exported from codexlens.api."""
        from codexlens.api import find_references as api_find_references

        assert callable(api_find_references)

    def test_models_exported_from_api(self):
        """Test that result models are exported from codexlens.api."""
        from codexlens.api import (
            GroupedReferences,
            ReferenceResult,
            DefinitionResult,
        )

        assert GroupedReferences is not None
        assert ReferenceResult is not None
        assert DefinitionResult is not None
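As an aside for readers of the diff: the tests above pin down the public shape of the references API. Purely as an illustration grounded in those assertions (the project path is a placeholder), a caller might use it like this:

from codexlens.api import find_references

groups = find_references("/path/to/project", "test_func", include_definition=False)
for group in groups:
    # Each group pairs one definition with the references attributed to it.
    print(group.definition.name, group.definition.file_path)
    for ref in group.references:
        print(f"  {ref.file_path}:{ref.line} ({ref.relationship})")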
codex-lens/tests/api/test_semantic_search.py (new file, 528 lines)
@@ -0,0 +1,528 @@
"""Tests for semantic_search API."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from codexlens.api import SemanticResult
from codexlens.api.semantic import (
    semantic_search,
    _build_search_options,
    _generate_match_reason,
    _split_camel_case,
    _transform_results,
)


class TestSemanticSearchFunctionSignature:
    """Test that semantic_search has the correct function signature."""

    def test_function_accepts_all_parameters(self):
        """Verify function signature matches spec."""
        import inspect
        sig = inspect.signature(semantic_search)
        params = list(sig.parameters.keys())

        expected_params = [
            "project_root",
            "query",
            "mode",
            "vector_weight",
            "structural_weight",
            "keyword_weight",
            "fusion_strategy",
            "kind_filter",
            "limit",
            "include_match_reason",
        ]

        assert params == expected_params

    def test_default_parameter_values(self):
        """Verify default parameter values match spec."""
        import inspect
        sig = inspect.signature(semantic_search)

        assert sig.parameters["mode"].default == "fusion"
        assert sig.parameters["vector_weight"].default == 0.5
        assert sig.parameters["structural_weight"].default == 0.3
        assert sig.parameters["keyword_weight"].default == 0.2
        assert sig.parameters["fusion_strategy"].default == "rrf"
        assert sig.parameters["kind_filter"].default is None
        assert sig.parameters["limit"].default == 20
        assert sig.parameters["include_match_reason"].default is False


class TestBuildSearchOptions:
    """Test _build_search_options helper function."""

    def test_vector_mode_options(self):
        """Test options for pure vector mode."""
        options = _build_search_options(
            mode="vector",
            vector_weight=1.0,
            structural_weight=0.0,
            keyword_weight=0.0,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is True
        assert options.pure_vector is True
        assert options.enable_fuzzy is False

    def test_structural_mode_options(self):
        """Test options for structural mode."""
        options = _build_search_options(
            mode="structural",
            vector_weight=0.0,
            structural_weight=1.0,
            keyword_weight=0.0,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is False
        assert options.enable_fuzzy is True
        assert options.include_symbols is True

    def test_fusion_mode_options(self):
        """Test options for fusion mode (default)."""
        options = _build_search_options(
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            limit=20,
        )

        assert options.hybrid_mode is True
        assert options.enable_vector is True  # vector_weight > 0
        assert options.enable_fuzzy is True  # keyword_weight > 0
        assert options.include_symbols is True  # structural_weight > 0


class TestTransformResults:
    """Test _transform_results helper function."""

    def test_transforms_basic_result(self):
        """Test basic result transformation."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "def authenticate():"
        mock_result.symbol_name = "authenticate"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 10
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].symbol_name == "authenticate"
        assert results[0].kind == "function"
        assert results[0].file_path == "/project/src/auth.py"
        assert results[0].line == 10
        assert results[0].fusion_score == 0.85

    def test_kind_filter_excludes_non_matching(self):
        """Test that kind_filter excludes non-matching results."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "AUTH_TOKEN = 'secret'"
        mock_result.symbol_name = "AUTH_TOKEN"
        mock_result.symbol_kind = "variable"
        mock_result.start_line = 5
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=["function", "class"],  # Exclude variable
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 0

    def test_kind_filter_includes_matching(self):
        """Test that kind_filter includes matching results."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "class AuthManager:"
        mock_result.symbol_name = "AuthManager"
        mock_result.symbol_kind = "class"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=["function", "class"],  # Include class
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].symbol_name == "AuthManager"

    def test_include_match_reason_generates_reason(self):
        """Test that include_match_reason generates match reasons."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.85
        mock_result.excerpt = "def authenticate(user, password):"
        mock_result.symbol_name = "authenticate"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 10
        mock_result.symbol = None
        mock_result.metadata = {}

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=True,
            query="authenticate",
        )

        assert len(results) == 1
        assert results[0].match_reason is not None
        assert "authenticate" in results[0].match_reason.lower()


class TestGenerateMatchReason:
    """Test _generate_match_reason helper function."""

    def test_direct_name_match(self):
        """Test match reason for direct name match."""
        reason = _generate_match_reason(
            query="authenticate",
            symbol_name="authenticate",
            symbol_kind="function",
            snippet="def authenticate(user): pass",
            vector_score=0.8,
            structural_score=None,
        )

        assert "authenticate" in reason.lower()

    def test_keyword_match(self):
        """Test match reason for keyword match in snippet."""
        reason = _generate_match_reason(
            query="password validation",
            symbol_name="verify_user",
            symbol_kind="function",
            snippet="def verify_user(password): validate(password)",
            vector_score=0.6,
            structural_score=None,
        )

        assert "password" in reason.lower() or "validation" in reason.lower()

    def test_high_semantic_similarity(self):
        """Test match reason mentions semantic similarity for high vector score."""
        reason = _generate_match_reason(
            query="authentication",
            symbol_name="login_handler",
            symbol_kind="function",
            snippet="def login_handler(): pass",
            vector_score=0.85,
            structural_score=None,
        )

        assert "semantic" in reason.lower()

    def test_returns_string_even_with_no_matches(self):
        """Test that a reason string is always returned."""
        reason = _generate_match_reason(
            query="xyz123",
            symbol_name="abc456",
            symbol_kind="function",
            snippet="completely unrelated code",
            vector_score=0.3,
            structural_score=None,
        )

        assert isinstance(reason, str)
        assert len(reason) > 0


class TestSplitCamelCase:
    """Test _split_camel_case helper function."""

    def test_camel_case(self):
        """Test splitting camelCase."""
        result = _split_camel_case("authenticateUser")
        assert "authenticate" in result.lower()
        assert "user" in result.lower()

    def test_pascal_case(self):
        """Test splitting PascalCase."""
        result = _split_camel_case("AuthManager")
        assert "auth" in result.lower()
        assert "manager" in result.lower()

    def test_snake_case(self):
        """Test splitting snake_case."""
        result = _split_camel_case("auth_manager")
        assert "auth" in result.lower()
        assert "manager" in result.lower()

    def test_mixed_case(self):
        """Test splitting mixed case."""
        result = _split_camel_case("HTTPRequestHandler")
        # Should handle acronyms
        assert "http" in result.lower() or "request" in result.lower()


class TestSemanticResultDataclass:
    """Test SemanticResult dataclass structure."""

    def test_semantic_result_fields(self):
        """Test SemanticResult has all required fields."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=0.8,
            structural_score=0.6,
            fusion_score=0.7,
            snippet="def test(): pass",
            match_reason="Test match",
        )

        assert result.symbol_name == "test"
        assert result.kind == "function"
        assert result.file_path == "/test.py"
        assert result.line == 1
        assert result.vector_score == 0.8
        assert result.structural_score == 0.6
        assert result.fusion_score == 0.7
        assert result.snippet == "def test(): pass"
        assert result.match_reason == "Test match"

    def test_semantic_result_optional_fields(self):
        """Test SemanticResult with optional None fields."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=None,  # Degraded - no vector index
            structural_score=None,  # Degraded - no relationships
            fusion_score=0.5,
            snippet="def test(): pass",
            match_reason=None,  # Not requested
        )

        assert result.vector_score is None
        assert result.structural_score is None
        assert result.match_reason is None

    def test_semantic_result_to_dict(self):
        """Test SemanticResult.to_dict() filters None values."""
        result = SemanticResult(
            symbol_name="test",
            kind="function",
            file_path="/test.py",
            line=1,
            vector_score=None,
            structural_score=0.6,
            fusion_score=0.7,
            snippet="def test(): pass",
            match_reason=None,
        )

        d = result.to_dict()

        assert "symbol_name" in d
        assert "vector_score" not in d  # None values filtered
        assert "structural_score" in d
        assert "match_reason" not in d  # None values filtered


class TestFusionStrategyMapping:
    """Test fusion_strategy parameter mapping via _execute_search."""

    def test_rrf_strategy_calls_search(self):
        """Test that rrf strategy maps to standard search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="rrf",
            options=mock_options,
            limit=20,
        )

        mock_engine.search.assert_called_once()

    def test_staged_strategy_calls_staged_cascade_search(self):
        """Test that staged strategy maps to staged_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.staged_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="staged",
            options=mock_options,
            limit=20,
        )

        mock_engine.staged_cascade_search.assert_called_once()

    def test_binary_strategy_calls_binary_cascade_search(self):
        """Test that binary strategy maps to binary_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.binary_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="binary",
            options=mock_options,
            limit=20,
        )

        mock_engine.binary_cascade_search.assert_called_once()

    def test_hybrid_strategy_calls_hybrid_cascade_search(self):
        """Test that hybrid strategy maps to hybrid_cascade_search."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.hybrid_cascade_search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="hybrid",
            options=mock_options,
            limit=20,
        )

        mock_engine.hybrid_cascade_search.assert_called_once()

    def test_unknown_strategy_defaults_to_rrf(self):
        """Test that unknown strategy defaults to standard search (rrf)."""
        from codexlens.api.semantic import _execute_search

        mock_engine = MagicMock()
        mock_engine.search.return_value = MagicMock(results=[])
        mock_options = MagicMock()

        _execute_search(
            engine=mock_engine,
            query="test query",
            source_path=Path("/test"),
            fusion_strategy="unknown_strategy",
            options=mock_options,
            limit=20,
        )

        mock_engine.search.assert_called_once()


class TestGracefulDegradation:
    """Test graceful degradation behavior."""

    def test_vector_score_none_when_no_vector_index(self):
        """Test vector_score=None when vector index unavailable."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.5
        mock_result.excerpt = "def auth(): pass"
        mock_result.symbol_name = "auth"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {}  # No vector score in metadata

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        # When no source_scores in metadata, vector_score should be None
        assert results[0].vector_score is None

    def test_structural_score_extracted_from_fts(self):
        """Test structural_score extracted from FTS scores."""
        mock_result = MagicMock()
        mock_result.path = "/project/src/auth.py"
        mock_result.score = 0.8
        mock_result.excerpt = "def auth(): pass"
        mock_result.symbol_name = "auth"
        mock_result.symbol_kind = "function"
        mock_result.start_line = 1
        mock_result.symbol = None
        mock_result.metadata = {
            "source_scores": {
                "exact": 0.9,
                "fuzzy": 0.7,
            }
        }

        results = _transform_results(
            results=[mock_result],
            mode="fusion",
            vector_weight=0.5,
            structural_weight=0.3,
            keyword_weight=0.2,
            kind_filter=None,
            include_match_reason=False,
            query="auth",
        )

        assert len(results) == 1
        assert results[0].structural_score == 0.9  # max of exact/fuzzy
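Likewise, purely as an illustration of the semantic_search signature and defaults asserted above (the project path and query are placeholders, not values from the diff):

from codexlens.api.semantic import semantic_search

results = semantic_search(
    "/path/to/project",
    "password validation",
    mode="fusion",             # default asserted above
    fusion_strategy="rrf",     # default asserted above
    kind_filter=["function", "class"],
    limit=20,
    include_match_reason=True,
)
for r in results:
    print(r.symbol_name, r.kind, r.file_path, r.line, r.fusion_score, r.match_reason)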
codex-lens/tests/lsp/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests package for LSP module."""
codex-lens/tests/lsp/test_hover.py (new file, 477 lines)
@@ -0,0 +1,477 @@
"""Tests for hover provider."""

from __future__ import annotations

import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock
import tempfile

from codexlens.entities import Symbol


class TestHoverInfo:
    """Test HoverInfo dataclass."""

    def test_hover_info_import(self):
        """HoverInfo can be imported."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from codexlens.lsp.providers import HoverInfo

        assert HoverInfo is not None

    def test_hover_info_fields(self):
        """HoverInfo has all required fields."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo

        info = HoverInfo(
            name="my_function",
            kind="function",
            signature="def my_function(x: int) -> str:",
            documentation="A test function.",
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        assert info.name == "my_function"
        assert info.kind == "function"
        assert info.signature == "def my_function(x: int) -> str:"
        assert info.documentation == "A test function."
        assert info.file_path == "/test/file.py"
        assert info.line_range == (10, 15)

    def test_hover_info_optional_documentation(self):
        """Documentation can be None."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="/test.py",
            line_range=(1, 2),
        )
        assert info.documentation is None


class TestHoverProvider:
    """Test HoverProvider class."""

    def test_provider_import(self):
        """HoverProvider can be imported."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        assert HoverProvider is not None

    def test_returns_none_for_unknown_symbol(self):
        """Returns None when symbol not found."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_index.search.return_value = []
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("unknown_symbol")

        assert result is None
        mock_index.search.assert_called_once_with(
            name="unknown_symbol", limit=1, prefix_mode=False
        )

    def test_returns_none_for_non_exact_match(self):
        """Returns None when search returns non-exact matches."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Return a symbol with different name (prefix match but not exact)
        mock_symbol = Mock()
        mock_symbol.name = "my_function_extended"
        mock_symbol.kind = "function"
        mock_symbol.file = "/test/file.py"
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("my_function")

        assert result is None

    def test_returns_hover_info_for_known_symbol(self):
        """Returns HoverInfo for found symbol."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = None  # No file, will use fallback signature
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)
        result = provider.get_hover_info("my_func")

        assert result is not None
        assert result.name == "my_func"
        assert result.kind == "function"
        assert result.line_range == (10, 15)
        assert result.signature == "function my_func"

    def test_extracts_signature_from_file(self):
        """Extracts signature from actual file content."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Create a temporary file with Python content
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("# comment\n")
            f.write("def test_function(x: int, y: str) -> bool:\n")
            f.write("    return True\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "test_function"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (2, 3)  # Line 2 (1-based)

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("test_function")

            assert result is not None
            assert "def test_function(x: int, y: str) -> bool:" in result.signature
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def test_extracts_multiline_signature(self):
        """Extracts multiline function signature."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        # Create a temporary file with multiline signature
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("def complex_function(\n")
            f.write("    arg1: int,\n")
            f.write("    arg2: str,\n")
            f.write(") -> bool:\n")
            f.write("    return True\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "complex_function"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (1, 5)  # Line 1 (1-based)

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("complex_function")

            assert result is not None
            assert "def complex_function(" in result.signature
            # Should capture multiline signature
            assert "arg1: int" in result.signature
        finally:
            Path(temp_path).unlink(missing_ok=True)

    def test_handles_nonexistent_file_gracefully(self):
        """Returns fallback signature when file doesn't exist."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_symbol = Mock()
        mock_symbol.name = "my_func"
        mock_symbol.kind = "function"
        mock_symbol.file = "/nonexistent/path/file.py"
        mock_symbol.range = (10, 15)

        mock_index = Mock()
        mock_index.search.return_value = [mock_symbol]

        provider = HoverProvider(mock_index, None)
        result = provider.get_hover_info("my_func")

        assert result is not None
        assert result.signature == "function my_func"

    def test_handles_invalid_line_range(self):
        """Returns fallback signature when line range is invalid."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write("def test():\n")
            f.write("    pass\n")
            temp_path = f.name

        try:
            mock_symbol = Mock()
            mock_symbol.name = "test"
            mock_symbol.kind = "function"
            mock_symbol.file = temp_path
            mock_symbol.range = (100, 105)  # Line beyond file length

            mock_index = Mock()
            mock_index.search.return_value = [mock_symbol]

            provider = HoverProvider(mock_index, None)
            result = provider.get_hover_info("test")

            assert result is not None
            assert result.signature == "function test"
        finally:
            Path(temp_path).unlink(missing_ok=True)


class TestFormatHoverMarkdown:
    """Test markdown formatting."""

    def test_format_python_signature(self):
        """Formats Python signature with python code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func(x: int) -> str:",
            documentation=None,
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```python" in result
        assert "def func(x: int) -> str:" in result
        assert "function" in result
        assert "file.py" in result
        assert "line 10" in result

    def test_format_javascript_signature(self):
        """Formats JavaScript signature with javascript code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="myFunc",
            kind="function",
            signature="function myFunc(x) {",
            documentation=None,
            file_path="/test/file.js",
            line_range=(5, 10),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```javascript" in result
        assert "function myFunc(x) {" in result

    def test_format_typescript_signature(self):
        """Formats TypeScript signature with typescript code fence."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="myFunc",
            kind="function",
            signature="function myFunc(x: number): string {",
            documentation=None,
            file_path="/test/file.ts",
            line_range=(5, 10),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "```typescript" in result

    def test_format_with_documentation(self):
        """Includes documentation when available."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation="This is a test function.",
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "This is a test function." in result
        assert "---" in result  # Separator before docs

    def test_format_without_documentation(self):
        """Does not include documentation section when None."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="/test/file.py",
            line_range=(10, 15),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        # Should have one separator for location, not two
        # The result should not have duplicate doc separator
        lines = result.split("\n")
        separator_count = sum(1 for line in lines if line.strip() == "---")
        assert separator_count == 1  # Only location separator

    def test_format_unknown_extension(self):
        """Uses empty code fence for unknown file extensions."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="func code here",
            documentation=None,
            file_path="/test/file.xyz",
            line_range=(1, 2),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        # Should have code fence without language specifier
        assert "```\n" in result or "```xyz" not in result

    def test_format_class_symbol(self):
        """Formats class symbol correctly."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="MyClass",
            kind="class",
            signature="class MyClass:",
            documentation=None,
            file_path="/test/file.py",
            line_range=(1, 20),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "class MyClass:" in result
        assert "*class*" in result
        assert "line 1" in result

    def test_format_empty_file_path(self):
        """Handles empty file path gracefully."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverInfo, HoverProvider

        info = HoverInfo(
            name="func",
            kind="function",
            signature="def func():",
            documentation=None,
            file_path="",
            line_range=(1, 2),
        )
        mock_index = Mock()
        provider = HoverProvider(mock_index, None)

        result = provider.format_hover_markdown(info)

        assert "unknown" in result or "```" in result


class TestHoverProviderRegistry:
    """Test HoverProvider with registry integration."""

    def test_provider_accepts_none_registry(self):
        """HoverProvider works without registry."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_index.search.return_value = []

        provider = HoverProvider(mock_index, None)
        result = provider.get_hover_info("test")

        assert result is None
        assert provider.registry is None

    def test_provider_stores_registry(self):
        """HoverProvider stores registry reference."""
        pytest.importorskip("pygls")

        from codexlens.lsp.providers import HoverProvider

        mock_index = Mock()
        mock_registry = Mock()

        provider = HoverProvider(mock_index, mock_registry)

        assert provider.global_index is mock_index
        assert provider.registry is mock_registry
codex-lens/tests/lsp/test_references.py (new file, 497 lines)
@@ -0,0 +1,497 @@
"""Tests for reference search functionality.

This module tests the ReferenceResult dataclass and search_references method
in ChainSearchEngine, as well as the updated lsp_references handler.
"""

from __future__ import annotations

import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch
import sqlite3
import tempfile
import os


class TestReferenceResult:
    """Test ReferenceResult dataclass."""

    def test_reference_result_fields(self):
        """ReferenceResult has all required fields."""
        from codexlens.search.chain_search import ReferenceResult

        ref = ReferenceResult(
            file_path="/test/file.py",
            line=10,
            column=5,
            context="def foo():",
            relationship_type="call",
        )
        assert ref.file_path == "/test/file.py"
        assert ref.line == 10
        assert ref.column == 5
        assert ref.context == "def foo():"
        assert ref.relationship_type == "call"

    def test_reference_result_with_empty_context(self):
        """ReferenceResult can have empty context."""
        from codexlens.search.chain_search import ReferenceResult

        ref = ReferenceResult(
            file_path="/test/file.py",
            line=1,
            column=0,
            context="",
            relationship_type="import",
        )
        assert ref.context == ""

    def test_reference_result_different_relationship_types(self):
        """ReferenceResult supports different relationship types."""
        from codexlens.search.chain_search import ReferenceResult

        types = ["call", "import", "inheritance", "implementation", "usage"]
        for rel_type in types:
            ref = ReferenceResult(
                file_path="/test/file.py",
                line=1,
                column=0,
                context="test",
                relationship_type=rel_type,
            )
            assert ref.relationship_type == rel_type


class TestExtractContext:
    """Test the _extract_context helper method."""

    def test_extract_context_middle_of_file(self):
        """Extract context from middle of file."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        content = "\n".join([
            "line 1",
            "line 2",
            "line 3",
            "line 4",  # target line
            "line 5",
            "line 6",
            "line 7",
        ])

        # Create minimal mock engine to test _extract_context
        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=4, context_lines=2)

        assert "line 2" in context
        assert "line 3" in context
        assert "line 4" in context
        assert "line 5" in context
        assert "line 6" in context

    def test_extract_context_start_of_file(self):
        """Extract context at start of file."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "\n".join([
            "line 1",  # target
            "line 2",
            "line 3",
            "line 4",
        ])

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=1, context_lines=2)

        assert "line 1" in context
        assert "line 2" in context
        assert "line 3" in context

    def test_extract_context_end_of_file(self):
        """Extract context at end of file."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "\n".join([
            "line 1",
            "line 2",
            "line 3",
            "line 4",  # target
        ])

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context(content, line=4, context_lines=2)

        assert "line 2" in context
        assert "line 3" in context
        assert "line 4" in context

    def test_extract_context_empty_content(self):
        """Extract context from empty content."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        context = engine._extract_context("", line=1, context_lines=3)

        assert context == ""

    def test_extract_context_invalid_line(self):
        """Extract context with invalid line number."""
        from codexlens.search.chain_search import ChainSearchEngine

        content = "line 1\nline 2\nline 3"

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        # Line 0 (invalid)
        assert engine._extract_context(content, line=0, context_lines=1) == ""

        # Line beyond end
        assert engine._extract_context(content, line=100, context_lines=1) == ""


class TestSearchReferences:
    """Test search_references method."""

    def test_returns_empty_for_no_source_path_and_no_registry(self):
        """Returns empty list when no source path and registry has no mappings."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_registry.list_mappings.return_value = []
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)
        results = engine.search_references("test_symbol")

        assert results == []

    def test_returns_empty_for_no_indexes(self):
        """Returns empty list when no indexes found."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()
        mock_mapper.source_to_index_db.return_value = Path("/nonexistent/_index.db")

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with patch.object(engine, "_find_start_index", return_value=None):
            results = engine.search_references("test_symbol", Path("/some/path"))

        assert results == []

    def test_deduplicates_results(self):
        """Removes duplicate file:line references."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        # Create a temporary database with duplicate relationships
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'def test(): pass');
                INSERT INTO symbols VALUES (1, 1, 'test_func', 'function', 1, 1);
                INSERT INTO code_relationships VALUES (1, 1, 'target_func', 'call', 10, NULL);
                INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 10, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target_func", Path(tmpdir))

            # Should only have 1 result due to deduplication
            assert len(results) == 1
            assert results[0].line == 10

    def test_sorts_by_file_and_line(self):
        """Results sorted by file path then line number."""
        from codexlens.search.chain_search import ChainSearchEngine, ReferenceResult

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/b_file.py', 'python', 'content');
                INSERT INTO files VALUES (2, '/test/a_file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'func1', 'function', 1, 1);
                INSERT INTO symbols VALUES (2, 2, 'func2', 'function', 1, 1);
                INSERT INTO code_relationships VALUES (1, 1, 'target', 'call', 20, NULL);
                INSERT INTO code_relationships VALUES (2, 1, 'target', 'call', 10, NULL);
                INSERT INTO code_relationships VALUES (3, 2, 'target', 'call', 5, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target", Path(tmpdir))

            # Should be sorted: a_file.py:5, b_file.py:10, b_file.py:20
            assert len(results) == 3
            assert results[0].file_path == "/test/a_file.py"
            assert results[0].line == 5
            assert results[1].file_path == "/test/b_file.py"
            assert results[1].line == 10
            assert results[2].file_path == "/test/b_file.py"
            assert results[2].line == 20

    def test_respects_limit(self):
        """Returns at most limit results."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'func', 'function', 1, 1);
            """)
            # Insert many relationships
            for i in range(50):
                conn.execute(
                    "INSERT INTO code_relationships VALUES (?, 1, 'target', 'call', ?, NULL)",
                    (i + 1, i + 1)
                )
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target", Path(tmpdir), limit=10)

            assert len(results) == 10

    def test_matches_qualified_name(self):
        """Matches symbols by qualified name suffix."""
        from codexlens.search.chain_search import ChainSearchEngine

        mock_registry = Mock()
        mock_mapper = Mock()

        engine = ChainSearchEngine(mock_registry, mock_mapper)

        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "_index.db"
            conn = sqlite3.connect(str(db_path))
            conn.executescript("""
                CREATE TABLE files (
                    id INTEGER PRIMARY KEY,
                    path TEXT NOT NULL,
                    language TEXT NOT NULL,
                    content TEXT NOT NULL
                );
                CREATE TABLE symbols (
                    id INTEGER PRIMARY KEY,
                    file_id INTEGER NOT NULL,
                    name TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    start_line INTEGER NOT NULL,
                    end_line INTEGER NOT NULL
                );
                CREATE TABLE code_relationships (
                    id INTEGER PRIMARY KEY,
                    source_symbol_id INTEGER NOT NULL,
                    target_qualified_name TEXT NOT NULL,
                    relationship_type TEXT NOT NULL,
                    source_line INTEGER NOT NULL,
                    target_file TEXT
                );

                INSERT INTO files VALUES (1, '/test/file.py', 'python', 'content');
                INSERT INTO symbols VALUES (1, 1, 'caller', 'function', 1, 1);
                -- Fully qualified name
                INSERT INTO code_relationships VALUES (1, 1, 'module.submodule.target_func', 'call', 10, NULL);
                -- Simple name
                INSERT INTO code_relationships VALUES (2, 1, 'target_func', 'call', 20, NULL);
            """)
            conn.commit()
            conn.close()

            with patch.object(engine, "_find_start_index", return_value=db_path):
                with patch.object(engine, "_collect_index_paths", return_value=[db_path]):
                    results = engine.search_references("target_func", Path(tmpdir))

            # Should find both references
            assert len(results) == 2


class TestLspReferencesHandler:
    """Test the LSP references handler."""

    def test_handler_uses_search_engine(self):
        """Handler uses search_engine.search_references when available."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from lsprotocol import types as lsp
        from codexlens.lsp.handlers import _path_to_uri
        from codexlens.search.chain_search import ReferenceResult

        # Create mock references
        mock_references = [
            ReferenceResult(
                file_path="/test/file1.py",
                line=10,
                column=5,
                context="def foo():",
                relationship_type="call",
            ),
            ReferenceResult(
                file_path="/test/file2.py",
                line=20,
                column=0,
                context="import foo",
                relationship_type="import",
            ),
        ]

        # Verify conversion to LSP Location
        locations = []
        for ref in mock_references:
            locations.append(
                lsp.Location(
                    uri=_path_to_uri(ref.file_path),
                    range=lsp.Range(
                        start=lsp.Position(
                            line=max(0, ref.line - 1),
                            character=ref.column,
                        ),
                        end=lsp.Position(
                            line=max(0, ref.line - 1),
                            character=ref.column + len("foo"),
                        ),
                    ),
                )
            )

        assert len(locations) == 2
        # First reference at line 10 (0-indexed = 9)
        assert locations[0].range.start.line == 9
        assert locations[0].range.start.character == 5
        # Second reference at line 20 (0-indexed = 19)
        assert locations[1].range.start.line == 19
        assert locations[1].range.start.character == 0

    def test_handler_falls_back_to_global_index(self):
        """Handler falls back to global_index when search_engine unavailable."""
        pytest.importorskip("pygls")
        pytest.importorskip("lsprotocol")

        from codexlens.lsp.handlers import symbol_to_location
        from codexlens.entities import Symbol

        # Test fallback path converts Symbol to Location
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(10, 15),
            file="/test/file.py",
        )

        location = symbol_to_location(symbol)
        assert location is not None
        # LSP uses 0-based lines
        assert location.range.start.line == 9
        assert location.range.end.line == 14
210
codex-lens/tests/lsp/test_server.py
Normal file
210
codex-lens/tests/lsp/test_server.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for codex-lens LSP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from codexlens.entities import Symbol
|
||||
|
||||
|
||||
class TestCodexLensLanguageServer:
|
||||
"""Tests for CodexLensLanguageServer."""
|
||||
|
||||
def test_server_import(self):
|
||||
"""Test that server module can be imported."""
|
||||
pytest.importorskip("pygls")
|
||||
pytest.importorskip("lsprotocol")
|
||||
|
||||
from codexlens.lsp.server import CodexLensLanguageServer, server
|
||||
|
||||
assert CodexLensLanguageServer is not None
|
||||
assert server is not None
|
||||
assert server.name == "codexlens-lsp"
|
||||
|
||||
def test_server_initialization(self):
|
||||
"""Test server instance creation."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.server import CodexLensLanguageServer
|
||||
|
||||
ls = CodexLensLanguageServer()
|
||||
assert ls.registry is None
|
||||
assert ls.mapper is None
|
||||
assert ls.global_index is None
|
||||
assert ls.search_engine is None
|
||||
assert ls.workspace_root is None
|
||||
|
||||
|
||||
class TestDefinitionHandler:
|
||||
"""Tests for definition handler."""
|
||||
|
||||
def test_definition_lookup(self):
|
||||
"""Test definition lookup returns location for known symbol."""
|
||||
pytest.importorskip("pygls")
|
||||
pytest.importorskip("lsprotocol")
|
||||
|
||||
from lsprotocol import types as lsp
|
||||
from codexlens.lsp.handlers import symbol_to_location
|
||||
|
||||
symbol = Symbol(
|
||||
name="test_function",
|
||||
kind="function",
|
||||
range=(10, 15),
|
||||
file="/path/to/file.py",
|
||||
)
|
||||
|
||||
location = symbol_to_location(symbol)
|
||||
|
||||
assert location is not None
|
||||
assert isinstance(location, lsp.Location)
|
||||
# LSP uses 0-based lines
|
||||
assert location.range.start.line == 9
|
||||
assert location.range.end.line == 14
|
||||
|
||||
def test_definition_no_file(self):
|
||||
"""Test definition lookup returns None for symbol without file."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.handlers import symbol_to_location
|
||||
|
||||
symbol = Symbol(
|
||||
name="test_function",
|
||||
kind="function",
|
||||
range=(10, 15),
|
||||
file=None,
|
||||
)
|
||||
|
||||
location = symbol_to_location(symbol)
|
||||
assert location is None
|
||||
|
||||
|
||||
class TestCompletionHandler:
|
||||
"""Tests for completion handler."""
|
||||
|
||||
def test_get_prefix_at_position(self):
|
||||
"""Test extracting prefix at cursor position."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.handlers import _get_prefix_at_position
|
||||
|
||||
document_text = "def hello_world():\n print(hel"
|
||||
|
||||
# Cursor at end of "hel"
|
||||
prefix = _get_prefix_at_position(document_text, 1, 14)
|
||||
assert prefix == "hel"
|
||||
|
||||
# Cursor at beginning of line (after whitespace)
|
||||
prefix = _get_prefix_at_position(document_text, 1, 4)
|
||||
assert prefix == ""
|
||||
|
||||
# Cursor after "he" in "hello_world" - returns text before cursor
|
||||
prefix = _get_prefix_at_position(document_text, 0, 6)
|
||||
assert prefix == "he"
|
||||
|
||||
# Cursor at end of "hello_world"
|
||||
prefix = _get_prefix_at_position(document_text, 0, 15)
|
||||
assert prefix == "hello_world"
|
||||
|
||||
def test_get_word_at_position(self):
|
||||
"""Test extracting word at cursor position."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.handlers import _get_word_at_position
|
||||
|
||||
document_text = "def hello_world():\n print(msg)"
|
||||
|
||||
# Cursor on "hello_world"
|
||||
word = _get_word_at_position(document_text, 0, 6)
|
||||
assert word == "hello_world"
|
||||
|
||||
# Cursor on "print"
|
||||
word = _get_word_at_position(document_text, 1, 6)
|
||||
assert word == "print"
|
||||
|
||||
# Cursor on "msg"
|
||||
word = _get_word_at_position(document_text, 1, 11)
|
||||
assert word == "msg"
|
||||
|
||||
def test_symbol_kind_mapping(self):
|
||||
"""Test symbol kind to completion kind mapping."""
|
||||
pytest.importorskip("pygls")
|
||||
pytest.importorskip("lsprotocol")
|
||||
|
||||
from lsprotocol import types as lsp
|
||||
from codexlens.lsp.handlers import _symbol_kind_to_completion_kind
|
||||
|
||||
assert _symbol_kind_to_completion_kind("function") == lsp.CompletionItemKind.Function
|
||||
assert _symbol_kind_to_completion_kind("class") == lsp.CompletionItemKind.Class
|
||||
assert _symbol_kind_to_completion_kind("method") == lsp.CompletionItemKind.Method
|
||||
assert _symbol_kind_to_completion_kind("variable") == lsp.CompletionItemKind.Variable
|
||||
|
||||
# Unknown kind should default to Text
|
||||
assert _symbol_kind_to_completion_kind("unknown") == lsp.CompletionItemKind.Text
|
||||
|
||||
|
||||
class TestWorkspaceSymbolHandler:
|
||||
"""Tests for workspace symbol handler."""
|
||||
|
||||
def test_symbol_kind_to_lsp(self):
|
||||
"""Test symbol kind to LSP SymbolKind mapping."""
|
||||
pytest.importorskip("pygls")
|
||||
pytest.importorskip("lsprotocol")
|
||||
|
||||
from lsprotocol import types as lsp
|
||||
from codexlens.lsp.handlers import _symbol_kind_to_lsp
|
||||
|
||||
assert _symbol_kind_to_lsp("function") == lsp.SymbolKind.Function
|
||||
assert _symbol_kind_to_lsp("class") == lsp.SymbolKind.Class
|
||||
assert _symbol_kind_to_lsp("method") == lsp.SymbolKind.Method
|
||||
assert _symbol_kind_to_lsp("interface") == lsp.SymbolKind.Interface
|
||||
|
||||
# Unknown kind should default to Variable
|
||||
assert _symbol_kind_to_lsp("unknown") == lsp.SymbolKind.Variable
|
||||
|
||||
|
||||
class TestUriConversion:
|
||||
"""Tests for URI path conversion."""
|
||||
|
||||
def test_path_to_uri(self):
|
||||
"""Test path to URI conversion."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.handlers import _path_to_uri
|
||||
|
||||
# Unix path
|
||||
uri = _path_to_uri("/home/user/file.py")
|
||||
assert uri.startswith("file://")
|
||||
assert "file.py" in uri
|
||||
|
||||
def test_uri_to_path(self):
|
||||
"""Test URI to path conversion."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
from codexlens.lsp.handlers import _uri_to_path
|
||||
|
||||
# Basic URI
|
||||
path = _uri_to_path("file:///home/user/file.py")
|
||||
assert path.name == "file.py"
|
||||
|
||||
|
||||
class TestMainEntryPoint:
|
||||
"""Tests for main entry point."""
|
||||
|
||||
def test_main_help(self):
|
||||
"""Test that main shows help without errors."""
|
||||
pytest.importorskip("pygls")
|
||||
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
# Patch sys.argv to show help
|
||||
with patch.object(sys, 'argv', ['codexlens-lsp', '--help']):
|
||||
from codexlens.lsp.server import main
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
main()
|
||||
|
||||
# Help exits with 0
|
||||
assert exc_info.value.code == 0
|
||||
1
codex-lens/tests/mcp/__init__.py
Normal file
1
codex-lens/tests/mcp/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for MCP (Model Context Protocol) module."""
|
||||
208
codex-lens/tests/mcp/test_hooks.py
Normal file
208
codex-lens/tests/mcp/test_hooks.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Tests for MCP hooks module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.mcp.hooks import HookManager, create_context_for_prompt
|
||||
from codexlens.mcp.schema import MCPContext, SymbolInfo
|
||||
|
||||
|
||||
class TestHookManager:
|
||||
"""Test HookManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider(self):
|
||||
"""Create a mock MCP provider."""
|
||||
provider = Mock()
|
||||
provider.build_context.return_value = MCPContext(
|
||||
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
|
||||
context_type="symbol_explanation",
|
||||
)
|
||||
provider.build_context_for_file.return_value = MCPContext(
|
||||
context_type="file_overview",
|
||||
)
|
||||
return provider
|
||||
|
||||
@pytest.fixture
|
||||
def hook_manager(self, mock_provider):
|
||||
"""Create a HookManager with mocked provider."""
|
||||
return HookManager(mock_provider)
|
||||
|
||||
def test_default_hooks_registered(self, hook_manager):
|
||||
"""Default hooks are registered on initialization."""
|
||||
assert "explain" in hook_manager._pre_hooks
|
||||
assert "refactor" in hook_manager._pre_hooks
|
||||
assert "document" in hook_manager._pre_hooks
|
||||
|
||||
def test_execute_pre_hook_returns_context(self, hook_manager, mock_provider):
|
||||
"""execute_pre_hook returns MCPContext for registered hook."""
|
||||
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, MCPContext)
|
||||
mock_provider.build_context.assert_called_once()
|
||||
|
||||
def test_execute_pre_hook_returns_none_for_unknown_action(self, hook_manager):
|
||||
"""execute_pre_hook returns None for unregistered action."""
|
||||
result = hook_manager.execute_pre_hook("unknown_action", {"symbol": "test"})
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_execute_pre_hook_handles_exception(self, hook_manager, mock_provider):
|
||||
"""execute_pre_hook handles provider exceptions gracefully."""
|
||||
mock_provider.build_context.side_effect = Exception("Provider failed")
|
||||
|
||||
result = hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_execute_post_hook_no_error_for_unregistered(self, hook_manager):
|
||||
"""execute_post_hook doesn't error for unregistered action."""
|
||||
# Should not raise
|
||||
hook_manager.execute_post_hook("unknown", {"result": "data"})
|
||||
|
||||
def test_pre_explain_hook_calls_build_context(self, hook_manager, mock_provider):
|
||||
"""_pre_explain_hook calls build_context correctly."""
|
||||
hook_manager.execute_pre_hook("explain", {"symbol": "my_func"})
|
||||
|
||||
mock_provider.build_context.assert_called_with(
|
||||
symbol_name="my_func",
|
||||
context_type="symbol_explanation",
|
||||
include_references=True,
|
||||
include_related=True,
|
||||
)
|
||||
|
||||
def test_pre_explain_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
|
||||
"""_pre_explain_hook returns None when symbol param missing."""
|
||||
result = hook_manager.execute_pre_hook("explain", {})
|
||||
|
||||
assert result is None
|
||||
mock_provider.build_context.assert_not_called()
|
||||
|
||||
def test_pre_refactor_hook_calls_build_context(self, hook_manager, mock_provider):
|
||||
"""_pre_refactor_hook calls build_context with refactor settings."""
|
||||
hook_manager.execute_pre_hook("refactor", {"symbol": "my_class"})
|
||||
|
||||
mock_provider.build_context.assert_called_with(
|
||||
symbol_name="my_class",
|
||||
context_type="refactor_context",
|
||||
include_references=True,
|
||||
include_related=True,
|
||||
max_references=20,
|
||||
)
|
||||
|
||||
def test_pre_refactor_hook_returns_none_without_symbol(self, hook_manager, mock_provider):
|
||||
"""_pre_refactor_hook returns None when symbol param missing."""
|
||||
result = hook_manager.execute_pre_hook("refactor", {})
|
||||
|
||||
assert result is None
|
||||
mock_provider.build_context.assert_not_called()
|
||||
|
||||
def test_pre_document_hook_with_symbol(self, hook_manager, mock_provider):
|
||||
"""_pre_document_hook uses build_context when symbol provided."""
|
||||
hook_manager.execute_pre_hook("document", {"symbol": "my_func"})
|
||||
|
||||
mock_provider.build_context.assert_called_with(
|
||||
symbol_name="my_func",
|
||||
context_type="documentation_context",
|
||||
include_references=False,
|
||||
include_related=True,
|
||||
)
|
||||
|
||||
def test_pre_document_hook_with_file_path(self, hook_manager, mock_provider):
|
||||
"""_pre_document_hook uses build_context_for_file when file_path provided."""
|
||||
hook_manager.execute_pre_hook("document", {"file_path": "/src/module.py"})
|
||||
|
||||
mock_provider.build_context_for_file.assert_called_once()
|
||||
call_args = mock_provider.build_context_for_file.call_args
|
||||
assert call_args[0][0] == Path("/src/module.py")
|
||||
assert call_args[1].get("context_type") == "file_documentation"
|
||||
|
||||
def test_pre_document_hook_prefers_symbol_over_file(self, hook_manager, mock_provider):
|
||||
"""_pre_document_hook prefers symbol when both provided."""
|
||||
hook_manager.execute_pre_hook(
|
||||
"document", {"symbol": "my_func", "file_path": "/src/module.py"}
|
||||
)
|
||||
|
||||
mock_provider.build_context.assert_called_once()
|
||||
mock_provider.build_context_for_file.assert_not_called()
|
||||
|
||||
def test_pre_document_hook_returns_none_without_params(self, hook_manager, mock_provider):
|
||||
"""_pre_document_hook returns None when neither symbol nor file_path provided."""
|
||||
result = hook_manager.execute_pre_hook("document", {})
|
||||
|
||||
assert result is None
|
||||
mock_provider.build_context.assert_not_called()
|
||||
mock_provider.build_context_for_file.assert_not_called()
|
||||
|
||||
def test_register_pre_hook(self, hook_manager):
|
||||
"""register_pre_hook adds custom hook."""
|
||||
custom_hook = Mock(return_value=MCPContext())
|
||||
|
||||
hook_manager.register_pre_hook("custom_action", custom_hook)
|
||||
|
||||
assert "custom_action" in hook_manager._pre_hooks
|
||||
hook_manager.execute_pre_hook("custom_action", {"data": "value"})
|
||||
custom_hook.assert_called_once_with({"data": "value"})
|
||||
|
||||
def test_register_post_hook(self, hook_manager):
|
||||
"""register_post_hook adds custom hook."""
|
||||
custom_hook = Mock()
|
||||
|
||||
hook_manager.register_post_hook("custom_action", custom_hook)
|
||||
|
||||
assert "custom_action" in hook_manager._post_hooks
|
||||
hook_manager.execute_post_hook("custom_action", {"result": "data"})
|
||||
custom_hook.assert_called_once_with({"result": "data"})
|
||||
|
||||
def test_execute_post_hook_handles_exception(self, hook_manager):
|
||||
"""execute_post_hook handles hook exceptions gracefully."""
|
||||
failing_hook = Mock(side_effect=Exception("Hook failed"))
|
||||
hook_manager.register_post_hook("failing", failing_hook)
|
||||
|
||||
# Should not raise
|
||||
hook_manager.execute_post_hook("failing", {"data": "value"})
|
||||
|
||||
|
||||
class TestCreateContextForPrompt:
|
||||
"""Test create_context_for_prompt function."""
|
||||
|
||||
def test_returns_prompt_injection_string(self):
|
||||
"""create_context_for_prompt returns formatted string."""
|
||||
mock_provider = Mock()
|
||||
mock_provider.build_context.return_value = MCPContext(
|
||||
symbol=SymbolInfo("test_func", "function", "/test.py", 1, 10),
|
||||
definition="def test_func(): pass",
|
||||
)
|
||||
|
||||
result = create_context_for_prompt(
|
||||
mock_provider, "explain", {"symbol": "test_func"}
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "<code_context>" in result
|
||||
assert "test_func" in result
|
||||
assert "</code_context>" in result
|
||||
|
||||
def test_returns_empty_string_when_no_context(self):
|
||||
"""create_context_for_prompt returns empty string when no context built."""
|
||||
mock_provider = Mock()
|
||||
mock_provider.build_context.return_value = None
|
||||
|
||||
result = create_context_for_prompt(
|
||||
mock_provider, "explain", {"symbol": "nonexistent"}
|
||||
)
|
||||
|
||||
assert result == ""
|
||||
|
||||
def test_returns_empty_string_for_unknown_action(self):
|
||||
"""create_context_for_prompt returns empty string for unregistered action."""
|
||||
mock_provider = Mock()
|
||||
|
||||
result = create_context_for_prompt(
|
||||
mock_provider, "unknown_action", {"data": "value"}
|
||||
)
|
||||
|
||||
assert result == ""
|
||||
mock_provider.build_context.assert_not_called()
|
||||
383
codex-lens/tests/mcp/test_provider.py
Normal file
383
codex-lens/tests/mcp/test_provider.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Tests for MCP provider."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from codexlens.mcp.provider import MCPProvider
|
||||
from codexlens.mcp.schema import MCPContext, SymbolInfo, ReferenceInfo
|
||||
|
||||
|
||||
class TestMCPProvider:
|
||||
"""Test MCPProvider class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_global_index(self):
|
||||
"""Create a mock global index."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_search_engine(self):
|
||||
"""Create a mock search engine."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry(self):
|
||||
"""Create a mock registry."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_global_index, mock_search_engine, mock_registry):
|
||||
"""Create an MCPProvider with mocked dependencies."""
|
||||
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
|
||||
|
||||
def test_build_context_returns_none_for_unknown_symbol(self, provider, mock_global_index):
|
||||
"""build_context returns None when symbol is not found."""
|
||||
mock_global_index.search.return_value = []
|
||||
|
||||
result = provider.build_context("unknown_symbol")
|
||||
|
||||
assert result is None
|
||||
mock_global_index.search.assert_called_once_with(
|
||||
"unknown_symbol", prefix_mode=False, limit=1
|
||||
)
|
||||
|
||||
def test_build_context_returns_mcp_context(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context returns MCPContext for known symbol."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = []
|
||||
|
||||
result = provider.build_context("my_func")
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, MCPContext)
|
||||
assert result.symbol is not None
|
||||
assert result.symbol.name == "my_func"
|
||||
assert result.symbol.kind == "function"
|
||||
assert result.context_type == "symbol_explanation"
|
||||
|
||||
def test_build_context_with_custom_context_type(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context respects custom context_type."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = []
|
||||
|
||||
result = provider.build_context("my_func", context_type="refactor_context")
|
||||
|
||||
assert result is not None
|
||||
assert result.context_type == "refactor_context"
|
||||
|
||||
def test_build_context_includes_references(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context includes references when include_references=True."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_ref = Mock()
|
||||
mock_ref.file_path = "/caller.py"
|
||||
mock_ref.line = 25
|
||||
mock_ref.column = 4
|
||||
mock_ref.context = "result = my_func()"
|
||||
mock_ref.relationship_type = "call"
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = [mock_ref]
|
||||
|
||||
result = provider.build_context("my_func", include_references=True)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.references) == 1
|
||||
assert result.references[0].file_path == "/caller.py"
|
||||
assert result.references[0].line == 25
|
||||
assert result.references[0].relationship_type == "call"
|
||||
|
||||
def test_build_context_excludes_references_when_disabled(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context excludes references when include_references=False."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = []
|
||||
|
||||
# Disable both references and related to avoid any search_references calls
|
||||
result = provider.build_context(
|
||||
"my_func", include_references=False, include_related=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.references) == 0
|
||||
mock_search_engine.search_references.assert_not_called()
|
||||
|
||||
def test_build_context_respects_max_references(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context passes max_references to search engine."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = []
|
||||
|
||||
# Disable include_related to test only the references call
|
||||
provider.build_context("my_func", max_references=5, include_related=False)
|
||||
|
||||
mock_search_engine.search_references.assert_called_once_with(
|
||||
"my_func", limit=5
|
||||
)
|
||||
|
||||
def test_build_context_includes_metadata(
|
||||
self, provider, mock_global_index, mock_search_engine
|
||||
):
|
||||
"""build_context includes source metadata."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = "/test.py"
|
||||
mock_symbol.range = (10, 20)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = []
|
||||
|
||||
result = provider.build_context("my_func")
|
||||
|
||||
assert result is not None
|
||||
assert result.metadata.get("source") == "codex-lens"
|
||||
|
||||
def test_extract_definition_with_valid_file(self, provider):
|
||||
"""_extract_definition reads file content correctly."""
|
||||
# Create a temporary file with some content
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
||||
f.write("# Line 1\n")
|
||||
f.write("# Line 2\n")
|
||||
f.write("def my_func():\n") # Line 3
|
||||
f.write(" pass\n") # Line 4
|
||||
f.write("# Line 5\n")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.file = temp_path
|
||||
mock_symbol.range = (3, 4) # 1-based line numbers
|
||||
|
||||
definition = provider._extract_definition(mock_symbol)
|
||||
|
||||
assert definition is not None
|
||||
assert "def my_func():" in definition
|
||||
assert "pass" in definition
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_extract_definition_returns_none_for_missing_file(self, provider):
|
||||
"""_extract_definition returns None for non-existent file."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.file = "/nonexistent/path/file.py"
|
||||
mock_symbol.range = (1, 5)
|
||||
|
||||
definition = provider._extract_definition(mock_symbol)
|
||||
|
||||
assert definition is None
|
||||
|
||||
def test_extract_definition_returns_none_for_none_file(self, provider):
|
||||
"""_extract_definition returns None when symbol.file is None."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.file = None
|
||||
mock_symbol.range = (1, 5)
|
||||
|
||||
definition = provider._extract_definition(mock_symbol)
|
||||
|
||||
assert definition is None
|
||||
|
||||
def test_build_context_for_file_returns_context(
|
||||
self, provider, mock_global_index
|
||||
):
|
||||
"""build_context_for_file returns MCPContext."""
|
||||
mock_global_index.search.return_value = []
|
||||
|
||||
result = provider.build_context_for_file(
|
||||
Path("/test/file.py"),
|
||||
context_type="file_overview",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, MCPContext)
|
||||
assert result.context_type == "file_overview"
|
||||
assert result.metadata.get("file_path") == str(Path("/test/file.py"))
|
||||
|
||||
def test_build_context_for_file_includes_symbols(
|
||||
self, provider, mock_global_index
|
||||
):
|
||||
"""build_context_for_file includes symbols from the file."""
|
||||
# Create temp file to get resolved path
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
||||
f.write("def func(): pass\n")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "func"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = temp_path
|
||||
mock_symbol.range = (1, 1)
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
|
||||
result = provider.build_context_for_file(Path(temp_path))
|
||||
|
||||
assert result is not None
|
||||
# Symbols from this file should be in related_symbols
|
||||
assert len(result.related_symbols) >= 0 # May be 0 if filtering doesn't match
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
class TestMCPProviderRelatedSymbols:
|
||||
"""Test related symbols functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self):
|
||||
"""Create provider with mocks."""
|
||||
mock_global_index = Mock()
|
||||
mock_search_engine = Mock()
|
||||
mock_registry = Mock()
|
||||
return MCPProvider(mock_global_index, mock_search_engine, mock_registry)
|
||||
|
||||
def test_get_related_symbols_from_references(self, provider):
|
||||
"""_get_related_symbols extracts symbols from references."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.file = "/test.py"
|
||||
|
||||
mock_ref1 = Mock()
|
||||
mock_ref1.file_path = "/caller1.py"
|
||||
mock_ref1.relationship_type = "call"
|
||||
|
||||
mock_ref2 = Mock()
|
||||
mock_ref2.file_path = "/caller2.py"
|
||||
mock_ref2.relationship_type = "import"
|
||||
|
||||
provider.search_engine.search_references.return_value = [mock_ref1, mock_ref2]
|
||||
|
||||
related = provider._get_related_symbols(mock_symbol)
|
||||
|
||||
assert len(related) == 2
|
||||
assert related[0].relationship == "call"
|
||||
assert related[1].relationship == "import"
|
||||
|
||||
def test_get_related_symbols_limits_results(self, provider):
|
||||
"""_get_related_symbols limits to 10 unique relationship types."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.file = "/test.py"
|
||||
|
||||
# Create 15 references with unique relationship types
|
||||
refs = []
|
||||
for i in range(15):
|
||||
ref = Mock()
|
||||
ref.file_path = f"/file{i}.py"
|
||||
ref.relationship_type = f"type{i}"
|
||||
refs.append(ref)
|
||||
|
||||
provider.search_engine.search_references.return_value = refs
|
||||
|
||||
related = provider._get_related_symbols(mock_symbol)
|
||||
|
||||
assert len(related) <= 10
|
||||
|
||||
def test_get_related_symbols_handles_exception(self, provider):
|
||||
"""_get_related_symbols handles exceptions gracefully."""
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_func"
|
||||
mock_symbol.file = "/test.py"
|
||||
|
||||
provider.search_engine.search_references.side_effect = Exception("Search failed")
|
||||
|
||||
related = provider._get_related_symbols(mock_symbol)
|
||||
|
||||
assert related == []
|
||||
|
||||
|
||||
class TestMCPProviderIntegration:
|
||||
"""Integration-style tests for MCPProvider."""
|
||||
|
||||
def test_full_context_workflow(self):
|
||||
"""Test complete context building workflow."""
|
||||
# Create temp file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
||||
f.write("def my_function(arg1, arg2):\n")
|
||||
f.write(" '''This is my function.'''\n")
|
||||
f.write(" return arg1 + arg2\n")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# Setup mocks
|
||||
mock_global_index = Mock()
|
||||
mock_search_engine = Mock()
|
||||
mock_registry = Mock()
|
||||
|
||||
mock_symbol = Mock()
|
||||
mock_symbol.name = "my_function"
|
||||
mock_symbol.kind = "function"
|
||||
mock_symbol.file = temp_path
|
||||
mock_symbol.range = (1, 3)
|
||||
|
||||
mock_ref = Mock()
|
||||
mock_ref.file_path = "/user.py"
|
||||
mock_ref.line = 10
|
||||
mock_ref.column = 4
|
||||
mock_ref.context = "result = my_function(1, 2)"
|
||||
mock_ref.relationship_type = "call"
|
||||
|
||||
mock_global_index.search.return_value = [mock_symbol]
|
||||
mock_search_engine.search_references.return_value = [mock_ref]
|
||||
|
||||
provider = MCPProvider(mock_global_index, mock_search_engine, mock_registry)
|
||||
context = provider.build_context("my_function")
|
||||
|
||||
assert context is not None
|
||||
assert context.symbol.name == "my_function"
|
||||
assert context.definition is not None
|
||||
assert "def my_function" in context.definition
|
||||
assert len(context.references) == 1
|
||||
assert context.references[0].relationship_type == "call"
|
||||
|
||||
# Test serialization
|
||||
json_str = context.to_json()
|
||||
assert "my_function" in json_str
|
||||
|
||||
# Test prompt injection
|
||||
prompt = context.to_prompt_injection()
|
||||
assert "<code_context>" in prompt
|
||||
assert "my_function" in prompt
|
||||
assert "</code_context>" in prompt
|
||||
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
288
codex-lens/tests/mcp/test_schema.py
Normal file
288
codex-lens/tests/mcp/test_schema.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Tests for MCP schema."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from codexlens.mcp.schema import (
|
||||
MCPContext,
|
||||
SymbolInfo,
|
||||
ReferenceInfo,
|
||||
RelatedSymbol,
|
||||
)
|
||||
|
||||
|
||||
class TestSymbolInfo:
|
||||
"""Test SymbolInfo dataclass."""
|
||||
|
||||
def test_to_dict_includes_all_fields(self):
|
||||
"""SymbolInfo.to_dict() includes all non-None fields."""
|
||||
info = SymbolInfo(
|
||||
name="func",
|
||||
kind="function",
|
||||
file_path="/test.py",
|
||||
line_start=10,
|
||||
line_end=20,
|
||||
signature="def func():",
|
||||
documentation="Test doc",
|
||||
)
|
||||
d = info.to_dict()
|
||||
assert d["name"] == "func"
|
||||
assert d["kind"] == "function"
|
||||
assert d["file_path"] == "/test.py"
|
||||
assert d["line_start"] == 10
|
||||
assert d["line_end"] == 20
|
||||
assert d["signature"] == "def func():"
|
||||
assert d["documentation"] == "Test doc"
|
||||
|
||||
def test_to_dict_excludes_none(self):
|
||||
"""SymbolInfo.to_dict() excludes None fields."""
|
||||
info = SymbolInfo(
|
||||
name="func",
|
||||
kind="function",
|
||||
file_path="/test.py",
|
||||
line_start=10,
|
||||
line_end=20,
|
||||
)
|
||||
d = info.to_dict()
|
||||
assert "signature" not in d
|
||||
assert "documentation" not in d
|
||||
assert "name" in d
|
||||
assert "kind" in d
|
||||
|
||||
def test_basic_creation(self):
|
||||
"""SymbolInfo can be created with required fields only."""
|
||||
info = SymbolInfo(
|
||||
name="MyClass",
|
||||
kind="class",
|
||||
file_path="/src/module.py",
|
||||
line_start=1,
|
||||
line_end=50,
|
||||
)
|
||||
assert info.name == "MyClass"
|
||||
assert info.kind == "class"
|
||||
assert info.signature is None
|
||||
assert info.documentation is None
|
||||
|
||||
|
||||
class TestReferenceInfo:
|
||||
"""Test ReferenceInfo dataclass."""
|
||||
|
||||
def test_to_dict(self):
|
||||
"""ReferenceInfo.to_dict() returns all fields."""
|
||||
ref = ReferenceInfo(
|
||||
file_path="/src/main.py",
|
||||
line=25,
|
||||
column=4,
|
||||
context="result = func()",
|
||||
relationship_type="call",
|
||||
)
|
||||
d = ref.to_dict()
|
||||
assert d["file_path"] == "/src/main.py"
|
||||
assert d["line"] == 25
|
||||
assert d["column"] == 4
|
||||
assert d["context"] == "result = func()"
|
||||
assert d["relationship_type"] == "call"
|
||||
|
||||
def test_all_fields_required(self):
|
||||
"""ReferenceInfo requires all fields."""
|
||||
ref = ReferenceInfo(
|
||||
file_path="/test.py",
|
||||
line=10,
|
||||
column=0,
|
||||
context="import module",
|
||||
relationship_type="import",
|
||||
)
|
||||
assert ref.file_path == "/test.py"
|
||||
assert ref.relationship_type == "import"
|
||||
|
||||
|
||||
class TestRelatedSymbol:
|
||||
"""Test RelatedSymbol dataclass."""
|
||||
|
||||
def test_to_dict_includes_all_fields(self):
|
||||
"""RelatedSymbol.to_dict() includes all non-None fields."""
|
||||
sym = RelatedSymbol(
|
||||
name="BaseClass",
|
||||
kind="class",
|
||||
relationship="inherits",
|
||||
file_path="/src/base.py",
|
||||
)
|
||||
d = sym.to_dict()
|
||||
assert d["name"] == "BaseClass"
|
||||
assert d["kind"] == "class"
|
||||
assert d["relationship"] == "inherits"
|
||||
assert d["file_path"] == "/src/base.py"
|
||||
|
||||
def test_to_dict_excludes_none(self):
|
||||
"""RelatedSymbol.to_dict() excludes None file_path."""
|
||||
sym = RelatedSymbol(
|
||||
name="helper",
|
||||
kind="function",
|
||||
relationship="calls",
|
||||
)
|
||||
d = sym.to_dict()
|
||||
assert "file_path" not in d
|
||||
assert d["name"] == "helper"
|
||||
assert d["relationship"] == "calls"
|
||||
|
||||
|
||||
class TestMCPContext:
|
||||
"""Test MCPContext dataclass."""
|
||||
|
||||
def test_to_dict_basic(self):
|
||||
"""MCPContext.to_dict() returns basic structure."""
|
||||
ctx = MCPContext(context_type="test")
|
||||
d = ctx.to_dict()
|
||||
assert d["version"] == "1.0"
|
||||
assert d["context_type"] == "test"
|
||||
assert d["metadata"] == {}
|
||||
|
||||
def test_to_dict_with_symbol(self):
|
||||
"""MCPContext.to_dict() includes symbol when present."""
|
||||
ctx = MCPContext(
|
||||
context_type="test",
|
||||
symbol=SymbolInfo("f", "function", "/t.py", 1, 2),
|
||||
)
|
||||
d = ctx.to_dict()
|
||||
assert "symbol" in d
|
||||
assert d["symbol"]["name"] == "f"
|
||||
assert d["symbol"]["kind"] == "function"
|
||||
|
||||
def test_to_dict_with_references(self):
|
||||
"""MCPContext.to_dict() includes references when present."""
|
||||
ctx = MCPContext(
|
||||
context_type="test",
|
||||
references=[
|
||||
ReferenceInfo("/a.py", 10, 0, "call()", "call"),
|
||||
ReferenceInfo("/b.py", 20, 5, "import x", "import"),
|
||||
],
|
||||
)
|
||||
d = ctx.to_dict()
|
||||
assert "references" in d
|
||||
assert len(d["references"]) == 2
|
||||
assert d["references"][0]["line"] == 10
|
||||
|
||||
def test_to_dict_with_related_symbols(self):
|
||||
"""MCPContext.to_dict() includes related_symbols when present."""
|
||||
ctx = MCPContext(
|
||||
context_type="test",
|
||||
related_symbols=[
|
||||
RelatedSymbol("Base", "class", "inherits"),
|
||||
RelatedSymbol("helper", "function", "calls"),
|
||||
],
|
||||
)
|
||||
d = ctx.to_dict()
|
||||
assert "related_symbols" in d
|
||||
assert len(d["related_symbols"]) == 2
|
||||
|
||||
def test_to_json(self):
|
||||
"""MCPContext.to_json() returns valid JSON."""
|
||||
ctx = MCPContext(context_type="test")
|
||||
j = ctx.to_json()
|
||||
parsed = json.loads(j)
|
||||
assert parsed["version"] == "1.0"
|
||||
assert parsed["context_type"] == "test"
|
||||
|
||||
def test_to_json_with_indent(self):
|
||||
"""MCPContext.to_json() respects indent parameter."""
|
||||
ctx = MCPContext(context_type="test")
|
||||
j = ctx.to_json(indent=4)
|
||||
# Check it's properly indented
|
||||
assert " " in j
|
||||
|
||||
def test_to_prompt_injection_basic(self):
|
||||
"""MCPContext.to_prompt_injection() returns formatted string."""
|
||||
ctx = MCPContext(
|
||||
symbol=SymbolInfo("my_func", "function", "/test.py", 10, 20),
|
||||
definition="def my_func(): pass",
|
||||
)
|
||||
prompt = ctx.to_prompt_injection()
|
||||
assert "<code_context>" in prompt
|
||||
assert "my_func" in prompt
|
||||
assert "def my_func()" in prompt
|
||||
assert "</code_context>" in prompt
|
||||
|
||||
def test_to_prompt_injection_with_references(self):
|
||||
"""MCPContext.to_prompt_injection() includes references."""
|
||||
ctx = MCPContext(
|
||||
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
|
||||
references=[
|
||||
ReferenceInfo("/a.py", 10, 0, "func()", "call"),
|
||||
ReferenceInfo("/b.py", 20, 0, "from x import func", "import"),
|
||||
],
|
||||
)
|
||||
prompt = ctx.to_prompt_injection()
|
||||
assert "References (2 found)" in prompt
|
||||
assert "/a.py:10" in prompt
|
||||
assert "call" in prompt
|
||||
|
||||
def test_to_prompt_injection_limits_references(self):
|
||||
"""MCPContext.to_prompt_injection() limits references to 5."""
|
||||
refs = [
|
||||
ReferenceInfo(f"/file{i}.py", i, 0, f"ref{i}", "call")
|
||||
for i in range(10)
|
||||
]
|
||||
ctx = MCPContext(
|
||||
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
|
||||
references=refs,
|
||||
)
|
||||
prompt = ctx.to_prompt_injection()
|
||||
# Should show "10 found" but only include 5
|
||||
assert "References (10 found)" in prompt
|
||||
assert "/file0.py" in prompt
|
||||
assert "/file4.py" in prompt
|
||||
assert "/file5.py" not in prompt
|
||||
|
||||
def test_to_prompt_injection_with_related_symbols(self):
|
||||
"""MCPContext.to_prompt_injection() includes related symbols."""
|
||||
ctx = MCPContext(
|
||||
symbol=SymbolInfo("MyClass", "class", "/test.py", 1, 50),
|
||||
related_symbols=[
|
||||
RelatedSymbol("BaseClass", "class", "inherits"),
|
||||
RelatedSymbol("helper", "function", "calls"),
|
||||
],
|
||||
)
|
||||
prompt = ctx.to_prompt_injection()
|
||||
assert "Related Symbols" in prompt
|
||||
assert "BaseClass (inherits)" in prompt
|
||||
assert "helper (calls)" in prompt
|
||||
|
||||
def test_to_prompt_injection_limits_related_symbols(self):
|
||||
"""MCPContext.to_prompt_injection() limits related symbols to 10."""
|
||||
related = [
|
||||
RelatedSymbol(f"sym{i}", "function", "calls")
|
||||
for i in range(15)
|
||||
]
|
||||
ctx = MCPContext(
|
||||
symbol=SymbolInfo("func", "function", "/test.py", 1, 5),
|
||||
related_symbols=related,
|
||||
)
|
||||
prompt = ctx.to_prompt_injection()
|
||||
assert "sym0 (calls)" in prompt
|
||||
assert "sym9 (calls)" in prompt
|
||||
assert "sym10 (calls)" not in prompt
|
||||
|
||||
def test_empty_context(self):
|
||||
"""MCPContext works with minimal data."""
|
||||
ctx = MCPContext()
|
||||
d = ctx.to_dict()
|
||||
assert d["version"] == "1.0"
|
||||
assert d["context_type"] == "code_context"
|
||||
|
||||
prompt = ctx.to_prompt_injection()
|
||||
assert "<code_context>" in prompt
|
||||
assert "</code_context>" in prompt
|
||||
|
||||
def test_metadata_preserved(self):
|
||||
"""MCPContext preserves custom metadata."""
|
||||
ctx = MCPContext(
|
||||
context_type="custom",
|
||||
metadata={
|
||||
"source": "codex-lens",
|
||||
"indexed_at": "2024-01-01",
|
||||
"custom_key": "custom_value",
|
||||
},
|
||||
)
|
||||
d = ctx.to_dict()
|
||||
assert d["metadata"]["source"] == "codex-lens"
|
||||
assert d["metadata"]["custom_key"] == "custom_value"
|
||||
@@ -79,3 +79,87 @@ def test_symbol_filtering_handles_path_failures(monkeypatch: pytest.MonkeyPatch,
|
||||
if os.name == "nt":
|
||||
assert "CrossDrive" in caplog.text
|
||||
|
||||
|
||||
def test_cascade_search_strategy_routing(temp_paths: Path) -> None:
|
||||
"""Test cascade_search() routes to correct strategy implementation."""
|
||||
from unittest.mock import patch
|
||||
from codexlens.search.chain_search import ChainSearchResult, SearchStats
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
mapper = PathMapper(index_root=temp_paths / "indexes")
|
||||
config = Config(data_dir=temp_paths / "data")
|
||||
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
source_path = temp_paths / "src"
|
||||
|
||||
# Test strategy='staged' routing
|
||||
with patch.object(engine, "staged_cascade_search") as mock_staged:
|
||||
mock_staged.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="staged")
|
||||
mock_staged.assert_called_once()
|
||||
|
||||
# Test strategy='binary' routing
|
||||
with patch.object(engine, "binary_cascade_search") as mock_binary:
|
||||
mock_binary.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="binary")
|
||||
mock_binary.assert_called_once()
|
||||
|
||||
# Test strategy='hybrid' routing
|
||||
with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
|
||||
mock_hybrid.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="hybrid")
|
||||
mock_hybrid.assert_called_once()
|
||||
|
||||
# Test strategy='binary_rerank' routing
|
||||
with patch.object(engine, "binary_rerank_cascade_search") as mock_br:
|
||||
mock_br.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="binary_rerank")
|
||||
mock_br.assert_called_once()
|
||||
|
||||
# Test strategy='dense_rerank' routing
|
||||
with patch.object(engine, "dense_rerank_cascade_search") as mock_dr:
|
||||
mock_dr.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="dense_rerank")
|
||||
mock_dr.assert_called_once()
|
||||
|
||||
# Test default routing (no strategy specified) - defaults to binary
|
||||
with patch.object(engine, "binary_cascade_search") as mock_default:
|
||||
mock_default.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path)
|
||||
mock_default.assert_called_once()
|
||||
|
||||
|
||||
def test_cascade_search_invalid_strategy(temp_paths: Path) -> None:
|
||||
"""Test cascade_search() defaults to 'binary' for invalid strategy."""
|
||||
from unittest.mock import patch
|
||||
from codexlens.search.chain_search import ChainSearchResult, SearchStats
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
mapper = PathMapper(index_root=temp_paths / "indexes")
|
||||
config = Config(data_dir=temp_paths / "data")
|
||||
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
source_path = temp_paths / "src"
|
||||
|
||||
# Invalid strategy should default to binary
|
||||
with patch.object(engine, "binary_cascade_search") as mock_binary:
|
||||
mock_binary.return_value = ChainSearchResult(
|
||||
query="query", results=[], symbols=[], stats=SearchStats()
|
||||
)
|
||||
engine.cascade_search("query", source_path, strategy="invalid_strategy")
|
||||
mock_binary.assert_called_once()
|
||||
|
||||
|
||||
766
codex-lens/tests/test_clustering_strategies.py
Normal file
766
codex-lens/tests/test_clustering_strategies.py
Normal file
@@ -0,0 +1,766 @@
|
||||
"""Unit tests for clustering strategies in the hybrid search pipeline.
|
||||
|
||||
Tests cover:
|
||||
1. HDBSCANStrategy - Primary HDBSCAN clustering
|
||||
2. DBSCANStrategy - Fallback DBSCAN clustering
|
||||
3. NoOpStrategy - No-op fallback when clustering unavailable
|
||||
4. ClusteringStrategyFactory - Factory with fallback chain
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.clustering import (
|
||||
BaseClusteringStrategy,
|
||||
ClusteringConfig,
|
||||
ClusteringStrategyFactory,
|
||||
NoOpStrategy,
|
||||
check_clustering_strategy_available,
|
||||
get_strategy,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_results() -> List[SearchResult]:
|
||||
"""Create sample search results for testing."""
|
||||
return [
|
||||
SearchResult(path="a.py", score=0.9, excerpt="def foo(): pass"),
|
||||
SearchResult(path="b.py", score=0.8, excerpt="def foo(): pass"),
|
||||
SearchResult(path="c.py", score=0.7, excerpt="def bar(): pass"),
|
||||
SearchResult(path="d.py", score=0.6, excerpt="def bar(): pass"),
|
||||
SearchResult(path="e.py", score=0.5, excerpt="def baz(): pass"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embeddings():
|
||||
"""Create mock embeddings for 5 results.
|
||||
|
||||
Creates embeddings that should form 2 clusters:
|
||||
- Results 0, 1 (similar to each other)
|
||||
- Results 2, 3 (similar to each other)
|
||||
- Result 4 (noise/singleton)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Create embeddings in 3D for simplicity
|
||||
return np.array(
|
||||
[
|
||||
[1.0, 0.0, 0.0], # Result 0 - cluster A
|
||||
[0.9, 0.1, 0.0], # Result 1 - cluster A
|
||||
[0.0, 1.0, 0.0], # Result 2 - cluster B
|
||||
[0.1, 0.9, 0.0], # Result 3 - cluster B
|
||||
[0.0, 0.0, 1.0], # Result 4 - noise/singleton
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_config() -> ClusteringConfig:
|
||||
"""Create default clustering configuration."""
|
||||
return ClusteringConfig(
|
||||
min_cluster_size=2,
|
||||
min_samples=1,
|
||||
metric="euclidean",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test ClusteringConfig
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestClusteringConfig:
|
||||
"""Tests for ClusteringConfig validation."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default configuration values."""
|
||||
config = ClusteringConfig()
|
||||
assert config.min_cluster_size == 3
|
||||
assert config.min_samples == 2
|
||||
assert config.metric == "cosine"
|
||||
assert config.cluster_selection_epsilon == 0.0
|
||||
assert config.allow_single_cluster is True
|
||||
assert config.prediction_data is False
|
||||
|
||||
def test_custom_values(self):
|
||||
"""Test custom configuration values."""
|
||||
config = ClusteringConfig(
|
||||
min_cluster_size=5,
|
||||
min_samples=3,
|
||||
metric="euclidean",
|
||||
cluster_selection_epsilon=0.1,
|
||||
allow_single_cluster=False,
|
||||
prediction_data=True,
|
||||
)
|
||||
assert config.min_cluster_size == 5
|
||||
assert config.min_samples == 3
|
||||
assert config.metric == "euclidean"
|
||||
|
||||
def test_invalid_min_cluster_size(self):
|
||||
"""Test validation rejects min_cluster_size < 2."""
|
||||
with pytest.raises(ValueError, match="min_cluster_size must be >= 2"):
|
||||
ClusteringConfig(min_cluster_size=1)
|
||||
|
||||
def test_invalid_min_samples(self):
|
||||
"""Test validation rejects min_samples < 1."""
|
||||
with pytest.raises(ValueError, match="min_samples must be >= 1"):
|
||||
ClusteringConfig(min_samples=0)
|
||||
|
||||
def test_invalid_metric(self):
|
||||
"""Test validation rejects invalid metric."""
|
||||
with pytest.raises(ValueError, match="metric must be one of"):
|
||||
ClusteringConfig(metric="invalid")
|
||||
|
||||
def test_invalid_epsilon(self):
|
||||
"""Test validation rejects negative epsilon."""
|
||||
with pytest.raises(ValueError, match="cluster_selection_epsilon must be >= 0"):
|
||||
ClusteringConfig(cluster_selection_epsilon=-0.1)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test NoOpStrategy
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNoOpStrategy:
|
||||
"""Tests for NoOpStrategy - always available."""
|
||||
|
||||
def test_cluster_returns_singleton_clusters(
|
||||
self, sample_results: List[SearchResult], mock_embeddings
|
||||
):
|
||||
"""Test cluster() returns each result as singleton cluster."""
|
||||
strategy = NoOpStrategy()
|
||||
clusters = strategy.cluster(mock_embeddings, sample_results)
|
||||
|
||||
assert len(clusters) == 5
|
||||
for i, cluster in enumerate(clusters):
|
||||
assert cluster == [i]
|
||||
|
||||
def test_cluster_empty_results(self):
|
||||
"""Test cluster() with empty results."""
|
||||
import numpy as np
|
||||
|
||||
strategy = NoOpStrategy()
|
||||
clusters = strategy.cluster(np.array([]), [])
|
||||
|
||||
assert clusters == []
|
||||
|
||||
def test_select_representatives_returns_all_sorted(
|
||||
self, sample_results: List[SearchResult]
|
||||
):
|
||||
"""Test select_representatives() returns all results sorted by score."""
|
||||
strategy = NoOpStrategy()
|
||||
clusters = [[i] for i in range(len(sample_results))]
|
||||
representatives = strategy.select_representatives(clusters, sample_results)
|
||||
|
||||
assert len(representatives) == 5
|
||||
# Check sorted by score descending
|
||||
scores = [r.score for r in representatives]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
def test_select_representatives_empty(self):
|
||||
"""Test select_representatives() with empty input."""
|
||||
strategy = NoOpStrategy()
|
||||
representatives = strategy.select_representatives([], [])
|
||||
assert representatives == []
|
||||
|
||||
def test_fit_predict_convenience_method(
|
||||
self, sample_results: List[SearchResult], mock_embeddings
|
||||
):
|
||||
"""Test fit_predict() convenience method."""
|
||||
strategy = NoOpStrategy()
|
||||
representatives = strategy.fit_predict(mock_embeddings, sample_results)
|
||||
|
||||
assert len(representatives) == 5
|
||||
# All results returned, sorted by score
|
||||
assert representatives[0].score >= representatives[-1].score
|
||||
|
||||
|
||||
# =============================================================================
# Test HDBSCANStrategy
# =============================================================================


class TestHDBSCANStrategy:
    """Tests for HDBSCANStrategy - requires hdbscan package."""

    @pytest.fixture
    def hdbscan_strategy(self, default_config):
        """Create HDBSCANStrategy if available."""
        try:
            from codexlens.search.clustering import HDBSCANStrategy

            return HDBSCANStrategy(default_config)
        except ImportError:
            pytest.skip("hdbscan not installed")

    def test_cluster_returns_list_of_lists(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test cluster() returns List[List[int]]."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)

        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for idx in cluster:
                assert isinstance(idx, int)
                assert 0 <= idx < len(sample_results)

    def test_cluster_covers_all_results(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test all result indices appear in clusters."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)

        all_indices = set()
        for cluster in clusters:
            all_indices.update(cluster)

        assert all_indices == set(range(len(sample_results)))

    def test_cluster_empty_results(self, hdbscan_strategy):
        """Test cluster() with empty results."""
        import numpy as np

        clusters = hdbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
        assert clusters == []

    def test_cluster_single_result(self, hdbscan_strategy):
        """Test cluster() with single result."""
        import numpy as np

        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = hdbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_cluster_fewer_than_min_cluster_size(self, hdbscan_strategy):
        """Test cluster() with fewer results than min_cluster_size."""
        import numpy as np

        # Strategy has min_cluster_size=2, so 1 result returns singleton
        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = hdbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_select_representatives_picks_highest_score(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test select_representatives() picks highest score per cluster."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = hdbscan_strategy.select_representatives(
            clusters, sample_results
        )

        # Each representative should be the highest-scored in its cluster
        for rep in representatives:
            # Find the cluster containing this representative
            rep_idx = next(
                i for i, r in enumerate(sample_results) if r.path == rep.path
            )
            for cluster in clusters:
                if rep_idx in cluster:
                    cluster_scores = [sample_results[i].score for i in cluster]
                    assert rep.score == max(cluster_scores)
                    break

    def test_select_representatives_sorted_by_score(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test representatives are sorted by score descending."""
        clusters = hdbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = hdbscan_strategy.select_representatives(
            clusters, sample_results
        )

        scores = [r.score for r in representatives]
        assert scores == sorted(scores, reverse=True)

    def test_fit_predict_end_to_end(
        self, hdbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test fit_predict() end-to-end clustering."""
        representatives = hdbscan_strategy.fit_predict(mock_embeddings, sample_results)

        # Should have fewer or equal representatives than input
        assert len(representatives) <= len(sample_results)
        # All representatives should be from original results
        rep_paths = {r.path for r in representatives}
        original_paths = {r.path for r in sample_results}
        assert rep_paths.issubset(original_paths)

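The HDBSCAN tests above fix two invariants: every input index lands in exactly one cluster, and inputs smaller than min_cluster_size degrade to singletons. Assuming the strategy wraps the hdbscan package, its core could look like the following sketch; this is an illustration of the technique, not the shipped code.

    import hdbscan
    import numpy as np

    def hdbscan_clusters(embeddings: np.ndarray, min_cluster_size: int = 2):
        # Too few points to cluster: each result becomes its own cluster
        if len(embeddings) < min_cluster_size:
            return [[i] for i in range(len(embeddings))]
        labels = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size).fit_predict(embeddings)
        grouped, singletons = {}, []
        for idx, label in enumerate(labels):
            if label == -1:  # HDBSCAN noise points stay as singletons
                singletons.append([idx])
            else:
                grouped.setdefault(label, []).append(idx)
        return list(grouped.values()) + singletons

Every index appears exactly once, matching test_cluster_covers_all_results, and a single input yields [[0]] as the singleton tests expect.
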
# =============================================================================
# Test DBSCANStrategy
# =============================================================================


class TestDBSCANStrategy:
    """Tests for DBSCANStrategy - requires sklearn."""

    @pytest.fixture
    def dbscan_strategy(self, default_config):
        """Create DBSCANStrategy if available."""
        try:
            from codexlens.search.clustering import DBSCANStrategy

            return DBSCANStrategy(default_config)
        except ImportError:
            pytest.skip("sklearn not installed")

    def test_cluster_returns_list_of_lists(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test cluster() returns List[List[int]]."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)

        assert isinstance(clusters, list)
        for cluster in clusters:
            assert isinstance(cluster, list)
            for idx in cluster:
                assert isinstance(idx, int)
                assert 0 <= idx < len(sample_results)

    def test_cluster_covers_all_results(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test all result indices appear in clusters."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)

        all_indices = set()
        for cluster in clusters:
            all_indices.update(cluster)

        assert all_indices == set(range(len(sample_results)))

    def test_cluster_empty_results(self, dbscan_strategy):
        """Test cluster() with empty results."""
        import numpy as np

        clusters = dbscan_strategy.cluster(np.array([]).reshape(0, 3), [])
        assert clusters == []

    def test_cluster_single_result(self, dbscan_strategy):
        """Test cluster() with single result."""
        import numpy as np

        result = SearchResult(path="a.py", score=0.9, excerpt="test")
        embeddings = np.array([[1.0, 0.0, 0.0]])
        clusters = dbscan_strategy.cluster(embeddings, [result])

        assert len(clusters) == 1
        assert clusters[0] == [0]

    def test_cluster_with_explicit_eps(self, default_config):
        """Test cluster() with explicit eps parameter."""
        try:
            from codexlens.search.clustering import DBSCANStrategy
        except ImportError:
            pytest.skip("sklearn not installed")

        import numpy as np

        strategy = DBSCANStrategy(default_config, eps=0.5)
        results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(3)]
        embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0]])

        clusters = strategy.cluster(embeddings, results)
        # With eps=0.5, first two should cluster, third should be separate
        assert len(clusters) >= 2

    def test_auto_compute_eps(self, dbscan_strategy, mock_embeddings):
        """Test eps auto-computation from distance distribution."""
        # Should not raise - eps is computed automatically
        results = [SearchResult(path=f"{i}.py", score=0.5, excerpt="test") for i in range(5)]
        clusters = dbscan_strategy.cluster(mock_embeddings, results)
        assert len(clusters) > 0

    def test_select_representatives_picks_highest_score(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test select_representatives() picks highest score per cluster."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = dbscan_strategy.select_representatives(
            clusters, sample_results
        )

        # Each representative should be the highest-scored in its cluster
        for rep in representatives:
            rep_idx = next(
                i for i, r in enumerate(sample_results) if r.path == rep.path
            )
            for cluster in clusters:
                if rep_idx in cluster:
                    cluster_scores = [sample_results[i].score for i in cluster]
                    assert rep.score == max(cluster_scores)
                    break

    def test_select_representatives_sorted_by_score(
        self, dbscan_strategy, sample_results: List[SearchResult], mock_embeddings
    ):
        """Test representatives are sorted by score descending."""
        clusters = dbscan_strategy.cluster(mock_embeddings, sample_results)
        representatives = dbscan_strategy.select_representatives(
            clusters, sample_results
        )

        scores = [r.score for r in representatives]
        assert scores == sorted(scores, reverse=True)

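test_auto_compute_eps only asserts that clustering works without an explicit eps, and the factory test further down passes an eps_percentile kwarg, which suggests eps is derived from the pairwise-distance distribution. One plausible derivation, offered purely as an assumption about the mechanism and not as the library's verified code:

    import numpy as np

    def auto_eps(embeddings: np.ndarray, percentile: float = 10.0) -> float:
        # Pairwise Euclidean distances, ignoring the zero diagonal
        diffs = embeddings[:, None, :] - embeddings[None, :, :]
        dists = np.sqrt((diffs ** 2).sum(axis=-1))
        off_diag = dists[~np.eye(len(embeddings), dtype=bool)]
        # A low percentile keeps eps tight so loosely related points stay apart
        return float(np.percentile(off_diag, percentile))
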
# =============================================================================
# Test ClusteringStrategyFactory
# =============================================================================


class TestClusteringStrategyFactory:
    """Tests for ClusteringStrategyFactory."""

    def test_check_noop_always_available(self):
        """Test noop strategy is always available."""
        ok, err = check_clustering_strategy_available("noop")
        assert ok is True
        assert err is None

    def test_check_invalid_strategy(self):
        """Test invalid strategy name returns error."""
        ok, err = check_clustering_strategy_available("invalid")
        assert ok is False
        assert "Invalid clustering strategy" in err

    def test_get_strategy_noop(self, default_config):
        """Test get_strategy('noop') returns NoOpStrategy."""
        strategy = get_strategy("noop", default_config)
        assert isinstance(strategy, NoOpStrategy)

    def test_get_strategy_auto_returns_something(self, default_config):
        """Test get_strategy('auto') returns a strategy."""
        strategy = get_strategy("auto", default_config)
        assert isinstance(strategy, BaseClusteringStrategy)

    def test_get_strategy_with_fallback_enabled(self, default_config):
        """Test fallback when primary strategy unavailable."""
        # Mock hdbscan unavailable
        with patch.dict("sys.modules", {"hdbscan": None}):
            # Should fall back to dbscan or noop
            strategy = get_strategy("hdbscan", default_config, fallback=True)
            assert isinstance(strategy, BaseClusteringStrategy)

    def test_get_strategy_fallback_disabled_raises(self, default_config):
        """Test ImportError when fallback disabled and strategy unavailable."""
        with patch(
            "codexlens.search.clustering.factory.check_clustering_strategy_available"
        ) as mock_check:
            mock_check.return_value = (False, "Test error")

            with pytest.raises(ImportError, match="Test error"):
                get_strategy("hdbscan", default_config, fallback=False)

    def test_get_strategy_invalid_raises(self, default_config):
        """Test ValueError for invalid strategy name."""
        with pytest.raises(ValueError, match="Unknown clustering strategy"):
            get_strategy("invalid", default_config)

    def test_factory_class_interface(self, default_config):
        """Test ClusteringStrategyFactory class interface."""
        strategy = ClusteringStrategyFactory.get_strategy("noop", default_config)
        assert isinstance(strategy, NoOpStrategy)

        ok, err = ClusteringStrategyFactory.check_available("noop")
        assert ok is True

    @pytest.mark.skipif(
        not check_clustering_strategy_available("hdbscan")[0],
        reason="hdbscan not installed",
    )
    def test_get_strategy_hdbscan(self, default_config):
        """Test get_strategy('hdbscan') returns HDBSCANStrategy."""
        from codexlens.search.clustering import HDBSCANStrategy

        strategy = get_strategy("hdbscan", default_config)
        assert isinstance(strategy, HDBSCANStrategy)

    @pytest.mark.skipif(
        not check_clustering_strategy_available("dbscan")[0],
        reason="sklearn not installed",
    )
    def test_get_strategy_dbscan(self, default_config):
        """Test get_strategy('dbscan') returns DBSCANStrategy."""
        from codexlens.search.clustering import DBSCANStrategy

        strategy = get_strategy("dbscan", default_config)
        assert isinstance(strategy, DBSCANStrategy)

    @pytest.mark.skipif(
        not check_clustering_strategy_available("dbscan")[0],
        reason="sklearn not installed",
    )
    def test_get_strategy_dbscan_with_kwargs(self, default_config):
        """Test DBSCANStrategy kwargs passed through factory."""
        strategy = get_strategy("dbscan", default_config, eps=0.3, eps_percentile=20.0)
        assert strategy.eps == 0.3
        assert strategy.eps_percentile == 20.0

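Taken together, the factory tests document the intended call pattern: probe availability, construct by name (optionally with kwargs), and rely on fallback or the always-available noop strategy when an optional dependency is missing. A minimal usage sketch that mirrors only the calls these tests exercise (the fixture names are reused for illustration):

    # Probe first, then construct; "noop" never needs optional packages.
    ok, err = check_clustering_strategy_available("hdbscan")
    strategy = get_strategy("hdbscan" if ok else "noop", default_config)
    representatives = strategy.fit_predict(mock_embeddings, sample_results)
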
# =============================================================================
# Integration Tests
# =============================================================================


class TestClusteringIntegration:
    """Integration tests for clustering strategies."""

    def test_all_strategies_same_interface(
        self, sample_results: List[SearchResult], mock_embeddings, default_config
    ):
        """Test all strategies have consistent interface."""
        strategies = [NoOpStrategy(default_config)]

        # Add available strategies
        try:
            from codexlens.search.clustering import HDBSCANStrategy

            strategies.append(HDBSCANStrategy(default_config))
        except ImportError:
            pass

        try:
            from codexlens.search.clustering import DBSCANStrategy

            strategies.append(DBSCANStrategy(default_config))
        except ImportError:
            pass

        for strategy in strategies:
            # All should implement cluster()
            clusters = strategy.cluster(mock_embeddings, sample_results)
            assert isinstance(clusters, list)

            # All should implement select_representatives()
            reps = strategy.select_representatives(clusters, sample_results)
            assert isinstance(reps, list)
            assert all(isinstance(r, SearchResult) for r in reps)

            # All should implement fit_predict()
            reps = strategy.fit_predict(mock_embeddings, sample_results)
            assert isinstance(reps, list)

    def test_clustering_reduces_redundancy(
        self, default_config
    ):
        """Test clustering reduces redundant similar results."""
        import numpy as np

        # Create results with very similar embeddings
        results = [
            SearchResult(path=f"{i}.py", score=0.9 - i * 0.01, excerpt="def foo(): pass")
            for i in range(10)
        ]
        # Very similar embeddings - should cluster together
        embeddings = np.array(
            [[1.0 + i * 0.01, 0.0, 0.0] for i in range(10)], dtype=np.float32
        )

        strategy = get_strategy("auto", default_config)
        representatives = strategy.fit_predict(embeddings, results)

        # Should have fewer representatives than input (clustering reduced redundancy)
        # NoOp returns all, but HDBSCAN/DBSCAN should reduce
        assert len(representatives) <= len(results)


# =============================================================================
# Test FrequencyStrategy
# =============================================================================


class TestFrequencyStrategy:
    """Tests for FrequencyStrategy - frequency-based clustering."""

    @pytest.fixture
    def frequency_config(self):
        """Create FrequencyConfig for testing."""
        from codexlens.search.clustering import FrequencyConfig
        return FrequencyConfig(min_frequency=1, max_representatives_per_group=3)

    @pytest.fixture
    def frequency_strategy(self, frequency_config):
        """Create FrequencyStrategy instance."""
        from codexlens.search.clustering import FrequencyStrategy
        return FrequencyStrategy(frequency_config)

    @pytest.fixture
    def symbol_results(self) -> List[SearchResult]:
        """Create sample results with symbol names for frequency testing."""
        return [
            SearchResult(path="auth.py", score=0.9, excerpt="authenticate user", symbol_name="authenticate"),
            SearchResult(path="login.py", score=0.85, excerpt="authenticate login", symbol_name="authenticate"),
            SearchResult(path="session.py", score=0.8, excerpt="authenticate session", symbol_name="authenticate"),
            SearchResult(path="utils.py", score=0.7, excerpt="helper function", symbol_name="helper_func"),
            SearchResult(path="validate.py", score=0.6, excerpt="validate input", symbol_name="validate"),
            SearchResult(path="check.py", score=0.55, excerpt="validate data", symbol_name="validate"),
        ]

    def test_frequency_strategy_available(self):
        """Test FrequencyStrategy is always available (no deps)."""
        ok, err = check_clustering_strategy_available("frequency")
        assert ok is True
        assert err is None

    def test_get_strategy_frequency(self):
        """Test get_strategy('frequency') returns FrequencyStrategy."""
        from codexlens.search.clustering import FrequencyStrategy
        strategy = get_strategy("frequency")
        assert isinstance(strategy, FrequencyStrategy)

    def test_cluster_groups_by_symbol(self, frequency_strategy, symbol_results):
        """Test cluster() groups results by symbol name."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)

        # Should have 3 groups: authenticate(3), validate(2), helper_func(1)
        assert len(clusters) == 3

        # First cluster should be authenticate (highest frequency)
        first_cluster_symbols = [symbol_results[i].symbol_name for i in clusters[0]]
        assert all(s == "authenticate" for s in first_cluster_symbols)
        assert len(clusters[0]) == 3

    def test_cluster_orders_by_frequency(self, frequency_strategy, symbol_results):
        """Test clusters are ordered by frequency (descending)."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)

        # Verify frequency ordering
        frequencies = [len(c) for c in clusters]
        assert frequencies == sorted(frequencies, reverse=True)

    def test_select_representatives_adds_frequency_metadata(self, frequency_strategy, symbol_results):
        """Test representatives have frequency metadata."""
        import numpy as np
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = frequency_strategy.cluster(embeddings, symbol_results)
        reps = frequency_strategy.select_representatives(clusters, symbol_results, embeddings)

        # Check frequency metadata
        for rep in reps:
            assert "frequency" in rep.metadata
            assert rep.metadata["frequency"] >= 1

    def test_min_frequency_filter_mode(self, symbol_results):
        """Test min_frequency with filter mode removes low-frequency results."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(min_frequency=2, keep_mode="filter")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # helper_func (freq=1) should be filtered out
        rep_symbols = [r.symbol_name for r in reps]
        assert "helper_func" not in rep_symbols
        assert "authenticate" in rep_symbols
        assert "validate" in rep_symbols

    def test_min_frequency_demote_mode(self, symbol_results):
        """Test min_frequency with demote mode keeps but deprioritizes low-frequency."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(min_frequency=2, keep_mode="demote")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # helper_func should still be present but at the end
        rep_symbols = [r.symbol_name for r in reps]
        assert "helper_func" in rep_symbols
        # Should be demoted to end
        helper_idx = rep_symbols.index("helper_func")
        assert helper_idx == len(rep_symbols) - 1

    def test_group_by_file(self, symbol_results):
        """Test grouping by file path instead of symbol."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(group_by="file")
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        clusters = strategy.cluster(embeddings, symbol_results)

        # Each file should be its own group (all unique paths)
        assert len(clusters) == 6

    def test_max_representatives_per_group(self, symbol_results):
        """Test max_representatives_per_group limits output per symbol."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(max_representatives_per_group=1)
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # Should have at most 1 per group = 3 groups = 3 reps
        assert len(reps) == 3

    def test_frequency_boost_score(self, symbol_results):
        """Test frequency_weight boosts high-frequency results."""
        from codexlens.search.clustering import FrequencyStrategy, FrequencyConfig
        import numpy as np

        config = FrequencyConfig(frequency_weight=0.5)  # Strong boost
        strategy = FrequencyStrategy(config)
        embeddings = np.random.rand(len(symbol_results), 128)

        reps = strategy.fit_predict(embeddings, symbol_results)

        # High-frequency results should have boosted scores in metadata
        for rep in reps:
            if rep.metadata.get("frequency", 1) > 1:
                assert rep.metadata.get("frequency_boosted_score", 0) > rep.score

    def test_empty_results(self, frequency_strategy):
        """Test handling of empty results."""
        import numpy as np

        clusters = frequency_strategy.cluster(np.array([]).reshape(0, 128), [])
        assert clusters == []

        reps = frequency_strategy.select_representatives([], [], None)
        assert reps == []

    def test_factory_with_kwargs(self):
        """Test factory passes kwargs to FrequencyConfig."""
        strategy = get_strategy("frequency", min_frequency=3, group_by="file")
        assert strategy.config.min_frequency == 3
        assert strategy.config.group_by == "file"

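The frequency tests pin down the grouping semantics without any optional dependency: results are bucketed by symbol_name (or by path when group_by="file") and buckets are emitted largest first. A dependency-free sketch of that grouping step; the path fallback for unnamed symbols is an assumption, not something these tests verify:

    from collections import defaultdict

    def group_by_symbol(results):
        groups = defaultdict(list)
        for idx, result in enumerate(results):
            key = result.symbol_name or result.path  # assumed fallback for unnamed results
            groups[key].append(idx)
        # Largest group first, matching test_cluster_orders_by_frequency
        return sorted(groups.values(), key=len, reverse=True)
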
698
codex-lens/tests/test_staged_cascade.py
Normal file
@@ -0,0 +1,698 @@
"""Integration tests for staged cascade search pipeline.

Tests the 4-stage pipeline:
1. Stage 1: Binary coarse search
2. Stage 2: LSP graph expansion
3. Stage 3: Clustering and representative selection
4. Stage 4: Optional cross-encoder reranking
"""

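Read alongside the stage tests below, the control flow being verified is roughly the following; this is a hedged sketch of what the mocks imply, not the engine's actual source:

    # candidates, index_root = engine._stage1_binary_search(query, index_paths, coarse_k, stats)
    # if not candidates:
    #     return engine.hybrid_cascade_search(...)          # graceful fallback
    # expanded = engine._stage2_lsp_expand(candidates, index_root=index_root)
    # pruned = engine._stage3_cluster_prune(expanded, target_count=k)
    # if config.enable_staged_rerank:
    #     pruned = engine._stage4_optional_rerank(query, pruned, k)
    # deduplicate by path (keep the highest score) and return a ChainSearchResult
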
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import List
from unittest.mock import MagicMock, Mock, patch

import pytest

from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore


# =============================================================================
# Test Fixtures
# =============================================================================


@pytest.fixture
def temp_paths():
    """Create temporary directory structure."""
    tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
    root = Path(tmpdir.name)
    yield root
    try:
        tmpdir.cleanup()
    except (PermissionError, OSError):
        pass


@pytest.fixture
def mock_registry(temp_paths: Path):
    """Create mock registry store."""
    registry = RegistryStore(db_path=temp_paths / "registry.db")
    registry.initialize()
    return registry


@pytest.fixture
def mock_mapper(temp_paths: Path):
    """Create path mapper."""
    return PathMapper(index_root=temp_paths / "indexes")


@pytest.fixture
def mock_config():
    """Create mock config with staged cascade settings."""
    config = MagicMock(spec=Config)
    config.cascade_coarse_k = 100
    config.cascade_fine_k = 10
    config.enable_staged_rerank = False
    config.staged_clustering_strategy = "auto"
    config.staged_clustering_min_size = 3
    config.graph_expansion_depth = 2
    return config


@pytest.fixture
def sample_binary_results() -> List[SearchResult]:
    """Create sample binary search results for testing."""
    return [
        SearchResult(
            path="a.py",
            score=0.95,
            excerpt="def authenticate_user(username, password):",
            symbol_name="authenticate_user",
            symbol_kind="function",
            start_line=10,
            end_line=15,
        ),
        SearchResult(
            path="b.py",
            score=0.85,
            excerpt="class AuthManager:",
            symbol_name="AuthManager",
            symbol_kind="class",
            start_line=5,
            end_line=20,
        ),
        SearchResult(
            path="c.py",
            score=0.75,
            excerpt="def check_credentials(user, pwd):",
            symbol_name="check_credentials",
            symbol_kind="function",
            start_line=30,
            end_line=35,
        ),
    ]


@pytest.fixture
def sample_expanded_results() -> List[SearchResult]:
    """Create sample expanded results (after LSP expansion)."""
    return [
        SearchResult(
            path="a.py",
            score=0.95,
            excerpt="def authenticate_user(username, password):",
            symbol_name="authenticate_user",
            symbol_kind="function",
        ),
        SearchResult(
            path="a.py",
            score=0.90,
            excerpt="def verify_password(pwd):",
            symbol_name="verify_password",
            symbol_kind="function",
        ),
        SearchResult(
            path="b.py",
            score=0.85,
            excerpt="class AuthManager:",
            symbol_name="AuthManager",
            symbol_kind="class",
        ),
        SearchResult(
            path="b.py",
            score=0.80,
            excerpt="def login(self, user):",
            symbol_name="login",
            symbol_kind="function",
        ),
        SearchResult(
            path="c.py",
            score=0.75,
            excerpt="def check_credentials(user, pwd):",
            symbol_name="check_credentials",
            symbol_kind="function",
        ),
        SearchResult(
            path="d.py",
            score=0.70,
            excerpt="class UserModel:",
            symbol_name="UserModel",
            symbol_kind="class",
        ),
    ]


# =============================================================================
# Test Stage Methods
# =============================================================================


class TestStage1BinarySearch:
    """Tests for Stage 1: Binary coarse search."""

    def test_stage1_returns_results_with_index_root(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search returns results and index_root."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock the binary embedding backend (import is inside the method)
        with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
            with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
                mock_index = MagicMock()
                mock_index.count.return_value = 10
                mock_index.search.return_value = ([1, 2, 3], [10, 20, 30])
                mock_binary_idx.return_value = mock_index

                index_paths = [Path("/fake/index1/_index.db")]
                stats = SearchStats()

                results, index_root = engine._stage1_binary_search(
                    "query", index_paths, coarse_k=10, stats=stats
                )

                assert isinstance(results, list)
                assert isinstance(index_root, (Path, type(None)))

    def test_stage1_handles_empty_index_paths(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search handles empty index paths."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        index_paths = []
        stats = SearchStats()

        results, index_root = engine._stage1_binary_search(
            "query", index_paths, coarse_k=10, stats=stats
        )

        assert results == []
        assert index_root is None

    def test_stage1_aggregates_results_from_multiple_indexes(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage1_binary_search aggregates results from multiple indexes."""
        from codexlens.search.chain_search import SearchStats

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch("codexlens.indexing.embedding.BinaryEmbeddingBackend"):
            with patch.object(engine, "_get_or_create_binary_index") as mock_binary_idx:
                mock_index = MagicMock()
                mock_index.count.return_value = 10
                # Return different results for different calls
                mock_index.search.side_effect = [
                    ([1, 2], [10, 20]),
                    ([3, 4], [15, 25]),
                ]
                mock_binary_idx.return_value = mock_index

                index_paths = [
                    Path("/fake/index1/_index.db"),
                    Path("/fake/index2/_index.db"),
                ]
                stats = SearchStats()

                results, _ = engine._stage1_binary_search(
                    "query", index_paths, coarse_k=10, stats=stats
                )

                # Should aggregate candidates from both indexes
                assert isinstance(results, list)

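Stage 1 is exercised only through mocks here, but the shape is visible in the fixtures: a bit-packed binary index per shard is queried for (ids, distances) and candidates are merged across shards. For background, a generic Hamming-distance top-k over packed vectors looks like the sketch below; it is an assumption about the technique, not code taken from codexlens:

    import numpy as np

    def hamming_topk(query_bits: np.ndarray, index_bits: np.ndarray, k: int):
        # query_bits: (n_bytes,) uint8; index_bits: (n_docs, n_bytes) uint8
        distances = np.unpackbits(np.bitwise_xor(index_bits, query_bits), axis=1).sum(axis=1)
        order = np.argsort(distances)[:k]
        return order, distances[order]
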
class TestStage2LSPExpand:
    """Tests for Stage 2: LSP graph expansion."""

    def test_stage2_returns_expanded_results(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand returns expanded results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Import is inside the method, so we need to patch it there
        with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
            mock_expander = MagicMock()
            mock_expander.expand.return_value = [
                SearchResult(path="related.py", score=0.7, excerpt="related")
            ]
            mock_expander_cls.return_value = mock_expander

            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake/index")
            )

            assert isinstance(expanded, list)
            # Should include original results
            assert len(expanded) >= len(sample_binary_results)

    def test_stage2_handles_no_index_root(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand handles missing index_root."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        expanded = engine._stage2_lsp_expand(sample_binary_results, index_root=None)

        # Should return original results unchanged
        assert expanded == sample_binary_results

    def test_stage2_handles_empty_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage2_lsp_expand handles empty input."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        expanded = engine._stage2_lsp_expand([], index_root=Path("/fake"))

        assert expanded == []

    def test_stage2_deduplicates_results(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test _stage2_lsp_expand deduplicates by (path, symbol_name, start_line)."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock expander to return duplicate of first result
        with patch("codexlens.search.graph_expander.GraphExpander") as mock_expander_cls:
            mock_expander = MagicMock()
            duplicate = SearchResult(
                path=sample_binary_results[0].path,
                score=0.5,
                excerpt="duplicate",
                symbol_name=sample_binary_results[0].symbol_name,
                start_line=sample_binary_results[0].start_line,
            )
            mock_expander.expand.return_value = [duplicate]
            mock_expander_cls.return_value = mock_expander

            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake")
            )

            # Should not include duplicate
            assert len(expanded) == len(sample_binary_results)

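test_stage2_deduplicates_results fixes the dedup key as (path, symbol_name, start_line). The merge step implied by that test can be as small as this sketch (the helper name is illustrative):

    def merge_unique(seed_results, expanded_results):
        # Seed results are always kept; expanded results join only if unseen
        seen = {(r.path, r.symbol_name, r.start_line) for r in seed_results}
        merged = list(seed_results)
        for r in expanded_results:
            key = (r.path, r.symbol_name, r.start_line)
            if key not in seen:
                seen.add(key)
                merged.append(r)
        return merged
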
class TestStage3ClusterPrune:
    """Tests for Stage 3: Clustering and representative selection."""

    def test_stage3_returns_representatives(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test _stage3_cluster_prune returns representative results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            import numpy as np

            # Mock embeddings
            mock_embed.return_value = np.random.rand(
                len(sample_expanded_results), 128
            ).astype(np.float32)

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            assert isinstance(clustered, list)
            assert len(clustered) <= len(sample_expanded_results)
            assert all(isinstance(r, SearchResult) for r in clustered)

    def test_stage3_handles_few_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage3_cluster_prune skips clustering for few results."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        few_results = [
            SearchResult(path="a.py", score=0.9, excerpt="a"),
            SearchResult(path="b.py", score=0.8, excerpt="b"),
        ]

        clustered = engine._stage3_cluster_prune(few_results, target_count=5)

        # Should return all results unchanged
        assert clustered == few_results

    def test_stage3_handles_no_embeddings(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test _stage3_cluster_prune falls back to score-based selection without embeddings."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            mock_embed.return_value = None

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should return top-scored results
            assert len(clustered) <= 3
            # Should be sorted by score descending
            scores = [r.score for r in clustered]
            assert scores == sorted(scores, reverse=True)

    def test_stage3_uses_config_clustering_strategy(
        self, mock_registry, mock_mapper, sample_expanded_results
    ):
        """Test _stage3_cluster_prune uses config clustering strategy."""
        config = MagicMock(spec=Config)
        config.staged_clustering_strategy = "auto"
        config.staged_clustering_min_size = 2

        engine = ChainSearchEngine(mock_registry, PathMapper(), config=config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            import numpy as np

            mock_embed.return_value = np.random.rand(
                len(sample_expanded_results), 128
            ).astype(np.float32)

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should use clustering (auto will pick best available)
            # Result should be a list of SearchResult objects
            assert isinstance(clustered, list)
            assert all(isinstance(r, SearchResult) for r in clustered)

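test_stage3_handles_no_embeddings documents the fallback contract for this stage: when _get_embeddings_for_clustering returns None, selection degrades to a plain top-N by score. That fallback is essentially one line; a sketch under that assumption:

    # Assumed fallback when no embeddings are available:
    pruned = sorted(expanded_results, key=lambda r: r.score, reverse=True)[:target_count]
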
class TestStage4OptionalRerank:
    """Tests for Stage 4: Optional cross-encoder reranking."""

    def test_stage4_reranks_with_reranker(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage4_optional_rerank uses _cross_encoder_rerank."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        results = [
            SearchResult(path="a.py", score=0.9, excerpt="a"),
            SearchResult(path="b.py", score=0.8, excerpt="b"),
            SearchResult(path="c.py", score=0.7, excerpt="c"),
        ]

        # Mock the _cross_encoder_rerank method that _stage4 calls
        with patch.object(engine, "_cross_encoder_rerank") as mock_rerank:
            mock_rerank.return_value = [
                SearchResult(path="c.py", score=0.95, excerpt="c"),
                SearchResult(path="a.py", score=0.85, excerpt="a"),
            ]

            reranked = engine._stage4_optional_rerank("query", results, k=2)

            mock_rerank.assert_called_once_with("query", results, 2)
            assert len(reranked) <= 2
            # First result should be reranked winner
            assert reranked[0].path == "c.py"

    def test_stage4_handles_empty_results(
        self, mock_registry, mock_mapper, mock_config
    ):
        """Test _stage4_optional_rerank handles empty input."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        reranked = engine._stage4_optional_rerank("query", [], k=2)

        # Should return empty list
        assert reranked == []


# =============================================================================
# Integration Tests
# =============================================================================


class TestStagedCascadeIntegration:
    """Integration tests for staged_cascade_search() end-to-end."""

    def test_staged_cascade_returns_chain_result(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search returns ChainSearchResult."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Mock all stages
        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src", k=10, coarse_k=100
                            )

                            from codexlens.search.chain_search import ChainSearchResult

                            assert isinstance(result, ChainSearchResult)
                            assert result.query == "query"
                            assert len(result.results) <= 10

    def test_staged_cascade_includes_stage_stats(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search includes per-stage timing stats."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src"
                            )

                            # Check for stage stats in errors field
                            stage_stats = None
                            for err in result.stats.errors:
                                if err.startswith("STAGE_STATS:"):
                                    stage_stats = json.loads(err.replace("STAGE_STATS:", ""))
                                    break

                            assert stage_stats is not None
                            assert "stage_times" in stage_stats
                            assert "stage_counts" in stage_stats
                            assert "stage1_binary_ms" in stage_stats["stage_times"]
                            assert "stage1_candidates" in stage_stats["stage_counts"]

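Because the per-stage telemetry rides along in stats.errors under a "STAGE_STATS:" prefix, a caller can recover it the same way this test does; a small usage sketch, with the printed keys taken from the assertions above:

    stage_stats = next(
        (json.loads(err[len("STAGE_STATS:"):]) for err in result.stats.errors
         if err.startswith("STAGE_STATS:")),
        None,
    )
    if stage_stats:
        print(stage_stats["stage_times"]["stage1_binary_ms"])
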
    def test_staged_cascade_with_rerank_enabled(
        self, mock_registry, mock_mapper, temp_paths
    ):
        """Test staged_cascade_search with reranking enabled."""
        config = MagicMock(spec=Config)
        config.cascade_coarse_k = 100
        config.cascade_fine_k = 10
        config.enable_staged_rerank = True
        config.staged_clustering_strategy = "auto"
        config.graph_expansion_depth = 2

        engine = ChainSearchEngine(mock_registry, mock_mapper, config=config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a")
                            ]

                            with patch.object(engine, "_stage4_optional_rerank") as mock_stage4:
                                mock_stage4.return_value = [
                                    SearchResult(path="a.py", score=0.95, excerpt="a")
                                ]

                                result = engine.staged_cascade_search(
                                    "query", temp_paths / "src"
                                )

                                # Verify stage 4 was called
                                mock_stage4.assert_called_once()

    def test_staged_cascade_fallback_to_hybrid(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search falls back to hybrid when numpy unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch("codexlens.search.chain_search.NUMPY_AVAILABLE", False):
            with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
                mock_hybrid.return_value = MagicMock()

                engine.staged_cascade_search("query", temp_paths / "src")

                # Should fall back to hybrid cascade
                mock_hybrid.assert_called_once()

    def test_staged_cascade_deduplicates_final_results(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged_cascade_search deduplicates results by path."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    mock_stage1.return_value = (
                        [SearchResult(path="a.py", score=0.9, excerpt="a")],
                        temp_paths / "index",
                    )

                    with patch.object(engine, "_stage2_lsp_expand") as mock_stage2:
                        mock_stage2.return_value = [
                            SearchResult(path="a.py", score=0.9, excerpt="a")
                        ]

                        with patch.object(engine, "_stage3_cluster_prune") as mock_stage3:
                            # Return duplicates with different scores
                            mock_stage3.return_value = [
                                SearchResult(path="a.py", score=0.9, excerpt="a"),
                                SearchResult(path="a.py", score=0.8, excerpt="a duplicate"),
                                SearchResult(path="b.py", score=0.7, excerpt="b"),
                            ]

                            result = engine.staged_cascade_search(
                                "query", temp_paths / "src", k=10
                            )

                            # Should deduplicate a.py (keep higher score)
                            paths = [r.path for r in result.results]
                            assert len(paths) == len(set(paths))
                            # a.py should have score 0.9
                            a_result = next(r for r in result.results if r.path == "a.py")
                            assert a_result.score == 0.9


# =============================================================================
# Graceful Degradation Tests
# =============================================================================


class TestStagedCascadeGracefulDegradation:
    """Tests for graceful degradation when dependencies unavailable."""

    def test_falls_back_when_clustering_unavailable(
        self, mock_registry, mock_mapper, mock_config, sample_expanded_results
    ):
        """Test clustering stage falls back gracefully when clustering unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_get_embeddings_for_clustering") as mock_embed:
            mock_embed.return_value = None

            clustered = engine._stage3_cluster_prune(
                sample_expanded_results, target_count=3
            )

            # Should fall back to score-based selection
            assert len(clustered) <= 3

    def test_falls_back_when_graph_expander_unavailable(
        self, mock_registry, mock_mapper, mock_config, sample_binary_results
    ):
        """Test LSP expansion falls back when GraphExpander unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        # Patch the import inside the method
        with patch("codexlens.search.graph_expander.GraphExpander", side_effect=ImportError):
            expanded = engine._stage2_lsp_expand(
                sample_binary_results, index_root=Path("/fake")
            )

            # Should return original results
            assert expanded == sample_binary_results

    def test_handles_stage_failures_gracefully(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
        """Test staged pipeline handles stage failures gracefully."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch.object(engine, "_find_start_index") as mock_find:
            mock_find.return_value = temp_paths / "index" / "_index.db"

            with patch.object(engine, "_collect_index_paths") as mock_collect:
                mock_collect.return_value = [temp_paths / "index" / "_index.db"]

                with patch.object(engine, "_stage1_binary_search") as mock_stage1:
                    # Stage 1 returns no results
                    mock_stage1.return_value = ([], None)

                    with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
                        mock_hybrid.return_value = MagicMock()

                        engine.staged_cascade_search("query", temp_paths / "src")

                        # Should fall back to hybrid when stage 1 fails
                        mock_hybrid.assert_called_once()