Mirror of https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-13 02:41:50 +08:00
Refactor code structure and remove redundant changes
25
codex-lens/build/lib/codexlens/semantic/reranker/__init__.py
Normal file
@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.

This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""

from __future__ import annotations

from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available

__all__ = [
    "BaseReranker",
    "check_reranker_available",
    "get_reranker",
    "CrossEncoderReranker",
    "check_cross_encoder_available",
    "FastEmbedReranker",
    "check_fastembed_reranker_available",
    "ONNXReranker",
    "check_onnx_reranker_available",
]
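Note: a minimal usage sketch of the public API exported above. It assumes the package is importable as `codexlens` (matching the file paths in this diff) and that the chosen backend's optional dependency (here fastembed) is installed; the example strings are illustrative only.

from codexlens.semantic.reranker import check_reranker_available, get_reranker

ok, err = check_reranker_available("fastembed")
if not ok:
    raise SystemExit(err)

reranker = get_reranker("fastembed")
scores = reranker.score_pairs([
    ("how to parse json in python", "Use the json module: json.loads(text)."),
    ("how to parse json in python", "Regular expressions match text patterns."),
])
print(scores)  # one score per (query, doc) pair, higher means more relevant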
403
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py
Normal file
@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.

Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""

from __future__ import annotations

import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"


def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback."""
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Try loading from .env files
    try:
        from codexlens.env_config import get_env
        return get_env(key, workspace_root=workspace_root)
    except ImportError:
        return None


def check_httpx_available() -> tuple[bool, str | None]:
    try:
        import httpx  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"httpx not available: {exc}. Install with: pip install httpx"
    return True, None


class APIReranker(BaseReranker):
    """Reranker backed by a remote reranking HTTP API."""

    _PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
        "siliconflow": {
            "api_base": "https://api.siliconflow.cn",
            "endpoint": "/v1/rerank",
            "default_model": "BAAI/bge-reranker-v2-m3",
        },
        "cohere": {
            "api_base": "https://api.cohere.ai",
            "endpoint": "/v1/rerank",
            "default_model": "rerank-english-v3.0",
        },
        "jina": {
            "api_base": "https://api.jina.ai",
            "endpoint": "/v1/rerank",
            "default_model": "jina-reranker-v2-base-multilingual",
        },
    }

    def __init__(
        self,
        *,
        provider: str = "siliconflow",
        model_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        backoff_base_s: float = 0.5,
        backoff_max_s: float = 8.0,
        env_api_key: str = _DEFAULT_ENV_API_KEY,
        workspace_root: Path | str | None = None,
        max_input_tokens: int | None = None,
    ) -> None:
        ok, err = check_httpx_available()
        if not ok:  # pragma: no cover - exercised via factory availability tests
            raise ImportError(err)

        import httpx

        self._workspace_root = Path(workspace_root) if workspace_root else None

        self.provider = (provider or "").strip().lower()
        if self.provider not in self._PROVIDER_DEFAULTS:
            raise ValueError(
                f"Unknown reranker provider: {provider}. "
                f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
            )

        defaults = self._PROVIDER_DEFAULTS[self.provider]

        # Load api_base from env with .env fallback
        env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
        self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
        self.endpoint = defaults["endpoint"]

        # Load model from env with .env fallback
        env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
        self.model_name = (model_name or env_model or defaults["default_model"]).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        # Load API key from env with .env fallback
        resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
        resolved_key = resolved_key.strip()
        if not resolved_key:
            raise ValueError(
                f"Missing API key for reranker provider '{self.provider}'. "
                f"Pass api_key=... or set ${env_api_key}."
            )
        self._api_key = resolved_key

        self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
        self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
        self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
        self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        if self.provider == "cohere":
            headers.setdefault("Cohere-Version", "2022-12-06")

        self._client = httpx.Client(
            base_url=self.api_base,
            headers=headers,
            timeout=self.timeout_s,
        )

        # Store max_input_tokens with model-aware defaults
        if max_input_tokens is not None:
            self._max_input_tokens = max_input_tokens
        else:
            # Infer from model name
            model_lower = self.model_name.lower()
            if '8b' in model_lower or 'large' in model_lower:
                self._max_input_tokens = 32768
            else:
                self._max_input_tokens = 8192

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking."""
        return self._max_input_tokens

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # pragma: no cover - defensive
            return

    def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
        if retry_after_s is not None and retry_after_s > 0:
            time.sleep(min(float(retry_after_s), self.backoff_max_s))
            return

        exp = self.backoff_base_s * (2**attempt)
        jitter = random.uniform(0, min(0.5, self.backoff_base_s))
        time.sleep(min(self.backoff_max_s, exp + jitter))

    @staticmethod
    def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
        value = (headers.get("Retry-After") or "").strip()
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    @staticmethod
    def _should_retry_status(status_code: int) -> bool:
        return status_code == 429 or 500 <= status_code <= 599

    def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
        last_exc: Exception | None = None

        for attempt in range(self.max_retries + 1):
            try:
                response = self._client.post(self.endpoint, json=dict(payload))
            except Exception as exc:  # httpx is optional at import-time
                last_exc = exc
                if attempt < self.max_retries:
                    self._sleep_backoff(attempt)
                    continue
                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' after "
                    f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
                ) from exc

            status = int(getattr(response, "status_code", 0) or 0)
            if status >= 400:
                body_preview = ""
                try:
                    body_preview = (response.text or "").strip()
                except Exception:
                    body_preview = ""
                if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"

                if self._should_retry_status(status) and attempt < self.max_retries:
                    retry_after = self._parse_retry_after_seconds(response.headers)
                    logger.warning(
                        "Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
                        self.api_base,
                        self.endpoint,
                        status,
                        attempt + 1,
                        self.max_retries + 1,
                    )
                    self._sleep_backoff(attempt, retry_after_s=retry_after)
                    continue

                if status in {401, 403}:
                    raise RuntimeError(
                        f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
                        "Check your API key."
                    )

                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
                    f"Response: {body_preview or '<empty>'}"
                )

            try:
                data = response.json()
            except Exception as exc:
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' is not valid JSON: "
                    f"{type(exc).__name__}: {exc}"
                ) from exc

            if not isinstance(data, dict):
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' must be a JSON object; "
                    f"got {type(data).__name__}"
                )

            return data

        raise RuntimeError(
            f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
        )

    @staticmethod
    def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
        if not isinstance(results, list):
            raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")

        scores: list[float] = [0.0 for _ in range(expected)]
        filled = 0

        for item in results:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            score = item.get("relevance_score", item.get("score"))
            if idx is None or score is None:
                continue
            try:
                idx_int = int(idx)
                score_f = float(score)
            except (TypeError, ValueError):
                continue
            if 0 <= idx_int < expected:
                scores[idx_int] = score_f
                filled += 1

        if filled != expected:
            raise RuntimeError(
                f"Rerank response contained {filled}/{expected} scored documents; "
                "ensure top_n matches the number of documents."
            )

        return scores

    def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
        payload: dict[str, Any] = {
            "model": self.model_name,
            "query": query,
            "documents": list(documents),
            "top_n": len(documents),
            "return_documents": False,
        }
        return payload

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using fast heuristic.

        Uses len(text) // 4 as approximation (~4 chars per token for English).
        Not perfectly accurate for all models/languages but sufficient for
        batch sizing decisions where exact counts aren't critical.
        """
        return len(text) // 4

    def _create_token_aware_batches(
        self,
        query: str,
        documents: Sequence[str],
    ) -> list[list[tuple[int, str]]]:
        """Split documents into batches that fit within token limits.

        Uses 90% of max_input_tokens as safety margin.
        Each batch includes the query tokens overhead.
        """
        max_tokens = int(self._max_input_tokens * 0.9)
        query_tokens = self._estimate_tokens(query)

        batches: list[list[tuple[int, str]]] = []
        current_batch: list[tuple[int, str]] = []
        current_tokens = query_tokens  # Start with query overhead

        for idx, doc in enumerate(documents):
            doc_tokens = self._estimate_tokens(doc)

            # Warn if single document exceeds token limit (will be truncated by API)
            if doc_tokens > max_tokens - query_tokens:
                logger.warning(
                    f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
                    f"(limit: {max_tokens - query_tokens} after query overhead). "
                    "Document will likely be truncated by the API."
                )

            # If batch would exceed limit, start new batch
            if current_tokens + doc_tokens > max_tokens and current_batch:
                batches.append(current_batch)
                current_batch = []
                current_tokens = query_tokens

            current_batch.append((idx, doc))
            current_tokens += doc_tokens

        if current_batch:
            batches.append(current_batch)

        return batches

    def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
        if not documents:
            return []

        # Create token-aware batches
        batches = self._create_token_aware_batches(query, documents)

        if len(batches) == 1:
            # Single batch - original behavior
            payload = self._build_payload(query=query, documents=documents)
            data = self._request_json(payload)
            results = data.get("results")
            return self._extract_scores_from_results(results, expected=len(documents))

        # Multiple batches - process each and merge results
        logger.info(
            f"Splitting {len(documents)} documents into {len(batches)} batches "
            f"(max_input_tokens: {self._max_input_tokens})"
        )

        all_scores: list[float] = [0.0] * len(documents)

        for batch in batches:
            batch_docs = [doc for _, doc in batch]
            payload = self._build_payload(query=query, documents=batch_docs)
            data = self._request_json(payload)
            results = data.get("results")
            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))

            # Map scores back to original indices
            for (orig_idx, _), score in zip(batch, batch_scores):
                all_scores[orig_idx] = score

        return all_scores

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # noqa: ARG002 - kept for BaseReranker compatibility
    ) -> list[float]:
        if not pairs:
            return []

        grouped: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            grouped.setdefault(str(query), []).append((idx, str(doc)))

        scores: list[float] = [0.0 for _ in range(len(pairs))]

        for query, items in grouped.items():
            documents = [doc for _, doc in items]
            query_scores = self._rerank_one_query(query=query, documents=documents)
            for (orig_idx, _), score in zip(items, query_scores):
                scores[orig_idx] = float(score)

        return scores
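Note: a standalone restatement of the token-aware batching arithmetic implemented above, with illustrative numbers and no HTTP involved. The helper `estimate_tokens` mirrors `APIReranker._estimate_tokens`; the 8192-token default and the 90% margin are the values used in this file.

def estimate_tokens(text: str) -> int:
    return len(text) // 4  # same ~4 chars/token heuristic as APIReranker._estimate_tokens

max_input_tokens = 8192                 # default for models without "8b"/"large" in the name
budget = int(max_input_tokens * 0.9)    # 90% safety margin -> 7372 tokens
query = "q" * 100                       # ~25 tokens of query overhead carried by every batch
docs = ["d" * 6000] * 10                # ~1500 estimated tokens per document

batch, used, batches = [], estimate_tokens(query), []
for i, doc in enumerate(docs):
    t = estimate_tokens(doc)
    if used + t > budget and batch:     # flush the batch before it would exceed the budget
        batches.append(batch)
        batch, used = [], estimate_tokens(query)
    batch.append(i)
    used += t
if batch:
    batches.append(batch)
print([len(b) for b in batches])        # -> [4, 4, 2] documents per rerank request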
46
codex-lens/build/lib/codexlens/semantic/reranker/base.py
Normal file
@@ -0,0 +1,46 @@
"""Base class for rerankers.

Defines the interface that all rerankers must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence


class BaseReranker(ABC):
    """Base class for all rerankers.

    All reranker implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking.

        Returns:
            int: Maximum number of tokens that can be processed at once.
                Default is 8192 if not overridden by implementation.
        """
        return 8192

    @abstractmethod
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair).
        """
        ...
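Note: a minimal sketch of a custom backend against this interface. `ConstantReranker` is hypothetical and exists only to illustrate the contract (one score per input pair, in input order); it has no optional dependencies.

from typing import Sequence

from codexlens.semantic.reranker import BaseReranker


class ConstantReranker(BaseReranker):
    """Hypothetical backend: returns the same score for every pair."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        # One score per input pair, preserving input order.
        return [0.5 for _ in pairs]


assert ConstantReranker().max_input_tokens == 8192  # inherited default from BaseReranker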
159
codex-lens/build/lib/codexlens/semantic/reranker/factory.py
Normal file
@@ -0,0 +1,159 @@
"""Factory for creating rerankers.

Provides a unified interface for instantiating different reranker backends.
"""

from __future__ import annotations

from typing import Any

from .base import BaseReranker


def check_reranker_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific reranker backend can be used.

    Notes:
        - "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
        - "onnx" redirects to "fastembed" for backward compatibility.
        - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
        - "api" uses a remote reranking HTTP API (requires httpx).
        - "litellm" uses `ccw-litellm` for unified access to LLM providers.
    """
    backend = (backend or "").strip().lower()

    if backend == "legacy":
        from .legacy import check_cross_encoder_available

        return check_cross_encoder_available()

    if backend == "fastembed":
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        from .fastembed_reranker import check_fastembed_reranker_available

        return check_fastembed_reranker_available()

    if backend == "litellm":
        try:
            import ccw_litellm  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional dependency
            return (
                False,
                f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
            )

        try:
            from .litellm_reranker import LiteLLMReranker  # noqa: F401
        except Exception as exc:  # pragma: no cover - defensive
            return False, f"LiteLLM reranker backend not available: {exc}"

        return True, None

    if backend == "api":
        from .api_reranker import check_httpx_available

        return check_httpx_available()

    return False, (
        f"Invalid reranker backend: {backend}. "
        "Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
    )


def get_reranker(
    backend: str = "fastembed",
    model_name: str | None = None,
    *,
    device: str | None = None,
    **kwargs: Any,
) -> BaseReranker:
    """Factory function to create reranker based on backend.

    Args:
        backend: Reranker backend to use. Options:
            - "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
            - "onnx": Redirects to fastembed for backward compatibility
            - "api": HTTP API backend (remote providers)
            - "litellm": LiteLLM backend (LLM-based, for API mode)
            - "legacy": sentence-transformers CrossEncoder backend (optional)
        model_name: Model identifier for model-based backends. Defaults depend on backend:
            - fastembed: Xenova/ms-marco-MiniLM-L-6-v2
            - onnx: (redirects to fastembed)
            - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
            - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: default
        device: Optional device string for backends that support it (legacy only).
        **kwargs: Additional backend-specific arguments.

    Returns:
        BaseReranker: Configured reranker instance.

    Raises:
        ValueError: If backend is not recognized.
        ImportError: If required backend dependencies are not installed or backend is unavailable.
    """
    backend = (backend or "").strip().lower()

    if backend == "fastembed":
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "onnx":
        # Redirect to fastembed for backward compatibility
        ok, err = check_reranker_available("fastembed")
        if not ok:
            raise ImportError(err)

        from .fastembed_reranker import FastEmbedReranker

        resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via fastembed providers.
        return FastEmbedReranker(model_name=resolved_model_name, **kwargs)

    if backend == "legacy":
        ok, err = check_reranker_available("legacy")
        if not ok:
            raise ImportError(err)

        from .legacy import CrossEncoderReranker

        resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
        return CrossEncoderReranker(model_name=resolved_model_name, device=device)

    if backend == "litellm":
        ok, err = check_reranker_available("litellm")
        if not ok:
            raise ImportError(err)

        from .litellm_reranker import LiteLLMReranker

        _ = device  # Device selection is not applicable to remote LLM backends.
        resolved_model_name = (model_name or "").strip() or "default"
        return LiteLLMReranker(model=resolved_model_name, **kwargs)

    if backend == "api":
        ok, err = check_reranker_available("api")
        if not ok:
            raise ImportError(err)

        from .api_reranker import APIReranker

        _ = device  # Device selection is not applicable to remote HTTP backends.
        resolved_model_name = (model_name or "").strip() or None
        return APIReranker(model_name=resolved_model_name, **kwargs)

    raise ValueError(
        f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
    )
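Note: a small sketch of how a caller might combine the two factory helpers above to fall back across backends. The preference order and the `pick_reranker` helper are examples, not something this module prescribes; the "api" backend additionally needs an API key (e.g. RERANKER_API_KEY) at construction time even when httpx is available.

from codexlens.semantic.reranker import check_reranker_available, get_reranker


def pick_reranker(preferred=("fastembed", "api", "legacy"), **kwargs):
    """Return a reranker from the first backend whose dependencies are present."""
    errors = {}
    for backend in preferred:
        ok, err = check_reranker_available(backend)
        if ok:
            return get_reranker(backend, **kwargs)
        errors[backend] = err
    raise ImportError(f"No reranker backend available: {errors}")


reranker = pick_reranker()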
257
codex-lens/build/lib/codexlens/semantic/reranker/fastembed_reranker.py
Normal file
@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.

This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.

Install:
    pip install fastembed>=0.4.0
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_fastembed_reranker_available() -> tuple[bool, str | None]:
    """Check whether fastembed reranker dependencies are available."""
    try:
        import fastembed  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
        )

    try:
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed TextCrossEncoder not available: {exc}. "
            "Upgrade with: pip install fastembed>=0.4.0",
        )

    return True, None


class FastEmbedReranker(BaseReranker):
    """Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    # Alternative models supported by fastembed:
    # - "BAAI/bge-reranker-base"
    # - "BAAI/bge-reranker-large"
    # - "cross-encoder/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        cache_dir: str | None = None,
        threads: int | None = None,
    ) -> None:
        """Initialize FastEmbed reranker.

        Args:
            model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
            use_gpu: Whether to use GPU acceleration when available.
            cache_dir: Optional directory for caching downloaded models.
            threads: Optional number of threads for ONNX Runtime.
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.cache_dir = cache_dir
        self.threads = threads

        self._encoder: Any | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        """Lazy-load the TextCrossEncoder model."""
        if self._encoder is not None:
            return

        ok, err = check_fastembed_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._encoder is not None:
                return

            from fastembed.rerank.cross_encoder import TextCrossEncoder

            # Determine providers based on GPU preference
            providers: list[str] | None = None
            if self.use_gpu:
                try:
                    from ..gpu_support import get_optimal_providers

                    providers = get_optimal_providers(use_gpu=True, with_device_options=False)
                except Exception:
                    # Fallback: let fastembed decide
                    providers = None

            # Build initialization kwargs
            init_kwargs: dict[str, Any] = {}
            if self.cache_dir:
                init_kwargs["cache_dir"] = self.cache_dir
            if self.threads is not None:
                init_kwargs["threads"] = self.threads
            if providers:
                init_kwargs["providers"] = providers

            logger.debug(
                "Loading FastEmbed reranker model: %s (use_gpu=%s)",
                self.model_name,
                self.use_gpu,
            )

            self._encoder = TextCrossEncoder(
                model_name=self.model_name,
                **init_kwargs,
            )

            logger.debug("FastEmbed reranker model loaded successfully")

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Numerically stable sigmoid function."""
        if x < -709:
            return 0.0
        if x > 709:
            return 1.0
        import math
        return 1.0 / (1.0 + math.exp(-x))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair), normalized to [0, 1] range.
        """
        if not pairs:
            return []

        self._load_model()

        if self._encoder is None:  # pragma: no cover - defensive
            return []

        # FastEmbed's TextCrossEncoder.rerank() expects a query and list of documents.
        # For batch scoring of multiple query-doc pairs, we need to process them.
        # Group by query for efficiency when same query appears multiple times.
        query_to_docs: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            if query not in query_to_docs:
                query_to_docs[query] = []
            query_to_docs[query].append((idx, doc))

        # Score each query group
        scores: list[float] = [0.0] * len(pairs)

        for query, indexed_docs in query_to_docs.items():
            docs = [doc for _, doc in indexed_docs]
            indices = [idx for idx, _ in indexed_docs]

            try:
                # TextCrossEncoder.rerank returns raw float scores in same order as input
                raw_scores = list(
                    self._encoder.rerank(
                        query=query,
                        documents=docs,
                        batch_size=batch_size,
                    )
                )

                # Map scores back to original positions and normalize with sigmoid
                for i, raw_score in enumerate(raw_scores):
                    if i < len(indices):
                        original_idx = indices[i]
                        # Normalize score to [0, 1] using stable sigmoid
                        scores[original_idx] = self._sigmoid(float(raw_score))

            except Exception as e:
                logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
                # Leave scores as 0.0 for failed queries

        return scores

    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        top_k: int | None = None,
        batch_size: int = 32,
    ) -> list[tuple[float, str, int]]:
        """Rerank documents for a single query.

        This is a convenience method that provides results in ranked order.

        Args:
            query: The query string.
            documents: List of documents to rerank.
            top_k: Return only top K results. None returns all.
            batch_size: Batch size for scoring.

        Returns:
            List of (score, document, original_index) tuples, sorted by score descending.
        """
        if not documents:
            return []

        self._load_model()

        if self._encoder is None:  # pragma: no cover - defensive
            return []

        try:
            # TextCrossEncoder.rerank returns raw float scores in same order as input
            raw_scores = list(
                self._encoder.rerank(
                    query=query,
                    documents=list(documents),
                    batch_size=batch_size,
                )
            )

            # Convert to our format: (normalized_score, document, original_index)
            ranked = []
            for idx, raw_score in enumerate(raw_scores):
                if idx < len(documents):
                    # Normalize score to [0, 1] using stable sigmoid
                    normalized = self._sigmoid(float(raw_score))
                    ranked.append((normalized, documents[idx], idx))

            # Sort by score descending
            ranked.sort(key=lambda x: x[0], reverse=True)

            if top_k is not None and top_k > 0:
                ranked = ranked[:top_k]

            return ranked

        except Exception as e:
            logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
            return []
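Note: a minimal usage sketch of the convenience `rerank` method defined above. It assumes fastembed is installed; the default model is downloaded lazily on the first call, and the example documents are illustrative only.

from codexlens.semantic.reranker import FastEmbedReranker

reranker = FastEmbedReranker()  # lazy: the model loads on the first rerank/score call
ranked = reranker.rerank(
    "database connection pooling",
    [
        "A connection pool reuses open database connections between requests.",
        "Binary search runs in O(log n) time on a sorted array.",
    ],
    top_k=1,
)
for score, doc, original_index in ranked:
    print(f"{score:.3f}  [{original_index}]  {doc}")  # sigmoid-normalized score in [0, 1]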
91
codex-lens/build/lib/codexlens/semantic/reranker/legacy.py
Normal file
@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.

Install with: pip install codexlens[reranker-legacy]
"""

from __future__ import annotations

import logging
import threading
from typing import List, Sequence, Tuple

from .base import BaseReranker

logger = logging.getLogger(__name__)

try:
    from sentence_transformers import CrossEncoder as _CrossEncoder

    CROSS_ENCODER_AVAILABLE = True
    _import_error: str | None = None
except ImportError as exc:  # pragma: no cover - optional dependency
    _CrossEncoder = None  # type: ignore[assignment]
    CROSS_ENCODER_AVAILABLE = False
    _import_error = str(exc)


def check_cross_encoder_available() -> tuple[bool, str | None]:
    if CROSS_ENCODER_AVAILABLE:
        return True, None
    return (
        False,
        _import_error
        or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
    )


class CrossEncoderReranker(BaseReranker):
    """Cross-encoder reranker with lazy model loading."""

    def __init__(self, model_name: str, *, device: str | None = None) -> None:
        self.model_name = (model_name or "").strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.device = (device or "").strip() or None
        self._model = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None:
            return

        ok, err = check_cross_encoder_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None:
                return

            try:
                if self.device:
                    self._model = _CrossEncoder(self.model_name, device=self.device)  # type: ignore[misc]
                else:
                    self._model = _CrossEncoder(self.model_name)  # type: ignore[misc]
            except Exception as exc:
                logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
                raise

    def score_pairs(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> List[float]:
        """Score (query, doc) pairs using the cross-encoder.

        Returns:
            List of scores (one per pair) in the model's native scale (usually logits).
        """
        if not pairs:
            return []

        self._load_model()

        if self._model is None:  # pragma: no cover - defensive
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores = self._model.predict(list(pairs), batch_size=bs)  # type: ignore[union-attr]
        return [float(s) for s in scores]
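Note: unlike the fastembed, ONNX, and API backends in this commit, the legacy backend returns raw model scores (usually logits). If a caller needs scales that are comparable across backends, one option is to apply a sigmoid on the caller side, as in this sketch; it assumes sentence-transformers is installed and downloads the named model on first use.

import math

from codexlens.semantic.reranker import CrossEncoderReranker


def to_unit_interval(logit: float) -> float:
    # Clamp to avoid overflow, then squash the raw logit into [0, 1].
    return 1.0 / (1.0 + math.exp(-max(-709.0, min(709.0, logit))))


raw = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2").score_pairs(
    [("what is a mutex", "A mutex provides mutual exclusion for shared state.")]
)
print([to_unit_interval(s) for s in raw])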
214
codex-lens/build/lib/codexlens/semantic/reranker/litellm_reranker.py
Normal file
@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.

This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.

Notes:
    - This backend is experimental and may be slow/expensive compared to local
      rerankers.
    - It relies on `ccw-litellm` for a unified LLM API across providers.
"""

from __future__ import annotations

import json
import logging
import re
import threading
import time
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _coerce_score_to_unit_interval(score: float) -> float:
    """Coerce a numeric score into [0, 1].

    The prompt asks for a float in [0, 1], but some models may respond with 0-10
    or 0-100 scales. This function attempts a conservative normalization.
    """
    if 0.0 <= score <= 1.0:
        return score
    if 0.0 <= score <= 10.0:
        return score / 10.0
    if 0.0 <= score <= 100.0:
        return score / 100.0
    return max(0.0, min(1.0, score))


def _extract_score(text: str) -> float | None:
    """Extract a numeric relevance score from an LLM response."""
    content = (text or "").strip()
    if not content:
        return None

    # Prefer JSON if present.
    if "{" in content and "}" in content:
        try:
            start = content.index("{")
            end = content.rindex("}") + 1
            payload = json.loads(content[start:end])
            if isinstance(payload, dict) and "score" in payload:
                return float(payload["score"])
        except Exception:
            pass

    match = _NUMBER_RE.search(content)
    if not match:
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None


class LiteLLMReranker(BaseReranker):
    """Experimental reranker that uses a LiteLLM-compatible model.

    This reranker scores each (query, doc) pair in isolation (single-pair mode)
    to improve prompt reliability across providers.
    """

    _SYSTEM_PROMPT = (
        "You are a relevance scoring assistant.\n"
        "Given a search query and a document snippet, output a single numeric "
        "relevance score between 0 and 1.\n\n"
        "Scoring guidance:\n"
        "- 1.0: The document directly answers the query.\n"
        "- 0.5: The document is partially relevant.\n"
        "- 0.0: The document is unrelated.\n\n"
        "Output requirements:\n"
        "- Output ONLY the number (e.g., 0.73).\n"
        "- Do not include any other text."
    )

    def __init__(
        self,
        model: str = "default",
        *,
        requests_per_minute: float | None = None,
        min_interval_seconds: float | None = None,
        default_score: float = 0.0,
        max_doc_chars: int = 8000,
        **litellm_kwargs: Any,
    ) -> None:
        """Initialize the reranker.

        Args:
            model: Model name from ccw-litellm configuration (default: "default").
            requests_per_minute: Optional rate limit in requests per minute.
            min_interval_seconds: Optional minimum interval between requests. If set,
                it takes precedence over requests_per_minute.
            default_score: Score to use when an API call fails or parsing fails.
            max_doc_chars: Maximum number of document characters to include in the prompt.
            **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.

        Raises:
            ImportError: If ccw-litellm is not installed.
            ValueError: If model is blank.
        """
        self.model_name = (model or "").strip()
        if not self.model_name:
            raise ValueError("model cannot be blank")

        self.default_score = float(default_score)

        self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0

        if min_interval_seconds is not None:
            self._min_interval_seconds = max(0.0, float(min_interval_seconds))
        elif requests_per_minute is not None and float(requests_per_minute) > 0:
            self._min_interval_seconds = 60.0 / float(requests_per_minute)
        else:
            self._min_interval_seconds = 0.0

        # Prefer deterministic output by default; allow overrides via kwargs.
        litellm_kwargs = dict(litellm_kwargs)
        litellm_kwargs.setdefault("temperature", 0.0)
        litellm_kwargs.setdefault("max_tokens", 16)

        try:
            from ccw_litellm import ChatMessage, LiteLLMClient
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from exc

        self._ChatMessage = ChatMessage
        self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)

        self._lock = threading.RLock()
        self._last_request_at = 0.0

    def _sanitize_text(self, text: str) -> str:
        # Keep consistent with LiteLLMEmbedderWrapper workaround.
        if text.startswith("import"):
            return " " + text
        return text

    def _rate_limit(self) -> None:
        if self._min_interval_seconds <= 0:
            return
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_request_at
            if elapsed < self._min_interval_seconds:
                time.sleep(self._min_interval_seconds - elapsed)
            self._last_request_at = time.monotonic()

    def _build_user_prompt(self, query: str, doc: str) -> str:
        sanitized_query = self._sanitize_text(query or "")
        sanitized_doc = self._sanitize_text(doc or "")
        if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
            sanitized_doc = sanitized_doc[: self.max_doc_chars]

        return (
            "Query:\n"
            f"{sanitized_query}\n\n"
            "Document:\n"
            f"{sanitized_doc}\n\n"
            "Return the relevance score (0 to 1) as a single number:"
        )

    def _score_single_pair(self, query: str, doc: str) -> float:
        messages = [
            self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
            self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
        ]

        try:
            self._rate_limit()
            response = self._client.chat(messages)
        except Exception as exc:
            logger.debug("LiteLLM reranker request failed: %s", exc)
            return self.default_score

        raw = getattr(response, "content", "") or ""
        score = _extract_score(raw)
        if score is None:
            logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
            return self.default_score
        return _coerce_score_to_unit_interval(float(score))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with per-pair LLM calls."""
        if not pairs:
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32

        scores: list[float] = []
        for i in range(0, len(pairs), bs):
            batch = pairs[i : i + bs]
            for query, doc in batch:
                scores.append(self._score_single_pair(query, doc))
        return scores
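Note: a small local illustration of the module-private parsing and normalization helpers above, run against the kinds of responses different models tend to produce; no LLM call is involved, and importing the private names is only for demonstration.

from codexlens.semantic.reranker.litellm_reranker import (
    _coerce_score_to_unit_interval,
    _extract_score,
)

samples = [
    "0.73",                       # the requested format
    '{"score": 8}',               # JSON with a 0-10 scale
    "Relevance: 85 out of 100.",  # free text with a 0-100 scale
    "not a number",               # unparseable -> caller falls back to default_score
]
for text in samples:
    parsed = _extract_score(text)
    coerced = None if parsed is None else _coerce_score_to_unit_interval(parsed)
    print(f"{text!r:32} -> {parsed!r} -> {coerced!r}")
# 0.73 stays 0.73, 8 becomes 0.8, 85 becomes 0.85, and the last sample stays None.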
268
codex-lens/build/lib/codexlens/semantic/reranker/onnx_reranker.py
Normal file
@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.

This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Iterable, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_onnx_reranker_available() -> tuple[bool, str | None]:
    """Check whether Optimum + ONNXRuntime reranker dependencies are available."""
    try:
        import numpy  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


class ONNXReranker(BaseReranker):
    """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        providers: list[Any] | None = None,
        max_length: int | None = None,
    ) -> None:
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.providers = providers

        self.max_length = int(max_length) if max_length is not None else None

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._model_input_names: set[str] | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_onnx_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer

            if self.providers is None:
                from ..gpu_support import get_optimal_providers

                # Include device_id options for DirectML/CUDA selection when available.
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                model_kwargs = {}

            try:
                self._model = ORTModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments.
                self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache model input names to filter tokenizer outputs defensively.
            input_names: set[str] | None = None
            for attr in ("input_names", "model_input_names"):
                names = getattr(self._model, attr, None)
                if isinstance(names, (list, tuple)) and names:
                    input_names = {str(n) for n in names}
                    break
            if input_names is None:
                try:
                    session = getattr(self._model, "model", None)
                    if session is not None and hasattr(session, "get_inputs"):
                        input_names = {i.name for i in session.get_inputs()}
                except Exception:
                    input_names = None
            self._model_input_names = input_names

    @staticmethod
    def _sigmoid(x: "Any") -> "Any":
        import numpy as np

        x = np.clip(x, -50.0, 50.0)
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def _select_relevance_logit(logits: "Any") -> "Any":
        import numpy as np

        arr = np.asarray(logits)
        if arr.ndim == 0:
            return arr.reshape(1)
        if arr.ndim == 1:
            return arr
        if arr.ndim >= 2:
            # Common cases:
            # - Regression: (batch, 1)
            # - Binary classification: (batch, 2)
            if arr.shape[-1] == 1:
                return arr[..., 0]
            if arr.shape[-1] == 2:
                # Convert 2-logit softmax into a single logit via difference.
                return arr[..., 1] - arr[..., 0]
            return arr.max(axis=-1)
        return arr.reshape(-1)

    def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")  # pragma: no cover - defensive

        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]

        tokenizer_kwargs: dict[str, Any] = {
            "text": queries,
            "text_pair": docs,
            "padding": True,
            "truncation": True,
            "return_tensors": "np",
        }

        max_len = self.max_length
        if max_len is None:
            try:
                model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
                if 0 < model_max < 10_000:
                    max_len = model_max
                else:
                    max_len = 512
            except Exception:
                max_len = 512
        if max_len is not None and max_len > 0:
            tokenizer_kwargs["max_length"] = int(max_len)

        encoded = self._tokenizer(**tokenizer_kwargs)
        inputs = dict(encoded)

        # Some models do not accept token_type_ids; filter to known input names if available.
        if self._model_input_names:
            inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}

        return inputs

    def _forward_logits(self, inputs: dict[str, Any]) -> Any:
        if self._model is None:
            raise RuntimeError("Model not loaded")  # pragma: no cover - defensive

        outputs = self._model(**inputs)
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        if isinstance(outputs, (list, tuple)) and outputs:
            return outputs[0]
        raise RuntimeError("Unexpected model output format")  # pragma: no cover - defensive

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
        if not pairs:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:  # pragma: no cover - defensive
            return []

        import numpy as np

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores: list[float] = []

        for batch in _iter_batches(list(pairs), bs):
            inputs = self._tokenize_batch(batch)
            logits = self._forward_logits(inputs)
            rel_logits = self._select_relevance_logit(logits)
            probs = self._sigmoid(rel_logits)
            probs = np.clip(probs, 0.0, 1.0)
            scores.extend([float(p) for p in probs.reshape(-1).tolist()])

        if len(scores) != len(pairs):
            logger.debug(
                "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
            )
            return scores[: len(pairs)]

        return scores
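Note: a short numeric check of why `_select_relevance_logit` reduces a (batch, 2) output to the logit difference: for two logits [z0, z1], sigmoid(z1 - z0) equals the softmax probability of class 1, so the reduction followed by the sigmoid above reproduces the model's relevance probability exactly. The numbers below are illustrative only.

import numpy as np

logits = np.array([[0.3, 2.1]])  # one (query, doc) pair, two classification logits
softmax_p1 = np.exp(logits[:, 1]) / np.exp(logits).sum(axis=1)
sigmoid_of_diff = 1.0 / (1.0 + np.exp(-(logits[:, 1] - logits[:, 0])))
print(softmax_p1, sigmoid_of_diff)  # both print approximately [0.858]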