feat: Add type validation for RRF weights and implement caching for embedder instances

This commit is contained in:
catlog22
2026-01-02 19:50:51 +08:00
parent c268b531aa
commit 96b44e1482
2 changed files with 74 additions and 3 deletions

View File

@@ -85,8 +85,19 @@ class HybridSearchEngine:
weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
config: Optional runtime config (enables optional reranking features)
embedder: Optional embedder instance for embedding-based reranking
Raises:
TypeError: If weights is not a dict (e.g., if a Path is passed)
"""
self.logger = logging.getLogger(__name__)
# Validate weights type to catch common usage errors
if weights is not None and not isinstance(weights, dict):
raise TypeError(
f"weights must be a dict, got {type(weights).__name__}. "
f"Did you mean to pass index_path to search() instead of __init__()?"
)
self.weights = weights or DEFAULT_WEIGHTS.copy()
self._config = config
self.embedder = embedder

View File

@@ -1,14 +1,23 @@
"""Factory for creating embedders.
Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Dict, List, Optional
from .base import BaseEmbedder
# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)
def get_embedder(
backend: str = "fastembed",
@@ -65,13 +74,38 @@ def get_embedder(
... ]
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
"""
# Build cache key from immutable configuration
if backend == "fastembed":
cache_key = ("fastembed", profile, None, use_gpu)
elif backend == "litellm":
# For litellm, use model as part of cache key
# Multi-endpoint mode is not cached as it's more complex
if endpoints and len(endpoints) > 1:
cache_key = None # Skip cache for multi-endpoint
else:
effective_model = endpoints[0]["model"] if endpoints else model
cache_key = ("litellm", None, effective_model, None)
else:
cache_key = None
# Check cache first (thread-safe)
if cache_key is not None:
with _cache_lock:
if cache_key in _embedder_cache:
_logger.debug("Returning cached embedder for %s", cache_key)
return _embedder_cache[cache_key]
# Create new embedder instance
embedder: Optional[BaseEmbedder] = None
if backend == "fastembed":
from .embedder import Embedder
return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
elif backend == "litellm":
# Check if multi-endpoint mode is requested
if endpoints and len(endpoints) > 1:
from .rotational_embedder import create_rotational_embedder
# Multi-endpoint is not cached
return create_rotational_embedder(
endpoints_config=endpoints,
strategy=strategy,
@@ -86,13 +120,39 @@ def get_embedder(
if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
else:
# No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=model, **kwargs)
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
else:
raise ValueError(
f"Unknown backend: {backend}. "
f"Supported backends: 'fastembed', 'litellm'"
)
# Cache the embedder for future use (thread-safe)
if cache_key is not None and embedder is not None:
with _cache_lock:
# Double-check to avoid race condition
if cache_key not in _embedder_cache:
_embedder_cache[cache_key] = embedder
_logger.debug("Cached new embedder for %s", cache_key)
else:
# Another thread created it already, use that one
embedder = _embedder_cache[cache_key]
return embedder # type: ignore
def clear_embedder_cache() -> int:
"""Clear the embedder cache.
Returns:
Number of embedders cleared from cache
"""
with _cache_lock:
count = len(_embedder_cache)
_embedder_cache.clear()
_logger.debug("Cleared %d embedders from cache", count)
return count