mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-09 02:24:11 +08:00
92 lines
2.7 KiB
Python
92 lines
2.7 KiB
Python
"""Legacy sentence-transformers cross-encoder reranker.
|
|
|
|
Install with: pip install codexlens[reranker-legacy]
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
from typing import List, Sequence, Tuple
|
|
|
|
from .base import BaseReranker
|
|
|
|
logger = logging.getLogger(__name__)

# sentence-transformers is an optional extra. Record whether the import worked
# and, when it did not, the ImportError text, so check_cross_encoder_available()
# can report a useful reason instead of failing at module import time.
try:
    from sentence_transformers import CrossEncoder as _CrossEncoder
except ImportError as exc:  # pragma: no cover - optional dependency
    _CrossEncoder = None  # type: ignore[assignment]
    CROSS_ENCODER_AVAILABLE = False
    _import_error = str(exc)
else:
    CROSS_ENCODER_AVAILABLE = True
    _import_error: str | None = None
|
|
|
|
|
|
def check_cross_encoder_available() -> tuple[bool, str | None]:
|
|
if CROSS_ENCODER_AVAILABLE:
|
|
return True, None
|
|
return (
|
|
False,
|
|
_import_error
|
|
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
|
|
)
|
|
|
|
|
|
class CrossEncoderReranker(BaseReranker):
    """Rerank (query, document) pairs with a sentence-transformers CrossEncoder.

    The underlying model is not constructed in ``__init__``. It is built on the
    first ``score_pairs`` call that has work to do. Construction happens under a
    re-entrant lock, and the "already loaded?" check is repeated inside it, so
    concurrent first calls build the model only once.
    """

    def __init__(self, model_name: str, *, device: str | None = None) -> None:
        """Store configuration; does not load the model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``. Surrounding
                whitespace is stripped; a blank (or ``None``) value is rejected.
            device: Optional device string forwarded to ``CrossEncoder``. Blank
                or ``None`` means "let CrossEncoder pick its default".

        Raises:
            ValueError: If ``model_name`` is empty after stripping.
        """
        cleaned_name = (model_name or "").strip()
        self.model_name = cleaned_name
        if not cleaned_name:
            raise ValueError("model_name cannot be blank")

        cleaned_device = (device or "").strip()
        self.device = cleaned_device if cleaned_device else None
        self._model = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        """Build the CrossEncoder on first use; later calls are no-ops.

        Raises:
            ImportError: If sentence-transformers could not be imported.
            Exception: Anything raised by the ``CrossEncoder`` constructor.
                It is logged at debug level and re-raised unchanged.
        """
        # Fast path: skip the lock entirely once the model exists.
        if self._model is not None:
            return

        available, reason = check_cross_encoder_available()
        if not available:
            raise ImportError(reason)

        with self._lock:
            # Re-check: another thread may have finished loading while we waited.
            if self._model is not None:
                return

            # Only pass ``device`` when one was configured, so CrossEncoder
            # falls back to its own default otherwise.
            ctor_kwargs = {"device": self.device} if self.device else {}
            try:
                self._model = _CrossEncoder(self.model_name, **ctor_kwargs)  # type: ignore[misc]
            except Exception as exc:
                logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
                raise

    def score_pairs(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> List[float]:
        """Score (query, doc) pairs with the cross-encoder.

        An empty ``pairs`` returns ``[]`` without loading the model.

        Args:
            pairs: ``(query, document)`` tuples to score.
            batch_size: Batch size for ``predict``. Falsy or non-positive
                values fall back to 32.

        Returns:
            One float per pair, taken directly from the model's ``predict``
            output. No normalization is applied here, so for typical
            cross-encoders these are raw scores/logits.

        Raises:
            ImportError: If sentence-transformers is unavailable.
        """
        if not pairs:
            return []

        self._load_model()

        model = self._model
        if model is None:  # pragma: no cover - defensive
            return []

        effective_batch = 32
        if batch_size and int(batch_size) > 0:
            effective_batch = int(batch_size)

        raw_scores = model.predict(list(pairs), batch_size=effective_batch)  # type: ignore[union-attr]
        return list(map(float, raw_scores))
|