Files
Claude-Code-Workflow/codex-lens/build/lib/codexlens/semantic/reranker/onnx_reranker.py

269 lines
9.5 KiB
Python

"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
    """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading.

    The tokenizer and ONNX model are loaded on first call to ``score_pairs``
    so that the heavy optional dependencies (optimum, transformers,
    onnxruntime) are only imported when reranking is actually used.
    """

    # Fallback checkpoint used when the caller does not supply a model name.
    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        providers: list[Any] | None = None,
        max_length: int | None = None,
    ) -> None:
        """Configure the reranker; no model weights are loaded here.

        Args:
            model_name: Hugging Face model id. Falls back to ``DEFAULT_MODEL``
                when ``None`` or empty.
            use_gpu: Prefer GPU execution providers when ``providers`` is
                resolved automatically at load time.
            providers: Explicit ONNX Runtime providers (names or
                ``(name, options)`` tuples). ``None`` means "resolve lazily".
            max_length: Hard cap for tokenized sequence length; ``None`` means
                "derive from the tokenizer, capped at 512".

        Raises:
            ValueError: If ``model_name`` is blank after stripping whitespace.
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")
        self.use_gpu = bool(use_gpu)
        self.providers = providers
        self.max_length = int(max_length) if max_length is not None else None
        # Lazily populated by _load_model(); None means "not loaded yet".
        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._model_input_names: set[str] | None = None
        # Reentrant lock serializes the one-time model load across threads.
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        """Load tokenizer and ONNX model on first use (idempotent, thread-safe).

        Raises:
            ImportError: If any optional dependency is missing.
        """
        # Fast path: already loaded — avoid taking the lock at all.
        if self._model is not None and self._tokenizer is not None:
            return
        ok, err = check_onnx_reranker_available()
        if not ok:
            raise ImportError(err)
        with self._lock:
            # Re-check under the lock: another thread may have finished loading
            # between the unlocked check above and lock acquisition.
            if self._model is not None and self._tokenizer is not None:
                return
            from inspect import signature
            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer
            if self.providers is None:
                from ..gpu_support import get_optimal_providers
                # Include device_id options for DirectML/CUDA selection when available.
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )
            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        # Entries may be plain names or (name, options) tuples.
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                # Signature introspection failed — load with default kwargs.
                model_kwargs = {}
            try:
                self._model = ORTModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments.
                self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
            # Cache model input names to filter tokenizer outputs defensively.
            input_names: set[str] | None = None
            for attr in ("input_names", "model_input_names"):
                names = getattr(self._model, attr, None)
                if isinstance(names, (list, tuple)) and names:
                    input_names = {str(n) for n in names}
                    break
            if input_names is None:
                # Last resort: ask the underlying ONNX Runtime session directly.
                try:
                    session = getattr(self._model, "model", None)
                    if session is not None and hasattr(session, "get_inputs"):
                        input_names = {i.name for i in session.get_inputs()}
                except Exception:
                    input_names = None
            self._model_input_names = input_names

    @staticmethod
    def _sigmoid(x: "Any") -> "Any":
        """Numerically stable elementwise sigmoid over a numpy array."""
        import numpy as np
        # Clip logits so np.exp cannot overflow for extreme values.
        x = np.clip(x, -50.0, 50.0)
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def _select_relevance_logit(logits: "Any") -> "Any":
        """Reduce model logits of any common shape to one relevance logit per pair."""
        import numpy as np
        arr = np.asarray(logits)
        if arr.ndim == 0:
            return arr.reshape(1)
        if arr.ndim == 1:
            return arr
        if arr.ndim >= 2:
            # Common cases:
            # - Regression: (batch, 1)
            # - Binary classification: (batch, 2)
            if arr.shape[-1] == 1:
                return arr[..., 0]
            if arr.shape[-1] == 2:
                # Convert 2-logit softmax into a single logit via difference.
                return arr[..., 1] - arr[..., 0]
            # Multi-class head: take the strongest class logit as relevance.
            return arr.max(axis=-1)
        return arr.reshape(-1)

    def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
        """Tokenize (query, doc) pairs into numpy model inputs.

        Raises:
            RuntimeError: If called before the tokenizer was loaded.
        """
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")  # pragma: no cover - defensive
        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]
        tokenizer_kwargs: dict[str, Any] = {
            "text": queries,
            "text_pair": docs,
            "padding": True,
            "truncation": True,
            "return_tensors": "np",
        }
        max_len = self.max_length
        if max_len is None:
            # Derive a sane limit from the tokenizer; very large values are a
            # transformers sentinel for "no limit", so cap at 512 instead.
            try:
                model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
                if 0 < model_max < 10_000:
                    max_len = model_max
                else:
                    max_len = 512
            except Exception:
                max_len = 512
        if max_len is not None and max_len > 0:
            tokenizer_kwargs["max_length"] = int(max_len)
        encoded = self._tokenizer(**tokenizer_kwargs)
        inputs = dict(encoded)
        # Some models do not accept token_type_ids; filter to known input names if available.
        if self._model_input_names:
            inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
        return inputs

    def _forward_logits(self, inputs: dict[str, Any]) -> Any:
        """Run the model and extract logits, tolerating several output formats.

        Raises:
            RuntimeError: If the model is not loaded or the output shape is
                unrecognized.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded")  # pragma: no cover - defensive
        outputs = self._model(**inputs)
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        if isinstance(outputs, (list, tuple)) and outputs:
            return outputs[0]
        raise RuntimeError("Unexpected model output format")  # pragma: no cover - defensive

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1].

        Args:
            pairs: Sequence of (query, document) text pairs.
            batch_size: Pairs per forward pass; non-positive values fall back
                to 32.

        Returns:
            One float score per input pair, in input order.
        """
        if not pairs:
            return []
        self._load_model()
        if self._model is None or self._tokenizer is None:  # pragma: no cover - defensive
            return []
        import numpy as np
        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores: list[float] = []
        for batch in _iter_batches(list(pairs), bs):
            inputs = self._tokenize_batch(batch)
            logits = self._forward_logits(inputs)
            rel_logits = self._select_relevance_logit(logits)
            probs = self._sigmoid(rel_logits)
            probs = np.clip(probs, 0.0, 1.0)
            scores.extend([float(p) for p in probs.reshape(-1).tolist()])
        if len(scores) != len(pairs):
            # Should not happen; log and truncate rather than return a
            # misaligned list.
            logger.debug(
                "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
            )
            return scores[: len(pairs)]
        return scores