Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-13 02:41:50 +08:00
feat: Add type validation for RRF weights and implement caching for embedder instances
@@ -85,8 +85,19 @@ class HybridSearchEngine:
             weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
             config: Optional runtime config (enables optional reranking features)
             embedder: Optional embedder instance for embedding-based reranking
+
+        Raises:
+            TypeError: If weights is not a dict (e.g., if a Path is passed)
         """
         self.logger = logging.getLogger(__name__)
+
+        # Validate weights type to catch common usage errors
+        if weights is not None and not isinstance(weights, dict):
+            raise TypeError(
+                f"weights must be a dict, got {type(weights).__name__}. "
+                f"Did you mean to pass index_path to search() instead of __init__()?"
+            )
+
         self.weights = weights or DEFAULT_WEIGHTS.copy()
         self._config = config
         self.embedder = embedder
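
For reference, here is a minimal, self-contained sketch of the validation pattern this hunk adds. The class below is a stripped-down stand-in: the real constructor takes more parameters, and the DEFAULT_WEIGHTS values are placeholders rather than the project's actual defaults.

from pathlib import Path
from typing import Optional

DEFAULT_WEIGHTS = {"bm25": 0.5, "dense": 0.5}  # placeholder values, not from the diff


class HybridSearchEngine:
    def __init__(self, weights: Optional[dict] = None):
        # Reject non-dict weights early, with a hint about the likely mistake
        if weights is not None and not isinstance(weights, dict):
            raise TypeError(
                f"weights must be a dict, got {type(weights).__name__}. "
                f"Did you mean to pass index_path to search() instead of __init__()?"
            )
        self.weights = weights or DEFAULT_WEIGHTS.copy()


HybridSearchEngine(weights={"bm25": 0.7, "dense": 0.3})  # accepted
try:
    HybridSearchEngine(weights=Path("index/"))  # the mistake the check guards against
except TypeError as exc:
    print(exc)  # weights must be a dict, got PosixPath. ... (WindowsPath on Windows)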
@@ -1,14 +1,23 @@
 """Factory for creating embedders.
 
 Provides a unified interface for instantiating different embedder backends.
+Includes caching to avoid repeated model loading overhead.
 """
 
 from __future__ import annotations
 
+import logging
+import threading
 from typing import Any, Dict, List, Optional
 
 from .base import BaseEmbedder
 
+# Module-level cache for embedder instances
+# Key: (backend, profile, model, use_gpu) -> embedder instance
+_embedder_cache: Dict[tuple, BaseEmbedder] = {}
+_cache_lock = threading.Lock()
+_logger = logging.getLogger(__name__)
+
 
 def get_embedder(
     backend: str = "fastembed",
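
Because the cache key is a plain tuple, two calls with the same configuration hash to the same slot. A small illustration (the concrete profile and flag values are made up):

# Key layout: (backend, profile, model, use_gpu)
k1 = ("fastembed", "code", None, False)
k2 = ("fastembed", "code", None, False)
assert k1 == k2 and hash(k1) == hash(k2)  # equal config -> equal key

cache: dict = {}
cache[k1] = object()
assert cache[k2] is cache[k1]  # lookup with an equal tuple finds the same instance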
@@ -65,13 +74,38 @@ def get_embedder(
         ... ]
         >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
     """
+    # Build cache key from immutable configuration
+    if backend == "fastembed":
+        cache_key = ("fastembed", profile, None, use_gpu)
+    elif backend == "litellm":
+        # For litellm, use model as part of cache key
+        # Multi-endpoint mode is not cached as it's more complex
+        if endpoints and len(endpoints) > 1:
+            cache_key = None  # Skip cache for multi-endpoint
+        else:
+            effective_model = endpoints[0]["model"] if endpoints else model
+            cache_key = ("litellm", None, effective_model, None)
+    else:
+        cache_key = None
+
+    # Check cache first (thread-safe)
+    if cache_key is not None:
+        with _cache_lock:
+            if cache_key in _embedder_cache:
+                _logger.debug("Returning cached embedder for %s", cache_key)
+                return _embedder_cache[cache_key]
+
+    # Create new embedder instance
+    embedder: Optional[BaseEmbedder] = None
+
     if backend == "fastembed":
         from .embedder import Embedder
-        return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
+        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
     elif backend == "litellm":
         # Check if multi-endpoint mode is requested
         if endpoints and len(endpoints) > 1:
             from .rotational_embedder import create_rotational_embedder
+            # Multi-endpoint is not cached
             return create_rotational_embedder(
                 endpoints_config=endpoints,
                 strategy=strategy,
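
To make the key-construction branches concrete, the same logic is extracted below as a standalone helper (illustration only; the endpoint dicts and model names are made up):

def build_cache_key(backend, profile=None, model=None, use_gpu=False, endpoints=None):
    if backend == "fastembed":
        return ("fastembed", profile, None, use_gpu)
    if backend == "litellm":
        if endpoints and len(endpoints) > 1:
            return None  # multi-endpoint mode is never cached
        effective_model = endpoints[0]["model"] if endpoints else model
        return ("litellm", None, effective_model, None)
    return None  # unknown backend: no caching


assert build_cache_key("fastembed", profile="code") == ("fastembed", "code", None, False)
assert build_cache_key("litellm", model="openai/text-embedding-3-small") == (
    "litellm", None, "openai/text-embedding-3-small", None
)
assert build_cache_key("litellm", endpoints=[{"model": "m1"}, {"model": "m2"}]) is None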
@@ -86,13 +120,39 @@ def get_embedder(
             if "api_base" in ep:
                 ep_kwargs["api_base"] = ep["api_base"]
             from .litellm_embedder import LiteLLMEmbedderWrapper
-            return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
+            embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
         else:
             # No endpoints list - use model parameter
             from .litellm_embedder import LiteLLMEmbedderWrapper
-            return LiteLLMEmbedderWrapper(model=model, **kwargs)
+            embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
     else:
         raise ValueError(
             f"Unknown backend: {backend}. "
             f"Supported backends: 'fastembed', 'litellm'"
         )
+
+    # Cache the embedder for future use (thread-safe)
+    if cache_key is not None and embedder is not None:
+        with _cache_lock:
+            # Double-check to avoid race condition
+            if cache_key not in _embedder_cache:
+                _embedder_cache[cache_key] = embedder
+                _logger.debug("Cached new embedder for %s", cache_key)
+            else:
+                # Another thread created it already, use that one
+                embedder = _embedder_cache[cache_key]
+
+    return embedder  # type: ignore
+
+
+def clear_embedder_cache() -> int:
+    """Clear the embedder cache.
+
+    Returns:
+        Number of embedders cleared from cache
+    """
+    with _cache_lock:
+        count = len(_embedder_cache)
+        _embedder_cache.clear()
+        _logger.debug("Cleared %d embedders from cache", count)
+        return count
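
Note the double-checked pattern above: the embedder is constructed outside the lock (model loading can be slow, and holding the lock during it would serialize all callers), so two threads may race to build the same instance; the second check under the lock ensures both end up with whichever instance was cached first. From the caller's side, the factory now behaves as sketched below; the import path is an assumption for illustration, since the diff does not show the module's location.

# Hypothetical import path -- adjust to wherever this factory module lives.
from embedders.factory import get_embedder, clear_embedder_cache

a = get_embedder(backend="fastembed", profile="code", use_gpu=False)
b = get_embedder(backend="fastembed", profile="code", use_gpu=False)
assert a is b                    # second call is served from the module-level cache

freed = clear_embedder_cache()   # drop every cached instance
assert freed == 1
c = get_embedder(backend="fastembed", profile="code", use_gpu=False)
assert c is not a                # cache was cleared, so the model is loaded again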