Implement search and reranking functionality with FTS and embedding support

- Added BaseReranker abstract class for defining reranking interfaces.
- Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs.
- Introduced FTSEngine for full-text search capabilities using SQLite FTS5.
- Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking.
- Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion.
- Created unit and integration tests for the new search and reranking components.
- Established configuration management for search parameters and models.
This commit is contained in:
catlog22
2026-03-16 23:03:17 +08:00
parent 5a4b18d9b1
commit de4158597b
41 changed files with 2655 additions and 1848 deletions

View File

View File

@@ -0,0 +1,108 @@
import pytest
import numpy as np
import tempfile
from pathlib import Path
from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore
from codexlens.embed.base import BaseEmbedder
from codexlens.rerank.base import BaseReranker
from codexlens.search.fts import FTSEngine
from codexlens.search.pipeline import SearchPipeline
# Test documents: 20 code snippets with id, path, content.
# Ids double as vector-store ids, paths mimic a small web application,
# and contents are chosen so keyword queries ("authenticate", "cache",
# "password", ...) have obvious expected matches in the search tests.
TEST_DOCS = [
    (0, "auth.py", "def authenticate(user, password): return check_hash(password, user.hash)"),
    (1, "auth.py", "def authorize(user, permission): return permission in user.roles"),
    (2, "models.py", "class User: def __init__(self, name, email): self.name = name; self.email = email"),
    (3, "models.py", "class Session: token = None; expires_at = None"),
    (4, "middleware.py", "def auth_middleware(request): token = request.headers.get('Authorization')"),
    (5, "utils.py", "def hash_password(password): import bcrypt; return bcrypt.hashpw(password)"),
    (6, "config.py", "DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///db.sqlite3')"),
    (7, "search.py", "def search_users(query): return User.objects.filter(name__icontains=query)"),
    (8, "api.py", "def get_user(request, user_id): user = User.objects.get(id=user_id)"),
    (9, "api.py", "def create_user(request): data = request.json(); user = User(**data)"),
    (10, "tests.py", "def test_authenticate(): assert authenticate('admin', 'pass') is not None"),
    (11, "tests.py", "def test_search(): results = search_users('alice'); assert len(results) > 0"),
    (12, "router.py", "app.route('/users', methods=['GET'])(list_users)"),
    (13, "router.py", "app.route('/login', methods=['POST'])(login_handler)"),
    (14, "db.py", "def get_connection(): return sqlite3.connect(DATABASE_URL)"),
    (15, "cache.py", "def cache_get(key): return redis_client.get(key)"),
    (16, "cache.py", "def cache_set(key, value, ttl=3600): redis_client.setex(key, ttl, value)"),
    (17, "errors.py", "class AuthError(Exception): status_code = 401"),
    (18, "errors.py", "class NotFoundError(Exception): status_code = 404"),
    (19, "validators.py", "def validate_email(email): return '@' in email and '.' in email.split('@')[1]"),
]
DIM = 32  # Use small dim for fast tests
def make_stable_vec(doc_id: int, dim: int = DIM) -> np.ndarray:
    """Return a unit-norm float32 vector derived deterministically from *doc_id*.

    The doc id seeds the RNG directly, so the same id always maps to the
    same vector regardless of call order.
    """
    generator = np.random.default_rng(seed=doc_id)
    raw = generator.standard_normal(dim).astype(np.float32)
    return raw / np.linalg.norm(raw)
