mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-12 02:37:45 +08:00
feat: Implement adaptive RRF weights and query intent detection
- Added integration tests for adaptive RRF weights in hybrid search. - Enhanced query intent detection with new classifications: keyword, semantic, and mixed. - Introduced symbol boosting in search results based on explicit symbol matches. - Implemented embedding-based reranking with configurable options. - Added global symbol index for efficient symbol lookups across projects. - Improved file deletion handling on Windows to avoid permission errors. - Updated chunk configuration to increase overlap for better context. - Modified package.json test script to target specific test files. - Created comprehensive writing style guidelines for documentation. - Added TypeScript tests for query intent detection and adaptive weights. - Established performance benchmarks for global symbol indexing.
This commit is contained in:
293
codex-lens/tests/test_global_index.py
Normal file
293
codex-lens/tests/test_global_index.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import Symbol
|
||||
from codexlens.errors import StorageError
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def temp_paths():
|
||||
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
|
||||
root = Path(tmpdir.name)
|
||||
yield root
|
||||
try:
|
||||
tmpdir.cleanup()
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def test_add_symbol(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "a.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
store.add_symbol(
|
||||
Symbol(name="AuthManager", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
matches = store.search("AuthManager", kind="class", limit=10, prefix_mode=True)
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "AuthManager"
|
||||
assert matches[0].file == str(file_path.resolve())
|
||||
|
||||
# Schema version safety: newer schema versions should be rejected.
|
||||
bad_db = temp_paths / "indexes" / "_global_symbols_bad.db"
|
||||
bad_db.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(bad_db)
|
||||
conn.execute("PRAGMA user_version = 999")
|
||||
conn.close()
|
||||
|
||||
with pytest.raises(StorageError):
|
||||
GlobalSymbolIndex(bad_db, project_id=1).initialize()
|
||||
|
||||
|
||||
def test_search_symbols(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "mod.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("def authenticate():\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=7) as store:
|
||||
store.add_symbol(
|
||||
Symbol(name="authenticate", kind="function", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
locations = store.search_symbols("auth", kind="function", limit=10, prefix_mode=True)
|
||||
assert locations
|
||||
assert any(p.endswith("mod.py") for p, _ in locations)
|
||||
assert any(rng == (1, 2) for _, rng in locations)
|
||||
|
||||
|
||||
def test_update_file_symbols(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
file_path = temp_paths / "src" / "mod.py"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("def a():\n pass\n", encoding="utf-8")
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=7) as store:
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[
|
||||
Symbol(name="old_func", kind="function", range=(1, 2)),
|
||||
Symbol(name="Other", kind="class", range=(10, 20)),
|
||||
],
|
||||
index_path=index_path,
|
||||
)
|
||||
assert any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
|
||||
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[Symbol(name="new_func", kind="function", range=(3, 4))],
|
||||
index_path=index_path,
|
||||
)
|
||||
assert not any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
|
||||
assert any(s.name == "new_func" for s in store.search("new_", prefix_mode=True))
|
||||
|
||||
# Backward-compatible path: index_path can be omitted after it's been established.
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[Symbol(name="new_func2", kind="function", range=(5, 6))],
|
||||
index_path=None,
|
||||
)
|
||||
assert any(s.name == "new_func2" for s in store.search("new_func2", prefix_mode=True))
|
||||
|
||||
# New file + symbols without index_path should raise.
|
||||
missing_index_file = temp_paths / "src" / "new_file.py"
|
||||
with pytest.raises(StorageError):
|
||||
store.update_file_symbols(
|
||||
file_path=missing_index_file,
|
||||
symbols=[Symbol(name="must_fail", kind="function", range=(1, 1))],
|
||||
index_path=None,
|
||||
)
|
||||
|
||||
deleted = store.delete_file_symbols(file_path)
|
||||
assert deleted > 0
|
||||
|
||||
|
||||
def test_incremental_updates(temp_paths: Path, monkeypatch):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
file_path = temp_paths / "src" / "same.py"
|
||||
idx_a = temp_paths / "indexes" / "a" / "_index.db"
|
||||
idx_b = temp_paths / "indexes" / "b" / "_index.db"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
idx_a.parent.mkdir(parents=True, exist_ok=True)
|
||||
idx_a.write_text("", encoding="utf-8")
|
||||
idx_b.parent.mkdir(parents=True, exist_ok=True)
|
||||
idx_b.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=42) as store:
|
||||
sym = Symbol(name="AuthManager", kind="class", range=(1, 2))
|
||||
store.add_symbol(sym, file_path=file_path, index_path=idx_a)
|
||||
store.add_symbol(sym, file_path=file_path, index_path=idx_b)
|
||||
|
||||
# prefix_mode=False exercises substring matching.
|
||||
assert store.search("Manager", prefix_mode=False)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT index_path
|
||||
FROM global_symbols
|
||||
WHERE project_id=? AND symbol_name=? AND symbol_kind=? AND file_path=?
|
||||
""",
|
||||
(42, "AuthManager", "class", str(file_path.resolve())),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
assert row is not None
|
||||
assert str(Path(row[0]).resolve()) == str(idx_b.resolve())
|
||||
|
||||
# Migration path coverage: simulate a future schema version and an older DB version.
|
||||
migrating_db = temp_paths / "indexes" / "_global_symbols_migrate.db"
|
||||
migrating_db.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(migrating_db)
|
||||
conn.execute("PRAGMA user_version = 1")
|
||||
conn.close()
|
||||
|
||||
monkeypatch.setattr(GlobalSymbolIndex, "SCHEMA_VERSION", 2)
|
||||
GlobalSymbolIndex(migrating_db, project_id=1).initialize()
|
||||
|
||||
|
||||
def test_concurrent_access(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "a.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class A:\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
def add_many(worker_id: int):
|
||||
for i in range(50):
|
||||
store.add_symbol(
|
||||
Symbol(name=f"Sym{worker_id}_{i}", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=8) as ex:
|
||||
list(ex.map(add_many, range(8)))
|
||||
|
||||
matches = store.search("Sym", kind="class", limit=1000, prefix_mode=True)
|
||||
assert len(matches) >= 200
|
||||
|
||||
|
||||
def test_chain_search_integration(temp_paths: Path):
|
||||
project_root = temp_paths / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = temp_paths / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_db_path.write_text("", encoding="utf-8")
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
project_info = registry.register_project(project_root, mapper.source_to_index_dir(project_root))
|
||||
|
||||
global_db_path = project_info.index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
|
||||
with GlobalSymbolIndex(global_db_path, project_id=project_info.id) as global_index:
|
||||
file_path = project_root / "auth.py"
|
||||
global_index.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[
|
||||
Symbol(name="AuthManager", kind="class", range=(1, 10)),
|
||||
Symbol(name="authenticate", kind="function", range=(12, 20)),
|
||||
],
|
||||
index_path=index_db_path,
|
||||
)
|
||||
|
||||
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=True)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
engine._search_symbols_parallel = MagicMock(side_effect=AssertionError("should not traverse chain"))
|
||||
|
||||
symbols = engine.search_symbols("Auth", project_root)
|
||||
assert any(s.name == "AuthManager" for s in symbols)
|
||||
registry.close()
|
||||
|
||||
|
||||
def test_disabled_fallback(temp_paths: Path):
|
||||
project_root = temp_paths / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = temp_paths / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_db_path.write_text("", encoding="utf-8")
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
registry.register_project(project_root, mapper.source_to_index_dir(project_root))
|
||||
|
||||
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=False)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
engine._collect_index_paths = MagicMock(return_value=[index_db_path])
|
||||
engine._search_symbols_parallel = MagicMock(
|
||||
return_value=[Symbol(name="FallbackSymbol", kind="function", range=(1, 2))]
|
||||
)
|
||||
|
||||
symbols = engine.search_symbols("Fallback", project_root)
|
||||
assert any(s.name == "FallbackSymbol" for s in symbols)
|
||||
assert engine._search_symbols_parallel.called
|
||||
registry.close()
|
||||
|
||||
|
||||
def test_performance_benchmark(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "perf.py"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
for i in range(500):
|
||||
store.add_symbol(
|
||||
Symbol(name=f"AuthManager{i}", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
results = store.search("AuthManager", kind="class", limit=50, prefix_mode=True)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
assert elapsed_ms < 100.0
|
||||
assert results
|
||||
@@ -551,3 +551,72 @@ class UserProfile:
|
||||
# Verify <15% overhead (reasonable threshold for performance tests with system variance)
|
||||
assert overhead < 15.0, f"Overhead {overhead:.2f}% exceeds 15% threshold (base={base_time:.4f}s, hybrid={hybrid_time:.4f}s)"
|
||||
|
||||
|
||||
class TestHybridChunkerV1Optimizations:
|
||||
"""Tests for v1.0 optimization behaviors (parent metadata + determinism)."""
|
||||
|
||||
def test_merged_docstring_metadata(self):
|
||||
"""Docstring chunks include parent_symbol metadata when applicable."""
|
||||
config = ChunkConfig(min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
content = '''"""Module docstring."""
|
||||
|
||||
def hello():
|
||||
"""Function docstring."""
|
||||
return 1
|
||||
'''
|
||||
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
|
||||
|
||||
chunks = chunker.chunk_file(content, symbols, "m.py", "python")
|
||||
func_doc_chunks = [
|
||||
c for c in chunks
|
||||
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 4
|
||||
]
|
||||
assert len(func_doc_chunks) == 1
|
||||
assert func_doc_chunks[0].metadata.get("parent_symbol") == "hello"
|
||||
assert func_doc_chunks[0].metadata.get("parent_symbol_kind") == "function"
|
||||
|
||||
def test_deterministic_chunk_boundaries(self):
|
||||
"""Chunk boundaries are stable across repeated runs on identical input."""
|
||||
config = ChunkConfig(max_chunk_size=80, overlap=10, min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
# No docstrings, no symbols -> sliding window path.
|
||||
content = "\n".join([f"line {i}: x = {i}" for i in range(1, 200)]) + "\n"
|
||||
|
||||
boundaries = []
|
||||
for _ in range(3):
|
||||
chunks = chunker.chunk_file(content, [], "deterministic.py", "python")
|
||||
boundaries.append([
|
||||
(c.metadata.get("start_line"), c.metadata.get("end_line"))
|
||||
for c in chunks
|
||||
if c.metadata.get("chunk_type") == "code"
|
||||
])
|
||||
|
||||
assert boundaries[0] == boundaries[1] == boundaries[2]
|
||||
|
||||
def test_orphan_docstrings(self):
|
||||
"""Module-level docstrings remain standalone (no parent_symbol assigned)."""
|
||||
config = ChunkConfig(min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
content = '''"""Module-level docstring."""
|
||||
|
||||
def hello():
|
||||
"""Function docstring."""
|
||||
return 1
|
||||
'''
|
||||
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
|
||||
chunks = chunker.chunk_file(content, symbols, "orphan.py", "python")
|
||||
|
||||
module_doc = [
|
||||
c for c in chunks
|
||||
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 1
|
||||
]
|
||||
assert len(module_doc) == 1
|
||||
assert module_doc[0].metadata.get("parent_symbol") is None
|
||||
|
||||
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
|
||||
assert code_chunks, "Expected at least one code chunk"
|
||||
assert all("Module-level docstring" not in c.content for c in code_chunks)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.hybrid_search import HybridSearchEngine
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
@@ -774,3 +775,97 @@ class TestHybridSearchWithVectorMock:
|
||||
assert hasattr(result, 'score')
|
||||
assert result.score > 0 # RRF fusion scores are positive
|
||||
|
||||
|
||||
class TestHybridSearchAdaptiveWeights:
|
||||
"""Integration tests for adaptive RRF weights + reranking gating."""
|
||||
|
||||
def test_adaptive_weights_code_query(self):
|
||||
"""Exact weight should dominate for code-like queries."""
|
||||
from unittest.mock import patch
|
||||
|
||||
engine = HybridSearchEngine()
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
captured = {}
|
||||
from codexlens.search import ranking as ranking_module
|
||||
|
||||
def capture_rrf(map_in, weights_in, k=60):
|
||||
captured["weights"] = dict(weights_in)
|
||||
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
|
||||
side_effect=capture_rrf,
|
||||
):
|
||||
engine.search(Path("dummy.db"), "def authenticate", enable_vector=True)
|
||||
|
||||
assert captured["weights"]["exact"] > 0.4
|
||||
|
||||
def test_adaptive_weights_nl_query(self):
|
||||
"""Vector weight should dominate for natural-language queries."""
|
||||
from unittest.mock import patch
|
||||
|
||||
engine = HybridSearchEngine()
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
captured = {}
|
||||
from codexlens.search import ranking as ranking_module
|
||||
|
||||
def capture_rrf(map_in, weights_in, k=60):
|
||||
captured["weights"] = dict(weights_in)
|
||||
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
|
||||
side_effect=capture_rrf,
|
||||
):
|
||||
engine.search(Path("dummy.db"), "how to handle user login", enable_vector=True)
|
||||
|
||||
assert captured["weights"]["vector"] > 0.6
|
||||
|
||||
def test_reranking_enabled(self, tmp_path):
|
||||
"""Reranking runs only when explicitly enabled via config."""
|
||||
from unittest.mock import patch
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
return [[1.0, 0.0] for _ in texts]
|
||||
|
||||
# Disabled: should not invoke rerank_results
|
||||
config_off = Config(data_dir=tmp_path / "off", enable_reranking=False)
|
||||
engine_off = HybridSearchEngine(config=config_off, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results"
|
||||
) as rerank_mock:
|
||||
engine_off.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
rerank_mock.assert_not_called()
|
||||
|
||||
# Enabled: should invoke rerank_results once
|
||||
config_on = Config(data_dir=tmp_path / "on", enable_reranking=True, reranking_top_k=10)
|
||||
engine_on = HybridSearchEngine(config=config_on, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results",
|
||||
side_effect=lambda q, r, e, top_k=50: r,
|
||||
) as rerank_mock:
|
||||
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.search.hybrid_search import HybridSearchEngine
|
||||
@@ -16,6 +17,22 @@ except ImportError:
|
||||
SEMANTIC_DEPS_AVAILABLE = False
|
||||
|
||||
|
||||
def _safe_unlink(path: Path, retries: int = 5, delay_s: float = 0.05) -> None:
|
||||
"""Best-effort unlink for Windows where SQLite can keep files locked briefly."""
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
path.unlink()
|
||||
return
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except PermissionError:
|
||||
time.sleep(delay_s * (attempt + 1))
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
class TestPureVectorSearch:
|
||||
"""Tests for pure vector search mode."""
|
||||
|
||||
@@ -48,7 +65,7 @@ class TestPureVectorSearch:
|
||||
store.close()
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_pure_vector_without_embeddings(self, sample_db):
|
||||
"""Test pure_vector mode returns empty when no embeddings exist."""
|
||||
@@ -200,12 +217,8 @@ def login_handler(credentials: dict) -> bool:
|
||||
yield db_path
|
||||
store.close()
|
||||
|
||||
# Ignore file deletion errors on Windows (SQLite file lock)
|
||||
try:
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
except PermissionError:
|
||||
pass # Ignore Windows file lock errors
|
||||
if db_path.exists():
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_pure_vector_with_embeddings(self, db_with_embeddings):
|
||||
"""Test pure vector search returns results when embeddings exist."""
|
||||
@@ -289,7 +302,7 @@ class TestSearchModeComparison:
|
||||
store.close()
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_mode_comparison_without_embeddings(self, comparison_db):
|
||||
"""Compare all search modes without embeddings."""
|
||||
|
||||
@@ -7,8 +7,12 @@ import pytest
|
||||
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import (
|
||||
apply_symbol_boost,
|
||||
QueryIntent,
|
||||
detect_query_intent,
|
||||
normalize_bm25_score,
|
||||
reciprocal_rank_fusion,
|
||||
rerank_results,
|
||||
tag_search_source,
|
||||
)
|
||||
|
||||
@@ -342,6 +346,62 @@ class TestTagSearchSource:
|
||||
assert tagged[0].symbol_kind == "function"
|
||||
|
||||
|
||||
class TestSymbolBoost:
|
||||
"""Tests for apply_symbol_boost function."""
|
||||
|
||||
def test_symbol_boost(self):
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.2, excerpt="...", symbol_name="foo"),
|
||||
SearchResult(path="b.py", score=0.21, excerpt="..."),
|
||||
]
|
||||
|
||||
boosted = apply_symbol_boost(results, boost_factor=1.5)
|
||||
|
||||
assert boosted[0].path == "a.py"
|
||||
assert boosted[0].score == pytest.approx(0.2 * 1.5)
|
||||
assert boosted[0].metadata["boosted"] is True
|
||||
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.2)
|
||||
|
||||
assert boosted[1].path == "b.py"
|
||||
assert boosted[1].score == pytest.approx(0.21)
|
||||
assert "boosted" not in boosted[1].metadata
|
||||
|
||||
|
||||
class TestEmbeddingReranking:
|
||||
"""Tests for rerank_results embedding-based similarity."""
|
||||
|
||||
def test_rerank_embedding_similarity(self):
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
mapping = {
|
||||
"query": [1.0, 0.0],
|
||||
"doc1": [1.0, 0.0],
|
||||
"doc2": [0.0, 1.0],
|
||||
}
|
||||
return [mapping[t] for t in texts]
|
||||
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.2, excerpt="doc1"),
|
||||
SearchResult(path="b.py", score=0.9, excerpt="doc2"),
|
||||
]
|
||||
|
||||
reranked = rerank_results("query", results, DummyEmbedder(), top_k=2)
|
||||
|
||||
assert reranked[0].path == "a.py"
|
||||
assert reranked[0].metadata["reranked"] is True
|
||||
assert reranked[0].metadata["rrf_score"] == pytest.approx(0.2)
|
||||
assert reranked[0].metadata["cosine_similarity"] == pytest.approx(1.0)
|
||||
assert reranked[0].score == pytest.approx(0.5 * 0.2 + 0.5 * 1.0)
|
||||
|
||||
assert reranked[1].path == "b.py"
|
||||
assert reranked[1].metadata["reranked"] is True
|
||||
assert reranked[1].metadata["rrf_score"] == pytest.approx(0.9)
|
||||
assert reranked[1].metadata["cosine_similarity"] == pytest.approx(0.0)
|
||||
assert reranked[1].score == pytest.approx(0.5 * 0.9 + 0.5 * 0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k_value", [30, 60, 100])
|
||||
class TestRRFParameterized:
|
||||
"""Parameterized tests for RRF with different k values."""
|
||||
@@ -419,3 +479,41 @@ class TestRRFEdgeCases:
|
||||
# Should work with normalization
|
||||
assert len(fused) == 1 # Deduplicated
|
||||
assert fused[0].score > 0
|
||||
|
||||
|
||||
class TestSymbolBoostAndIntentV1:
|
||||
"""Tests for symbol boosting and query intent detection (v1.0)."""
|
||||
|
||||
def test_symbol_boost_application(self):
|
||||
"""Results with symbol_name receive a multiplicative boost (default 1.5x)."""
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.4, excerpt="...", symbol_name="AuthManager"),
|
||||
SearchResult(path="b.py", score=0.41, excerpt="..."),
|
||||
]
|
||||
|
||||
boosted = apply_symbol_boost(results, boost_factor=1.5)
|
||||
|
||||
assert boosted[0].score == pytest.approx(0.4 * 1.5)
|
||||
assert boosted[0].metadata["boosted"] is True
|
||||
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.4)
|
||||
assert boosted[1].score == pytest.approx(0.41)
|
||||
assert "boosted" not in boosted[1].metadata
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("query", "expected"),
|
||||
[
|
||||
("def authenticate", QueryIntent.KEYWORD),
|
||||
("MyClass", QueryIntent.KEYWORD),
|
||||
("user_id", QueryIntent.KEYWORD),
|
||||
("UserService::authenticate", QueryIntent.KEYWORD),
|
||||
("ptr->next", QueryIntent.KEYWORD),
|
||||
("how to handle user login", QueryIntent.SEMANTIC),
|
||||
("what is authentication?", QueryIntent.SEMANTIC),
|
||||
("where is this used?", QueryIntent.SEMANTIC),
|
||||
("why does FooBar crash?", QueryIntent.MIXED),
|
||||
("how to use user_id in query", QueryIntent.MIXED),
|
||||
],
|
||||
)
|
||||
def test_query_intent_detection(self, query, expected):
|
||||
"""Detect intent for representative queries (Python/TypeScript parity)."""
|
||||
assert detect_query_intent(query) == expected
|
||||
|
||||
@@ -466,7 +466,18 @@ class TestDiagnostics:
|
||||
|
||||
yield db_path
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
for attempt in range(5):
|
||||
try:
|
||||
db_path.unlink()
|
||||
break
|
||||
except PermissionError:
|
||||
time.sleep(0.05 * (attempt + 1))
|
||||
else:
|
||||
# Best-effort cleanup (Windows SQLite locks can linger briefly).
|
||||
try:
|
||||
db_path.unlink(missing_ok=True)
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
def test_diagnose_empty_database(self, empty_db):
|
||||
"""Diagnose behavior with empty database."""
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestChunkConfig:
|
||||
"""Test default configuration values."""
|
||||
config = ChunkConfig()
|
||||
assert config.max_chunk_size == 1000
|
||||
assert config.overlap == 100
|
||||
assert config.overlap == 200
|
||||
assert config.min_chunk_size == 50
|
||||
|
||||
def test_custom_config(self):
|
||||
|
||||
Reference in New Issue
Block a user