mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
feat: Add type validation for RRF weights and implement caching for embedder instances
This commit is contained in:
@@ -85,8 +85,19 @@ class HybridSearchEngine:
|
||||
weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
|
||||
config: Optional runtime config (enables optional reranking features)
|
||||
embedder: Optional embedder instance for embedding-based reranking
|
||||
|
||||
Raises:
|
||||
TypeError: If weights is not a dict (e.g., if a Path is passed)
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Validate weights type to catch common usage errors
|
||||
if weights is not None and not isinstance(weights, dict):
|
||||
raise TypeError(
|
||||
f"weights must be a dict, got {type(weights).__name__}. "
|
||||
f"Did you mean to pass index_path to search() instead of __init__()?"
|
||||
)
|
||||
|
||||
self.weights = weights or DEFAULT_WEIGHTS.copy()
|
||||
self._config = config
|
||||
self.embedder = embedder
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
"""Factory for creating embedders.
|
||||
|
||||
Provides a unified interface for instantiating different embedder backends.
|
||||
Includes caching to avoid repeated model loading overhead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
# Module-level cache for embedder instances
|
||||
# Key: (backend, profile, model, use_gpu) -> embedder instance
|
||||
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_embedder(
|
||||
backend: str = "fastembed",
|
||||
@@ -65,13 +74,38 @@ def get_embedder(
|
||||
... ]
|
||||
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
|
||||
"""
|
||||
# Build cache key from immutable configuration
|
||||
if backend == "fastembed":
|
||||
cache_key = ("fastembed", profile, None, use_gpu)
|
||||
elif backend == "litellm":
|
||||
# For litellm, use model as part of cache key
|
||||
# Multi-endpoint mode is not cached as it's more complex
|
||||
if endpoints and len(endpoints) > 1:
|
||||
cache_key = None # Skip cache for multi-endpoint
|
||||
else:
|
||||
effective_model = endpoints[0]["model"] if endpoints else model
|
||||
cache_key = ("litellm", None, effective_model, None)
|
||||
else:
|
||||
cache_key = None
|
||||
|
||||
# Check cache first (thread-safe)
|
||||
if cache_key is not None:
|
||||
with _cache_lock:
|
||||
if cache_key in _embedder_cache:
|
||||
_logger.debug("Returning cached embedder for %s", cache_key)
|
||||
return _embedder_cache[cache_key]
|
||||
|
||||
# Create new embedder instance
|
||||
embedder: Optional[BaseEmbedder] = None
|
||||
|
||||
if backend == "fastembed":
|
||||
from .embedder import Embedder
|
||||
return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
|
||||
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
|
||||
elif backend == "litellm":
|
||||
# Check if multi-endpoint mode is requested
|
||||
if endpoints and len(endpoints) > 1:
|
||||
from .rotational_embedder import create_rotational_embedder
|
||||
# Multi-endpoint is not cached
|
||||
return create_rotational_embedder(
|
||||
endpoints_config=endpoints,
|
||||
strategy=strategy,
|
||||
@@ -86,13 +120,39 @@ def get_embedder(
|
||||
if "api_base" in ep:
|
||||
ep_kwargs["api_base"] = ep["api_base"]
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
|
||||
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
|
||||
else:
|
||||
# No endpoints list - use model parameter
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
return LiteLLMEmbedderWrapper(model=model, **kwargs)
|
||||
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend}. "
|
||||
f"Supported backends: 'fastembed', 'litellm'"
|
||||
)
|
||||
|
||||
# Cache the embedder for future use (thread-safe)
|
||||
if cache_key is not None and embedder is not None:
|
||||
with _cache_lock:
|
||||
# Double-check to avoid race condition
|
||||
if cache_key not in _embedder_cache:
|
||||
_embedder_cache[cache_key] = embedder
|
||||
_logger.debug("Cached new embedder for %s", cache_key)
|
||||
else:
|
||||
# Another thread created it already, use that one
|
||||
embedder = _embedder_cache[cache_key]
|
||||
|
||||
return embedder # type: ignore
|
||||
|
||||
|
||||
def clear_embedder_cache() -> int:
|
||||
"""Clear the embedder cache.
|
||||
|
||||
Returns:
|
||||
Number of embedders cleared from cache
|
||||
"""
|
||||
with _cache_lock:
|
||||
count = len(_embedder_cache)
|
||||
_embedder_cache.clear()
|
||||
_logger.debug("Cleared %d embedders from cache", count)
|
||||
return count
|
||||
|
||||
Reference in New Issue
Block a user