Refactor code structure and remove redundant changes

catlog22
2026-01-24 14:47:47 +08:00
parent cf5fecd66d
commit f2b0a5bbc9
113 changed files with 43217 additions and 235 deletions


@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker",
"check_onnx_reranker_available",
]
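For orientation, a minimal usage sketch of the exports above; the import path is an assumption here (adjust it to wherever this rerankers subpackage lives inside codexlens):

from codexlens.rerankers import check_reranker_available, get_reranker  # assumed path

ok, err = check_reranker_available("fastembed")
if not ok:
    raise SystemExit(f"No reranker backend available: {err}")
reranker = get_reranker("fastembed")
scores = reranker.score_pairs([
    ("open a file in python", "Use the built-in open() function with a context manager."),
    ("open a file in python", "The capital of France is Paris."),
])
print(scores)  # one float per (query, doc) pair; the first should score higher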


@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
max_input_tokens: int | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
# Store max_input_tokens with model-aware defaults
if max_input_tokens is not None:
self._max_input_tokens = max_input_tokens
else:
# Infer from model name
model_lower = self.model_name.lower()
if "8b" in model_lower or "large" in model_lower:
self._max_input_tokens = 32768
else:
self._max_input_tokens = 8192
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking."""
return self._max_input_tokens
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count using fast heuristic.
Uses len(text) // 4 as approximation (~4 chars per token for English).
Not perfectly accurate for all models/languages but sufficient for
batch sizing decisions where exact counts aren't critical.
"""
return len(text) // 4
def _create_token_aware_batches(
self,
query: str,
documents: Sequence[str],
) -> list[list[tuple[int, str]]]:
"""Split documents into batches that fit within token limits.
Uses 90% of max_input_tokens as a safety margin.
Each batch's budget includes the query token overhead.
"""
max_tokens = int(self._max_input_tokens * 0.9)
query_tokens = self._estimate_tokens(query)
batches: list[list[tuple[int, str]]] = []
current_batch: list[tuple[int, str]] = []
current_tokens = query_tokens # Start with query overhead
for idx, doc in enumerate(documents):
doc_tokens = self._estimate_tokens(doc)
# Warn if single document exceeds token limit (will be truncated by API)
if doc_tokens > max_tokens - query_tokens:
logger.warning(
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
f"(limit: {max_tokens - query_tokens} after query overhead). "
"Document will likely be truncated by the API."
)
# If batch would exceed limit, start new batch
if current_tokens + doc_tokens > max_tokens and current_batch:
batches.append(current_batch)
current_batch = []
current_tokens = query_tokens
current_batch.append((idx, doc))
current_tokens += doc_tokens
if current_batch:
batches.append(current_batch)
return batches
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
# Create token-aware batches
batches = self._create_token_aware_batches(query, documents)
if len(batches) == 1:
# Single batch - original behavior
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
# Multiple batches - process each and merge results
logger.info(
f"Splitting {len(documents)} documents into {len(batches)} batches "
f"(max_input_tokens: {self._max_input_tokens})"
)
all_scores: list[float] = [0.0] * len(documents)
for batch in batches:
batch_docs = [doc for _, doc in batch]
payload = self._build_payload(query=query, documents=batch_docs)
data = self._request_json(payload)
results = data.get("results")
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
# Map scores back to original indices
for (orig_idx, _), score in zip(batch, batch_scores):
all_scores[orig_idx] = score
return all_scores
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
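A short sketch of using APIReranker directly, outside the factory. The key below is a placeholder; in practice it is resolved from api_key=..., the RERANKER_API_KEY environment variable, or a .env file as implemented above:

reranker = APIReranker(
    provider="siliconflow",                # or "cohere" / "jina"
    model_name="BAAI/bge-reranker-v2-m3",  # the SiliconFlow default, spelled out
    api_key="<your-api-key>",              # placeholder credential
    timeout=15.0,
    max_retries=2,
)
try:
    scores = reranker.score_pairs([
        ("token-aware batching", "Documents are split so each request stays under the token limit."),
        ("token-aware batching", "Minutes of the weekly planning meeting."),
    ])
finally:
    reranker.close()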


@@ -0,0 +1,46 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking.
Returns:
int: Maximum number of tokens that can be processed at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
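To illustrate the contract, a toy (hypothetical) implementation that satisfies the interface using the imports in this module; real backends load their models lazily, as the other files in this commit do:

class KeywordOverlapReranker(BaseReranker):
    """Toy reranker: fraction of query words that appear in the document."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # unused in this toy example
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            words = query.lower().split()
            if not words:
                scores.append(0.0)
                continue
            doc_lower = doc.lower()
            scores.append(sum(1 for w in words if w in doc_lower) / len(words))
        return scores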


@@ -0,0 +1,159 @@
"""Factory for creating rerankers.
Provides a unified interface for instantiating different reranker backends.
"""
from __future__ import annotations
from typing import Any
from .base import BaseReranker
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
- "onnx" redirects to "fastembed" for backward compatibility.
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
backend = (backend or "").strip().lower()
if backend == "legacy":
from .legacy import check_cross_encoder_available
return check_cross_encoder_available()
if backend == "fastembed":
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "onnx":
# Redirect to fastembed for backward compatibility
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "litellm":
try:
import ccw_litellm # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
)
try:
from .litellm_reranker import LiteLLMReranker # noqa: F401
except Exception as exc: # pragma: no cover - defensive
return False, f"LiteLLM reranker backend not available: {exc}"
return True, None
if backend == "api":
from .api_reranker import check_httpx_available
return check_httpx_available()
return False, (
f"Invalid reranker backend: {backend}. "
"Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
)
def get_reranker(
backend: str = "fastembed",
model_name: str | None = None,
*,
device: str | None = None,
**kwargs: Any,
) -> BaseReranker:
"""Factory function to create reranker based on backend.
Args:
backend: Reranker backend to use. Options:
- "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
- "onnx": Redirects to fastembed for backward compatibility
- "api": HTTP API backend (remote providers)
- "litellm": LiteLLM backend (LLM-based, for API mode)
- "legacy": sentence-transformers CrossEncoder backend (optional)
model_name: Model identifier for model-based backends. Defaults depend on backend:
- fastembed: Xenova/ms-marco-MiniLM-L-6-v2
- onnx: (redirects to fastembed)
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
- litellm: default
device: Optional device string for backends that support it (legacy only).
**kwargs: Additional backend-specific arguments.
Returns:
BaseReranker: Configured reranker instance.
Raises:
ValueError: If backend is not recognized.
ImportError: If required backend dependencies are not installed or backend is unavailable.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "onnx":
# Redirect to fastembed for backward compatibility
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
if not ok:
raise ImportError(err)
from .legacy import CrossEncoderReranker
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
if backend == "litellm":
ok, err = check_reranker_available("litellm")
if not ok:
raise ImportError(err)
from .litellm_reranker import LiteLLMReranker
_ = device # Device selection is not applicable to remote LLM backends.
resolved_model_name = (model_name or "").strip() or "default"
return LiteLLMReranker(model=resolved_model_name, **kwargs)
if backend == "api":
ok, err = check_reranker_available("api")
if not ok:
raise ImportError(err)
from .api_reranker import APIReranker
_ = device # Device selection is not applicable to remote HTTP backends.
resolved_model_name = (model_name or "").strip() or None
return APIReranker(model_name=resolved_model_name, **kwargs)
raise ValueError(
f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
)
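A sketch of a common call pattern built on this factory: probe availability first, then fall back from the local fastembed backend to the remote API backend. The helper name is hypothetical; the API branch still requires httpx and an API key, as described in api_reranker:

def build_default_reranker() -> BaseReranker:
    ok, _ = check_reranker_available("fastembed")
    if ok:
        return get_reranker("fastembed")
    ok, err = check_reranker_available("api")
    if ok:
        return get_reranker("api", provider="siliconflow")
    raise ImportError(f"No reranker backend available: {err}")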


@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
Install:
pip install fastembed>=0.4.0
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
"""Check whether fastembed reranker dependencies are available."""
try:
import fastembed # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
)
try:
from fastembed.rerank.cross_encoder import TextCrossEncoder # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed TextCrossEncoder not available: {exc}. "
"Upgrade with: pip install fastembed>=0.4.0",
)
return True, None
class FastEmbedReranker(BaseReranker):
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
# Alternative models supported by fastembed:
# - "BAAI/bge-reranker-base"
# - "BAAI/bge-reranker-large"
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
cache_dir: str | None = None,
threads: int | None = None,
) -> None:
"""Initialize FastEmbed reranker.
Args:
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
use_gpu: Whether to use GPU acceleration when available.
cache_dir: Optional directory for caching downloaded models.
threads: Optional number of threads for ONNX Runtime.
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.cache_dir = cache_dir
self.threads = threads
self._encoder: Any | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
"""Lazy-load the TextCrossEncoder model."""
if self._encoder is not None:
return
ok, err = check_fastembed_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._encoder is not None:
return
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Determine providers based on GPU preference
providers: list[str] | None = None
if self.use_gpu:
try:
from ..gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
except Exception:
# Fallback: let fastembed decide
providers = None
# Build initialization kwargs
init_kwargs: dict[str, Any] = {}
if self.cache_dir:
init_kwargs["cache_dir"] = self.cache_dir
if self.threads is not None:
init_kwargs["threads"] = self.threads
if providers:
init_kwargs["providers"] = providers
logger.debug(
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
self.model_name,
self.use_gpu,
)
self._encoder = TextCrossEncoder(
model_name=self.model_name,
**init_kwargs,
)
logger.debug("FastEmbed reranker model loaded successfully")
@staticmethod
def _sigmoid(x: float) -> float:
"""Numerically stable sigmoid function."""
if x < -709:
return 0.0
if x > 709:
return 1.0
import math
return 1.0 / (1.0 + math.exp(-x))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair), normalized to [0, 1] range.
"""
if not pairs:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
# FastEmbed's TextCrossEncoder.rerank() expects a single query plus a list of documents,
# so batch scoring of (query, doc) pairs is done by grouping pairs by query and issuing
# one rerank call per unique query. This avoids redundant calls when a query repeats.
query_to_docs: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
if query not in query_to_docs:
query_to_docs[query] = []
query_to_docs[query].append((idx, doc))
# Score each query group
scores: list[float] = [0.0] * len(pairs)
for query, indexed_docs in query_to_docs.items():
docs = [doc for _, doc in indexed_docs]
indices = [idx for idx, _ in indexed_docs]
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=docs,
batch_size=batch_size,
)
)
# Map scores back to original positions and normalize with sigmoid
for i, raw_score in enumerate(raw_scores):
if i < len(indices):
original_idx = indices[i]
# Normalize score to [0, 1] using stable sigmoid
scores[original_idx] = self._sigmoid(float(raw_score))
except Exception as e:
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
# Leave scores as 0.0 for failed queries
return scores
def rerank(
self,
query: str,
documents: Sequence[str],
*,
top_k: int | None = None,
batch_size: int = 32,
) -> list[tuple[float, str, int]]:
"""Rerank documents for a single query.
This is a convenience method that provides results in ranked order.
Args:
query: The query string.
documents: List of documents to rerank.
top_k: Return only top K results. None returns all.
batch_size: Batch size for scoring.
Returns:
List of (score, document, original_index) tuples, sorted by score descending.
"""
if not documents:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=list(documents),
batch_size=batch_size,
)
)
# Convert to our format: (normalized_score, document, original_index)
ranked = []
for idx, raw_score in enumerate(raw_scores):
if idx < len(documents):
# Normalize score to [0, 1] using stable sigmoid
normalized = self._sigmoid(float(raw_score))
ranked.append((normalized, documents[idx], idx))
# Sort by score descending
ranked.sort(key=lambda x: x[0], reverse=True)
if top_k is not None and top_k > 0:
ranked = ranked[:top_k]
return ranked
except Exception as e:
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
return []
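A brief usage sketch; the model weights are downloaded lazily on the first call, and scores come back sigmoid-normalized as described above:

reranker = FastEmbedReranker(use_gpu=False)  # stick to CPU for a predictable environment
ranked = reranker.rerank(
    "numerically stable sigmoid",
    [
        "Clamp x before calling math.exp to avoid overflow.",
        "Shipping and returns policy for the online store.",
    ],
    top_k=1,
)
for score, doc, original_index in ranked:
    print(f"{score:.3f}  (doc #{original_index})  {doc}")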


@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.
Install with: pip install codexlens[reranker-legacy]
"""
from __future__ import annotations
import logging
import threading
from typing import List, Sequence, Tuple
from .base import BaseReranker
logger = logging.getLogger(__name__)
try:
from sentence_transformers import CrossEncoder as _CrossEncoder
CROSS_ENCODER_AVAILABLE = True
_import_error: str | None = None
except ImportError as exc: # pragma: no cover - optional dependency
_CrossEncoder = None # type: ignore[assignment]
CROSS_ENCODER_AVAILABLE = False
_import_error = str(exc)
def check_cross_encoder_available() -> tuple[bool, str | None]:
if CROSS_ENCODER_AVAILABLE:
return True, None
return (
False,
_import_error
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
)
class CrossEncoderReranker(BaseReranker):
"""Cross-encoder reranker with lazy model loading."""
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = (model_name or "").strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.device = (device or "").strip() or None
self._model = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None:
return
ok, err = check_cross_encoder_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None:
return
try:
if self.device:
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
else:
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
except Exception as exc:
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
raise
def score_pairs(
self,
pairs: Sequence[Tuple[str, str]],
*,
batch_size: int = 32,
) -> List[float]:
"""Score (query, doc) pairs using the cross-encoder.
Returns:
List of scores (one per pair) in the model's native scale (usually logits).
"""
if not pairs:
return []
self._load_model()
if self._model is None: # pragma: no cover - defensive
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
return [float(s) for s in scores]
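Usage mirrors the other backends, with one caveat from the docstring above: scores are returned in the model's native scale (typically logits), not normalized to [0, 1]. A small sketch:

reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
logits = reranker.score_pairs([
    ("what is a cross-encoder", "A cross-encoder scores a query and a document jointly."),
    ("what is a cross-encoder", "Quarterly sales figures for 2023."),
])
# Higher logit means more relevant; apply a sigmoid yourself if you need [0, 1] scores.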


@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.
Notes:
- This backend is experimental and may be slow/expensive compared to local
rerankers.
- It relies on `ccw-litellm` for a unified LLM API across providers.
"""
from __future__ import annotations
import json
import logging
import re
import threading
import time
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
def _coerce_score_to_unit_interval(score: float) -> float:
"""Coerce a numeric score into [0, 1].
The prompt asks for a float in [0, 1], but some models may respond with 0-10
or 0-100 scales. This function attempts a conservative normalization.
"""
if 0.0 <= score <= 1.0:
return score
if 0.0 <= score <= 10.0:
return score / 10.0
if 0.0 <= score <= 100.0:
return score / 100.0
return max(0.0, min(1.0, score))
def _extract_score(text: str) -> float | None:
"""Extract a numeric relevance score from an LLM response."""
content = (text or "").strip()
if not content:
return None
# Prefer JSON if present.
if "{" in content and "}" in content:
try:
start = content.index("{")
end = content.rindex("}") + 1
payload = json.loads(content[start:end])
if isinstance(payload, dict) and "score" in payload:
return float(payload["score"])
except Exception:
pass
match = _NUMBER_RE.search(content)
if not match:
return None
try:
return float(match.group(0))
except ValueError:
return None
class LiteLLMReranker(BaseReranker):
"""Experimental reranker that uses a LiteLLM-compatible model.
This reranker scores each (query, doc) pair in isolation (single-pair mode)
to improve prompt reliability across providers.
"""
_SYSTEM_PROMPT = (
"You are a relevance scoring assistant.\n"
"Given a search query and a document snippet, output a single numeric "
"relevance score between 0 and 1.\n\n"
"Scoring guidance:\n"
"- 1.0: The document directly answers the query.\n"
"- 0.5: The document is partially relevant.\n"
"- 0.0: The document is unrelated.\n\n"
"Output requirements:\n"
"- Output ONLY the number (e.g., 0.73).\n"
"- Do not include any other text."
)
def __init__(
self,
model: str = "default",
*,
requests_per_minute: float | None = None,
min_interval_seconds: float | None = None,
default_score: float = 0.0,
max_doc_chars: int = 8000,
**litellm_kwargs: Any,
) -> None:
"""Initialize the reranker.
Args:
model: Model name from ccw-litellm configuration (default: "default").
requests_per_minute: Optional rate limit in requests per minute.
min_interval_seconds: Optional minimum interval between requests. If set,
it takes precedence over requests_per_minute.
default_score: Score to use when an API call fails or parsing fails.
max_doc_chars: Maximum number of document characters to include in the prompt.
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
Raises:
ImportError: If ccw-litellm is not installed.
ValueError: If model is blank.
"""
self.model_name = (model or "").strip()
if not self.model_name:
raise ValueError("model cannot be blank")
self.default_score = float(default_score)
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
if min_interval_seconds is not None:
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
elif requests_per_minute is not None and float(requests_per_minute) > 0:
self._min_interval_seconds = 60.0 / float(requests_per_minute)
else:
self._min_interval_seconds = 0.0
# Prefer deterministic output by default; allow overrides via kwargs.
litellm_kwargs = dict(litellm_kwargs)
litellm_kwargs.setdefault("temperature", 0.0)
litellm_kwargs.setdefault("max_tokens", 16)
try:
from ccw_litellm import ChatMessage, LiteLLMClient
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from exc
self._ChatMessage = ChatMessage
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
self._lock = threading.RLock()
self._last_request_at = 0.0
def _sanitize_text(self, text: str) -> str:
# Keep consistent with LiteLLMEmbedderWrapper workaround.
if text.startswith("import"):
return " " + text
return text
def _rate_limit(self) -> None:
if self._min_interval_seconds <= 0:
return
with self._lock:
now = time.monotonic()
elapsed = now - self._last_request_at
if elapsed < self._min_interval_seconds:
time.sleep(self._min_interval_seconds - elapsed)
self._last_request_at = time.monotonic()
def _build_user_prompt(self, query: str, doc: str) -> str:
sanitized_query = self._sanitize_text(query or "")
sanitized_doc = self._sanitize_text(doc or "")
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
sanitized_doc = sanitized_doc[: self.max_doc_chars]
return (
"Query:\n"
f"{sanitized_query}\n\n"
"Document:\n"
f"{sanitized_doc}\n\n"
"Return the relevance score (0 to 1) as a single number:"
)
def _score_single_pair(self, query: str, doc: str) -> float:
messages = [
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
]
try:
self._rate_limit()
response = self._client.chat(messages)
except Exception as exc:
logger.debug("LiteLLM reranker request failed: %s", exc)
return self.default_score
raw = getattr(response, "content", "") or ""
score = _extract_score(raw)
if score is None:
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
return self.default_score
return _coerce_score_to_unit_interval(float(score))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with per-pair LLM calls."""
if not pairs:
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for i in range(0, len(pairs), bs):
batch = pairs[i : i + bs]
for query, doc in batch:
scores.append(self._score_single_pair(query, doc))
return scores
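A sketch of how the pieces combine; "default" refers to the model alias in the ccw-litellm configuration, and the rate limit shown is illustrative:

reranker = LiteLLMReranker(
    model="default",
    requests_per_minute=30,  # at most one LLM call every ~2 seconds
    default_score=0.0,       # used when a request fails or the score cannot be parsed
)
scores = reranker.score_pairs([
    ("exponential backoff", "sleep(min(cap, base * 2 ** attempt)) between retries"),
    ("exponential backoff", "CSS grid layout tutorial"),
])
# Raw responses such as '{"score": 0.8}' or '8/10' are parsed and coerced into [0, 1].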


@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
providers: list[Any] | None = None,
max_length: int | None = None,
) -> None:
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.providers = providers
self.max_length = int(max_length) if max_length is not None else None
self._tokenizer: Any | None = None
self._model: Any | None = None
self._model_input_names: set[str] | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_onnx_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
if self.providers is None:
from ..gpu_support import get_optimal_providers
# Include device_id options for DirectML/CUDA selection when available.
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
)
# Some Optimum versions accept `providers`, others accept a single `provider`.
# Prefer passing the full providers list, with a conservative fallback.
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
model_kwargs = {}
try:
self._model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
**model_kwargs,
)
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments.
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache model input names to filter tokenizer outputs defensively.
input_names: set[str] | None = None
for attr in ("input_names", "model_input_names"):
names = getattr(self._model, attr, None)
if isinstance(names, (list, tuple)) and names:
input_names = {str(n) for n in names}
break
if input_names is None:
try:
session = getattr(self._model, "model", None)
if session is not None and hasattr(session, "get_inputs"):
input_names = {i.name for i in session.get_inputs()}
except Exception:
input_names = None
self._model_input_names = input_names
@staticmethod
def _sigmoid(x: "Any") -> "Any":
import numpy as np
x = np.clip(x, -50.0, 50.0)
return 1.0 / (1.0 + np.exp(-x))
@staticmethod
def _select_relevance_logit(logits: "Any") -> "Any":
import numpy as np
arr = np.asarray(logits)
if arr.ndim == 0:
return arr.reshape(1)
if arr.ndim == 1:
return arr
if arr.ndim >= 2:
# Common cases:
# - Regression: (batch, 1)
# - Binary classification: (batch, 2)
if arr.shape[-1] == 1:
return arr[..., 0]
if arr.shape[-1] == 2:
# Convert 2-logit softmax into a single logit via difference.
return arr[..., 1] - arr[..., 0]
return arr.max(axis=-1)
return arr.reshape(-1)
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
queries = [q for q, _ in batch]
docs = [d for _, d in batch]
tokenizer_kwargs: dict[str, Any] = {
"text": queries,
"text_pair": docs,
"padding": True,
"truncation": True,
"return_tensors": "np",
}
max_len = self.max_length
if max_len is None:
try:
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
if 0 < model_max < 10_000:
max_len = model_max
else:
max_len = 512
except Exception:
max_len = 512
if max_len is not None and max_len > 0:
tokenizer_kwargs["max_length"] = int(max_len)
encoded = self._tokenizer(**tokenizer_kwargs)
inputs = dict(encoded)
# Some models do not accept token_type_ids; filter to known input names if available.
if self._model_input_names:
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
return inputs
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
if self._model is None:
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
outputs = self._model(**inputs)
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, dict) and "logits" in outputs:
return outputs["logits"]
if isinstance(outputs, (list, tuple)) and outputs:
return outputs[0]
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
if not pairs:
return []
self._load_model()
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
return []
import numpy as np
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for batch in _iter_batches(list(pairs), bs):
inputs = self._tokenize_batch(batch)
logits = self._forward_logits(inputs)
rel_logits = self._select_relevance_logit(logits)
probs = self._sigmoid(rel_logits)
probs = np.clip(probs, 0.0, 1.0)
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
if len(scores) != len(pairs):
logger.debug(
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
)
return scores[: len(pairs)]
return scores
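A minimal direct-use sketch (note that the factory above routes backend="onnx" to the fastembed backend, so this class is normally constructed explicitly):

reranker = ONNXReranker(use_gpu=False)  # prefer CPU execution providers
scores = reranker.score_pairs(
    [
        ("stable sigmoid", "Clip logits to [-50, 50] before exponentiating to avoid overflow."),
        ("stable sigmoid", "Release notes for version 2.1."),
    ],
    batch_size=16,
)
# Scores are already sigmoid-normalized into [0, 1].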