"""ONNX-optimized SPLADE sparse encoder for code search. This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors that combine the interpretability of BM25 with neural relevance modeling. Install (CPU): pip install onnxruntime optimum[onnxruntime] transformers Install (GPU): pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers """ from __future__ import annotations import logging import threading from pathlib import Path from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) def check_splade_available() -> Tuple[bool, Optional[str]]: """Check whether SPLADE dependencies are available. Returns: Tuple of (available: bool, error_message: Optional[str]) """ try: import numpy # noqa: F401 except ImportError as exc: return False, f"numpy not available: {exc}. Install with: pip install numpy" try: import onnxruntime # noqa: F401 except ImportError as exc: return ( False, f"onnxruntime not available: {exc}. Install with: pip install onnxruntime", ) try: from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401 except ImportError as exc: return ( False, f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]", ) try: from transformers import AutoTokenizer # noqa: F401 except ImportError as exc: return ( False, f"transformers not available: {exc}. Install with: pip install transformers", ) return True, None # Global cache for SPLADE encoders (singleton pattern) _splade_cache: Dict[str, "SpladeEncoder"] = {} _cache_lock = threading.RLock() def get_splade_encoder( model_name: str = "naver/splade-cocondenser-ensembledistil", use_gpu: bool = True, max_length: int = 512, sparsity_threshold: float = 0.01, cache_dir: Optional[str] = None, ) -> "SpladeEncoder": """Get or create cached SPLADE encoder (thread-safe singleton). This function provides significant performance improvement by reusing SpladeEncoder instances across multiple searches, avoiding repeated model loading overhead. Args: model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil) use_gpu: If True, use GPU acceleration when available max_length: Maximum sequence length for tokenization sparsity_threshold: Minimum weight to include in sparse vector cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) Returns: Cached SpladeEncoder instance for the given configuration """ global _splade_cache # Cache key includes all configuration parameters cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}" with _cache_lock: encoder = _splade_cache.get(cache_key) if encoder is not None: return encoder # Create new encoder and cache it encoder = SpladeEncoder( model_name=model_name, use_gpu=use_gpu, max_length=max_length, sparsity_threshold=sparsity_threshold, cache_dir=cache_dir, ) # Pre-load model to ensure it's ready encoder._load_model() _splade_cache[cache_key] = encoder return encoder def clear_splade_cache() -> None: """Clear the SPLADE encoder cache and release ONNX resources. This method ensures proper cleanup of ONNX model resources to prevent memory leaks when encoders are no longer needed. 
""" global _splade_cache with _cache_lock: # Release ONNX resources before clearing cache for encoder in _splade_cache.values(): if encoder._model is not None: del encoder._model encoder._model = None if encoder._tokenizer is not None: del encoder._tokenizer encoder._tokenizer = None _splade_cache.clear() class SpladeEncoder: """ONNX-optimized SPLADE sparse encoder. Produces sparse vectors with vocabulary-aligned dimensions. Output: Dict[int, float] mapping token_id to weight. SPLADE activation formula: splade_repr = log(1 + ReLU(logits)) * attention_mask splade_vec = max_pooling(splade_repr, axis=sequence_length) References: - SPLADE: https://arxiv.org/abs/2107.05720 - SPLADE v2: https://arxiv.org/abs/2109.10086 """ DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil" def __init__( self, model_name: str = DEFAULT_MODEL, use_gpu: bool = True, max_length: int = 512, sparsity_threshold: float = 0.01, providers: Optional[List[Any]] = None, cache_dir: Optional[str] = None, ) -> None: """Initialize SPLADE encoder. Args: model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil) use_gpu: If True, use GPU acceleration when available max_length: Maximum sequence length for tokenization sparsity_threshold: Minimum weight to include in sparse vector providers: Explicit ONNX providers list (overrides use_gpu) cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) """ self.model_name = (model_name or self.DEFAULT_MODEL).strip() if not self.model_name: raise ValueError("model_name cannot be blank") self.use_gpu = bool(use_gpu) self.max_length = int(max_length) if max_length > 0 else 512 self.sparsity_threshold = float(sparsity_threshold) self.providers = providers # Setup ONNX cache directory if cache_dir: self._cache_dir = Path(cache_dir) else: self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade" self._tokenizer: Any | None = None self._model: Any | None = None self._vocab_size: int | None = None self._lock = threading.RLock() def _get_local_cache_path(self) -> Path: """Get local cache path for this model's ONNX files. Returns: Path to the local ONNX cache directory for this model """ # Replace / with -- for filesystem-safe naming safe_name = self.model_name.replace("/", "--") return self._cache_dir / safe_name def _load_model(self) -> None: """Lazy load ONNX model and tokenizer. First checks local cache for ONNX model, falling back to HuggingFace download and conversion if not cached. 
""" if self._model is not None and self._tokenizer is not None: return ok, err = check_splade_available() if not ok: raise ImportError(err) with self._lock: if self._model is not None and self._tokenizer is not None: return from inspect import signature from optimum.onnxruntime import ORTModelForMaskedLM from transformers import AutoTokenizer if self.providers is None: from .gpu_support import get_optimal_providers, get_selected_device_id # Get providers as pure string list (cache-friendly) # NOTE: with_device_options=False to avoid tuple-based providers # which break optimum's caching mechanism self.providers = get_optimal_providers( use_gpu=self.use_gpu, with_device_options=False ) # Get device_id separately for provider_options self._device_id = get_selected_device_id() if self.use_gpu else None # Some Optimum versions accept `providers`, others accept a single `provider` # Prefer passing the full providers list, with a conservative fallback model_kwargs: dict[str, Any] = {} try: params = signature(ORTModelForMaskedLM.from_pretrained).parameters if "providers" in params: model_kwargs["providers"] = self.providers # Pass device_id via provider_options for GPU selection if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None: # Build provider_options dict for each GPU provider provider_options = {} for p in self.providers: if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"): provider_options[p] = {"device_id": self._device_id} if provider_options: model_kwargs["provider_options"] = provider_options elif "provider" in params: provider_name = "CPUExecutionProvider" if self.providers: first = self.providers[0] provider_name = first[0] if isinstance(first, tuple) else str(first) model_kwargs["provider"] = provider_name except Exception as e: logger.debug(f"Failed to inspect ORTModel signature: {e}") model_kwargs = {} # Check for local ONNX cache first local_cache = self._get_local_cache_path() onnx_model_path = local_cache / "model.onnx" if onnx_model_path.exists(): # Load from local cache logger.info(f"Loading SPLADE from local cache: {local_cache}") try: self._model = ORTModelForMaskedLM.from_pretrained( str(local_cache), **model_kwargs, ) self._tokenizer = AutoTokenizer.from_pretrained( str(local_cache), use_fast=True ) self._vocab_size = len(self._tokenizer) logger.info( f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}" ) return except Exception as e: logger.warning(f"Failed to load from cache, redownloading: {e}") # Download and convert from HuggingFace logger.info(f"Downloading SPLADE model: {self.model_name}") try: self._model = ORTModelForMaskedLM.from_pretrained( self.model_name, export=True, # Export to ONNX **model_kwargs, ) logger.debug(f"SPLADE model loaded: {self.model_name}") except TypeError: # Fallback for older Optimum versions: retry without provider arguments self._model = ORTModelForMaskedLM.from_pretrained( self.model_name, export=True, ) logger.warning( "Optimum version doesn't support provider parameters. 
" "Upgrade optimum for GPU acceleration: pip install --upgrade optimum" ) self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) # Cache vocabulary size self._vocab_size = len(self._tokenizer) logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}") # Save to local cache for future use try: local_cache.mkdir(parents=True, exist_ok=True) self._model.save_pretrained(str(local_cache)) self._tokenizer.save_pretrained(str(local_cache)) logger.info(f"SPLADE model cached to: {local_cache}") except Exception as e: logger.warning(f"Failed to cache SPLADE model: {e}") @staticmethod def _splade_activation(logits: Any, attention_mask: Any) -> Any: """Apply SPLADE activation function to model outputs. Formula: log(1 + ReLU(logits)) * attention_mask Args: logits: Model output logits (batch, seq_len, vocab_size) attention_mask: Attention mask (batch, seq_len) Returns: SPLADE representations (batch, seq_len, vocab_size) """ import numpy as np # ReLU activation relu_logits = np.maximum(0, logits) # Log(1 + x) transformation log_relu = np.log1p(relu_logits) # Apply attention mask (expand to match vocab dimension) # attention_mask: (batch, seq_len) -> (batch, seq_len, 1) mask_expanded = np.expand_dims(attention_mask, axis=-1) # Element-wise multiplication splade_repr = log_relu * mask_expanded return splade_repr @staticmethod def _max_pooling(splade_repr: Any) -> Any: """Max pooling over sequence length dimension. Args: splade_repr: SPLADE representations (batch, seq_len, vocab_size) Returns: Pooled sparse vectors (batch, vocab_size) """ import numpy as np # Max pooling over sequence dimension (axis=1) return np.max(splade_repr, axis=1) def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]: """Convert dense vector to sparse dictionary. Args: dense_vec: Dense vector (vocab_size,) Returns: Sparse dictionary {token_id: weight} with weights above threshold """ import numpy as np # Find non-zero indices above threshold nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0] # Create sparse dictionary sparse_dict = { int(idx): float(dense_vec[idx]) for idx in nonzero_indices } return sparse_dict def warmup(self, text: str = "warmup query") -> None: """Warmup the encoder by running a dummy inference. First-time model inference includes initialization overhead. Call this method once before the first real search to avoid latency spikes. Args: text: Dummy text for warmup (default: "warmup query") """ logger.info("Warming up SPLADE encoder...") # Trigger model loading and first inference _ = self.encode_text(text) logger.info("SPLADE encoder warmup complete") def encode_text(self, text: str) -> Dict[int, float]: """Encode text to sparse vector {token_id: weight}. 

        Args:
            text: Input text to encode

        Returns:
            Sparse vector as dictionary mapping token_id to weight
        """
        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        # Tokenize input
        encoded = self._tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="np",
        )

        # Forward pass through model
        outputs = self._model(**encoded)

        # Extract logits
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        elif isinstance(outputs, dict) and "logits" in outputs:
            logits = outputs["logits"]
        elif isinstance(outputs, (list, tuple)) and outputs:
            logits = outputs[0]
        else:
            raise RuntimeError("Unexpected model output format")

        # Apply SPLADE activation
        attention_mask = encoded["attention_mask"]
        splade_repr = self._splade_activation(logits, attention_mask)

        # Max pooling over sequence length
        splade_vec = self._max_pooling(splade_repr)

        # Convert to sparse dictionary (single item batch)
        sparse_dict = self._to_sparse_dict(splade_vec[0])

        return sparse_dict

    def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
        """Batch encode texts to sparse vectors.

        Args:
            texts: List of input texts to encode
            batch_size: Batch size for encoding (default: 32)

        Returns:
            List of sparse vectors as dictionaries
        """
        if not texts:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        results: List[Dict[int, float]] = []

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize batch
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="np",
            )

            # Forward pass through model
            outputs = self._model(**encoded)

            # Extract logits
            if hasattr(outputs, "logits"):
                logits = outputs.logits
            elif isinstance(outputs, dict) and "logits" in outputs:
                logits = outputs["logits"]
            elif isinstance(outputs, (list, tuple)) and outputs:
                logits = outputs[0]
            else:
                raise RuntimeError("Unexpected model output format")

            # Apply SPLADE activation
            attention_mask = encoded["attention_mask"]
            splade_repr = self._splade_activation(logits, attention_mask)

            # Max pooling over sequence length
            splade_vecs = self._max_pooling(splade_repr)

            # Convert each vector to sparse dictionary
            for vec in splade_vecs:
                sparse_dict = self._to_sparse_dict(vec)
                results.append(sparse_dict)

        return results

    @property
    def vocab_size(self) -> int:
        """Return vocabulary size (~30k for BERT-based models).

        Returns:
            Vocabulary size (number of tokens in tokenizer)
        """
        if self._vocab_size is not None:
            return self._vocab_size
        self._load_model()
        return self._vocab_size or 0

    def get_token(self, token_id: int) -> str:
        """Convert token_id to string (for debugging).

        Args:
            token_id: Token ID to convert

        Returns:
            Token string
        """
        self._load_model()
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")
        return self._tokenizer.decode([token_id])

    def get_top_tokens(
        self, sparse_vec: Dict[int, float], top_k: int = 10
    ) -> List[Tuple[str, float]]:
        """Get top-k tokens with highest weights from sparse vector.

        Useful for debugging and understanding what the model is focusing on.
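
        Illustrative sketch (the returned token strings and weights depend on
        the underlying model):

            vec = encoder.encode_text("parse yaml configuration file")
            encoder.get_top_tokens(vec, top_k=3)
            # e.g. [("yaml", 1.9), ("configuration", 1.4), ("parse", 1.1)]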

        Args:
            sparse_vec: Sparse vector as {token_id: weight}
            top_k: Number of top tokens to return

        Returns:
            List of (token_string, weight) tuples, sorted by weight descending
        """
        self._load_model()

        if not sparse_vec:
            return []

        # Sort by weight descending
        sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)

        # Take top-k and convert token_ids to strings
        top_items = sorted_items[:top_k]
        return [
            (self.get_token(token_id), weight)
            for token_id, weight in top_items
        ]
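

# Minimal usage sketch (illustrative, not part of the public API): running this
# module directly encodes a sample query and prints its top expansion terms.
# It assumes the optional dependencies listed in the module docstring are
# installed and that the default model can be downloaded or is already cached.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    available, error = check_splade_available()
    if not available:
        raise SystemExit(error)

    encoder = get_splade_encoder(use_gpu=False)
    query = "read configuration values from a yaml file"
    sparse_vec = encoder.encode_text(query)

    print(f"query: {query}")
    print(f"non-zero dimensions: {len(sparse_vec)} of {encoder.vocab_size}")
    for token, weight in encoder.get_top_tokens(sparse_vec, top_k=10):
        print(f"  {token:>15s}  {weight:.3f}")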