"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations

import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"


def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback."""
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]
    # Try loading from .env files
    try:
        from codexlens.env_config import get_env

        return get_env(key, workspace_root=workspace_root)
    except ImportError:
        return None


def check_httpx_available() -> tuple[bool, str | None]:
    try:
        import httpx  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"httpx not available: {exc}. Install with: pip install httpx"
    return True, None


class APIReranker(BaseReranker):
    """Reranker backed by a remote reranking HTTP API."""

    _PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
        "siliconflow": {
            "api_base": "https://api.siliconflow.cn",
            "endpoint": "/v1/rerank",
            "default_model": "BAAI/bge-reranker-v2-m3",
        },
        "cohere": {
            "api_base": "https://api.cohere.ai",
            "endpoint": "/v1/rerank",
            "default_model": "rerank-english-v3.0",
        },
        "jina": {
            "api_base": "https://api.jina.ai",
            "endpoint": "/v1/rerank",
            "default_model": "jina-reranker-v2-base-multilingual",
        },
    }
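
    # Resolution order used in __init__ below: explicit constructor arguments
    # win, then the RERANKER_API_BASE / RERANKER_MODEL environment variables
    # (with .env fallback), then the _PROVIDER_DEFAULTS entry above. For
    # example, provider="jina" with no overrides targets
    # https://api.jina.ai/v1/rerank with model jina-reranker-v2-base-multilingual.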

    def __init__(
        self,
        *,
        provider: str = "siliconflow",
        model_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        backoff_base_s: float = 0.5,
        backoff_max_s: float = 8.0,
        env_api_key: str = _DEFAULT_ENV_API_KEY,
        workspace_root: Path | str | None = None,
        max_input_tokens: int | None = None,
    ) -> None:
        ok, err = check_httpx_available()
        if not ok:  # pragma: no cover - exercised via factory availability tests
            raise ImportError(err)
        import httpx

        self._workspace_root = Path(workspace_root) if workspace_root else None
        self.provider = (provider or "").strip().lower()
        if self.provider not in self._PROVIDER_DEFAULTS:
            raise ValueError(
                f"Unknown reranker provider: {provider}. "
                f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
            )
        defaults = self._PROVIDER_DEFAULTS[self.provider]
        # Load api_base from env with .env fallback
        env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
        self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
        self.endpoint = defaults["endpoint"]
        # Load model from env with .env fallback
        env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
        self.model_name = (model_name or env_model or defaults["default_model"]).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")
        # Load API key from env with .env fallback
        resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
        resolved_key = resolved_key.strip()
        if not resolved_key:
            raise ValueError(
                f"Missing API key for reranker provider '{self.provider}'. "
                f"Pass api_key=... or set ${env_api_key}."
            )
        self._api_key = resolved_key
        self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
        # max_retries=0 is valid (no retries); only None or negatives fall back
        self.max_retries = int(max_retries) if max_retries is not None and int(max_retries) >= 0 else 3
        self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
        self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        if self.provider == "cohere":
            headers.setdefault("Cohere-Version", "2022-12-06")
        self._client = httpx.Client(
            base_url=self.api_base,
            headers=headers,
            timeout=self.timeout_s,
        )
        # Store max_input_tokens with model-aware defaults
        if max_input_tokens is not None:
            self._max_input_tokens = max_input_tokens
        else:
            # Infer from model name
            model_lower = self.model_name.lower()
            if "8b" in model_lower or "large" in model_lower:
                self._max_input_tokens = 32768
            else:
                self._max_input_tokens = 8192

    @property
    def max_input_tokens(self) -> int:
        """Return maximum token limit for reranking."""
        return self._max_input_tokens

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # pragma: no cover - defensive
            return

    def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
        if retry_after_s is not None and retry_after_s > 0:
            time.sleep(min(float(retry_after_s), self.backoff_max_s))
            return
        exp = self.backoff_base_s * (2**attempt)
        jitter = random.uniform(0, min(0.5, self.backoff_base_s))
        time.sleep(min(self.backoff_max_s, exp + jitter))
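
    # With the default backoff parameters (base 0.5 s, cap 8 s), the delays
    # before jitter are roughly: attempt 0 -> 0.5 s, 1 -> 1 s, 2 -> 2 s,
    # 3 -> 4 s, then capped at 8 s. A Retry-After header, when present, takes
    # precedence but is still capped at backoff_max_s.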

    @staticmethod
    def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
        value = (headers.get("Retry-After") or "").strip()
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    @staticmethod
    def _should_retry_status(status_code: int) -> bool:
        return status_code == 429 or 500 <= status_code <= 599

    def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
        last_exc: Exception | None = None
        for attempt in range(self.max_retries + 1):
            try:
                response = self._client.post(self.endpoint, json=dict(payload))
            except Exception as exc:  # httpx is optional at import-time
                last_exc = exc
                if attempt < self.max_retries:
                    self._sleep_backoff(attempt)
                    continue
                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' after "
                    f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
                ) from exc
            status = int(getattr(response, "status_code", 0) or 0)
            if status >= 400:
                body_preview = ""
                try:
                    body_preview = (response.text or "").strip()
                except Exception:
                    body_preview = ""
                if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"
                if self._should_retry_status(status) and attempt < self.max_retries:
                    retry_after = self._parse_retry_after_seconds(response.headers)
                    logger.warning(
                        "Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
                        self.api_base,
                        self.endpoint,
                        status,
                        attempt + 1,
                        self.max_retries + 1,
                    )
                    self._sleep_backoff(attempt, retry_after_s=retry_after)
                    continue
                if status in {401, 403}:
                    raise RuntimeError(
                        f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
                        "Check your API key."
                    )
                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
                    f"Response: {body_preview or '<empty>'}"
                )
            try:
                data = response.json()
            except Exception as exc:
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' is not valid JSON: "
                    f"{type(exc).__name__}: {exc}"
                ) from exc
            if not isinstance(data, dict):
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' must be a JSON object; "
                    f"got {type(data).__name__}"
                )
            return data
        raise RuntimeError(
            f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
        )

    @staticmethod
    def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
        if not isinstance(results, list):
            raise RuntimeError(
                f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}"
            )
        scores: list[float] = [0.0 for _ in range(expected)]
        filled = 0
        for item in results:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            score = item.get("relevance_score", item.get("score"))
            if idx is None or score is None:
                continue
            try:
                idx_int = int(idx)
                score_f = float(score)
            except (TypeError, ValueError):
                continue
            if 0 <= idx_int < expected:
                scores[idx_int] = score_f
                filled += 1
        if filled != expected:
            raise RuntimeError(
                f"Rerank response contained {filled}/{expected} scored documents; "
                "ensure top_n matches the number of documents."
            )
        return scores
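
    # The parser above assumes the common rerank response shape, e.g.:
    #   {"results": [{"index": 2, "relevance_score": 0.91},
    #                {"index": 0, "relevance_score": 0.47}, ...]}
    # Some providers name the field "score" instead of "relevance_score";
    # both are accepted.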

    def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
        payload: dict[str, Any] = {
            "model": self.model_name,
            "query": query,
            "documents": list(documents),
            "top_n": len(documents),
            "return_documents": False,
        }
        return payload
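
    # A payload built for the query "parse config" over two documents would
    # look like (illustrative values):
    #   {
    #       "model": "BAAI/bge-reranker-v2-m3",
    #       "query": "parse config",
    #       "documents": ["def load(...): ...", "class Config: ..."],
    #       "top_n": 2,
    #       "return_documents": False,
    #   }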

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using a fast heuristic.

        Uses len(text) // 4 as an approximation (~4 chars per token for
        English). Not perfectly accurate for all models/languages but
        sufficient for batch sizing decisions where exact counts aren't
        critical.
        """
        return len(text) // 4
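
    # Example: a 2,000-character chunk estimates to ~500 tokens. This feeds
    # the batching logic below, where only relative sizes matter, not exact
    # counts.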

    def _create_token_aware_batches(
        self,
        query: str,
        documents: Sequence[str],
    ) -> list[list[tuple[int, str]]]:
        """Split documents into batches that fit within token limits.

        Uses 90% of max_input_tokens as a safety margin.
        Each batch includes the query tokens as overhead.
        """
        max_tokens = int(self._max_input_tokens * 0.9)
        query_tokens = self._estimate_tokens(query)
        batches: list[list[tuple[int, str]]] = []
        current_batch: list[tuple[int, str]] = []
        current_tokens = query_tokens  # Start with query overhead
        for idx, doc in enumerate(documents):
            doc_tokens = self._estimate_tokens(doc)
            # Warn if a single document exceeds the token limit (it will likely
            # be truncated by the API)
            if doc_tokens > max_tokens - query_tokens:
                logger.warning(
                    f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
                    f"(limit: {max_tokens - query_tokens} after query overhead). "
                    "Document will likely be truncated by the API."
                )
            # If the batch would exceed the limit, start a new batch
            if current_tokens + doc_tokens > max_tokens and current_batch:
                batches.append(current_batch)
                current_batch = []
                current_tokens = query_tokens
            current_batch.append((idx, doc))
            current_tokens += doc_tokens
        if current_batch:
            batches.append(current_batch)
        return batches
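
    # Worked example (assuming the default 8192-token limit): the budget is
    # int(8192 * 0.9) = 7372 tokens. With a ~50-token query and documents of
    # ~3000 estimated tokens each, two documents fit per batch
    # (50 + 3000 + 3000 = 6050 <= 7372; a third would push it to 9050), so six
    # documents yield three batches of two.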

    def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
        if not documents:
            return []
        # Create token-aware batches
        batches = self._create_token_aware_batches(query, documents)
        if len(batches) == 1:
            # Single batch - original behavior
            payload = self._build_payload(query=query, documents=documents)
            data = self._request_json(payload)
            results = data.get("results")
            return self._extract_scores_from_results(results, expected=len(documents))
        # Multiple batches - process each and merge results
        logger.info(
            f"Splitting {len(documents)} documents into {len(batches)} batches "
            f"(max_input_tokens: {self._max_input_tokens})"
        )
        all_scores: list[float] = [0.0] * len(documents)
        for batch in batches:
            batch_docs = [doc for _, doc in batch]
            payload = self._build_payload(query=query, documents=batch_docs)
            data = self._request_json(payload)
            results = data.get("results")
            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
            # Map scores back to original indices
            for (orig_idx, _), score in zip(batch, batch_scores):
                all_scores[orig_idx] = score
        return all_scores
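
    # score_pairs below groups pairs by query so each distinct query costs one
    # rerank call (plus any token-driven batch splits). For example,
    # [(q1, d1), (q1, d2), (q2, d3)] becomes two API calls: q1 over [d1, d2]
    # and q2 over [d3], with scores mapped back to the original pair order.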

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # noqa: ARG002 - kept for BaseReranker compatibility
    ) -> list[float]:
        if not pairs:
            return []
        grouped: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            grouped.setdefault(str(query), []).append((idx, str(doc)))
        scores: list[float] = [0.0 for _ in range(len(pairs))]
        for query, items in grouped.items():
            documents = [doc for _, doc in items]
            query_scores = self._rerank_one_query(query=query, documents=documents)
            for (orig_idx, _), score in zip(items, query_scores):
                scores[orig_idx] = float(score)
        return scores
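

# A minimal usage sketch (not part of the library API): it assumes httpx is
# installed and RERANKER_API_KEY is set in the environment or a .env file.
# The query/document strings are illustrative.
if __name__ == "__main__":
    reranker = APIReranker(provider="siliconflow")
    try:
        pairs = [
            ("how is retry backoff implemented?", "def _sleep_backoff(...): ..."),
            ("how is retry backoff implemented?", "class Config: ..."),
        ]
        # One API call for the shared query; scores align with the pair order
        for (query, doc), score in zip(pairs, reranker.score_pairs(pairs)):
            print(f"{score:.3f}  {doc[:40]}")
    finally:
        reranker.close()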