diff --git a/codex-lens/src/codexlens/search/hybrid_search.py b/codex-lens/src/codexlens/search/hybrid_search.py
index 89dac0f9..181c7b72 100644
--- a/codex-lens/src/codexlens/search/hybrid_search.py
+++ b/codex-lens/src/codexlens/search/hybrid_search.py
@@ -85,8 +85,19 @@ class HybridSearchEngine:
             weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
             config: Optional runtime config (enables optional reranking features)
             embedder: Optional embedder instance for embedding-based reranking
+
+        Raises:
+            TypeError: If weights is not a dict (e.g., if a Path is passed)
         """
         self.logger = logging.getLogger(__name__)
+
+        # Validate weights type to catch common usage errors
+        if weights is not None and not isinstance(weights, dict):
+            raise TypeError(
+                f"weights must be a dict, got {type(weights).__name__}. "
+                f"Did you mean to pass index_path to search() instead of __init__()?"
+            )
+
         self.weights = weights or DEFAULT_WEIGHTS.copy()
         self._config = config
         self.embedder = embedder
diff --git a/codex-lens/src/codexlens/semantic/factory.py b/codex-lens/src/codexlens/semantic/factory.py
index fe360539..3295eba8 100644
--- a/codex-lens/src/codexlens/semantic/factory.py
+++ b/codex-lens/src/codexlens/semantic/factory.py
@@ -1,14 +1,23 @@
 """Factory for creating embedders.
 
 Provides a unified interface for instantiating different embedder backends.
+Includes caching to avoid repeated model loading overhead.
 """
 
 from __future__ import annotations
 
+import logging
+import threading
 from typing import Any, Dict, List, Optional
 
 from .base import BaseEmbedder
 
+# Module-level cache for embedder instances
+# Key: (backend, profile, model, use_gpu) -> embedder instance
+_embedder_cache: Dict[tuple, BaseEmbedder] = {}
+_cache_lock = threading.Lock()
+_logger = logging.getLogger(__name__)
+
 
 def get_embedder(
     backend: str = "fastembed",
@@ -65,13 +74,38 @@ def get_embedder(
         ... ]
         >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
     """
+    # Build cache key from immutable configuration
+    if backend == "fastembed":
+        cache_key = ("fastembed", profile, None, use_gpu)
+    elif backend == "litellm":
+        # For litellm, use model as part of cache key
+        # Multi-endpoint mode is not cached as it's more complex
+        if endpoints and len(endpoints) > 1:
+            cache_key = None  # Skip cache for multi-endpoint
+        else:
+            effective_model = endpoints[0]["model"] if endpoints else model
+            cache_key = ("litellm", None, effective_model, None)
+    else:
+        cache_key = None
+
+    # Check cache first (thread-safe)
+    if cache_key is not None:
+        with _cache_lock:
+            if cache_key in _embedder_cache:
+                _logger.debug("Returning cached embedder for %s", cache_key)
+                return _embedder_cache[cache_key]
+
+    # Create new embedder instance
+    embedder: Optional[BaseEmbedder] = None
+
     if backend == "fastembed":
         from .embedder import Embedder
-        return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
+        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
     elif backend == "litellm":
         # Check if multi-endpoint mode is requested
         if endpoints and len(endpoints) > 1:
             from .rotational_embedder import create_rotational_embedder
+            # Multi-endpoint is not cached
             return create_rotational_embedder(
                 endpoints_config=endpoints,
                 strategy=strategy,
@@ -86,13 +120,39 @@ def get_embedder(
                 if "api_base" in ep:
                     ep_kwargs["api_base"] = ep["api_base"]
                 from .litellm_embedder import LiteLLMEmbedderWrapper
-                return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
+                embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
             else:
                 # No endpoints list - use model parameter
                 from .litellm_embedder import LiteLLMEmbedderWrapper
-                return LiteLLMEmbedderWrapper(model=model, **kwargs)
+                embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
     else:
         raise ValueError(
             f"Unknown backend: {backend}. "
             f"Supported backends: 'fastembed', 'litellm'"
         )
+
+    # Cache the embedder for future use (thread-safe)
+    if cache_key is not None and embedder is not None:
+        with _cache_lock:
+            # Double-check to avoid race condition
+            if cache_key not in _embedder_cache:
+                _embedder_cache[cache_key] = embedder
+                _logger.debug("Cached new embedder for %s", cache_key)
+            else:
+                # Another thread created it already, use that one
+                embedder = _embedder_cache[cache_key]
+
+    return embedder  # type: ignore
+
+
+def clear_embedder_cache() -> int:
+    """Clear the embedder cache.
+
+    Returns:
+        Number of embedders cleared from cache
+    """
+    with _cache_lock:
+        count = len(_embedder_cache)
+        _embedder_cache.clear()
+        _logger.debug("Cleared %d embedders from cache", count)
+        return count
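Usage sketch (reviewer note, not part of the diff): the caching contract this change introduces could be exercised as below. This is a minimal sketch assuming `codexlens` is importable with the fastembed backend available; the profile name `"code"` and the `use_gpu` default are illustrative assumptions, not taken from the diff.

```python
# Sketch of the caching behavior added in semantic/factory.py.
# Assumptions: fastembed backend installed; "code" is a valid profile name.
from codexlens.semantic.factory import clear_embedder_cache, get_embedder

# First call constructs the embedder and stores it under the key
# ("fastembed", profile, None, use_gpu).
e1 = get_embedder(backend="fastembed", profile="code")

# An identical configuration hits the cache and returns the same instance,
# skipping the model-loading overhead.
e2 = get_embedder(backend="fastembed", profile="code")
assert e2 is e1

# A different configuration gets its own cache entry.
e3 = get_embedder(backend="fastembed", profile="code", use_gpu=True)
assert e3 is not e1

# Clearing is useful between tests; the return value reports how many
# cached embedders were dropped.
cleared = clear_embedder_cache()
```

Note the double-checked pattern in the diff: the membership test is repeated under the lock after construction, so concurrent callers converge on a single cached instance (a losing thread's freshly built embedder is discarded in favor of the cached one).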