class MockEmbedder(BaseEmbedder):
    """Deterministic stand-in embedder for tests.

    Seeds are derived from a CRC32 of the text rather than the builtin
    ``hash()``: str hashing is salted per process (PYTHONHASHSEED), so
    ``hash(text)`` would make the vectors differ between test runs even
    though the docstring promises stability.
    """

    def embed_single(self, text: str) -> np.ndarray:
        """Return a unit-norm float32 vector of dimension DIM for *text*."""
        import zlib  # local import keeps the mock self-contained

        seed = zlib.crc32(text.encode("utf-8"))
        rng = np.random.default_rng(seed=seed)
        vec = rng.standard_normal(DIM).astype(np.float32)
        vec /= np.linalg.norm(vec)
        return vec

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed each text independently; order is preserved."""
        return [self.embed_single(t) for t in texts]

    def embed(self, texts: list[str]) -> list[np.ndarray]:
        """Called by SearchPipeline as self._embedder.embed([query])[0]."""
        return self.embed_batch(texts)
class MockReranker(BaseReranker):
    """Scores each document by the fraction of query words it contains."""

    def score_pairs(self, query: str, documents: list[str]) -> list[float]:
        """Return one overlap score per document, in input order."""
        query_words = set(query.lower().split())
        denom = max(len(query_words), 1)  # guard against an empty query
        return [
            len(query_words.intersection(doc.lower().split())) / denom
            for doc in documents
        ]
@pytest.fixture
def config():
    """Small config profile: hnsw_ef=50, hnsw_M=16, binary_top_k=50, ann_top_k=20, rerank_top_k=10."""
    cfg = Config.small()
    return cfg
@pytest.fixture
def search_pipeline(tmp_path, config):
    """Assemble a full SearchPipeline with every TEST_DOC indexed in each backend."""
    embedder = MockEmbedder()
    binary_store = BinaryStore(tmp_path / "binary", dim=DIM, config=config)
    ann_index = ANNIndex(tmp_path / "ann.hnsw", dim=DIM, config=config)
    fts = FTSEngine(tmp_path / "fts.db")

    # Index every test document into the vector stores and the FTS engine.
    doc_ids = np.array([doc_id for doc_id, _, _ in TEST_DOCS], dtype=np.int64)
    doc_vectors = np.array(
        [embedder.embed_single(content) for _, _, content in TEST_DOCS],
        dtype=np.float32,
    )
    binary_store.add(doc_ids, doc_vectors)
    ann_index.add(doc_ids, doc_vectors)
    fts.add_documents(TEST_DOCS)

    return SearchPipeline(
        embedder=embedder,
        binary_store=binary_store,
        ann_index=ann_index,
        reranker=MockReranker(),
        fts=fts,
        config=config,
    )

View File

@@ -0,0 +1,44 @@
"""Integration tests for SearchPipeline using real components and mock embedder/reranker."""
from __future__ import annotations
def test_vector_search_returns_results(search_pipeline):
    """A natural-language query yields at least one float-scored result."""
    hits = search_pipeline.search("authentication middleware")
    assert hits
    assert all(isinstance(hit.score, float) for hit in hits)


def test_exact_keyword_search(search_pipeline):
    """Documents containing the literal keyword should surface."""
    hits = search_pipeline.search("authenticate")
    assert hits
    result_ids = {hit.id for hit in hits}
    # Doc 0 and 10 both contain "authenticate"
    assert result_ids.intersection({0, 10}), f"Expected doc 0 or 10 in results, got {result_ids}"


def test_pipeline_top_k_limit(search_pipeline):
    """The top_k argument caps how many results come back."""
    assert len(search_pipeline.search("user", top_k=5)) <= 5


def test_search_result_fields_populated(search_pipeline):
    """Every result carries a non-negative id and score, and a string path."""
    hits = search_pipeline.search("password")
    assert hits
    for hit in hits:
        assert hit.id >= 0
        assert hit.score >= 0
        assert isinstance(hit.path, str)


def test_empty_query_handled(search_pipeline):
    """An empty query must not raise; a (possibly empty) list is returned."""
    assert isinstance(search_pipeline.search(""), list)


def test_different_queries_give_different_results(search_pipeline):
    """Unrelated queries should not produce identical rankings."""
    auth_hits = search_pipeline.search("authenticate user")
    cache_hits = search_pipeline.search("cache redis")
    ids1 = [hit.id for hit in auth_hits]
    ids2 = [hit.id for hit in cache_hits]
    # Results should differ (different top IDs), unless the first set is empty.
    assert ids1 != ids2 or not auth_hits

View File

View File

@@ -0,0 +1,31 @@
from codexlens.config import Config
def test_config_instantiates_no_args():
    """Config() works without any arguments."""
    assert Config() is not None


def test_defaults_hnsw_ef():
    """The default profile fixes hnsw_ef at 150."""
    assert Config.defaults().hnsw_ef == 150


def test_defaults_hnsw_M():
    """The default profile fixes hnsw_M at 32."""
    assert Config.defaults().hnsw_M == 32


def test_small_hnsw_ef():
    """The small profile lowers hnsw_ef to 50."""
    assert Config.small().hnsw_ef == 50


def test_custom_instantiation():
    """Keyword overrides are honoured at construction time."""
    assert Config(hnsw_ef=100).hnsw_ef == 100


def test_fusion_weights_keys():
    """fusion_weights exposes exactly the four fusion sources."""
    expected = {"exact", "fuzzy", "vector", "graph"}
    assert set(Config().fusion_weights.keys()) == expected

View File

@@ -0,0 +1,136 @@
"""Unit tests for BinaryStore and ANNIndex (no fastembed required)."""
from __future__ import annotations
import concurrent.futures
import tempfile
from pathlib import Path
import numpy as np
import pytest
from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore
# Small embedding dimension keeps these unit tests fast.
DIM = 32
# Shared module-level RNG (seeded): draws are deterministic for a fixed
# call order within the test session.
RNG = np.random.default_rng(42)
def make_vectors(n: int, dim: int = DIM) -> np.ndarray:
    """Draw *n* float32 vectors of dimension *dim* from the shared RNG."""
    sample = RNG.standard_normal((n, dim))
    return sample.astype(np.float32)
def make_ids(n: int, start: int = 0) -> np.ndarray:
    """Return *n* consecutive int64 ids beginning at *start*."""
    return start + np.arange(n, dtype=np.int64)
# ---------------------------------------------------------------------------
# BinaryStore tests
# ---------------------------------------------------------------------------
class TestBinaryStore:
    """Unit tests for the binary (Hamming-distance) coarse-search store."""

    def test_binary_store_add_and_search(self, tmp_path: Path) -> None:
        store = BinaryStore(tmp_path, DIM, Config.small())
        sample = make_vectors(10)
        store.add(make_ids(10), sample)
        assert len(store) == 10
        k = 5
        got_ids, got_dists = store.coarse_search(sample[0], top_k=k)
        assert got_ids.shape == (k,)
        assert got_dists.shape == (k,)
        # Hamming distances count differing bits, so they are never negative.
        assert (got_dists >= 0).all()

    def test_binary_hamming_correctness(self, tmp_path: Path) -> None:
        store = BinaryStore(tmp_path, DIM, Config.small())
        sample = make_vectors(20)
        store.add(make_ids(20), sample)
        # Searching with an exactly-stored vector must return it at rank 1
        # with zero Hamming distance.
        got_ids, got_dists = store.coarse_search(sample[7], top_k=1)
        assert got_ids[0] == 7
        assert got_dists[0] == 0

    def test_binary_store_persist(self, tmp_path: Path) -> None:
        cfg = Config.small()
        sample = make_vectors(15)
        original = BinaryStore(tmp_path, DIM, cfg)
        original.add(make_ids(15), sample)
        original.save()
        # A fresh instance over the same path must see the persisted data.
        reloaded = BinaryStore(tmp_path, DIM, cfg)
        assert len(reloaded) == 15
        got_ids, got_dists = reloaded.coarse_search(sample[3], top_k=1)
        assert got_ids[0] == 3
        assert got_dists[0] == 0
# ---------------------------------------------------------------------------
# ANNIndex tests
# ---------------------------------------------------------------------------
class TestANNIndex:
    """Unit tests for the approximate-nearest-neighbour fine-search index."""

    def test_ann_index_add_and_search(self, tmp_path: Path) -> None:
        index = ANNIndex(tmp_path, DIM, Config.small())
        sample = make_vectors(50)
        index.add(make_ids(50), sample)
        assert len(index) == 50
        got_ids, got_dists = index.fine_search(sample[0], top_k=5)
        assert len(got_ids) == 5
        assert len(got_dists) == 5

    def test_ann_index_thread_safety(self, tmp_path: Path) -> None:
        index = ANNIndex(tmp_path, DIM, Config.small())
        sample = make_vectors(50)
        index.add(make_ids(50), sample)
        probe = sample[0]
        errors: list[Exception] = []

        def worker() -> None:
            # Collect rather than raise so all threads finish before asserting.
            try:
                index.fine_search(probe, top_k=3)
            except Exception as exc:
                errors.append(exc)

        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
            concurrent.futures.wait([pool.submit(worker) for _ in range(5)])
        assert errors == [], f"Thread safety errors: {errors}"

    def test_ann_index_save_load(self, tmp_path: Path) -> None:
        cfg = Config.small()
        sample = make_vectors(30)
        writer = ANNIndex(tmp_path, DIM, cfg)
        writer.add(make_ids(30), sample)
        writer.save()
        # Reload from disk into a brand-new instance.
        reader = ANNIndex(tmp_path, DIM, cfg)
        reader.load()
        assert len(reader) == 30
        got_ids, got_dists = reader.fine_search(sample[10], top_k=1)
        assert len(got_ids) == 1
        assert got_ids[0] == 10

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
import sys
import types
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
def _make_fastembed_mock():
"""Build a minimal fastembed stub so imports succeed without the real package."""
fastembed_mod = types.ModuleType("fastembed")
fastembed_mod.TextEmbedding = MagicMock()
sys.modules.setdefault("fastembed", fastembed_mod)
return fastembed_mod
_make_fastembed_mock()
from codexlens.config import Config # noqa: E402
from codexlens.embed.base import BaseEmbedder # noqa: E402
from codexlens.embed.local import EMBED_PROFILES, FastEmbedEmbedder # noqa: E402
class TestEmbedSingle(unittest.TestCase):
    """embed_single must normalise model output to a float32 ndarray."""

    def test_embed_single_returns_float32_ndarray(self):
        embedder = FastEmbedEmbedder(Config())
        stub_model = MagicMock()
        stub_model.embed.return_value = iter([np.ones(384, dtype=np.float64)])
        # Bypass lazy model loading by injecting the stub directly.
        embedder._model = stub_model
        vec = embedder.embed_single("hello world")
        self.assertIsInstance(vec, np.ndarray)
        self.assertEqual(vec.dtype, np.float32)
        self.assertEqual(vec.shape, (384,))
class TestEmbedBatch(unittest.TestCase):
    """embed_batch must return one float32 ndarray per input text."""

    def test_embed_batch_returns_list(self):
        embedder = FastEmbedEmbedder(Config())
        raw = [i * np.ones(384, dtype=np.float64) for i in range(3)]
        stub_model = MagicMock()
        stub_model.embed.return_value = iter(raw)
        embedder._model = stub_model
        out = embedder.embed_batch(["a", "b", "c"])
        self.assertIsInstance(out, list)
        self.assertEqual(len(out), 3)
        for vec in out:
            self.assertIsInstance(vec, np.ndarray)
            self.assertEqual(vec.dtype, np.float32)
class TestEmbedProfiles(unittest.TestCase):
    """Sanity checks on the static EMBED_PROFILES registry."""

    def test_embed_profiles_all_have_valid_keys(self):
        self.assertEqual(set(EMBED_PROFILES), {"small", "base", "large", "code"})

    def test_embed_profiles_model_ids_non_empty(self):
        for key, model_id in EMBED_PROFILES.items():
            self.assertIsInstance(model_id, str, msg=f"{key} model id should be str")
            self.assertTrue(model_id, msg=f"{key} model id should be non-empty")
class TestBaseEmbedderAbstract(unittest.TestCase):
    """The embedder base class must refuse direct instantiation."""

    def test_base_embedder_is_abstract(self):
        self.assertRaises(TypeError, BaseEmbedder)
if __name__ == "__main__":
    # Allow running this test module directly via the unittest CLI.
    unittest.main()

View File

@@ -0,0 +1,179 @@
from __future__ import annotations
import types
from unittest.mock import MagicMock, patch
import pytest
from codexlens.config import Config
from codexlens.rerank.base import BaseReranker
from codexlens.rerank.local import FastEmbedReranker
from codexlens.rerank.api import APIReranker
# ---------------------------------------------------------------------------
# BaseReranker
# ---------------------------------------------------------------------------
def test_base_reranker_is_abstract():
    """Instantiating the abstract base directly must raise TypeError."""
    pytest.raises(TypeError, BaseReranker)
# ---------------------------------------------------------------------------
# FastEmbedReranker
# ---------------------------------------------------------------------------
def _make_rerank_result(index: int, score: float) -> object:
obj = types.SimpleNamespace(index=index, score=score)
return obj
def test_local_reranker_score_pairs_length():
    """score_pairs returns exactly one score per document."""
    reranker = FastEmbedReranker(Config())
    fake_model = MagicMock()
    fake_model.rerank.return_value = iter(
        [_make_rerank_result(i, s) for i, s in ((0, 0.9), (1, 0.5), (2, 0.1))]
    )
    reranker._model = fake_model
    scores = reranker.score_pairs("query", ["doc0", "doc1", "doc2"])
    assert len(scores) == 3
def test_local_reranker_preserves_order():
    """Scores must be mapped back to document order even when the model
    reports results out of order (index 2 first here)."""
    reranker = FastEmbedReranker(Config())
    shuffled = [(2, 0.1), (1, 0.5), (0, 0.9)]
    fake_model = MagicMock()
    fake_model.rerank.return_value = iter(
        [_make_rerank_result(i, s) for i, s in shuffled]
    )
    reranker._model = fake_model
    scores = reranker.score_pairs("query", ["doc0", "doc1", "doc2"])
    assert scores[0] == pytest.approx(0.9)
    assert scores[1] == pytest.approx(0.5)
    assert scores[2] == pytest.approx(0.1)
# ---------------------------------------------------------------------------
# APIReranker
# ---------------------------------------------------------------------------
def _make_config(max_tokens_per_batch: int = 512) -> Config:
    """Config wired to a dummy rerank API endpoint for the tests below."""
    options = dict(
        reranker_api_url="https://api.example.com",
        reranker_api_key="test-key",
        reranker_api_model="test-model",
        reranker_api_max_tokens_per_batch=max_tokens_per_batch,
    )
    return Config(**options)
def test_api_reranker_batch_splitting():
    """Docs exceeding the token budget are split across multiple batches."""
    config = _make_config(max_tokens_per_batch=512)
    with patch("httpx.Client"):
        reranker = APIReranker(config)
    # Ten ~200-token docs (800 chars at ~4 chars/token) cannot fit one batch:
    # 200+200=400 <= 512 but 400+200=600 > 512, so at most 2 docs per batch.
    docs = ["x" * 800] * 10
    batches = reranker._split_batches(docs, max_tokens=512)
    assert len(batches) > 1
    for batch in batches:
        estimated = sum(len(text) // 4 for _, text in batch)
        # Within budget — unless a single oversized doc forms its own batch.
        assert estimated <= 512 or len(batch) == 1
def test_api_reranker_retry_on_429():
    """Two throttled (429) responses followed by a 200 must be retried through."""
    config = _make_config()

    throttled = MagicMock()
    throttled.status_code = 429

    ok = MagicMock()
    ok.status_code = 200
    ok.json.return_value = {
        "results": [
            {"index": 0, "relevance_score": 0.8},
            {"index": 1, "relevance_score": 0.3},
        ]
    }
    ok.raise_for_status = MagicMock()

    with patch("httpx.Client") as client_cls:
        client = MagicMock()
        client_cls.return_value = client
        client.post.side_effect = [throttled, throttled, ok]
        reranker = APIReranker(config)
        # Patch sleep so the back-off between retries costs no wall time.
        with patch("time.sleep"):
            result = reranker._call_api_with_retry(
                "query", [(0, "doc0"), (1, "doc1")], max_retries=3
            )
        assert client.post.call_count == 3
        assert 0 in result
        assert 1 in result
def test_api_reranker_merge_batches():
    """score_pairs must merge per-batch API responses back to input order."""
    config = _make_config(max_tokens_per_batch=100)
    # Four 200-char docs ≈ 50 tokens each: 50+50=100 <= 100 but 150 > 100,
    # so the splitter produces two batches of two docs each.
    docs = ["x" * 200] * 4

    def _ok_response(scores):
        # Build a 200 response whose payload scores batch-local indices 0..n-1.
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {
            "results": [
                {"index": i, "relevance_score": s} for i, s in enumerate(scores)
            ]
        }
        resp.raise_for_status = MagicMock()
        return resp

    with patch("httpx.Client") as client_cls:
        client = MagicMock()
        client_cls.return_value = client
        client.post.side_effect = [_ok_response([0.9, 0.8]), _ok_response([0.7, 0.6])]
        reranker = APIReranker(config)
        with patch("time.sleep"):
            scores = reranker.score_pairs("query", docs)
    assert len(scores) == 4
    # All original indices should have received a (positive) score.
    assert all(s > 0 for s in scores)

View File

@@ -0,0 +1,156 @@
"""Unit tests for search layer: FTSEngine, fusion, and SearchPipeline."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from codexlens.search.fts import FTSEngine
from codexlens.search.fusion import (
DEFAULT_WEIGHTS,
QueryIntent,
detect_query_intent,
get_adaptive_weights,
reciprocal_rank_fusion,
)
from codexlens.search.pipeline import SearchPipeline, SearchResult
from codexlens.config import Config
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_fts(docs: list[tuple[int, str, str]] | None = None) -> FTSEngine:
    """Build an in-memory FTSEngine, optionally pre-populated with *docs*."""
    fts = FTSEngine(":memory:")
    if docs:
        fts.add_documents(docs)
    return fts
