Files
Claude-Code-Workflow/codex-lens/tests/test_reranker_factory.py
catlog22 520f2d26f2 feat(codex-lens): add unified reranker architecture and file watcher
Unified Reranker Architecture:
- Add BaseReranker ABC with factory pattern
- Implement 4 backends: ONNX (default), API, LiteLLM, Legacy
- Add .env configuration parsing for API credentials
- Migrate from sentence-transformers to optimum+onnxruntime

File Watcher Module:
- Add real-time file system monitoring with watchdog
- Implement IncrementalIndexer for single-file updates
- Add WatcherManager with signal handling and graceful shutdown
- Add 'codexlens watch' CLI command
- Event filtering, debouncing, and deduplication
- Thread-safe design with proper resource cleanup

Tests: 16 watcher tests + 5 reranker test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-01 13:23:52 +08:00

316 lines
11 KiB
Python

"""Tests for reranker factory and availability checks."""
from __future__ import annotations
import builtins
import math
import sys
import types
import pytest
from codexlens.semantic.reranker import (
BaseReranker,
ONNXReranker,
check_reranker_available,
get_reranker,
)
from codexlens.semantic.reranker import legacy as legacy_module
def test_public_imports_work() -> None:
    """Re-importing the public names yields the very same objects."""
    from codexlens.semantic.reranker import (
        BaseReranker as ImportedBaseReranker,
        get_reranker as imported_get_reranker,
    )

    # Identity (not equality): the package must expose one object per name.
    assert ImportedBaseReranker is BaseReranker
    assert imported_get_reranker is get_reranker
def test_base_reranker_is_abstract() -> None:
    """Instantiating ``BaseReranker`` directly is expected to raise ``TypeError``."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        BaseReranker()  # type: ignore[abstract]
def test_check_reranker_available_invalid_backend() -> None:
    """An unrecognised backend name is reported as unavailable with a reason."""
    available, reason = check_reranker_available("nope")

    assert available is False
    # ``reason`` may be None, so fall back to "" before the substring check.
    assert "Invalid reranker backend" in (reason or "")
def test_get_reranker_invalid_backend_raises_value_error() -> None:
    """``get_reranker`` is expected to reject an unknown backend with ``ValueError``."""
    expect_unknown_backend = pytest.raises(ValueError, match="Unknown backend")
    with expect_unknown_backend:
        get_reranker("nope")
def test_get_reranker_legacy_missing_dependency_raises_import_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the legacy module flagged unavailable, ``ImportError`` carries its message."""
    # Module-level state of ``legacy_module`` simulated for this test.
    simulated_state = (
        ("CROSS_ENCODER_AVAILABLE", False),
        ("_import_error", "missing sentence-transformers"),
    )
    for attribute, value in simulated_state:
        monkeypatch.setattr(legacy_module, attribute, value)

    with pytest.raises(ImportError, match="missing sentence-transformers"):
        get_reranker(backend="legacy", model_name="dummy-model")
