mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
Unified Reranker Architecture: - Add BaseReranker ABC with factory pattern - Implement 4 backends: ONNX (default), API, LiteLLM, Legacy - Add .env configuration parsing for API credentials - Migrate from sentence-transformers to optimum+onnxruntime File Watcher Module: - Add real-time file system monitoring with watchdog - Implement IncrementalIndexer for single-file updates - Add WatcherManager with signal handling and graceful shutdown - Add 'codexlens watch' CLI command - Event filtering, debouncing, and deduplication - Thread-safe design with proper resource cleanup Tests: 16 watcher tests + 5 reranker test files 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
316 lines
11 KiB
Python
316 lines
11 KiB
Python
"""Tests for reranker factory and availability checks."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import math
|
|
import sys
|
|
import types
|
|
|
|
import pytest
|
|
|
|
from codexlens.semantic.reranker import (
|
|
BaseReranker,
|
|
ONNXReranker,
|
|
check_reranker_available,
|
|
get_reranker,
|
|
)
|
|
from codexlens.semantic.reranker import legacy as legacy_module
|
|
|
|
|
|
def test_public_imports_work() -> None:
|
|
from codexlens.semantic.reranker import BaseReranker as ImportedBaseReranker
|
|
from codexlens.semantic.reranker import get_reranker as imported_get_reranker
|
|
|
|
assert ImportedBaseReranker is BaseReranker
|
|
assert imported_get_reranker is get_reranker
|
|
|
|
|
|
def test_base_reranker_is_abstract() -> None:
|
|
with pytest.raises(TypeError):
|
|
BaseReranker() # type: ignore[abstract]
|
|
|
|
|
|
def test_check_reranker_available_invalid_backend() -> None:
|
|
ok, err = check_reranker_available("nope")
|
|
assert ok is False
|
|
assert "Invalid reranker backend" in (err or "")
|
|
|
|
|
|
def test_get_reranker_invalid_backend_raises_value_error() -> None:
|
|
with pytest.raises(ValueError, match="Unknown backend"):
|
|
get_reranker("nope")
|
|
|
|
|
|
def test_get_reranker_legacy_missing_dependency_raises_import_error(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", False)
|
|
monkeypatch.setattr(legacy_module, "_import_error", "missing sentence-transformers")
|
|
|
|
with pytest.raises(ImportError, match="missing sentence-transformers"):
|
|
get_reranker(backend="legacy", model_name="dummy-model")
|
|
|
|
|
|
def test_get_reranker_legacy_returns_cross_encoder_reranker(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
class DummyCrossEncoder:
|
|
def __init__(self, model_name: str, *, device: str | None = None) -> None:
|
|
self.model_name = model_name
|
|
self.device = device
|
|
self.last_batch_size: int | None = None
|
|
|
|
def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]:
|
|
self.last_batch_size = int(batch_size)
|
|
return [0.5 for _ in pairs]
|
|
|
|
monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder)
|
|
monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True)
|
|
monkeypatch.setattr(legacy_module, "_import_error", None)
|
|
|
|
reranker = get_reranker(backend=" LEGACY ", model_name="dummy-model", device="cpu")
|
|
assert isinstance(reranker, legacy_module.CrossEncoderReranker)
|
|
|
|
assert reranker.score_pairs([]) == []
|
|
|
|
scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0)
|
|
assert scores == pytest.approx([0.5, 0.5])
|
|
assert reranker._model is not None
|
|
assert reranker._model.last_batch_size == 32
|
|
|
|
|
|
def test_check_reranker_available_onnx_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
real_import = builtins.__import__
|
|
|
|
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
|
|
if name == "onnxruntime":
|
|
raise ImportError("no onnxruntime")
|
|
return real_import(name, globals, locals, fromlist, level)
|
|
|
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
|
|
|
ok, err = check_reranker_available("onnx")
|
|
assert ok is False
|
|
assert "onnxruntime not available" in (err or "")
|
|
|
|
|
|
def test_check_reranker_available_onnx_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
dummy_onnxruntime = types.ModuleType("onnxruntime")
|
|
dummy_optimum = types.ModuleType("optimum")
|
|
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
|
|
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
|
|
dummy_optimum_ort.ORTModelForSequenceClassification = object()
|
|
|
|
dummy_transformers = types.ModuleType("transformers")
|
|
dummy_transformers.AutoTokenizer = object()
|
|
|
|
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
|
|
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
|
|
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
|
|
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
|
|
|
|
ok, err = check_reranker_available("onnx")
|
|
assert ok is True
|
|
assert err is None
|
|
|
|
|
|
def test_check_reranker_available_litellm_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
real_import = builtins.__import__
|
|
|
|
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
|
|
if name == "ccw_litellm":
|
|
raise ImportError("no ccw-litellm")
|
|
return real_import(name, globals, locals, fromlist, level)
|
|
|
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
|
|
|
ok, err = check_reranker_available("litellm")
|
|
assert ok is False
|
|
assert "ccw-litellm not available" in (err or "")
|
|
|
|
|
|
def test_check_reranker_available_litellm_deps_present(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
dummy_litellm = types.ModuleType("ccw_litellm")
|
|
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)
|
|
|
|
ok, err = check_reranker_available("litellm")
|
|
assert ok is True
|
|
assert err is None
|
|
|
|
|
|
def test_check_reranker_available_api_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
real_import = builtins.__import__
|
|
|
|
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
|
|
if name == "httpx":
|
|
raise ImportError("no httpx")
|
|
return real_import(name, globals, locals, fromlist, level)
|
|
|
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
|
|
|
ok, err = check_reranker_available("api")
|
|
assert ok is False
|
|
assert "httpx not available" in (err or "")
|
|
|
|
|
|
def test_check_reranker_available_api_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
dummy_httpx = types.ModuleType("httpx")
|
|
monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)
|
|
|
|
ok, err = check_reranker_available("api")
|
|
assert ok is True
|
|
assert err is None
|
|
|
|
|
|
def test_get_reranker_litellm_returns_litellm_reranker(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
from dataclasses import dataclass
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ChatMessage:
|
|
role: str
|
|
content: str
|
|
|
|
class DummyLiteLLMClient:
|
|
def __init__(self, model: str = "default", **kwargs) -> None:
|
|
self.model = model
|
|
self.kwargs = kwargs
|
|
|
|
def chat(self, messages, **kwargs):
|
|
return types.SimpleNamespace(content="0.5")
|
|
|
|
dummy_litellm = types.ModuleType("ccw_litellm")
|
|
dummy_litellm.ChatMessage = ChatMessage
|
|
dummy_litellm.LiteLLMClient = DummyLiteLLMClient
|
|
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)
|
|
|
|
reranker = get_reranker(backend="litellm", model_name="dummy-model")
|
|
|
|
from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker
|
|
|
|
assert isinstance(reranker, LiteLLMReranker)
|
|
assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5])
|
|
|
|
|
|
def test_get_reranker_onnx_raises_import_error_with_dependency_hint(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
real_import = builtins.__import__
|
|
|
|
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
|
|
if name == "onnxruntime":
|
|
raise ImportError("no onnxruntime")
|
|
return real_import(name, globals, locals, fromlist, level)
|
|
|
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
|
|
|
with pytest.raises(ImportError) as exc:
|
|
get_reranker(backend="onnx", model_name="any")
|
|
|
|
assert "onnxruntime" in str(exc.value)
|
|
|
|
|
|
def test_get_reranker_default_backend_is_onnx(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
dummy_onnxruntime = types.ModuleType("onnxruntime")
|
|
dummy_optimum = types.ModuleType("optimum")
|
|
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
|
|
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
|
|
dummy_optimum_ort.ORTModelForSequenceClassification = object()
|
|
|
|
dummy_transformers = types.ModuleType("transformers")
|
|
dummy_transformers.AutoTokenizer = object()
|
|
|
|
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
|
|
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
|
|
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
|
|
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
|
|
|
|
reranker = get_reranker()
|
|
assert isinstance(reranker, ONNXReranker)
|
|
|
|
|
|
def test_onnx_reranker_scores_pairs_with_sigmoid_normalization(
|
|
monkeypatch: pytest.MonkeyPatch,
|
|
) -> None:
|
|
import numpy as np
|
|
|
|
dummy_onnxruntime = types.ModuleType("onnxruntime")
|
|
|
|
dummy_optimum = types.ModuleType("optimum")
|
|
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
|
|
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
|
|
|
|
class DummyModelOutput:
|
|
def __init__(self, logits: np.ndarray) -> None:
|
|
self.logits = logits
|
|
|
|
class DummyModel:
|
|
input_names = ["input_ids", "attention_mask"]
|
|
|
|
def __init__(self) -> None:
|
|
self.calls: list[int] = []
|
|
self._next_logit = 0
|
|
|
|
def __call__(self, **inputs):
|
|
batch = int(inputs["input_ids"].shape[0])
|
|
start = self._next_logit
|
|
self._next_logit += batch
|
|
self.calls.append(batch)
|
|
logits = np.arange(start, start + batch, dtype=np.float32).reshape(batch, 1)
|
|
return DummyModelOutput(logits=logits)
|
|
|
|
class DummyORTModelForSequenceClassification:
|
|
@classmethod
|
|
def from_pretrained(cls, model_name: str, providers=None, **kwargs):
|
|
_ = model_name, providers, kwargs
|
|
return DummyModel()
|
|
|
|
dummy_optimum_ort.ORTModelForSequenceClassification = DummyORTModelForSequenceClassification
|
|
|
|
dummy_transformers = types.ModuleType("transformers")
|
|
|
|
class DummyAutoTokenizer:
|
|
model_max_length = 512
|
|
|
|
@classmethod
|
|
def from_pretrained(cls, model_name: str, **kwargs):
|
|
_ = model_name, kwargs
|
|
return cls()
|
|
|
|
def __call__(self, *, text, text_pair, return_tensors, **kwargs):
|
|
_ = text_pair, kwargs
|
|
assert return_tensors == "np"
|
|
batch = len(text)
|
|
# Include token_type_ids to ensure input filtering is exercised.
|
|
return {
|
|
"input_ids": np.zeros((batch, 4), dtype=np.int64),
|
|
"attention_mask": np.ones((batch, 4), dtype=np.int64),
|
|
"token_type_ids": np.zeros((batch, 4), dtype=np.int64),
|
|
}
|
|
|
|
dummy_transformers.AutoTokenizer = DummyAutoTokenizer
|
|
|
|
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
|
|
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
|
|
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
|
|
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
|
|
|
|
reranker = get_reranker(backend="onnx", model_name="dummy-model", use_gpu=False)
|
|
assert isinstance(reranker, ONNXReranker)
|
|
assert reranker._model is None
|
|
|
|
pairs = [("q", f"d{idx}") for idx in range(5)]
|
|
scores = reranker.score_pairs(pairs, batch_size=2)
|
|
|
|
assert reranker._model is not None
|
|
assert reranker._model.calls == [2, 2, 1]
|
|
assert len(scores) == len(pairs)
|
|
assert all(0.0 <= s <= 1.0 for s in scores)
|
|
|
|
expected = [1.0 / (1.0 + math.exp(-float(i))) for i in range(len(pairs))]
|
|
assert scores == pytest.approx(expected, rel=1e-6, abs=1e-6)
|