mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-03-19 18:58:47 +08:00
Implement search and reranking functionality with FTS and embedding support

- Added BaseReranker abstract class for defining reranking interfaces.
- Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs.
- Introduced FTSEngine for full-text search capabilities using SQLite FTS5.
- Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking.
- Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion.
- Created unit and integration tests for the new search and reranking components.
- Established configuration management for search parameters and models.
This commit is contained in:
0
codex-lens-v2/tests/__init__.py
Normal file
0
codex-lens-v2/tests/__init__.py
Normal file
0
codex-lens-v2/tests/integration/__init__.py
Normal file
0
codex-lens-v2/tests/integration/__init__.py
Normal file
108
codex-lens-v2/tests/integration/conftest.py
Normal file
108
codex-lens-v2/tests/integration/conftest.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.core import ANNIndex, BinaryStore
|
||||
from codexlens.embed.base import BaseEmbedder
|
||||
from codexlens.rerank.base import BaseReranker
|
||||
from codexlens.search.fts import FTSEngine
|
||||
from codexlens.search.pipeline import SearchPipeline
|
||||
|
||||
# Test documents: 20 code snippets with id, path, content
|
||||
TEST_DOCS = [
|
||||
(0, "auth.py", "def authenticate(user, password): return check_hash(password, user.hash)"),
|
||||
(1, "auth.py", "def authorize(user, permission): return permission in user.roles"),
|
||||
(2, "models.py", "class User: def __init__(self, name, email): self.name = name; self.email = email"),
|
||||
(3, "models.py", "class Session: token = None; expires_at = None"),
|
||||
(4, "middleware.py", "def auth_middleware(request): token = request.headers.get('Authorization')"),
|
||||
(5, "utils.py", "def hash_password(password): import bcrypt; return bcrypt.hashpw(password)"),
|
||||
(6, "config.py", "DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///db.sqlite3')"),
|
||||
(7, "search.py", "def search_users(query): return User.objects.filter(name__icontains=query)"),
|
||||
(8, "api.py", "def get_user(request, user_id): user = User.objects.get(id=user_id)"),
|
||||
(9, "api.py", "def create_user(request): data = request.json(); user = User(**data)"),
|
||||
(10, "tests.py", "def test_authenticate(): assert authenticate('admin', 'pass') is not None"),
|
||||
(11, "tests.py", "def test_search(): results = search_users('alice'); assert len(results) > 0"),
|
||||
(12, "router.py", "app.route('/users', methods=['GET'])(list_users)"),
|
||||
(13, "router.py", "app.route('/login', methods=['POST'])(login_handler)"),
|
||||
(14, "db.py", "def get_connection(): return sqlite3.connect(DATABASE_URL)"),
|
||||
(15, "cache.py", "def cache_get(key): return redis_client.get(key)"),
|
||||
(16, "cache.py", "def cache_set(key, value, ttl=3600): redis_client.setex(key, ttl, value)"),
|
||||
(17, "errors.py", "class AuthError(Exception): status_code = 401"),
|
||||
(18, "errors.py", "class NotFoundError(Exception): status_code = 404"),
|
||||
(19, "validators.py", "def validate_email(email): return '@' in email and '.' in email.split('@')[1]"),
|
||||
]
|
||||
|
||||
DIM = 32 # Use small dim for fast tests
|
||||
|
||||
|
||||
def make_stable_vec(doc_id: int, dim: int = DIM) -> np.ndarray:
    """Generate a deterministic float32 unit vector for a given doc_id."""
    generator = np.random.default_rng(seed=doc_id)
    raw = generator.standard_normal(dim).astype(np.float32)
    return raw / np.linalg.norm(raw)
