feat(codex-lens): add unified reranker architecture and file watcher

Unified Reranker Architecture:
- Add BaseReranker ABC with factory pattern
- Implement 4 backends: ONNX (default), API, LiteLLM, Legacy
- Add .env configuration parsing for API credentials
- Migrate from sentence-transformers to optimum+onnxruntime

File Watcher Module:
- Add real-time file system monitoring with watchdog
- Implement IncrementalIndexer for single-file updates
- Add WatcherManager with signal handling and graceful shutdown
- Add 'codexlens watch' CLI command
- Add event filtering, debouncing, and deduplication
- Use a thread-safe design with proper resource cleanup

Tests: 16 watcher tests + 5 reranker test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: catlog22
Date: 2026-01-01 13:23:52 +08:00
Parent: 8ac27548ad
Commit: 520f2d26f2
27 changed files with 3571 additions and 14 deletions

File: __init__.py (reranker subpackage)

@@ -0,0 +1,22 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"ONNXReranker",
"check_onnx_reranker_available",
]
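A minimal usage sketch of the exports above; the absolute import path is an assumption (this commit does not show where the subpackage lives in the tree):

# Hedged sketch: "codexlens.rerankers" is an assumed import path.
from codexlens.rerankers import check_reranker_available, get_reranker

ok, err = check_reranker_available("onnx")
if not ok:
    raise SystemExit(err)
reranker = get_reranker("onnx")
scores = reranker.score_pairs([
    ("parse yaml config", "def load_yaml(path): ..."),
    ("parse yaml config", "class HttpServer: ..."),
])
# The ONNX backend returns sigmoid-normalized scores in [0, 1]; higher is more relevant.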

File: api_reranker.py

@@ -0,0 +1,310 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
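A worked example of the parsing contract in _extract_scores_from_results: providers return results sorted by relevance with an "index" field pointing back at the input position, and the parser restores input order:

# Response items in the shape used by the SiliconFlow/Cohere/Jina rerank APIs.
results = [
    {"index": 1, "relevance_score": 0.92},
    {"index": 0, "relevance_score": 0.17},
]
scores = APIReranker._extract_scores_from_results(results, expected=2)
assert scores == [0.17, 0.92]  # realigned to the original document order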

File: base.py

@@ -0,0 +1,36 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
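The contract is small; a hypothetical stub subclass (illustrative only, assuming BaseReranker above is in scope) shows the minimum a backend must provide:

class ConstantReranker(BaseReranker):
    """Hypothetical stub: scores every pair the same, e.g. for tests."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        return [0.5 for _ in pairs]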

File: factory.py

@@ -0,0 +1,138 @@
"""Factory for creating rerankers.
Provides a unified interface for instantiating different reranker backends.
"""
from __future__ import annotations
from typing import Any
from .base import BaseReranker
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
backend = (backend or "").strip().lower()
if backend == "legacy":
from .legacy import check_cross_encoder_available
return check_cross_encoder_available()
if backend == "onnx":
from .onnx_reranker import check_onnx_reranker_available
return check_onnx_reranker_available()
if backend == "litellm":
try:
import ccw_litellm # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
)
try:
from .litellm_reranker import LiteLLMReranker # noqa: F401
except Exception as exc: # pragma: no cover - defensive
return False, f"LiteLLM reranker backend not available: {exc}"
return True, None
if backend == "api":
from .api_reranker import check_httpx_available
return check_httpx_available()
return False, (
f"Invalid reranker backend: {backend}. "
"Must be 'onnx', 'api', 'litellm', or 'legacy'."
)
def get_reranker(
backend: str = "onnx",
model_name: str | None = None,
*,
device: str | None = None,
**kwargs: Any,
) -> BaseReranker:
"""Factory function to create reranker based on backend.
Args:
backend: Reranker backend to use. Options:
- "onnx": Optimum + onnxruntime backend (default)
- "api": HTTP API backend (remote providers)
- "litellm": LiteLLM backend (LLM-based, experimental)
- "legacy": sentence-transformers CrossEncoder backend (optional)
model_name: Model identifier for model-based backends. Defaults depend on backend:
- onnx: Xenova/ms-marco-MiniLM-L-6-v2
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: "default" (a named model from the ccw-litellm configuration)
device: Optional device string for backends that support it (legacy only).
**kwargs: Additional backend-specific arguments.
Returns:
BaseReranker: Configured reranker instance.
Raises:
ValueError: If backend is not recognized.
ImportError: If required backend dependencies are not installed or backend is unavailable.
"""
backend = (backend or "").strip().lower()
if backend == "onnx":
ok, err = check_reranker_available("onnx")
if not ok:
raise ImportError(err)
from .onnx_reranker import ONNXReranker
resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL
_ = device # Device selection is managed via ONNX Runtime providers.
return ONNXReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
if not ok:
raise ImportError(err)
from .legacy import CrossEncoderReranker
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
if backend == "litellm":
ok, err = check_reranker_available("litellm")
if not ok:
raise ImportError(err)
from .litellm_reranker import LiteLLMReranker
_ = device # Device selection is not applicable to remote LLM backends.
resolved_model_name = (model_name or "").strip() or "default"
return LiteLLMReranker(model=resolved_model_name, **kwargs)
if backend == "api":
ok, err = check_reranker_available("api")
if not ok:
raise ImportError(err)
from .api_reranker import APIReranker
_ = device # Device selection is not applicable to remote HTTP backends.
resolved_model_name = (model_name or "").strip() or None
return APIReranker(model_name=resolved_model_name, **kwargs)
raise ValueError(
f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'"
)
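A sketch of graceful degradation built on the two helpers above; the helper name and backend order are illustrative, not part of this commit:

def get_first_available_reranker(preferred=("onnx", "api", "legacy")) -> BaseReranker:
    """Hypothetical helper: return the first backend that reports available."""
    errors: list[str] = []
    for backend in preferred:
        ok, err = check_reranker_available(backend)
        if ok:
            return get_reranker(backend)
        errors.append(f"{backend}: {err}")
    raise ImportError("No reranker backend available:\n" + "\n".join(errors))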

File: legacy.py

@@ -1,6 +1,6 @@
-"""Optional cross-encoder reranker for second-stage search ranking.
+"""Legacy sentence-transformers cross-encoder reranker.
-Install with: pip install codexlens[reranker]
+Install with: pip install codexlens[reranker-legacy]
 """
 from __future__ import annotations
@@ -9,6 +9,8 @@ import logging
 import threading
 from typing import List, Sequence, Tuple
+from .base import BaseReranker
 logger = logging.getLogger(__name__)
 try:
@@ -25,10 +27,14 @@ except ImportError as exc:  # pragma: no cover - optional dependency
 def check_cross_encoder_available() -> tuple[bool, str | None]:
     if CROSS_ENCODER_AVAILABLE:
         return True, None
-    return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
+    return (
+        False,
+        _import_error
+        or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
+    )
-class CrossEncoderReranker:
+class CrossEncoderReranker(BaseReranker):
     """Cross-encoder reranker with lazy model loading."""
     def __init__(self, model_name: str, *, device: str | None = None) -> None:
@@ -83,4 +89,3 @@ class CrossEncoderReranker:
         bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
         scores = self._model.predict(list(pairs), batch_size=bs)  # type: ignore[union-attr]
         return [float(s) for s in scores]
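The net effect of this change is that the legacy backend now satisfies the unified interface, so a single subclass check covers all backends:

assert issubclass(CrossEncoderReranker, BaseReranker)  # holds after this diff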

File: litellm_reranker.py

@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.
Notes:
- This backend is experimental and may be slow/expensive compared to local
rerankers.
- It relies on `ccw-litellm` for a unified LLM API across providers.
"""
from __future__ import annotations
import json
import logging
import re
import threading
import time
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
def _coerce_score_to_unit_interval(score: float) -> float:
"""Coerce a numeric score into [0, 1].
The prompt asks for a float in [0, 1], but some models may respond with 0-10
or 0-100 scales. This function attempts a conservative normalization.
"""
if 0.0 <= score <= 1.0:
return score
if 0.0 <= score <= 10.0:
return score / 10.0
if 0.0 <= score <= 100.0:
return score / 100.0
return max(0.0, min(1.0, score))
def _extract_score(text: str) -> float | None:
"""Extract a numeric relevance score from an LLM response."""
content = (text or "").strip()
if not content:
return None
# Prefer JSON if present.
if "{" in content and "}" in content:
try:
start = content.index("{")
end = content.rindex("}") + 1
payload = json.loads(content[start:end])
if isinstance(payload, dict) and "score" in payload:
return float(payload["score"])
except Exception:
pass
match = _NUMBER_RE.search(content)
if not match:
return None
try:
return float(match.group(0))
except ValueError:
return None
class LiteLLMReranker(BaseReranker):
"""Experimental reranker that uses a LiteLLM-compatible model.
This reranker scores each (query, doc) pair in isolation (single-pair mode)
to improve prompt reliability across providers.
"""
_SYSTEM_PROMPT = (
"You are a relevance scoring assistant.\n"
"Given a search query and a document snippet, output a single numeric "
"relevance score between 0 and 1.\n\n"
"Scoring guidance:\n"
"- 1.0: The document directly answers the query.\n"
"- 0.5: The document is partially relevant.\n"
"- 0.0: The document is unrelated.\n\n"
"Output requirements:\n"
"- Output ONLY the number (e.g., 0.73).\n"
"- Do not include any other text."
)
def __init__(
self,
model: str = "default",
*,
requests_per_minute: float | None = None,
min_interval_seconds: float | None = None,
default_score: float = 0.0,
max_doc_chars: int = 8000,
**litellm_kwargs: Any,
) -> None:
"""Initialize the reranker.
Args:
model: Model name from ccw-litellm configuration (default: "default").
requests_per_minute: Optional rate limit in requests per minute.
min_interval_seconds: Optional minimum interval between requests. If set,
it takes precedence over requests_per_minute.
default_score: Score to use when an API call fails or parsing fails.
max_doc_chars: Maximum number of document characters to include in the prompt.
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
Raises:
ImportError: If ccw-litellm is not installed.
ValueError: If model is blank.
"""
self.model_name = (model or "").strip()
if not self.model_name:
raise ValueError("model cannot be blank")
self.default_score = float(default_score)
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
if min_interval_seconds is not None:
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
elif requests_per_minute is not None and float(requests_per_minute) > 0:
self._min_interval_seconds = 60.0 / float(requests_per_minute)
else:
self._min_interval_seconds = 0.0
# Prefer deterministic output by default; allow overrides via kwargs.
litellm_kwargs = dict(litellm_kwargs)
litellm_kwargs.setdefault("temperature", 0.0)
litellm_kwargs.setdefault("max_tokens", 16)
try:
from ccw_litellm import ChatMessage, LiteLLMClient
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from exc
self._ChatMessage = ChatMessage
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
self._lock = threading.RLock()
self._last_request_at = 0.0
def _sanitize_text(self, text: str) -> str:
# Keep consistent with LiteLLMEmbedderWrapper workaround.
if text.startswith("import"):
return " " + text
return text
def _rate_limit(self) -> None:
if self._min_interval_seconds <= 0:
return
with self._lock:
now = time.monotonic()
elapsed = now - self._last_request_at
if elapsed < self._min_interval_seconds:
time.sleep(self._min_interval_seconds - elapsed)
self._last_request_at = time.monotonic()
def _build_user_prompt(self, query: str, doc: str) -> str:
sanitized_query = self._sanitize_text(query or "")
sanitized_doc = self._sanitize_text(doc or "")
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
sanitized_doc = sanitized_doc[: self.max_doc_chars]
return (
"Query:\n"
f"{sanitized_query}\n\n"
"Document:\n"
f"{sanitized_doc}\n\n"
"Return the relevance score (0 to 1) as a single number:"
)
def _score_single_pair(self, query: str, doc: str) -> float:
messages = [
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
]
try:
self._rate_limit()
response = self._client.chat(messages)
except Exception as exc:
logger.debug("LiteLLM reranker request failed: %s", exc)
return self.default_score
raw = getattr(response, "content", "") or ""
score = _extract_score(raw)
if score is None:
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
return self.default_score
return _coerce_score_to_unit_interval(float(score))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with per-pair LLM calls."""
if not pairs:
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for i in range(0, len(pairs), bs):
batch = pairs[i : i + bs]
for query, doc in batch:
scores.append(self._score_single_pair(query, doc))
return scores
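The parsing helpers above tolerate the common response shapes; a few worked cases, checked against the code as written:

assert _extract_score("0.73") == 0.73                    # bare number
assert _extract_score('{"score": 8}') == 8.0             # JSON object with "score"
assert _extract_score("Relevance: 0.4 out of 1") == 0.4  # first number in free text
assert _coerce_score_to_unit_interval(8.0) == 0.8        # 0-10 scale normalized
assert _coerce_score_to_unit_interval(85.0) == 0.85      # 0-100 scale normalized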

File: onnx_reranker.py

@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
providers: list[Any] | None = None,
max_length: int | None = None,
) -> None:
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.providers = providers
self.max_length = int(max_length) if max_length is not None else None
self._tokenizer: Any | None = None
self._model: Any | None = None
self._model_input_names: set[str] | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_onnx_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
if self.providers is None:
from ..gpu_support import get_optimal_providers
# Include device_id options for DirectML/CUDA selection when available.
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
)
# Some Optimum versions accept `providers`, others accept a single `provider`.
# Prefer passing the full providers list, with a conservative fallback.
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
model_kwargs = {}
try:
self._model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
**model_kwargs,
)
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments.
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache model input names to filter tokenizer outputs defensively.
input_names: set[str] | None = None
for attr in ("input_names", "model_input_names"):
names = getattr(self._model, attr, None)
if isinstance(names, (list, tuple)) and names:
input_names = {str(n) for n in names}
break
if input_names is None:
try:
session = getattr(self._model, "model", None)
if session is not None and hasattr(session, "get_inputs"):
input_names = {i.name for i in session.get_inputs()}
except Exception:
input_names = None
self._model_input_names = input_names
@staticmethod
def _sigmoid(x: "Any") -> "Any":
import numpy as np
x = np.clip(x, -50.0, 50.0)
return 1.0 / (1.0 + np.exp(-x))
@staticmethod
def _select_relevance_logit(logits: "Any") -> "Any":
import numpy as np
arr = np.asarray(logits)
if arr.ndim == 0:
return arr.reshape(1)
if arr.ndim == 1:
return arr
if arr.ndim >= 2:
# Common cases:
# - Regression: (batch, 1)
# - Binary classification: (batch, 2)
if arr.shape[-1] == 1:
return arr[..., 0]
if arr.shape[-1] == 2:
# Convert 2-logit softmax into a single logit via difference.
return arr[..., 1] - arr[..., 0]
return arr.max(axis=-1)
return arr.reshape(-1)
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
queries = [q for q, _ in batch]
docs = [d for _, d in batch]
tokenizer_kwargs: dict[str, Any] = {
"text": queries,
"text_pair": docs,
"padding": True,
"truncation": True,
"return_tensors": "np",
}
max_len = self.max_length
if max_len is None:
try:
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
if 0 < model_max < 10_000:
max_len = model_max
else:
max_len = 512
except Exception:
max_len = 512
if max_len is not None and max_len > 0:
tokenizer_kwargs["max_length"] = int(max_len)
encoded = self._tokenizer(**tokenizer_kwargs)
inputs = dict(encoded)
# Some models do not accept token_type_ids; filter to known input names if available.
if self._model_input_names:
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
return inputs
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
if self._model is None:
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
outputs = self._model(**inputs)
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, dict) and "logits" in outputs:
return outputs["logits"]
if isinstance(outputs, (list, tuple)) and outputs:
return outputs[0]
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
if not pairs:
return []
self._load_model()
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
return []
import numpy as np
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for batch in _iter_batches(list(pairs), bs):
inputs = self._tokenize_batch(batch)
logits = self._forward_logits(inputs)
rel_logits = self._select_relevance_logit(logits)
probs = self._sigmoid(rel_logits)
probs = np.clip(probs, 0.0, 1.0)
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
if len(scores) != len(pairs):
logger.debug(
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
)
return scores[: len(pairs)]
return scores
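A worked example of the logit-to-score path (_select_relevance_logit followed by _sigmoid) for the common binary-classifier head:

import numpy as np

logits = np.array([[-1.2, 2.3], [0.5, -0.5]])       # (batch, 2) classification logits
rel = ONNXReranker._select_relevance_logit(logits)  # logit difference: [3.5, -1.0]
probs = ONNXReranker._sigmoid(rel)                  # sigmoid: approx [0.97, 0.27]
assert probs.shape == (2,)
assert 0.0 <= float(probs.min()) <= float(probs.max()) <= 1.0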