Implement search and reranking functionality with FTS and embedding support

- Added BaseReranker abstract class for defining reranking interfaces.
- Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs.
- Introduced FTSEngine for full-text search capabilities using SQLite FTS5.
- Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking.
- Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion.
- Created unit and integration tests for the new search and reranking components.
- Established configuration management for search parameters and models.
This commit is contained in:
catlog22
2026-03-16 23:03:17 +08:00
parent 5a4b18d9b1
commit de4158597b
41 changed files with 2655 additions and 1848 deletions

View File

@@ -0,0 +1,108 @@
import pytest
import numpy as np
import tempfile
from pathlib import Path
from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore
from codexlens.embed.base import BaseEmbedder
from codexlens.rerank.base import BaseReranker
from codexlens.search.fts import FTSEngine
from codexlens.search.pipeline import SearchPipeline
# Test documents: 20 code snippets with id, path, content
TEST_DOCS = [
(0, "auth.py", "def authenticate(user, password): return check_hash(password, user.hash)"),
(1, "auth.py", "def authorize(user, permission): return permission in user.roles"),
(2, "models.py", "class User: def __init__(self, name, email): self.name = name; self.email = email"),
(3, "models.py", "class Session: token = None; expires_at = None"),
(4, "middleware.py", "def auth_middleware(request): token = request.headers.get('Authorization')"),
(5, "utils.py", "def hash_password(password): import bcrypt; return bcrypt.hashpw(password)"),
(6, "config.py", "DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///db.sqlite3')"),
(7, "search.py", "def search_users(query): return User.objects.filter(name__icontains=query)"),
(8, "api.py", "def get_user(request, user_id): user = User.objects.get(id=user_id)"),
(9, "api.py", "def create_user(request): data = request.json(); user = User(**data)"),
(10, "tests.py", "def test_authenticate(): assert authenticate('admin', 'pass') is not None"),
(11, "tests.py", "def test_search(): results = search_users('alice'); assert len(results) > 0"),
(12, "router.py", "app.route('/users', methods=['GET'])(list_users)"),
(13, "router.py", "app.route('/login', methods=['POST'])(login_handler)"),
(14, "db.py", "def get_connection(): return sqlite3.connect(DATABASE_URL)"),
(15, "cache.py", "def cache_get(key): return redis_client.get(key)"),
(16, "cache.py", "def cache_set(key, value, ttl=3600): redis_client.setex(key, ttl, value)"),
(17, "errors.py", "class AuthError(Exception): status_code = 401"),
(18, "errors.py", "class NotFoundError(Exception): status_code = 404"),
(19, "validators.py", "def validate_email(email): return '@' in email and '.' in email.split('@')[1]"),
]
DIM = 32 # Use small dim for fast tests
def make_stable_vec(doc_id: int, dim: int = DIM) -> np.ndarray:
"""Generate a deterministic float32 vector for a given doc_id."""
rng = np.random.default_rng(seed=doc_id)
vec = rng.standard_normal(dim).astype(np.float32)
vec /= np.linalg.norm(vec)
return vec
class MockEmbedder(BaseEmbedder):
"""Returns stable deterministic vectors based on content hash."""
def embed_single(self, text: str) -> np.ndarray:
seed = hash(text) % (2**31)
rng = np.random.default_rng(seed=seed)
vec = rng.standard_normal(DIM).astype(np.float32)
vec /= np.linalg.norm(vec)
return vec
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
return [self.embed_single(t) for t in texts]
def embed(self, texts: list[str]) -> list[np.ndarray]:
"""Called by SearchPipeline as self._embedder.embed([query])[0]."""
return self.embed_batch(texts)
class MockReranker(BaseReranker):
"""Returns score based on simple keyword overlap."""
def score_pairs(self, query: str, documents: list[str]) -> list[float]:
query_words = set(query.lower().split())
scores = []
for doc in documents:
doc_words = set(doc.lower().split())
overlap = len(query_words & doc_words)
scores.append(float(overlap) / max(len(query_words), 1))
return scores
@pytest.fixture
def config():
return Config.small() # hnsw_ef=50, hnsw_M=16, binary_top_k=50, ann_top_k=20, rerank_top_k=10
@pytest.fixture
def search_pipeline(tmp_path, config):
"""Build a full SearchPipeline with 20 test docs indexed."""
embedder = MockEmbedder()
binary_store = BinaryStore(tmp_path / "binary", dim=DIM, config=config)
ann_index = ANNIndex(tmp_path / "ann.hnsw", dim=DIM, config=config)
fts = FTSEngine(tmp_path / "fts.db")
reranker = MockReranker()
# Index all test docs
ids = np.array([d[0] for d in TEST_DOCS], dtype=np.int64)
vectors = np.array([embedder.embed_single(d[2]) for d in TEST_DOCS], dtype=np.float32)
binary_store.add(ids, vectors)
ann_index.add(ids, vectors)
fts.add_documents(TEST_DOCS)
return SearchPipeline(
embedder=embedder,
binary_store=binary_store,
ann_index=ann_index,
reranker=reranker,
fts=fts,
config=config,
)

View File

@@ -0,0 +1,44 @@
"""Integration tests for SearchPipeline using real components and mock embedder/reranker."""
from __future__ import annotations
def test_vector_search_returns_results(search_pipeline):
results = search_pipeline.search("authentication middleware")
assert len(results) > 0
assert all(isinstance(r.score, float) for r in results)
def test_exact_keyword_search(search_pipeline):
results = search_pipeline.search("authenticate")
assert len(results) > 0
result_ids = {r.id for r in results}
# Doc 0 and 10 both contain "authenticate"
assert result_ids & {0, 10}, f"Expected doc 0 or 10 in results, got {result_ids}"
def test_pipeline_top_k_limit(search_pipeline):
results = search_pipeline.search("user", top_k=5)
assert len(results) <= 5
def test_search_result_fields_populated(search_pipeline):
results = search_pipeline.search("password")
assert len(results) > 0
for r in results:
assert r.id >= 0
assert r.score >= 0
assert isinstance(r.path, str)
def test_empty_query_handled(search_pipeline):
results = search_pipeline.search("")
assert isinstance(results, list) # no exception
def test_different_queries_give_different_results(search_pipeline):
r1 = search_pipeline.search("authenticate user")
r2 = search_pipeline.search("cache redis")
# Results should differ (different top IDs or scores), unless both are empty
ids1 = [r.id for r in r1]
ids2 = [r.id for r in r2]
assert ids1 != ids2 or len(r1) == 0