class MockEmbedder(BaseEmbedder):
    """Embedder stub producing deterministic unit vectors derived from content.

    The RNG seed comes from a sha256 digest of the text rather than the
    built-in ``hash``, because ``str.__hash__`` is salted per interpreter
    process (PYTHONHASHSEED) and would make "stable" vectors differ between
    test runs.
    """

    def embed_single(self, text: str) -> np.ndarray:
        """Return a deterministic unit-norm float32 vector of size DIM."""
        import hashlib

        # First 4 digest bytes -> 32-bit seed, reproducible across processes.
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:4], "big")
        rng = np.random.default_rng(seed=seed)
        vec = rng.standard_normal(DIM).astype(np.float32)
        vec /= np.linalg.norm(vec)
        return vec

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed each text independently."""
        return [self.embed_single(t) for t in texts]

    def embed(self, texts: list[str]) -> list[np.ndarray]:
        """Called by SearchPipeline as self._embedder.embed([query])[0]."""
        return self.embed_batch(texts)
class MockReranker(BaseReranker):
    """Reranker stub: score = fraction of query words present in the document."""

    def score_pairs(self, query: str, documents: list[str]) -> list[float]:
        """Return one keyword-overlap score per document."""
        wanted = set(query.lower().split())
        denom = max(len(wanted), 1)
        return [
            float(len(wanted & set(doc.lower().split()))) / denom
            for doc in documents
        ]
@pytest.fixture
def config():
    """Small test profile: hnsw_ef=50, hnsw_M=16, binary_top_k=50, ann_top_k=20, rerank_top_k=10."""
    return Config.small()
@pytest.fixture
def search_pipeline(tmp_path, config):
    """Assemble a full SearchPipeline with the 20-doc corpus indexed."""
    mock_embedder = MockEmbedder()
    store = BinaryStore(tmp_path / "binary", dim=DIM, config=config)
    index = ANNIndex(tmp_path / "ann.hnsw", dim=DIM, config=config)
    fts_engine = FTSEngine(tmp_path / "fts.db")

    # Index the whole fixture corpus into every backend.
    doc_ids = np.array([doc_id for doc_id, _, _ in TEST_DOCS], dtype=np.int64)
    doc_vecs = np.array(
        [mock_embedder.embed_single(content) for _, _, content in TEST_DOCS],
        dtype=np.float32,
    )

    store.add(doc_ids, doc_vecs)
    index.add(doc_ids, doc_vecs)
    fts_engine.add_documents(TEST_DOCS)

    return SearchPipeline(
        embedder=mock_embedder,
        binary_store=store,
        ann_index=index,
        reranker=MockReranker(),
        fts=fts_engine,
        config=config,
    )
44
codex-lens-v2/tests/integration/test_search_pipeline.py
Normal file
44
codex-lens-v2/tests/integration/test_search_pipeline.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Integration tests for SearchPipeline using real components and mock embedder/reranker."""
from __future__ import annotations


def test_vector_search_returns_results(search_pipeline):
    hits = search_pipeline.search("authentication middleware")
    assert hits
    assert all(isinstance(hit.score, float) for hit in hits)


def test_exact_keyword_search(search_pipeline):
    hits = search_pipeline.search("authenticate")
    assert hits
    result_ids = {hit.id for hit in hits}
    # Doc 0 and 10 both contain "authenticate"
    assert result_ids & {0, 10}, f"Expected doc 0 or 10 in results, got {result_ids}"


def test_pipeline_top_k_limit(search_pipeline):
    hits = search_pipeline.search("user", top_k=5)
    assert len(hits) <= 5


def test_search_result_fields_populated(search_pipeline):
    hits = search_pipeline.search("password")
    assert hits
    for hit in hits:
        assert hit.id >= 0
        assert hit.score >= 0
        assert isinstance(hit.path, str)


def test_empty_query_handled(search_pipeline):
    assert isinstance(search_pipeline.search(""), list)  # no exception


def test_different_queries_give_different_results(search_pipeline):
    first = search_pipeline.search("authenticate user")
    second = search_pipeline.search("cache redis")
    # Results should differ (different top IDs or scores), unless both are empty
    first_ids = [hit.id for hit in first]
    second_ids = [hit.id for hit in second]
    assert first_ids != second_ids or not first
0
codex-lens-v2/tests/unit/__init__.py
Normal file
0
codex-lens-v2/tests/unit/__init__.py
Normal file
31
codex-lens-v2/tests/unit/test_config.py
Normal file
31
codex-lens-v2/tests/unit/test_config.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from codexlens.config import Config


def test_config_instantiates_no_args():
    assert Config() is not None


def test_defaults_hnsw_ef():
    assert Config.defaults().hnsw_ef == 150


def test_defaults_hnsw_M():
    assert Config.defaults().hnsw_M == 32


def test_small_hnsw_ef():
    assert Config.small().hnsw_ef == 50


def test_custom_instantiation():
    assert Config(hnsw_ef=100).hnsw_ef == 100


