feat(model-lock): implement model lock management with localStorage support

catlog22
2026-01-03 19:48:07 +08:00
parent 6043e6aa3b
commit 0af84be775
4 changed files with 570 additions and 99 deletions

File: __init__.py

@@ -8,6 +8,7 @@ from __future__ import annotations
 from .base import BaseReranker
 from .factory import check_reranker_available, get_reranker
+from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
 from .legacy import CrossEncoderReranker, check_cross_encoder_available
 from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
@@ -17,6 +18,8 @@ __all__ = [
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker",
"check_onnx_reranker_available",
]
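
With these exports in place, callers can import the new backend straight from the package. A minimal sketch; the top-level import path codexlens.rerankers is hypothetical, since the diff only shows relative imports:

    from codexlens.rerankers import (  # hypothetical top-level path
        FastEmbedReranker,
        check_fastembed_reranker_available,
    )

    ok, err = check_fastembed_reranker_available()
    reranker = FastEmbedReranker() if ok else None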

File: factory.py

@@ -14,8 +14,9 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
- "onnx" redirects to "fastembed" for backward compatibility.
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
@@ -26,10 +27,16 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
         return check_cross_encoder_available()

-    if backend == "onnx":
-        from .onnx_reranker import check_onnx_reranker_available
-        return check_onnx_reranker_available()
+    if backend == "fastembed":
+        from .fastembed_reranker import check_fastembed_reranker_available
+        return check_fastembed_reranker_available()
+
+    if backend == "onnx":
+        # Redirect to fastembed for backward compatibility
+        from .fastembed_reranker import check_fastembed_reranker_available
+        return check_fastembed_reranker_available()

     if backend == "litellm":
         try:
@@ -54,12 +61,12 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
     return False, (
         f"Invalid reranker backend: {backend}. "
-        "Must be 'onnx', 'api', 'litellm', or 'legacy'."
+        "Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
     )


 def get_reranker(
-    backend: str = "onnx",
+    backend: str = "fastembed",
     model_name: str | None = None,
     *,
     device: str | None = None,
@@ -69,12 +76,14 @@ def get_reranker(
     Args:
         backend: Reranker backend to use. Options:
-            - "onnx": Optimum + onnxruntime backend (default)
+            - "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
+            - "onnx": Redirects to fastembed for backward compatibility
             - "api": HTTP API backend (remote providers)
-            - "litellm": LiteLLM backend (LLM-based, experimental)
+            - "litellm": LiteLLM backend (LLM-based, for API mode)
             - "legacy": sentence-transformers CrossEncoder backend (optional)
         model_name: Model identifier for model-based backends. Defaults depend on backend:
-            - onnx: Xenova/ms-marco-MiniLM-L-6-v2
+            - fastembed: Xenova/ms-marco-MiniLM-L-6-v2
+            - onnx: (redirects to fastembed)
             - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
             - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
             - litellm: default
@@ -90,16 +99,28 @@ def get_reranker(
"""
backend = (backend or "").strip().lower()
if backend == "onnx":
ok, err = check_reranker_available("onnx")
if backend == "fastembed":
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .onnx_reranker import ONNXReranker
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL
_ = device # Device selection is managed via ONNX Runtime providers.
return ONNXReranker(model_name=resolved_model_name, **kwargs)
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "onnx":
# Redirect to fastembed for backward compatibility
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
@@ -134,5 +155,5 @@ def get_reranker(
         return APIReranker(model_name=resolved_model_name, **kwargs)

     raise ValueError(
-        f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'"
+        f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
     )
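
Taken together, the factory now defaults to "fastembed" and quietly maps "onnx" onto it. A hedged usage sketch (the codexlens.rerankers import path is hypothetical, as above):

    from codexlens.rerankers import check_reranker_available, get_reranker

    ok, err = check_reranker_available("fastembed")
    if not ok:
        raise ImportError(err)

    # Either spelling yields a FastEmbedReranker after this change:
    reranker = get_reranker("fastembed")   # default model Xenova/ms-marco-MiniLM-L-6-v2
    same_thing = get_reranker("onnx")      # redirects to the fastembed backend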

File: fastembed_reranker.py

@@ -0,0 +1,256 @@
"""FastEmbed-based reranker backend.
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
Install:
pip install fastembed>=0.4.0
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
    """Check whether fastembed reranker dependencies are available."""
    try:
        import fastembed  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
        )
    try:
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed TextCrossEncoder not available: {exc}. "
            "Upgrade with: pip install fastembed>=0.4.0",
        )
    return True, None
class FastEmbedReranker(BaseReranker):
    """Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
    # Alternative models supported by fastembed:
    # - "BAAI/bge-reranker-base"
    # - "BAAI/bge-reranker-large"
    # - "cross-encoder/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        cache_dir: str | None = None,
        threads: int | None = None,
    ) -> None:
        """Initialize FastEmbed reranker.

        Args:
            model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
            use_gpu: Whether to use GPU acceleration when available.
            cache_dir: Optional directory for caching downloaded models.
            threads: Optional number of threads for ONNX Runtime.
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")
        self.use_gpu = bool(use_gpu)
        self.cache_dir = cache_dir
        self.threads = threads
        self._encoder: Any | None = None
        self._lock = threading.RLock()
    def _load_model(self) -> None:
        """Lazy-load the TextCrossEncoder model."""
        if self._encoder is not None:
            return
        ok, err = check_fastembed_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._encoder is not None:
                return

            from fastembed.rerank.cross_encoder import TextCrossEncoder

            # Determine providers based on GPU preference
            providers: list[str] | None = None
            if self.use_gpu:
                try:
                    from ..gpu_support import get_optimal_providers

                    providers = get_optimal_providers(use_gpu=True, with_device_options=False)
                except Exception:
                    # Fallback: let fastembed decide
                    providers = None

            # Build initialization kwargs
            init_kwargs: dict[str, Any] = {}
            if self.cache_dir:
                init_kwargs["cache_dir"] = self.cache_dir
            if self.threads is not None:
                init_kwargs["threads"] = self.threads
            if providers:
                init_kwargs["providers"] = providers

            logger.debug(
                "Loading FastEmbed reranker model: %s (use_gpu=%s)",
                self.model_name,
                self.use_gpu,
            )
            self._encoder = TextCrossEncoder(
                model_name=self.model_name,
                **init_kwargs,
            )
            logger.debug("FastEmbed reranker model loaded successfully")
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair), normalized to the [0, 1] range.
        """
        if not pairs:
            return []
        self._load_model()
        if self._encoder is None:  # pragma: no cover - defensive
            return []

        # FastEmbed's TextCrossEncoder.rerank() scores one query against a list
        # of documents, so pairs are grouped by query: each distinct query is
        # scored with a single rerank() call, which is cheaper when the same
        # query appears many times.
        query_to_docs: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            query_to_docs.setdefault(query, []).append((idx, doc))
        # Score each query group
        scores: list[float] = [0.0] * len(pairs)
        for query, indexed_docs in query_to_docs.items():
            docs = [doc for _, doc in indexed_docs]
            indices = [idx for idx, _ in indexed_docs]
            try:
                # TextCrossEncoder.rerank() yields one score per document, in
                # the same order as the input documents.
                results = list(
                    self._encoder.rerank(
                        query=query,
                        documents=docs,
                        batch_size=batch_size,
                    )
                )
                # Map scores back to the original pair positions.
                for doc_idx, result in enumerate(results):
                    # Some fastembed versions yield plain floats; others yield
                    # result objects carrying a `score` attribute.
                    score = float(getattr(result, "score", result))
                    if doc_idx < len(indices):
                        original_idx = indices[doc_idx]
                        # FastEmbed typically returns scores in [0, 1] already;
                        # apply a sigmoid only if a raw logit slips through.
                        if score < 0 or score > 1:
                            import math

                            score = 1.0 / (1.0 + math.exp(-score))
                        scores[original_idx] = score
            except Exception as e:
                logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
                # Leave scores at 0.0 for failed queries

        return scores
    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        top_k: int | None = None,
        batch_size: int = 32,
    ) -> list[tuple[float, str, int]]:
        """Rerank documents for a single query.

        This is a convenience method that provides results in ranked order.

        Args:
            query: The query string.
            documents: List of documents to rerank.
            top_k: Return only the top K results. None returns all.
            batch_size: Batch size for scoring.

        Returns:
            List of (score, document, original_index) tuples, sorted by score descending.
        """
        if not documents:
            return []
        self._load_model()
        if self._encoder is None:  # pragma: no cover - defensive
            return []
        try:
            results = list(
                self._encoder.rerank(
                    query=query,
                    documents=list(documents),
                    batch_size=batch_size,
                )
            )
            # Scores come back in input order; pair each with its document and
            # original index, then sort. As above, handle both plain-float and
            # object-shaped results.
            ranked = []
            for idx, result in enumerate(results):
                score = float(getattr(result, "score", result))
                doc = documents[idx] if idx < len(documents) else ""
                ranked.append((score, doc, idx))
            # Sort by score descending
            ranked.sort(key=lambda x: x[0], reverse=True)
            if top_k is not None and top_k > 0:
                ranked = ranked[:top_k]
            return ranked
        except Exception as e:
            logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
            return []
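
A short usage sketch for the new class, assuming fastembed>=0.4.0 is installed (the model downloads lazily on first call; the query and document strings are illustrative):

    reranker = FastEmbedReranker(use_gpu=False, threads=4)

    docs = [
        "Cross-encoders jointly encode the query and document for scoring.",
        "Tokyo is the capital of Japan.",
    ]

    # Batch-score (query, doc) pairs; one [0, 1] score per pair, in input order.
    scores = reranker.score_pairs([("what is a cross-encoder?", d) for d in docs])

    # Or rerank one query against many documents and keep only the best hit.
    best = reranker.rerank("what is a cross-encoder?", docs, top_k=1)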