mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-03-19 18:58:47 +08:00
Implement search and reranking functionality with FTS and embedding support
- Added BaseReranker abstract class for defining reranking interfaces. - Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs. - Introduced FTSEngine for full-text search capabilities using SQLite FTS5. - Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking. - Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion. - Created unit and integration tests for the new search and reranking components. - Established configuration management for search parameters and models.
This commit is contained in:
108
codex-lens-v2/tests/integration/conftest.py
Normal file
108
codex-lens-v2/tests/integration/conftest.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.core import ANNIndex, BinaryStore
|
||||
from codexlens.embed.base import BaseEmbedder
|
||||
from codexlens.rerank.base import BaseReranker
|
||||
from codexlens.search.fts import FTSEngine
|
||||
from codexlens.search.pipeline import SearchPipeline
|
||||
|
||||
# Test documents: 20 code snippets with id, path, content
|
||||
TEST_DOCS = [
|
||||
(0, "auth.py", "def authenticate(user, password): return check_hash(password, user.hash)"),
|
||||
(1, "auth.py", "def authorize(user, permission): return permission in user.roles"),
|
||||
(2, "models.py", "class User: def __init__(self, name, email): self.name = name; self.email = email"),
|
||||
(3, "models.py", "class Session: token = None; expires_at = None"),
|
||||
(4, "middleware.py", "def auth_middleware(request): token = request.headers.get('Authorization')"),
|
||||
(5, "utils.py", "def hash_password(password): import bcrypt; return bcrypt.hashpw(password)"),
|
||||
(6, "config.py", "DATABASE_URL = os.environ.get('DATABASE_URL', 'sqlite:///db.sqlite3')"),
|
||||
(7, "search.py", "def search_users(query): return User.objects.filter(name__icontains=query)"),
|
||||
(8, "api.py", "def get_user(request, user_id): user = User.objects.get(id=user_id)"),
|
||||
(9, "api.py", "def create_user(request): data = request.json(); user = User(**data)"),
|
||||
(10, "tests.py", "def test_authenticate(): assert authenticate('admin', 'pass') is not None"),
|
||||
(11, "tests.py", "def test_search(): results = search_users('alice'); assert len(results) > 0"),
|
||||
(12, "router.py", "app.route('/users', methods=['GET'])(list_users)"),
|
||||
(13, "router.py", "app.route('/login', methods=['POST'])(login_handler)"),
|
||||
(14, "db.py", "def get_connection(): return sqlite3.connect(DATABASE_URL)"),
|
||||
(15, "cache.py", "def cache_get(key): return redis_client.get(key)"),
|
||||
(16, "cache.py", "def cache_set(key, value, ttl=3600): redis_client.setex(key, ttl, value)"),
|
||||
(17, "errors.py", "class AuthError(Exception): status_code = 401"),
|
||||
(18, "errors.py", "class NotFoundError(Exception): status_code = 404"),
|
||||
(19, "validators.py", "def validate_email(email): return '@' in email and '.' in email.split('@')[1]"),
|
||||
]
|
||||
|
||||
DIM = 32 # Use small dim for fast tests
|
||||
|
||||
|
||||
def make_stable_vec(doc_id: int, dim: int = DIM) -> np.ndarray:
|
||||
"""Generate a deterministic float32 vector for a given doc_id."""
|
||||
rng = np.random.default_rng(seed=doc_id)
|
||||
vec = rng.standard_normal(dim).astype(np.float32)
|
||||
vec /= np.linalg.norm(vec)
|
||||
return vec
|
||||
|
||||
|
||||
class MockEmbedder(BaseEmbedder):
|
||||
"""Returns stable deterministic vectors based on content hash."""
|
||||
|
||||
def embed_single(self, text: str) -> np.ndarray:
|
||||
seed = hash(text) % (2**31)
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
vec = rng.standard_normal(DIM).astype(np.float32)
|
||||
vec /= np.linalg.norm(vec)
|
||||
return vec
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
return [self.embed_single(t) for t in texts]
|
||||
|
||||
def embed(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""Called by SearchPipeline as self._embedder.embed([query])[0]."""
|
||||
return self.embed_batch(texts)
|
||||
|
||||
|
||||
class MockReranker(BaseReranker):
|
||||
"""Returns score based on simple keyword overlap."""
|
||||
|
||||
def score_pairs(self, query: str, documents: list[str]) -> list[float]:
|
||||
query_words = set(query.lower().split())
|
||||
scores = []
|
||||
for doc in documents:
|
||||
doc_words = set(doc.lower().split())
|
||||
overlap = len(query_words & doc_words)
|
||||
scores.append(float(overlap) / max(len(query_words), 1))
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
return Config.small() # hnsw_ef=50, hnsw_M=16, binary_top_k=50, ann_top_k=20, rerank_top_k=10
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_pipeline(tmp_path, config):
|
||||
"""Build a full SearchPipeline with 20 test docs indexed."""
|
||||
embedder = MockEmbedder()
|
||||
binary_store = BinaryStore(tmp_path / "binary", dim=DIM, config=config)
|
||||
ann_index = ANNIndex(tmp_path / "ann.hnsw", dim=DIM, config=config)
|
||||
fts = FTSEngine(tmp_path / "fts.db")
|
||||
reranker = MockReranker()
|
||||
|
||||
# Index all test docs
|
||||
ids = np.array([d[0] for d in TEST_DOCS], dtype=np.int64)
|
||||
vectors = np.array([embedder.embed_single(d[2]) for d in TEST_DOCS], dtype=np.float32)
|
||||
|
||||
binary_store.add(ids, vectors)
|
||||
ann_index.add(ids, vectors)
|
||||
fts.add_documents(TEST_DOCS)
|
||||
|
||||
return SearchPipeline(
|
||||
embedder=embedder,
|
||||
binary_store=binary_store,
|
||||
ann_index=ann_index,
|
||||
reranker=reranker,
|
||||
fts=fts,
|
||||
config=config,
|
||||
)
|
||||
Reference in New Issue
Block a user