def test_fusion_weights_keys():
    assert set(Config().fusion_weights.keys()) == {"exact", "fuzzy", "vector", "graph"}
136
codex-lens-v2/tests/unit/test_core.py
Normal file
136
codex-lens-v2/tests/unit/test_core.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Unit tests for BinaryStore and ANNIndex (no fastembed required)."""
from __future__ import annotations

import concurrent.futures
import tempfile
from pathlib import Path

import numpy as np
import pytest

from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore


DIM = 32
# Shared, stateful RNG: successive make_vectors() calls draw fresh samples.
RNG = np.random.default_rng(42)


def make_vectors(n: int, dim: int = DIM) -> np.ndarray:
    """Draw n random float32 vectors from the shared RNG."""
    return RNG.standard_normal((n, dim)).astype(np.float32)


def make_ids(n: int, start: int = 0) -> np.ndarray:
    """Sequential int64 ids in [start, start + n)."""
    return np.arange(start, start + n, dtype=np.int64)


# ---------------------------------------------------------------------------
# BinaryStore tests
# ---------------------------------------------------------------------------


class TestBinaryStore:
    def test_binary_store_add_and_search(self, tmp_path: Path) -> None:
        store = BinaryStore(tmp_path, DIM, Config.small())
        sample = make_vectors(10)
        store.add(make_ids(10), sample)

        assert len(store) == 10

        k = 5
        got_ids, got_dists = store.coarse_search(sample[0], top_k=k)
        assert got_ids.shape == (k,)
        assert got_dists.shape == (k,)
        # distances are non-negative integers
        assert (got_dists >= 0).all()

    def test_binary_hamming_correctness(self, tmp_path: Path) -> None:
        store = BinaryStore(tmp_path, DIM, Config.small())
        sample = make_vectors(20)
        store.add(make_ids(20), sample)

        # Query with the exact stored vector; it must be the top-1 result
        got_ids, got_dists = store.coarse_search(sample[7], top_k=1)
        assert got_ids[0] == 7
        assert got_dists[0] == 0  # Hamming distance to itself is 0

    def test_binary_store_persist(self, tmp_path: Path) -> None:
        cfg = Config.small()
        store = BinaryStore(tmp_path, DIM, cfg)
        sample = make_vectors(15)
        store.add(make_ids(15), sample)
        store.save()

        # Load into a fresh instance
        reloaded = BinaryStore(tmp_path, DIM, cfg)
        assert len(reloaded) == 15

        got_ids, got_dists = reloaded.coarse_search(sample[3], top_k=1)
        assert got_ids[0] == 3
        assert got_dists[0] == 0


# ---------------------------------------------------------------------------
# ANNIndex tests
# ---------------------------------------------------------------------------


class TestANNIndex:
    def test_ann_index_add_and_search(self, tmp_path: Path) -> None:
        index = ANNIndex(tmp_path, DIM, Config.small())
        sample = make_vectors(50)
        index.add(make_ids(50), sample)

        assert len(index) == 50

        got_ids, got_dists = index.fine_search(sample[0], top_k=5)
        assert len(got_ids) == 5
        assert len(got_dists) == 5

    def test_ann_index_thread_safety(self, tmp_path: Path) -> None:
        index = ANNIndex(tmp_path, DIM, Config.small())
        sample = make_vectors(50)
        index.add(make_ids(50), sample)

        probe = sample[0]
        errors: list[Exception] = []

        def worker() -> None:
            try:
                index.fine_search(probe, top_k=3)
            except Exception as exc:
                errors.append(exc)

        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
            concurrent.futures.wait([pool.submit(worker) for _ in range(5)])

        assert errors == [], f"Thread safety errors: {errors}"

    def test_ann_index_save_load(self, tmp_path: Path) -> None:
        cfg = Config.small()
        index = ANNIndex(tmp_path, DIM, cfg)
        sample = make_vectors(30)
        index.add(make_ids(30), sample)
        index.save()

        # Load into a fresh instance
        reloaded = ANNIndex(tmp_path, DIM, cfg)
        reloaded.load()
        assert len(reloaded) == 30

        got_ids, got_dists = reloaded.fine_search(sample[10], top_k=1)
        assert len(got_ids) == 1
        assert got_ids[0] == 10
80
codex-lens-v2/tests/unit/test_embed.py
Normal file
80
codex-lens-v2/tests/unit/test_embed.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations

import sys
import types
import unittest
from unittest.mock import MagicMock, patch

import numpy as np


def _make_fastembed_mock():
    """Build a minimal fastembed stub so imports succeed without the real package."""
    stub = types.ModuleType("fastembed")
    stub.TextEmbedding = MagicMock()
    sys.modules.setdefault("fastembed", stub)
    return stub


_make_fastembed_mock()

from codexlens.config import Config  # noqa: E402
from codexlens.embed.base import BaseEmbedder  # noqa: E402
from codexlens.embed.local import EMBED_PROFILES, FastEmbedEmbedder  # noqa: E402


class TestEmbedSingle(unittest.TestCase):
    def test_embed_single_returns_float32_ndarray(self):
        embedder = FastEmbedEmbedder(Config())

        fake_model = MagicMock()
        fake_model.embed.return_value = iter([np.ones(384, dtype=np.float64)])

        # Inject mock model directly to bypass lazy load (no real fastembed needed)
        embedder._model = fake_model
        out = embedder.embed_single("hello world")

        self.assertIsInstance(out, np.ndarray)
        self.assertEqual(out.dtype, np.float32)
        self.assertEqual(out.shape, (384,))


class TestEmbedBatch(unittest.TestCase):
    def test_embed_batch_returns_list(self):
        embedder = FastEmbedEmbedder(Config())

        fake_vectors = [np.ones(384, dtype=np.float64) * i for i in range(3)]
        fake_model = MagicMock()
        fake_model.embed.return_value = iter(fake_vectors)

        embedder._model = fake_model
        out = embedder.embed_batch(["a", "b", "c"])

        self.assertIsInstance(out, list)
        self.assertEqual(len(out), 3)
        for item in out:
            self.assertIsInstance(item, np.ndarray)
            self.assertEqual(item.dtype, np.float32)


class TestEmbedProfiles(unittest.TestCase):
    def test_embed_profiles_all_have_valid_keys(self):
        self.assertEqual(set(EMBED_PROFILES.keys()), {"small", "base", "large", "code"})

    def test_embed_profiles_model_ids_non_empty(self):
        for key, model_id in EMBED_PROFILES.items():
            self.assertIsInstance(model_id, str, msg=f"{key} model id should be str")
            self.assertTrue(len(model_id) > 0, msg=f"{key} model id should be non-empty")


class TestBaseEmbedderAbstract(unittest.TestCase):
    def test_base_embedder_is_abstract(self):
        with self.assertRaises(TypeError):
            BaseEmbedder()  # type: ignore[abstract]


if __name__ == "__main__":
    unittest.main()
179
codex-lens-v2/tests/unit/test_rerank.py
Normal file
179
codex-lens-v2/tests/unit/test_rerank.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from __future__ import annotations

import types
from unittest.mock import MagicMock, patch

import pytest

from codexlens.config import Config
from codexlens.rerank.base import BaseReranker
from codexlens.rerank.local import FastEmbedReranker
from codexlens.rerank.api import APIReranker


# ---------------------------------------------------------------------------
# BaseReranker
# ---------------------------------------------------------------------------

def test_base_reranker_is_abstract():
    with pytest.raises(TypeError):
        BaseReranker()  # type: ignore[abstract]


# ---------------------------------------------------------------------------
# FastEmbedReranker
# ---------------------------------------------------------------------------

def _make_rerank_result(index: int, score: float) -> object:
    return types.SimpleNamespace(index=index, score=score)


def _reranker_with_results(results: list[object]) -> FastEmbedReranker:
    """FastEmbedReranker whose mocked model yields the given rerank results."""
    reranker = FastEmbedReranker(Config())
    fake_model = MagicMock()
    fake_model.rerank.return_value = iter(results)
    reranker._model = fake_model
    return reranker


def test_local_reranker_score_pairs_length():
    reranker = _reranker_with_results(
        [
            _make_rerank_result(0, 0.9),
            _make_rerank_result(1, 0.5),
            _make_rerank_result(2, 0.1),
        ]
    )
    scores = reranker.score_pairs("query", ["doc0", "doc1", "doc2"])
    assert len(scores) == 3


def test_local_reranker_preserves_order():
    # rerank returns results in reverse order (index 2, 1, 0)
    reranker = _reranker_with_results(
        [
            _make_rerank_result(2, 0.1),
            _make_rerank_result(1, 0.5),
            _make_rerank_result(0, 0.9),
        ]
    )
    scores = reranker.score_pairs("query", ["doc0", "doc1", "doc2"])
    assert scores[0] == pytest.approx(0.9)
    assert scores[1] == pytest.approx(0.5)
    assert scores[2] == pytest.approx(0.1)


# ---------------------------------------------------------------------------
# APIReranker
# ---------------------------------------------------------------------------

def _make_config(max_tokens_per_batch: int = 512) -> Config:
    return Config(
        reranker_api_url="https://api.example.com",
        reranker_api_key="test-key",
        reranker_api_model="test-model",
        reranker_api_max_tokens_per_batch=max_tokens_per_batch,
    )


def _json_response(results: list[dict]) -> MagicMock:
    """A mocked 200 response whose JSON body carries the given rerank results."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {"results": results}
    response.raise_for_status = MagicMock()
    return response


