"""Experimental LiteLLM reranker backend. This module provides :class:`LiteLLMReranker`, which uses an LLM to score the relevance of a single (query, document) pair per request. Notes: - This backend is experimental and may be slow/expensive compared to local rerankers. - It relies on `ccw-litellm` for a unified LLM API across providers. """ from __future__ import annotations import json import logging import re import threading import time from typing import Any, Sequence from .base import BaseReranker logger = logging.getLogger(__name__) _NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?") def _coerce_score_to_unit_interval(score: float) -> float: """Coerce a numeric score into [0, 1]. The prompt asks for a float in [0, 1], but some models may respond with 0-10 or 0-100 scales. This function attempts a conservative normalization. """ if 0.0 <= score <= 1.0: return score if 0.0 <= score <= 10.0: return score / 10.0 if 0.0 <= score <= 100.0: return score / 100.0 return max(0.0, min(1.0, score)) def _extract_score(text: str) -> float | None: """Extract a numeric relevance score from an LLM response.""" content = (text or "").strip() if not content: return None # Prefer JSON if present. if "{" in content and "}" in content: try: start = content.index("{") end = content.rindex("}") + 1 payload = json.loads(content[start:end]) if isinstance(payload, dict) and "score" in payload: return float(payload["score"]) except Exception: pass match = _NUMBER_RE.search(content) if not match: return None try: return float(match.group(0)) except ValueError: return None class LiteLLMReranker(BaseReranker): """Experimental reranker that uses a LiteLLM-compatible model. This reranker scores each (query, doc) pair in isolation (single-pair mode) to improve prompt reliability across providers. """ _SYSTEM_PROMPT = ( "You are a relevance scoring assistant.\n" "Given a search query and a document snippet, output a single numeric " "relevance score between 0 and 1.\n\n" "Scoring guidance:\n" "- 1.0: The document directly answers the query.\n" "- 0.5: The document is partially relevant.\n" "- 0.0: The document is unrelated.\n\n" "Output requirements:\n" "- Output ONLY the number (e.g., 0.73).\n" "- Do not include any other text." ) def __init__( self, model: str = "default", *, requests_per_minute: float | None = None, min_interval_seconds: float | None = None, default_score: float = 0.0, max_doc_chars: int = 8000, **litellm_kwargs: Any, ) -> None: """Initialize the reranker. Args: model: Model name from ccw-litellm configuration (default: "default"). requests_per_minute: Optional rate limit in requests per minute. min_interval_seconds: Optional minimum interval between requests. If set, it takes precedence over requests_per_minute. default_score: Score to use when an API call fails or parsing fails. max_doc_chars: Maximum number of document characters to include in the prompt. **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`. Raises: ImportError: If ccw-litellm is not installed. ValueError: If model is blank. """ self.model_name = (model or "").strip() if not self.model_name: raise ValueError("model cannot be blank") self.default_score = float(default_score) self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0 if min_interval_seconds is not None: self._min_interval_seconds = max(0.0, float(min_interval_seconds)) elif requests_per_minute is not None and float(requests_per_minute) > 0: self._min_interval_seconds = 60.0 / float(requests_per_minute) else: self._min_interval_seconds = 0.0 # Prefer deterministic output by default; allow overrides via kwargs. litellm_kwargs = dict(litellm_kwargs) litellm_kwargs.setdefault("temperature", 0.0) litellm_kwargs.setdefault("max_tokens", 16) try: from ccw_litellm import ChatMessage, LiteLLMClient except ImportError as exc: # pragma: no cover - optional dependency raise ImportError( "ccw-litellm not installed. Install with: pip install ccw-litellm" ) from exc self._ChatMessage = ChatMessage self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs) self._lock = threading.RLock() self._last_request_at = 0.0 def _sanitize_text(self, text: str) -> str: # Keep consistent with LiteLLMEmbedderWrapper workaround. if text.startswith("import"): return " " + text return text def _rate_limit(self) -> None: if self._min_interval_seconds <= 0: return with self._lock: now = time.monotonic() elapsed = now - self._last_request_at if elapsed < self._min_interval_seconds: time.sleep(self._min_interval_seconds - elapsed) self._last_request_at = time.monotonic() def _build_user_prompt(self, query: str, doc: str) -> str: sanitized_query = self._sanitize_text(query or "") sanitized_doc = self._sanitize_text(doc or "") if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars: sanitized_doc = sanitized_doc[: self.max_doc_chars] return ( "Query:\n" f"{sanitized_query}\n\n" "Document:\n" f"{sanitized_doc}\n\n" "Return the relevance score (0 to 1) as a single number:" ) def _score_single_pair(self, query: str, doc: str) -> float: messages = [ self._ChatMessage(role="system", content=self._SYSTEM_PROMPT), self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)), ] try: self._rate_limit() response = self._client.chat(messages) except Exception as exc: logger.debug("LiteLLM reranker request failed: %s", exc) return self.default_score raw = getattr(response, "content", "") or "" score = _extract_score(raw) if score is None: logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw) return self.default_score return _coerce_score_to_unit_interval(float(score)) def score_pairs( self, pairs: Sequence[tuple[str, str]], *, batch_size: int = 32, ) -> list[float]: """Score (query, doc) pairs with per-pair LLM calls.""" if not pairs: return [] bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32 scores: list[float] = [] for i in range(0, len(pairs), bs): batch = pairs[i : i + bs] for query, doc in batch: scores.append(self._score_single_pair(query, doc)) return scores