feat: Add type validation for RRF weights and implement caching for embedder instances

Author: catlog22
Date: 2026-01-02 19:50:51 +08:00
parent c268b531aa
commit 96b44e1482
2 changed files with 74 additions and 3 deletions


@@ -85,8 +85,19 @@ class HybridSearchEngine:
             weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
             config: Optional runtime config (enables optional reranking features)
             embedder: Optional embedder instance for embedding-based reranking
+
+        Raises:
+            TypeError: If weights is not a dict (e.g., if a Path is passed)
         """
         self.logger = logging.getLogger(__name__)
+
+        # Validate weights type to catch common usage errors
+        if weights is not None and not isinstance(weights, dict):
+            raise TypeError(
+                f"weights must be a dict, got {type(weights).__name__}. "
+                f"Did you mean to pass index_path to search() instead of __init__()?"
+            )
+
         self.weights = weights or DEFAULT_WEIGHTS.copy()
         self._config = config
         self.embedder = embedder
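To illustrate the new guard, here is a minimal usage sketch; it assumes HybridSearchEngine can be constructed with only these keyword arguments, and the RRF weight keys shown are placeholders rather than names taken from this diff:

    from pathlib import Path

    # A Path (easily confused with an index_path argument) now fails fast:
    try:
        engine = HybridSearchEngine(weights=Path("./index"))
    except TypeError as exc:
        print(exc)  # e.g. "weights must be a dict, got PosixPath. ..."

    # A plain dict of RRF weights is accepted as-is:
    engine = HybridSearchEngine(weights={"bm25": 0.4, "dense": 0.6})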


@@ -1,14 +1,23 @@
"""Factory for creating embedders. """Factory for creating embedders.
Provides a unified interface for instantiating different embedder backends. Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
""" """
from __future__ import annotations from __future__ import annotations
import logging
import threading
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from .base import BaseEmbedder from .base import BaseEmbedder
# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)
def get_embedder( def get_embedder(
backend: str = "fastembed", backend: str = "fastembed",
@@ -65,13 +74,38 @@ def get_embedder(
     ... ]
     >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
     """
+    # Build cache key from immutable configuration
+    if backend == "fastembed":
+        cache_key = ("fastembed", profile, None, use_gpu)
+    elif backend == "litellm":
+        # For litellm, use model as part of cache key
+        # Multi-endpoint mode is not cached as it's more complex
+        if endpoints and len(endpoints) > 1:
+            cache_key = None  # Skip cache for multi-endpoint
+        else:
+            effective_model = endpoints[0]["model"] if endpoints else model
+            cache_key = ("litellm", None, effective_model, None)
+    else:
+        cache_key = None
+
+    # Check cache first (thread-safe)
+    if cache_key is not None:
+        with _cache_lock:
+            if cache_key in _embedder_cache:
+                _logger.debug("Returning cached embedder for %s", cache_key)
+                return _embedder_cache[cache_key]
+
+    # Create new embedder instance
+    embedder: Optional[BaseEmbedder] = None
     if backend == "fastembed":
         from .embedder import Embedder
-        return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
+        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
     elif backend == "litellm":
         # Check if multi-endpoint mode is requested
         if endpoints and len(endpoints) > 1:
             from .rotational_embedder import create_rotational_embedder
+            # Multi-endpoint is not cached
             return create_rotational_embedder(
                 endpoints_config=endpoints,
                 strategy=strategy,
@@ -86,13 +120,39 @@ def get_embedder(
if "api_base" in ep: if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"] ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs) embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
else: else:
# No endpoints list - use model parameter # No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=model, **kwargs) embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
else: else:
raise ValueError( raise ValueError(
f"Unknown backend: {backend}. " f"Unknown backend: {backend}. "
f"Supported backends: 'fastembed', 'litellm'" f"Supported backends: 'fastembed', 'litellm'"
) )
# Cache the embedder for future use (thread-safe)
if cache_key is not None and embedder is not None:
with _cache_lock:
# Double-check to avoid race condition
if cache_key not in _embedder_cache:
_embedder_cache[cache_key] = embedder
_logger.debug("Cached new embedder for %s", cache_key)
else:
# Another thread created it already, use that one
embedder = _embedder_cache[cache_key]
return embedder # type: ignore
def clear_embedder_cache() -> int:
"""Clear the embedder cache.
Returns:
Number of embedders cleared from cache
"""
with _cache_lock:
count = len(_embedder_cache)
_embedder_cache.clear()
_logger.debug("Cleared %d embedders from cache", count)
return count
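From a caller's perspective, the cache makes repeated get_embedder calls with the same configuration return the same instance. A minimal sketch, where the import path and the "code" profile value are illustrative assumptions:

    from embedders.factory import get_embedder, clear_embedder_cache  # import path is illustrative

    e1 = get_embedder(backend="fastembed", profile="code", use_gpu=False)
    e2 = get_embedder(backend="fastembed", profile="code", use_gpu=False)
    assert e1 is e2  # second call hits the module-level cache

    # Multi-endpoint litellm configurations bypass the cache (cache_key is None),
    # so each such call builds a fresh rotational embedder.

    print(clear_embedder_cache())  # 1 -> the single cached fastembed embedder was dropped

Note that **kwargs are not part of the cache key, so two calls that differ only in extra keyword arguments will share one cached instance.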