# ---------------------------------------------------------------------------
# FTSEngine tests
# ---------------------------------------------------------------------------
def test_fts_add_and_exact_search():
    """An exact term query returns the matching document, ranked first."""
    engine = make_fts(
        [
            (1, "a.py", "def authenticate user password login"),
            (2, "b.py", "connect to database with credentials"),
            (3, "c.py", "render template html response"),
        ]
    )
    hits = engine.exact_search("authenticate", top_k=10)
    hit_ids = [row[0] for row in hits]
    assert 1 in hit_ids, "doc 1 should match 'authenticate'"
    # If doc 2 appears at all, doc 1 must still hold the top rank.
    assert 2 not in hit_ids or hits[0][0] == 1
def test_fts_fuzzy_search_prefix():
    """A prefix query matches documents containing longer forms of the term."""
    engine = make_fts(
        [
            (10, "auth.py", "authentication token refresh"),
            (11, "db.py", "database connection pool"),
            (12, "ui.py", "render button click handler"),
        ]
    )
    hits = engine.fuzzy_search("auth", top_k=10)
    hit_ids = [row[0] for row in hits]
    assert 10 in hit_ids, "prefix 'auth' should match doc 10 with 'authentication'"
# ---------------------------------------------------------------------------
# RRF fusion tests
# ---------------------------------------------------------------------------
def test_rrf_fusion_ordering():
    """When two sources agree on the same top-1, fusion keeps it on top."""
    ranked_a = [(1, 0.9), (2, 0.5), (3, 0.2)]
    ranked_b = [(1, 0.8), (3, 0.6), (2, 0.1)]
    fused = reciprocal_rank_fusion({"a": ranked_a, "b": ranked_b})
    assert fused[0][0] == 1, "doc 1 agreed top by both sources must rank first"