def test_api_reranker_batch_splitting():
    with patch("httpx.Client"):
        reranker = APIReranker(_make_config(max_tokens_per_batch=512))

    # 10 docs, each ~200 tokens (800 chars)
    docs = ["x" * 800] * 10
    batches = reranker._split_batches(docs, max_tokens=512)

    # Each doc is 200 tokens; batches should have at most 2 docs (200+200=400 <= 512, 400+200=600 > 512)
    assert len(batches) > 1
    for batch in batches:
        assert sum(len(text) // 4 for _, text in batch) <= 512 or len(batch) == 1


def test_api_reranker_retry_on_429():
    throttled = MagicMock()
    throttled.status_code = 429

    ok = _json_response(
        [
            {"index": 0, "relevance_score": 0.8},
            {"index": 1, "relevance_score": 0.3},
        ]
    )

    with patch("httpx.Client") as client_cls:
        client = MagicMock()
        client_cls.return_value = client
        client.post.side_effect = [throttled, throttled, ok]

        reranker = APIReranker(_make_config())

        with patch("time.sleep"):
            result = reranker._call_api_with_retry(
                "query",
                [(0, "doc0"), (1, "doc1")],
                max_retries=3,
            )

    assert client.post.call_count == 3
    assert 0 in result
    assert 1 in result


def test_api_reranker_merge_batches():
    # 4 docs of 50 tokens each (200 chars): 50+50=100 <= 100, 100+50=150 > 100
    # -> the splitter should produce 2 batches of 2 docs each.
    docs = ["x" * 200] * 4

    first_batch = _json_response(
        [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.8},
        ]
    )
    second_batch = _json_response(
        [
            {"index": 0, "relevance_score": 0.7},
            {"index": 1, "relevance_score": 0.6},
        ]
    )

    with patch("httpx.Client") as client_cls:
        client = MagicMock()
        client_cls.return_value = client
        client.post.side_effect = [first_batch, second_batch]

        reranker = APIReranker(_make_config(max_tokens_per_batch=100))

        with patch("time.sleep"):
            scores = reranker.score_pairs("query", docs)

    assert len(scores) == 4
    # All original indices should have scores
    assert all(s > 0 for s in scores)
