Implement search and reranking functionality with FTS and embedding support

- Added BaseReranker abstract class for defining reranking interfaces.
- Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs.
- Introduced FTSEngine for full-text search capabilities using SQLite FTS5.
- Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking.
- Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion.
- Created unit and integration tests for the new search and reranking components.
- Established configuration management for search parameters and models.
This commit is contained in:
catlog22
2026-03-16 23:03:17 +08:00
parent 5a4b18d9b1
commit de4158597b
41 changed files with 2655 additions and 1848 deletions

View File

View File

@@ -0,0 +1,31 @@
from codexlens.config import Config
def test_config_instantiates_no_args():
cfg = Config()
assert cfg is not None
def test_defaults_hnsw_ef():
cfg = Config.defaults()
assert cfg.hnsw_ef == 150
def test_defaults_hnsw_M():
cfg = Config.defaults()
assert cfg.hnsw_M == 32
def test_small_hnsw_ef():
cfg = Config.small()
assert cfg.hnsw_ef == 50
def test_custom_instantiation():
cfg = Config(hnsw_ef=100)
assert cfg.hnsw_ef == 100
def test_fusion_weights_keys():
cfg = Config()
assert set(cfg.fusion_weights.keys()) == {"exact", "fuzzy", "vector", "graph"}

View File

