"""Optimum + ONNX Runtime reranker backend.
|
|
|
|
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
|
|
classification models. It is designed to run without requiring PyTorch at
|
|
runtime by using numpy tensors and ONNX Runtime execution providers.
|
|
|
|
Install (CPU):
|
|
pip install onnxruntime optimum[onnxruntime] transformers
|
|
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Iterable, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_onnx_reranker_available() -> tuple[bool, str | None]:
    """Check whether Optimum + ONNX Runtime reranker dependencies are available."""
    try:
        import numpy  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None
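
# Callers can gate reranker construction on this check (sketch):
#
#     ok, err = check_onnx_reranker_available()
#     if not ok:
#         logger.warning("ONNX reranker disabled: %s", err)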


def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]
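# Example: _iter_batches([1, 2, 3, 4, 5], 2) yields [1, 2], [3, 4], then [5].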


class ONNXReranker(BaseReranker):
    """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        providers: list[Any] | None = None,
        max_length: int | None = None,
    ) -> None:
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.providers = providers

        self.max_length = int(max_length) if max_length is not None else None

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._model_input_names: set[str] | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_onnx_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer

            if self.providers is None:
                from ..gpu_support import get_optimal_providers

                # Include device_id options for DirectML/CUDA selection when available.
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )
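                # Illustrative result (actual entries depend on gpu_support and
                # the installed onnxruntime build):
                #     [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]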

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                model_kwargs = {}

            try:
                self._model = ORTModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments.
                self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache model input names to filter tokenizer outputs defensively.
            input_names: set[str] | None = None
            for attr in ("input_names", "model_input_names"):
                names = getattr(self._model, attr, None)
                if isinstance(names, (list, tuple)) and names:
                    input_names = {str(n) for n in names}
                    break
            if input_names is None:
                try:
                    session = getattr(self._model, "model", None)
                    if session is not None and hasattr(session, "get_inputs"):
                        input_names = {i.name for i in session.get_inputs()}
                except Exception:
                    input_names = None
            self._model_input_names = input_names

    @staticmethod
    def _sigmoid(x: "Any") -> "Any":
        import numpy as np

        # Clip logits so np.exp cannot overflow for extreme inputs.
        x = np.clip(x, -50.0, 50.0)
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def _select_relevance_logit(logits: "Any") -> "Any":
        import numpy as np

        arr = np.asarray(logits)
        if arr.ndim == 0:
            return arr.reshape(1)
        if arr.ndim == 1:
            return arr
        if arr.ndim >= 2:
            # Common cases:
            # - Regression: (batch, 1)
            # - Binary classification: (batch, 2)
            if arr.shape[-1] == 1:
                return arr[..., 0]
            if arr.shape[-1] == 2:
                # Convert 2-logit softmax into a single logit via difference.
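                # (sigmoid(z1 - z0) == softmax([z0, z1])[1], so the positive-class
                # probability is preserved exactly.)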
                return arr[..., 1] - arr[..., 0]
            return arr.max(axis=-1)
        return arr.reshape(-1)

    def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")  # pragma: no cover - defensive

        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]

        tokenizer_kwargs: dict[str, Any] = {
            "text": queries,
            "text_pair": docs,
            "padding": True,
            "truncation": True,
            "return_tensors": "np",
        }

        max_len = self.max_length
        if max_len is None:
            try:
                model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
                if 0 < model_max < 10_000:
                    max_len = model_max
                else:
                    max_len = 512
            except Exception:
                max_len = 512
        if max_len is not None and max_len > 0:
            tokenizer_kwargs["max_length"] = int(max_len)

        encoded = self._tokenizer(**tokenizer_kwargs)
        inputs = dict(encoded)

        # Some models do not accept token_type_ids; filter to known input names if available.
        if self._model_input_names:
            inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
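
        # At this point `inputs` typically holds "input_ids" and "attention_mask"
        # (plus "token_type_ids" for BERT-style models), each a numpy int array
        # of shape (batch, seq_len).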
        return inputs

    def _forward_logits(self, inputs: dict[str, Any]) -> Any:
        if self._model is None:
            raise RuntimeError("Model not loaded")  # pragma: no cover - defensive

        outputs = self._model(**inputs)
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        if isinstance(outputs, (list, tuple)) and outputs:
            return outputs[0]
        raise RuntimeError("Unexpected model output format")  # pragma: no cover - defensive

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
        if not pairs:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:  # pragma: no cover - defensive
            return []

        import numpy as np

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores: list[float] = []

        for batch in _iter_batches(list(pairs), bs):
            inputs = self._tokenize_batch(batch)
            logits = self._forward_logits(inputs)
            rel_logits = self._select_relevance_logit(logits)
            probs = self._sigmoid(rel_logits)
            probs = np.clip(probs, 0.0, 1.0)
            scores.extend([float(p) for p in probs.reshape(-1).tolist()])

        if len(scores) != len(pairs):
            logger.debug(
                "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
            )
            return scores[: len(pairs)]

        return scores