Files
Claude-Code-Workflow/codex-lens-v2/tests/unit/test_core.py
catlog22 de4158597b Implement search and reranking functionality with FTS and embedding support
- Added BaseReranker abstract class for defining reranking interfaces.
- Implemented FastEmbedReranker using fastembed's TextCrossEncoder for scoring document-query pairs.
- Introduced FTSEngine for full-text search capabilities using SQLite FTS5.
- Developed SearchPipeline to integrate embedding, binary search, ANN indexing, FTS, and reranking.
- Added fusion methods for combining results from different search strategies using Reciprocal Rank Fusion.
- Created unit and integration tests for the new search and reranking components.
- Established configuration management for search parameters and models.
2026-03-16 23:03:17 +08:00

137 lines
4.0 KiB
Python

"""Unit tests for BinaryStore and ANNIndex (no fastembed required)."""
from __future__ import annotations
import concurrent.futures
import tempfile
from pathlib import Path
import numpy as np
import pytest
from codexlens.config import Config
from codexlens.core import ANNIndex, BinaryStore
DIM = 32
RNG = np.random.default_rng(42)
def make_vectors(n: int, dim: int = DIM) -> np.ndarray:
return RNG.standard_normal((n, dim)).astype(np.float32)
def make_ids(n: int, start: int = 0) -> np.ndarray:
return np.arange(start, start + n, dtype=np.int64)
# ---------------------------------------------------------------------------
# BinaryStore tests
# ---------------------------------------------------------------------------
class TestBinaryStore:
def test_binary_store_add_and_search(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(10)
ids = make_ids(10)
store.add(ids, vecs)
assert len(store) == 10
top_k = 5
ret_ids, ret_dists = store.coarse_search(vecs[0], top_k=top_k)
assert ret_ids.shape == (top_k,)
assert ret_dists.shape == (top_k,)
# distances are non-negative integers
assert (ret_dists >= 0).all()
def test_binary_hamming_correctness(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(20)
ids = make_ids(20)
store.add(ids, vecs)
# Query with the exact stored vector; it must be the top-1 result
query = vecs[7]
ret_ids, ret_dists = store.coarse_search(query, top_k=1)
assert ret_ids[0] == 7
assert ret_dists[0] == 0 # Hamming distance to itself is 0
def test_binary_store_persist(self, tmp_path: Path) -> None:
cfg = Config.small()
store = BinaryStore(tmp_path, DIM, cfg)
vecs = make_vectors(15)
ids = make_ids(15)
store.add(ids, vecs)
store.save()
# Load into a fresh instance
store2 = BinaryStore(tmp_path, DIM, cfg)
assert len(store2) == 15
query = vecs[3]
ret_ids, ret_dists = store2.coarse_search(query, top_k=1)
assert ret_ids[0] == 3
assert ret_dists[0] == 0
# ---------------------------------------------------------------------------
# ANNIndex tests
# ---------------------------------------------------------------------------
class TestANNIndex:
def test_ann_index_add_and_search(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(50)
ids = make_ids(50)
idx.add(ids, vecs)
assert len(idx) == 50
ret_ids, ret_dists = idx.fine_search(vecs[0], top_k=5)
assert len(ret_ids) == 5
assert len(ret_dists) == 5
def test_ann_index_thread_safety(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(50)
ids = make_ids(50)
idx.add(ids, vecs)
query = vecs[0]
errors: list[Exception] = []
def search() -> None:
try:
idx.fine_search(query, top_k=3)
except Exception as exc:
errors.append(exc)
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool:
futures = [pool.submit(search) for _ in range(5)]
concurrent.futures.wait(futures)
assert errors == [], f"Thread safety errors: {errors}"
def test_ann_index_save_load(self, tmp_path: Path) -> None:
cfg = Config.small()
idx = ANNIndex(tmp_path, DIM, cfg)
vecs = make_vectors(30)
ids = make_ids(30)
idx.add(ids, vecs)
idx.save()
# Load into a fresh instance
idx2 = ANNIndex(tmp_path, DIM, cfg)
idx2.load()
assert len(idx2) == 30
ret_ids, ret_dists = idx2.fine_search(vecs[10], top_k=1)
assert len(ret_ids) == 1
assert ret_ids[0] == 10