def test_get_reranker_legacy_returns_cross_encoder_reranker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The legacy backend yields a ``CrossEncoderReranker`` that scores via the stub.

    The backend name is passed with surrounding spaces and in upper case, and
    a ``batch_size`` of 0 is expected to reach the model as 32.
    """

    class DummyCrossEncoder:
        """Stand-in for the cross-encoder class; records the batch size it saw."""

        def __init__(self, model_name: str, *, device: str | None = None) -> None:
            self.model_name = model_name
            self.device = device
            self.last_batch_size: int | None = None

        def predict(
            self, pairs: list[tuple[str, str]], *, batch_size: int = 32
        ) -> list[float]:
            self.last_batch_size = int(batch_size)
            # Constant score per pair keeps the expectation trivial.
            return [0.5 for _ in pairs]

    patched_state = (
        ("_CrossEncoder", DummyCrossEncoder),
        ("CROSS_ENCODER_AVAILABLE", True),
        ("_import_error", None),
    )
    for attribute, value in patched_state:
        monkeypatch.setattr(legacy_module, attribute, value)

    reranker = get_reranker(backend=" LEGACY ", model_name="dummy-model", device="cpu")
    assert isinstance(reranker, legacy_module.CrossEncoderReranker)

    # Empty input must produce an empty result.
    assert reranker.score_pairs([]) == []

    query_doc_pairs = [("q", "d1"), ("q", "d2")]
    scores = reranker.score_pairs(query_doc_pairs, batch_size=0)
    assert scores == pytest.approx([0.5, 0.5])

    assert reranker._model is not None
    assert reranker._model.last_batch_size == 32
def test_check_reranker_available_onnx_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    """When importing ``onnxruntime`` fails, the ONNX backend is reported unavailable."""
    original_import = builtins.__import__

    def import_without_onnxruntime(
        name: str, globals=None, locals=None, fromlist=(), level: int = 0
    ):
        # Every other import is delegated untouched to the real hook.
        if name != "onnxruntime":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("no onnxruntime")

    monkeypatch.setattr(builtins, "__import__", import_without_onnxruntime)

    available, reason = check_reranker_available("onnx")

    assert available is False
    assert "onnxruntime not available" in (reason or "")
def test_check_reranker_available_onnx_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
    """With stub ONNX dependencies in ``sys.modules`` the backend reports available."""
    optimum_pkg = types.ModuleType("optimum")
    optimum_pkg.__path__ = []  # Mark as package for submodule imports.

    optimum_ort = types.ModuleType("optimum.onnxruntime")
    optimum_ort.ORTModelForSequenceClassification = object()

    transformers_mod = types.ModuleType("transformers")
    transformers_mod.AutoTokenizer = object()

    # Insertion order mirrors the order in which the stubs get registered.
    stub_modules = {
        "onnxruntime": types.ModuleType("onnxruntime"),
        "optimum": optimum_pkg,
        "optimum.onnxruntime": optimum_ort,
        "transformers": transformers_mod,
    }
    for module_name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)

    available, reason = check_reranker_available("onnx")

    assert available is True
    assert reason is None
def test_check_reranker_available_litellm_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    """When importing ``ccw_litellm`` fails, the LiteLLM backend is reported unavailable."""
    original_import = builtins.__import__

    def import_without_ccw_litellm(
        name: str, globals=None, locals=None, fromlist=(), level: int = 0
    ):
        # Every other import is delegated untouched to the real hook.
        if name != "ccw_litellm":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("no ccw-litellm")

    monkeypatch.setattr(builtins, "__import__", import_without_ccw_litellm)

    available, reason = check_reranker_available("litellm")

    assert available is False
    assert "ccw-litellm not available" in (reason or "")
def test_check_reranker_available_litellm_deps_present(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With a stub ``ccw_litellm`` in ``sys.modules`` the backend reports available."""
    stub_module = types.ModuleType("ccw_litellm")
    monkeypatch.setitem(sys.modules, "ccw_litellm", stub_module)

    available, reason = check_reranker_available("litellm")

    assert available is True
    assert reason is None
def test_check_reranker_available_api_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    """When importing ``httpx`` fails, the API backend is reported unavailable."""
    original_import = builtins.__import__

    def import_without_httpx(
        name: str, globals=None, locals=None, fromlist=(), level: int = 0
    ):
        # Every other import is delegated untouched to the real hook.
        if name != "httpx":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("no httpx")

    monkeypatch.setattr(builtins, "__import__", import_without_httpx)

    available, reason = check_reranker_available("api")

    assert available is False
    assert "httpx not available" in (reason or "")
def test_check_reranker_available_api_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
    """With a stub ``httpx`` in ``sys.modules`` the API backend reports available."""
    stub_module = types.ModuleType("httpx")
    monkeypatch.setitem(sys.modules, "httpx", stub_module)

    available, reason = check_reranker_available("api")

    assert available is True
    assert reason is None
def test_get_reranker_litellm_returns_litellm_reranker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The ``litellm`` backend yields a ``LiteLLMReranker`` backed by the stub client.

    The stub client always replies with the text "0.5", which is expected to
    come back as the score for the single pair.
    """
    from dataclasses import dataclass

    @dataclass(frozen=True, slots=True)
    class ChatMessage:
        """Minimal message record exposed as ``ccw_litellm.ChatMessage``."""

        role: str
        content: str

    class DummyLiteLLMClient:
        """Stub client: remembers its constructor arguments, answers "0.5"."""

        def __init__(self, model: str = "default", **kwargs) -> None:
            self.model = model
            self.kwargs = kwargs

        def chat(self, messages, **kwargs):
            return types.SimpleNamespace(content="0.5")

    stub_module = types.ModuleType("ccw_litellm")
    stub_module.ChatMessage = ChatMessage
    stub_module.LiteLLMClient = DummyLiteLLMClient
    monkeypatch.setitem(sys.modules, "ccw_litellm", stub_module)

    reranker = get_reranker(backend="litellm", model_name="dummy-model")

    # Imported only after the stub is registered and the factory has run.
    from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

    assert isinstance(reranker, LiteLLMReranker)

    single_pair = [("q", "d")]
    assert reranker.score_pairs(single_pair) == pytest.approx([0.5])
def test_get_reranker_onnx_raises_import_error_with_dependency_hint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A failing ``onnxruntime`` import makes the factory raise an ``ImportError`` naming it."""
    original_import = builtins.__import__

    def import_without_onnxruntime(
        name: str, globals=None, locals=None, fromlist=(), level: int = 0
    ):
        # Every other import is delegated untouched to the real hook.
        if name != "onnxruntime":
            return original_import(name, globals, locals, fromlist, level)
        raise ImportError("no onnxruntime")

    monkeypatch.setattr(builtins, "__import__", import_without_onnxruntime)

    with pytest.raises(ImportError) as excinfo:
        get_reranker(backend="onnx", model_name="any")

    error_text = str(excinfo.value)
    assert "onnxruntime" in error_text