@@ -0,0 +1,136 @@
"""Unit tests for BinaryStore and ANNIndex (no fastembed required)."""
from __future__ import annotations
import concurrent.futures
import tempfile
from pathlib import Path
import numpy as np
import pytest
from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore
DIM = 32
RNG = np.random.default_rng(42)
def make_vectors(n: int, dim: int = DIM) -> np.ndarray:
return RNG.standard_normal((n, dim)).astype(np.float32)
def make_ids(n: int, start: int = 0) -> np.ndarray:
return np.arange(start, start + n, dtype=np.int64)
# ---------------------------------------------------------------------------
# BinaryStore tests
# ---------------------------------------------------------------------------
class TestBinaryStore:
def test_binary_store_add_and_search(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(10)
ids = make_ids(10)
store.add(ids, vecs)
assert len(store) == 10
top_k = 5
ret_ids, ret_dists = store.coarse_search(vecs[0], top_k=top_k)
assert ret_ids.shape == (top_k,)
assert ret_dists.shape == (top_k,)
# distances are non-negative integers
assert (ret_dists >= 0).all()
def test_binary_hamming_correctness(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(20)
ids = make_ids(20)
store.add(ids, vecs)
# Query with the exact stored vector; it must be the top-1 result
query = vecs[7]
ret_ids, ret_dists = store.coarse_search(query, top_k=1)
assert ret_ids[0] == 7
assert ret_dists[0] == 0 # Hamming distance to itself is 0
def test_binary_store_persist(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(15)
ids = make_ids(15)
store.add(ids, vecs)
store.save()
# Load into a fresh instance
store2 = BinaryStore(tmp_path, DIM, cfg)
assert len(store2) == 15
query = vecs[3]
ret_ids, ret_dists = store2.coarse_search(query, top_k=1)
assert ret_ids[0] == 3
assert ret_dists[0] == 0
# ---------------------------------------------------------------------------
# ANNIndex tests
# ---------------------------------------------------------------------------
class TestANNIndex:
def test_ann_index_add_and_search(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(50)
ids = make_ids(50)
idx.add(ids, vecs)
assert len(idx) == 50
ret_ids, ret_dists = idx.fine_search(vecs[0], top_k=5)
assert len(ret_ids) == 5
assert len(ret_dists) == 5
def test_ann_index_thread_safety(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(50)
ids = make_ids(50)
idx.add(ids, vecs)
query = vecs[0]
errors: list[Exception] = []
def search() -> None:
try:
idx.fine_search(query, top_k=3)
except Exception as exc:
errors.append(exc)
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
futures = [pool.submit(search) for _ in range(5)]
concurrent.futures.wait(futures)
assert errors == [], f"Thread safety errors: {errors}"
def test_ann_index_save_load(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(30)
ids = make_ids(30)
idx.add(ids, vecs)
idx.save()
# Load into a fresh instance
idx2 = ANNIndex(tmp_path, DIM, cfg)
idx2.load()
assert len(idx2) == 30
ret_ids, ret_dists = idx2.fine_search(vecs[10], top_k=1)
assert len(ret_ids) == 1
assert ret_ids[0] == 10

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
import sys
import types
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
def _make_fastembed_mock():
"""Build a minimal fastembed stub so imports succeed without the real package."""
fastembed_mod = types.ModuleType("fastembed")
fastembed_mod.TextEmbedding = MagicMock()
sys.modules.setdefault("fastembed", fastembed_mod)
return fastembed_mod
_make_fastembed_mock()
from codexlens.config import Config # noqa: E402
from codexlens.embed.base import BaseEmbedder # noqa: E402
from codexlens.embed.local import EMBED_PROFILES, FastEmbedEmbedder # noqa: E402
class TestEmbedSingle(unittest.TestCase):
def test_embed_single_returns_float32_ndarray(self):
config = Config()
embedder = FastEmbedEmbedder(config)
mock_model = MagicMock()
mock_model.embed.return_value = iter([np.ones(384, dtype=np.float64)])
# Inject mock model directly to bypass lazy load (no real fastembed needed)
embedder._model = mock_model
result = embedder.embed_single("hello world")
self.assertIsInstance(result, np.ndarray)
self.assertEqual(result.dtype, np.float32)
self.assertEqual(result.shape, (384,))
class TestEmbedBatch(unittest.TestCase):
def test_embed_batch_returns_list(self):
config = Config()
embedder = FastEmbedEmbedder(config)
vecs = [np.ones(384, dtype=np.float64) * i for i in range(3)]
mock_model = MagicMock()
mock_model.embed.return_value = iter(vecs)
embedder._model = mock_model
result = embedder.embed_batch(["a", "b", "c"])
self.assertIsInstance(result, list)
self.assertEqual(len(result), 3)
for arr in result:
self.assertIsInstance(arr, np.ndarray)
self.assertEqual(arr.dtype, np.float32)
class TestEmbedProfiles(unittest.TestCase):
def test_embed_profiles_all_have_valid_keys(self):
expected_keys = {"small", "base", "large", "code"}
self.assertEqual(set(EMBED_PROFILES.keys()), expected_keys)
def test_embed_profiles_model_ids_non_empty(self):
for key, model_id in EMBED_PROFILES.items():
self.assertIsInstance(model_id, str, msg=f"{key} model id should be str")
self.assertTrue(len(model_id) > 0, msg=f"{key} model id should be non-empty")
class TestBaseEmbedderAbstract(unittest.TestCase):
def test_base_embedder_is_abstract(self):
with self.assertRaises(TypeError):
BaseEmbedder() # type: ignore[abstract]
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,179 @@
from __future__ import annotations
import types
from unittest.mock import MagicMock, patch
import pytest
from codexlens.config import Config
from codexlens.rerank.base import BaseReranker
from codexlens.rerank.local import FastEmbedReranker
from codexlens.rerank.api import APIReranker
# ---------------------------------------------------------------------------
# BaseReranker
# ---------------------------------------------------------------------------
def test_base_reranker_is_abstract():
with pytest.raises(TypeError):
BaseReranker() # type: ignore[abstract]
# ---------------------------------------------------------------------------
# FastEmbedReranker
# ---------------------------------------------------------------------------
def _make_rerank_result(index: int, score: float) -> object:
obj = types.SimpleNamespace(index=index, score=score)
return obj
def test_local_reranker_score_pairs_length():
config = Config()
reranker = FastEmbedReranker(config)
mock_results = [
_make_rerank_result(0, 0.9),
_make_rerank_result(1, 0.5),
_make_rerank_result(2, 0.1),
]
mock_model = MagicMock()
mock_model.rerank.return_value = iter(mock_results)
reranker._model = mock_model
docs = ["doc0", "doc1", "doc2"]
scores = reranker.score_pairs("query", docs)
assert len(scores) == 3
def test_local_reranker_preserves_order():
config = Config()
reranker = FastEmbedReranker(config)
# rerank returns results in reverse order (index 2, 1, 0)
mock_results = [
_make_rerank_result(2, 0.1),
_make_rerank_result(1, 0.5),
_make_rerank_result(0, 0.9),
]
mock_model = MagicMock()
mock_model.rerank.return_value = iter(mock_results)
reranker._model = mock_model
docs = ["doc0", "doc1", "doc2"]
scores = reranker.score_pairs("query", docs)
assert scores[0] == pytest.approx(0.9)
assert scores[1] == pytest.approx(0.5)
assert scores[2] == pytest.approx(0.1)
# ---------------------------------------------------------------------------
# APIReranker
# ---------------------------------------------------------------------------
def _make_config(max_tokens_per_batch: int = 512) -> Config:
return Config(
reranker_api_url="https://api.example.com",
reranker_api_key="test-key",
reranker_api_model="test-model",
reranker_api_max_tokens_per_batch=max_tokens_per_batch,
)
def test_api_reranker_batch_splitting():
config = _make_config(max_tokens_per_batch=512)
with patch("httpx.Client"):
reranker = APIReranker(config)
# 10 docs, each ~200 tokens (800 chars)
docs = ["x" * 800] * 10
batches = reranker._split_batches(docs, max_tokens=512)
# Each doc is 200 tokens; batches should have at most 2 docs (200+200=400 <= 512, 400+200=600 > 512)
assert len(batches) > 1
for batch in batches:
total = sum(len(text) // 4 for _, text in batch)
assert total <= 512 or len(batch) == 1
def test_api_reranker_retry_on_429():
config = _make_config()
mock_429 = MagicMock()
mock_429.status_code = 429
mock_200 = MagicMock()
mock_200.status_code = 200
mock_200.json.return_value = {
"results": [
{"index": 0, "relevance_score": 0.8},
{"index": 1, "relevance_score": 0.3},
]
}
mock_200.raise_for_status = MagicMock()
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.side_effect = [mock_429, mock_429, mock_200]
reranker = APIReranker(config)
with patch("time.sleep"):
result = reranker._call_api_with_retry(
"query",
[(0, "doc0"), (1, "doc1")],
max_retries=3,
)
assert mock_client.post.call_count == 3
assert 0 in result
assert 1 in result
def test_api_reranker_merge_batches():
config = _make_config(max_tokens_per_batch=100)
# 4 docs of 25 tokens each (100 chars); each batch holds at most 4 docs
# Use smaller docs to force 2 batches: 2 docs per batch (50 tokens each = 200 chars)
docs = ["x" * 200] * 4 # 50 tokens each; 50+50=100 <= 100, 100+50=150 > 100 -> 2 per batch
batch0_response = MagicMock()
batch0_response.status_code = 200
batch0_response.json.return_value = {
"results": [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.8},
]
}
batch0_response.raise_for_status = MagicMock()
batch1_response = MagicMock()
batch1_response.status_code = 200
batch1_response.json.return_value = {
"results": [
{"index": 0, "relevance_score": 0.7},
{"index": 1, "relevance_score": 0.6},
]
}
batch1_response.raise_for_status = MagicMock()
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.side_effect = [batch0_response, batch1_response]
reranker = APIReranker(config)
with patch("time.sleep"):
scores = reranker.score_pairs("query", docs)
assert len(scores) == 4
# All original indices should have scores
assert all(s > 0 for s in scores)

View File

@@ -0,0 +1,156 @@
"""Unit tests for search layer: FTSEngine, fusion, and SearchPipeline."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from codexlens.search.fts import FTSEngine
from codexlens.search.fusion import (
DEFAULT_WEIGHTS,
QueryIntent,
detect_query_intent,
get_adaptive_weights,
reciprocal_rank_fusion,
)
from codexlens.search.pipeline import SearchPipeline, SearchResult
from codexlens.config import Config
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_fts(docs: list[tuple[int, str, str]] | None = None) -> FTSEngine:
"""Create an in-memory FTSEngine and optionally add documents."""
engine = FTSEngine(":memory:")
if docs:
engine.add_documents(docs)
return engine
# ---------------------------------------------------------------------------
# FTSEngine tests
# ---------------------------------------------------------------------------
def test_fts_add_and_exact_search():
docs = [
(1, "a.py", "def authenticate user password login"),
(2, "b.py", "connect to database with credentials"),
(3, "c.py", "render template html response"),
]
engine = make_fts(docs)
results = engine.exact_search("authenticate", top_k=10)
ids = [r[0] for r in results]
assert 1 in ids, "doc 1 should match 'authenticate'"
assert 2 not in ids or results[0][0] == 1 # doc 1 must rank higher
def test_fts_fuzzy_search_prefix():
docs = [
(10, "auth.py", "authentication token refresh"),
(11, "db.py", "database connection pool"),
(12, "ui.py", "render button click handler"),
]
engine = make_fts(docs)
# Prefix 'auth' should match 'authentication' in doc 10
results = engine.fuzzy_search("auth", top_k=10)
ids = [r[0] for r in results]
assert 10 in ids, "prefix 'auth' should match doc 10 with 'authentication'"
# ---------------------------------------------------------------------------
# RRF fusion tests
# ---------------------------------------------------------------------------
def test_rrf_fusion_ordering():
"""When two sources agree on top-1, it should rank first in fused result."""
source_a = [(1, 0.9), (2, 0.5), (3, 0.2)]
source_b = [(1, 0.8), (3, 0.6), (2, 0.1)]
fused = reciprocal_rank_fusion({"a": source_a, "b": source_b})
assert fused[0][0] == 1, "doc 1 agreed top by both sources must rank first"
def test_rrf_equal_weight_default():
"""Calling with None weights should use DEFAULT_WEIGHTS shape (not crash)."""
source_exact = [(5, 1.0), (6, 0.8)]
source_vector = [(6, 0.9), (5, 0.7)]
# Should not raise and should return results
fused = reciprocal_rank_fusion(
{"exact": source_exact, "vector": source_vector},
weights=None,
)
assert len(fused) == 2
ids = [r[0] for r in fused]
assert 5 in ids and 6 in ids
# ---------------------------------------------------------------------------
# detect_query_intent tests
# ---------------------------------------------------------------------------
def test_detect_intent_code_symbol():
assert detect_query_intent("def authenticate()") == QueryIntent.CODE_SYMBOL
def test_detect_intent_natural():
assert detect_query_intent("how do I authenticate users") == QueryIntent.NATURAL_LANGUAGE
# ---------------------------------------------------------------------------
# SearchPipeline tests
# ---------------------------------------------------------------------------
def _make_pipeline(fts: FTSEngine, top_k: int = 5) -> SearchPipeline:
"""Build a SearchPipeline with mocked heavy components."""
cfg = Config.small()
cfg.reranker_top_k = top_k
embedder = MagicMock()
embedder.embed.return_value = [[0.1] * cfg.embed_dim]
binary_store = MagicMock()
binary_store.coarse_search.return_value = ([1, 2, 3], None)
ann_index = MagicMock()
ann_index.fine_search.return_value = ([1, 2, 3], [0.9, 0.8, 0.7])
reranker = MagicMock()
# Return a score for each content string passed
reranker.score_pairs.side_effect = lambda q, contents: [0.9 - i * 0.1 for i in range(len(contents))]
return SearchPipeline(
embedder=embedder,
binary_store=binary_store,
ann_index=ann_index,
reranker=reranker,
fts=fts,
config=cfg,
)
def test_pipeline_search_returns_results():
docs = [
(1, "a.py", "test content alpha"),
(2, "b.py", "test content beta"),
(3, "c.py", "test content gamma"),
]
fts = make_fts(docs)
pipeline = _make_pipeline(fts)
results = pipeline.search("test")
assert len(results) > 0
assert all(isinstance(r, SearchResult) for r in results)
def test_pipeline_top_k_limit():
docs = [
(1, "a.py", "hello world one"),
(2, "b.py", "hello world two"),
(3, "c.py", "hello world three"),
(4, "d.py", "hello world four"),
(5, "e.py", "hello world five"),
]
fts = make_fts(docs)
pipeline = _make_pipeline(fts, top_k=2)
results = pipeline.search("hello", top_k=2)
assert len(results) <= 2, "pipeline must respect top_k limit"