mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-13 02:41:50 +08:00
Refactor code structure and remove redundant changes
This commit is contained in:
@@ -0,0 +1,257 @@
|
||||
"""FastEmbed-based reranker backend.
|
||||
|
||||
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
|
||||
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
|
||||
|
||||
Install:
|
||||
pip install fastembed>=0.4.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
|
||||
"""Check whether fastembed reranker dependencies are available."""
|
||||
try:
|
||||
import fastembed # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
|
||||
)
|
||||
|
||||
try:
|
||||
from fastembed.rerank.cross_encoder import TextCrossEncoder # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"fastembed TextCrossEncoder not available: {exc}. "
|
||||
"Upgrade with: pip install fastembed>=0.4.0",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
class FastEmbedReranker(BaseReranker):
|
||||
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
|
||||
|
||||
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
# Alternative models supported by fastembed:
|
||||
# - "BAAI/bge-reranker-base"
|
||||
# - "BAAI/bge-reranker-large"
|
||||
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str | None = None,
|
||||
*,
|
||||
use_gpu: bool = True,
|
||||
cache_dir: str | None = None,
|
||||
threads: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize FastEmbed reranker.
|
||||
|
||||
Args:
|
||||
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
|
||||
use_gpu: Whether to use GPU acceleration when available.
|
||||
cache_dir: Optional directory for caching downloaded models.
|
||||
threads: Optional number of threads for ONNX Runtime.
|
||||
"""
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.cache_dir = cache_dir
|
||||
self.threads = threads
|
||||
|
||||
self._encoder: Any | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Lazy-load the TextCrossEncoder model."""
|
||||
if self._encoder is not None:
|
||||
return
|
||||
|
||||
ok, err = check_fastembed_reranker_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._encoder is not None:
|
||||
return
|
||||
|
||||
from fastembed.rerank.cross_encoder import TextCrossEncoder
|
||||
|
||||
# Determine providers based on GPU preference
|
||||
providers: list[str] | None = None
|
||||
if self.use_gpu:
|
||||
try:
|
||||
from ..gpu_support import get_optimal_providers
|
||||
|
||||
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
|
||||
except Exception:
|
||||
# Fallback: let fastembed decide
|
||||
providers = None
|
||||
|
||||
# Build initialization kwargs
|
||||
init_kwargs: dict[str, Any] = {}
|
||||
if self.cache_dir:
|
||||
init_kwargs["cache_dir"] = self.cache_dir
|
||||
if self.threads is not None:
|
||||
init_kwargs["threads"] = self.threads
|
||||
if providers:
|
||||
init_kwargs["providers"] = providers
|
||||
|
||||
logger.debug(
|
||||
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
|
||||
self.model_name,
|
||||
self.use_gpu,
|
||||
)
|
||||
|
||||
self._encoder = TextCrossEncoder(
|
||||
model_name=self.model_name,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
logger.debug("FastEmbed reranker model loaded successfully")
|
||||
|
||||
@staticmethod
|
||||
def _sigmoid(x: float) -> float:
|
||||
"""Numerically stable sigmoid function."""
|
||||
if x < -709:
|
||||
return 0.0
|
||||
if x > 709:
|
||||
return 1.0
|
||||
import math
|
||||
return 1.0 / (1.0 + math.exp(-x))
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs.
|
||||
|
||||
Args:
|
||||
pairs: Sequence of (query, doc) string pairs to score.
|
||||
batch_size: Batch size for scoring.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair), normalized to [0, 1] range.
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._encoder is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
# FastEmbed's TextCrossEncoder.rerank() expects a query and list of documents.
|
||||
# For batch scoring of multiple query-doc pairs, we need to process them.
|
||||
# Group by query for efficiency when same query appears multiple times.
|
||||
query_to_docs: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, (query, doc) in enumerate(pairs):
|
||||
if query not in query_to_docs:
|
||||
query_to_docs[query] = []
|
||||
query_to_docs[query].append((idx, doc))
|
||||
|
||||
# Score each query group
|
||||
scores: list[float] = [0.0] * len(pairs)
|
||||
|
||||
for query, indexed_docs in query_to_docs.items():
|
||||
docs = [doc for _, doc in indexed_docs]
|
||||
indices = [idx for idx, _ in indexed_docs]
|
||||
|
||||
try:
|
||||
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||
raw_scores = list(
|
||||
self._encoder.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Map scores back to original positions and normalize with sigmoid
|
||||
for i, raw_score in enumerate(raw_scores):
|
||||
if i < len(indices):
|
||||
original_idx = indices[i]
|
||||
# Normalize score to [0, 1] using stable sigmoid
|
||||
scores[original_idx] = self._sigmoid(float(raw_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
|
||||
# Leave scores as 0.0 for failed queries
|
||||
|
||||
return scores
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
*,
|
||||
top_k: int | None = None,
|
||||
batch_size: int = 32,
|
||||
) -> list[tuple[float, str, int]]:
|
||||
"""Rerank documents for a single query.
|
||||
|
||||
This is a convenience method that provides results in ranked order.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
documents: List of documents to rerank.
|
||||
top_k: Return only top K results. None returns all.
|
||||
batch_size: Batch size for scoring.
|
||||
|
||||
Returns:
|
||||
List of (score, document, original_index) tuples, sorted by score descending.
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._encoder is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
try:
|
||||
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||
raw_scores = list(
|
||||
self._encoder.rerank(
|
||||
query=query,
|
||||
documents=list(documents),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to our format: (normalized_score, document, original_index)
|
||||
ranked = []
|
||||
for idx, raw_score in enumerate(raw_scores):
|
||||
if idx < len(documents):
|
||||
# Normalize score to [0, 1] using stable sigmoid
|
||||
normalized = self._sigmoid(float(raw_score))
|
||||
ranked.append((normalized, documents[idx], idx))
|
||||
|
||||
# Sort by score descending
|
||||
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if top_k is not None and top_k > 0:
|
||||
ranked = ranked[:top_k]
|
||||
|
||||
return ranked
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
|
||||
return []
|
||||
Reference in New Issue
Block a user