"""Experimental LiteLLM reranker backend.
|
|
|
|
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
|
|
relevance of a single (query, document) pair per request.
|
|
|
|
Notes:
|
|
- This backend is experimental and may be slow/expensive compared to local
|
|
rerankers.
|
|
- It relies on `ccw-litellm` for a unified LLM API across providers.
|
|
"""

from __future__ import annotations

import json
import logging
import re
import threading
import time
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _coerce_score_to_unit_interval(score: float) -> float:
    """Coerce a numeric score into [0, 1].

    The prompt asks for a float in [0, 1], but some models may respond with
    0-10 or 0-100 scales. This function attempts a conservative normalization.
    """
    if 0.0 <= score <= 1.0:
        return score
    if 0.0 <= score <= 10.0:
        return score / 10.0
    if 0.0 <= score <= 100.0:
        return score / 100.0
    return max(0.0, min(1.0, score))
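

# Illustrative behavior: 0.73 -> 0.73, 7.5 -> 0.75, 85.0 -> 0.85,
# -3.0 -> 0.0 (clamped).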


def _extract_score(text: str) -> float | None:
    """Extract a numeric relevance score from an LLM response."""
    content = (text or "").strip()
    if not content:
        return None

    # Prefer JSON if present.
    if "{" in content and "}" in content:
        try:
            start = content.index("{")
            end = content.rindex("}") + 1
            payload = json.loads(content[start:end])
            if isinstance(payload, dict) and "score" in payload:
                return float(payload["score"])
        except Exception:
            pass

    match = _NUMBER_RE.search(content)
    if not match:
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None
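

# Illustrative behavior: '{"score": 0.73}' -> 0.73; "Relevance: 8" -> 8.0
# (later normalized by _coerce_score_to_unit_interval); "no digits here" -> None.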


class LiteLLMReranker(BaseReranker):
    """Experimental reranker that uses a LiteLLM-compatible model.

    This reranker scores each (query, doc) pair in isolation (single-pair mode)
    to improve prompt reliability across providers.
    """

    _SYSTEM_PROMPT = (
        "You are a relevance scoring assistant.\n"
        "Given a search query and a document snippet, output a single numeric "
        "relevance score between 0 and 1.\n\n"
        "Scoring guidance:\n"
        "- 1.0: The document directly answers the query.\n"
        "- 0.5: The document is partially relevant.\n"
        "- 0.0: The document is unrelated.\n\n"
        "Output requirements:\n"
        "- Output ONLY the number (e.g., 0.73).\n"
        "- Do not include any other text."
    )
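
    # Note: the prompt demands a bare number, but _extract_score also tolerates
    # JSON such as {"score": 0.73} from providers that wrap their output.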

    def __init__(
        self,
        model: str = "default",
        *,
        requests_per_minute: float | None = None,
        min_interval_seconds: float | None = None,
        default_score: float = 0.0,
        max_doc_chars: int = 8000,
        **litellm_kwargs: Any,
    ) -> None:
        """Initialize the reranker.

        Args:
            model: Model name from ccw-litellm configuration (default: "default").
            requests_per_minute: Optional rate limit in requests per minute.
            min_interval_seconds: Optional minimum interval between requests.
                If set, it takes precedence over requests_per_minute.
            default_score: Score to use when an API call fails or parsing fails.
            max_doc_chars: Maximum number of document characters to include in
                the prompt.
            **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.

        Raises:
            ImportError: If ccw-litellm is not installed.
            ValueError: If model is blank.
        """
        self.model_name = (model or "").strip()
        if not self.model_name:
            raise ValueError("model cannot be blank")

        self.default_score = float(default_score)

        self.max_doc_chars = max(0, int(max_doc_chars))

        if min_interval_seconds is not None:
            self._min_interval_seconds = max(0.0, float(min_interval_seconds))
        elif requests_per_minute is not None and float(requests_per_minute) > 0:
            self._min_interval_seconds = 60.0 / float(requests_per_minute)
        else:
            self._min_interval_seconds = 0.0
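
        # Illustrative: requests_per_minute=30 yields a 2.0-second minimum
        # interval; an explicit min_interval_seconds overrides that calculation.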

        # Prefer deterministic output by default; allow overrides via kwargs.
        litellm_kwargs = dict(litellm_kwargs)
        litellm_kwargs.setdefault("temperature", 0.0)
        litellm_kwargs.setdefault("max_tokens", 16)

        try:
            from ccw_litellm import ChatMessage, LiteLLMClient
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from exc

        self._ChatMessage = ChatMessage
        self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)

        self._lock = threading.RLock()
        self._last_request_at = 0.0

    def _sanitize_text(self, text: str) -> str:
        # Keep consistent with LiteLLMEmbedderWrapper workaround.
        if text.startswith("import"):
            return " " + text
        return text

    def _rate_limit(self) -> None:
        if self._min_interval_seconds <= 0:
            return
        # The lock is held while sleeping, so concurrent callers are serialized
        # and the minimum interval is enforced across all threads.
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_request_at
            if elapsed < self._min_interval_seconds:
                time.sleep(self._min_interval_seconds - elapsed)
            self._last_request_at = time.monotonic()

    def _build_user_prompt(self, query: str, doc: str) -> str:
        sanitized_query = self._sanitize_text(query or "")
        sanitized_doc = self._sanitize_text(doc or "")
        if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
            sanitized_doc = sanitized_doc[: self.max_doc_chars]

        return (
            "Query:\n"
            f"{sanitized_query}\n\n"
            "Document:\n"
            f"{sanitized_doc}\n\n"
            "Return the relevance score (0 to 1) as a single number:"
        )
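
    # Illustrative prompt for ("rust lifetimes", "Lifetimes ensure references
    # remain valid."):
    #
    #   Query:
    #   rust lifetimes
    #
    #   Document:
    #   Lifetimes ensure references remain valid.
    #
    #   Return the relevance score (0 to 1) as a single number: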

    def _score_single_pair(self, query: str, doc: str) -> float:
        messages = [
            self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
            self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
        ]

        # Fail soft: any transport or provider error yields the configured
        # default score instead of raising.
        try:
            self._rate_limit()
            response = self._client.chat(messages)
        except Exception as exc:
            logger.debug("LiteLLM reranker request failed: %s", exc)
            return self.default_score

        raw = getattr(response, "content", "") or ""
        score = _extract_score(raw)
        if score is None:
            logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
            return self.default_score
        return _coerce_score_to_unit_interval(float(score))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with per-pair LLM calls."""
        if not pairs:
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32

        scores: list[float] = []
        # Batching only chunks the iteration; each pair still costs one
        # sequential LLM request.
        for i in range(0, len(pairs), bs):
            batch = pairs[i : i + bs]
            for query, doc in batch:
                scores.append(self._score_single_pair(query, doc))
        return scores
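

# Usage sketch (illustrative; assumes ccw-litellm is installed and a model
# named "default" is configured for it):
#
#   reranker = LiteLLMReranker(model="default", requests_per_minute=30)
#   scores = reranker.score_pairs(
#       [("what is a reranker", "A reranker reorders retrieved documents.")]
#   )
#   # scores is a list of floats in [0, 1]; failed calls fall back to
#   # default_score.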