156
codex-lens-v2/tests/unit/test_search.py
Normal file
156
codex-lens-v2/tests/unit/test_search.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Unit tests for search layer: FTSEngine, fusion, and SearchPipeline."""
from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from codexlens.search.fts import FTSEngine
from codexlens.search.fusion import (
    DEFAULT_WEIGHTS,
    QueryIntent,
    detect_query_intent,
    get_adaptive_weights,
    reciprocal_rank_fusion,
)
from codexlens.search.pipeline import SearchPipeline, SearchResult
from codexlens.config import Config


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def make_fts(docs: list[tuple[int, str, str]] | None = None) -> FTSEngine:
    """Create an in-memory FTSEngine and optionally add documents."""
    engine = FTSEngine(":memory:")
    if docs:
        engine.add_documents(docs)
    return engine


# ---------------------------------------------------------------------------
# FTSEngine tests
# ---------------------------------------------------------------------------

def test_fts_add_and_exact_search():
    engine = make_fts(
        [
            (1, "a.py", "def authenticate user password login"),
            (2, "b.py", "connect to database with credentials"),
            (3, "c.py", "render template html response"),
        ]
    )
    hits = engine.exact_search("authenticate", top_k=10)
    hit_ids = [hit[0] for hit in hits]
    assert 1 in hit_ids, "doc 1 should match 'authenticate'"
    assert 2 not in hit_ids or hits[0][0] == 1  # doc 1 must rank higher


def test_fts_fuzzy_search_prefix():
    engine = make_fts(
        [
            (10, "auth.py", "authentication token refresh"),
            (11, "db.py", "database connection pool"),
            (12, "ui.py", "render button click handler"),
        ]
    )
    # Prefix 'auth' should match 'authentication' in doc 10
    hits = engine.fuzzy_search("auth", top_k=10)
    hit_ids = [hit[0] for hit in hits]
    assert 10 in hit_ids, "prefix 'auth' should match doc 10 with 'authentication'"


# ---------------------------------------------------------------------------
# RRF fusion tests
# ---------------------------------------------------------------------------

def test_rrf_fusion_ordering():
    """When two sources agree on top-1, it should rank first in fused result."""
    fused = reciprocal_rank_fusion(
        {
            "a": [(1, 0.9), (2, 0.5), (3, 0.2)],
            "b": [(1, 0.8), (3, 0.6), (2, 0.1)],
        }
    )
    assert fused[0][0] == 1, "doc 1 agreed top by both sources must rank first"


def test_rrf_equal_weight_default():
    """Calling with None weights should use DEFAULT_WEIGHTS shape (not crash)."""
    # Should not raise and should return results
    fused = reciprocal_rank_fusion(
        {
            "exact": [(5, 1.0), (6, 0.8)],
            "vector": [(6, 0.9), (5, 0.7)],
        },
        weights=None,
    )
    assert len(fused) == 2
    fused_ids = [entry[0] for entry in fused]
    assert 5 in fused_ids and 6 in fused_ids


# ---------------------------------------------------------------------------
# detect_query_intent tests
# ---------------------------------------------------------------------------

def test_detect_intent_code_symbol():
    assert detect_query_intent("def authenticate()") == QueryIntent.CODE_SYMBOL


def test_detect_intent_natural():
    assert detect_query_intent("how do I authenticate users") == QueryIntent.NATURAL_LANGUAGE


# ---------------------------------------------------------------------------
# SearchPipeline tests
# ---------------------------------------------------------------------------

def _make_pipeline(fts: FTSEngine, top_k: int = 5) -> SearchPipeline:
    """Build a SearchPipeline with mocked heavy components."""
    cfg = Config.small()
    cfg.reranker_top_k = top_k

    embedder = MagicMock()
    embedder.embed.return_value = [[0.1] * cfg.embed_dim]

    binary_store = MagicMock()
    binary_store.coarse_search.return_value = ([1, 2, 3], None)

    ann_index = MagicMock()
    ann_index.fine_search.return_value = ([1, 2, 3], [0.9, 0.8, 0.7])

    reranker = MagicMock()
    # Return a score for each content string passed
    reranker.score_pairs.side_effect = lambda q, contents: [0.9 - i * 0.1 for i in range(len(contents))]

    return SearchPipeline(
        embedder=embedder,
        binary_store=binary_store,
        ann_index=ann_index,
        reranker=reranker,
        fts=fts,
        config=cfg,
    )


def test_pipeline_search_returns_results():
    pipeline = _make_pipeline(
        make_fts(
            [
                (1, "a.py", "test content alpha"),
                (2, "b.py", "test content beta"),
                (3, "c.py", "test content gamma"),
            ]
        )
    )
    hits = pipeline.search("test")
    assert len(hits) > 0
    assert all(isinstance(hit, SearchResult) for hit in hits)


def test_pipeline_top_k_limit():
    pipeline = _make_pipeline(
        make_fts(
            [
                (1, "a.py", "hello world one"),
                (2, "b.py", "hello world two"),
                (3, "c.py", "hello world three"),
                (4, "d.py", "hello world four"),
                (5, "e.py", "hello world five"),
            ]
        ),
        top_k=2,
    )
    hits = pipeline.search("hello", top_k=2)
    assert len(hits) <= 2, "pipeline must respect top_k limit"
Reference in New Issue
Block a user