def test_get_reranker_default_backend_is_onnx(monkeypatch: pytest.MonkeyPatch) -> None:
    """Calling ``get_reranker()`` with no arguments returns an ``ONNXReranker``."""
    optimum_pkg = types.ModuleType("optimum")
    optimum_pkg.__path__ = []  # Mark as package for submodule imports.

    optimum_ort = types.ModuleType("optimum.onnxruntime")
    optimum_ort.ORTModelForSequenceClassification = object()

    transformers_mod = types.ModuleType("transformers")
    transformers_mod.AutoTokenizer = object()

    # Insertion order mirrors the order in which the stubs get registered.
    stub_modules = {
        "onnxruntime": types.ModuleType("onnxruntime"),
        "optimum": optimum_pkg,
        "optimum.onnxruntime": optimum_ort,
        "transformers": transformers_mod,
    }
    for module_name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)

    default_reranker = get_reranker()

    assert isinstance(default_reranker, ONNXReranker)
def test_onnx_reranker_scores_pairs_with_sigmoid_normalization(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Score five pairs in batches of two and compare against ``sigmoid(logit)``.

    The stub model emits consecutive logits 0, 1, 2, ... across calls, so the
    score expected for pair ``i`` is ``1 / (1 + exp(-i))``. The test also
    checks that the model attribute is ``None`` before scoring and that the
    stub saw batches of sizes ``[2, 2, 1]``.
    """
    import numpy as np

    class DummyModelOutput:
        """Container exposing ``logits`` on the stub model's return value."""

        def __init__(self, logits: np.ndarray) -> None:
            self.logits = logits

    class DummyModel:
        """Callable stub that records batch sizes and emits increasing logits."""

        # Deliberately omits "token_type_ids", which the stub tokenizer returns.
        input_names = ["input_ids", "attention_mask"]

        def __init__(self) -> None:
            self.calls: list[int] = []
            self._next_logit = 0

        def __call__(self, **inputs):
            rows = int(inputs["input_ids"].shape[0])
            first = self._next_logit
            stop = first + rows
            self._next_logit = stop
            self.calls.append(rows)
            column = np.arange(first, stop, dtype=np.float32).reshape(rows, 1)
            return DummyModelOutput(logits=column)

    class DummyORTModelForSequenceClassification:
        """Stub loader whose ``from_pretrained`` hands back a ``DummyModel``."""

        @classmethod
        def from_pretrained(cls, model_name: str, providers=None, **kwargs):
            del model_name, providers, kwargs  # accepted but intentionally unused
            return DummyModel()

    class DummyAutoTokenizer:
        """Stub tokenizer returning fixed-width NumPy arrays, one row per text."""

        model_max_length = 512

        @classmethod
        def from_pretrained(cls, model_name: str, **kwargs):
            del model_name, kwargs  # accepted but intentionally unused
            return cls()

        def __call__(self, *, text, text_pair, return_tensors, **kwargs):
            del text_pair, kwargs  # accepted but intentionally unused
            assert return_tensors == "np"
            shape = (len(text), 4)
            # Include token_type_ids to ensure input filtering is exercised.
            return {
                "input_ids": np.zeros(shape, dtype=np.int64),
                "attention_mask": np.ones(shape, dtype=np.int64),
                "token_type_ids": np.zeros(shape, dtype=np.int64),
            }

    optimum_pkg = types.ModuleType("optimum")
    optimum_pkg.__path__ = []  # Mark as package for submodule imports.

    optimum_ort = types.ModuleType("optimum.onnxruntime")
    optimum_ort.ORTModelForSequenceClassification = DummyORTModelForSequenceClassification

    transformers_mod = types.ModuleType("transformers")
    transformers_mod.AutoTokenizer = DummyAutoTokenizer

    # Insertion order mirrors the order in which the stubs get registered.
    stub_modules = {
        "onnxruntime": types.ModuleType("onnxruntime"),
        "optimum": optimum_pkg,
        "optimum.onnxruntime": optimum_ort,
        "transformers": transformers_mod,
    }
    for module_name, module in stub_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)

    reranker = get_reranker(backend="onnx", model_name="dummy-model", use_gpu=False)
    assert isinstance(reranker, ONNXReranker)
    # Nothing has been scored yet, so no model is attached.
    assert reranker._model is None

    pairs = [("q", f"d{idx}") for idx in range(5)]
    scores = reranker.score_pairs(pairs, batch_size=2)

    assert reranker._model is not None
    # Five pairs at batch_size=2 -> two full batches and a remainder of one.
    assert reranker._model.calls == [2, 2, 1]

    assert len(scores) == len(pairs)
    assert all(0.0 <= s <= 1.0 for s in scores)

    expected = [1.0 / (1.0 + math.exp(-float(i))) for i, _ in enumerate(pairs)]
    assert scores == pytest.approx(expected, rel=1e-6, abs=1e-6)