def test_rrf_equal_weight_default():
    """weights=None must fall back to DEFAULT_WEIGHTS without crashing."""
    fused = reciprocal_rank_fusion(
        {
            "exact": [(5, 1.0), (6, 0.8)],
            "vector": [(6, 0.9), (5, 0.7)],
        },
        weights=None,
    )
    assert len(fused) == 2
    fused_ids = [doc_id for doc_id, _ in fused]
    assert 5 in fused_ids and 6 in fused_ids
# ---------------------------------------------------------------------------
# detect_query_intent tests
# ---------------------------------------------------------------------------
def test_detect_intent_code_symbol():
    """Code-like syntax is classified as a symbol query."""
    intent = detect_query_intent("def authenticate()")
    assert intent == QueryIntent.CODE_SYMBOL


def test_detect_intent_natural():
    """Plain English is classified as a natural-language query."""
    intent = detect_query_intent("how do I authenticate users")
    assert intent == QueryIntent.NATURAL_LANGUAGE
# ---------------------------------------------------------------------------
# SearchPipeline tests
# ---------------------------------------------------------------------------
def _make_pipeline(fts: FTSEngine, top_k: int = 5) -> SearchPipeline:
    """Assemble a SearchPipeline whose heavy components are MagicMocks."""
    cfg = Config.small()
    cfg.reranker_top_k = top_k

    embedder = MagicMock()
    embedder.embed.return_value = [[0.1] * cfg.embed_dim]

    binary_store = MagicMock()
    binary_store.coarse_search.return_value = ([1, 2, 3], None)

    ann_index = MagicMock()
    ann_index.fine_search.return_value = ([1, 2, 3], [0.9, 0.8, 0.7])

    reranker = MagicMock()

    def _fake_scores(query, contents):
        # One descending score per content string, regardless of query.
        return [0.9 - 0.1 * rank for rank in range(len(contents))]

    reranker.score_pairs.side_effect = _fake_scores

    return SearchPipeline(
        embedder=embedder,
        binary_store=binary_store,
        ann_index=ann_index,
        reranker=reranker,
        fts=fts,
        config=cfg,
    )
def test_pipeline_search_returns_results():
    """A simple query through the mocked pipeline yields SearchResult objects."""
    pipeline = _make_pipeline(
        make_fts(
            [
                (1, "a.py", "test content alpha"),
                (2, "b.py", "test content beta"),
                (3, "c.py", "test content gamma"),
            ]
        )
    )
    hits = pipeline.search("test")
    assert hits
    assert all(isinstance(hit, SearchResult) for hit in hits)
def test_pipeline_top_k_limit():
    """The pipeline must never return more than the requested top_k hits."""
    corpus = [
        (1, "a.py", "hello world one"),
        (2, "b.py", "hello world two"),
        (3, "c.py", "hello world three"),
        (4, "d.py", "hello world four"),
        (5, "e.py", "hello world five"),
    ]
    pipeline = _make_pipeline(make_fts(corpus), top_k=2)
    hits = pipeline.search("hello", top_k=2)
    assert len(hits) <= 2, "pipeline must respect top_k limit"