diff --git a/codex-lens/src/codexlens/semantic/splade_encoder.py b/codex-lens/src/codexlens/semantic/splade_encoder.py index 646c84bd..59d2ac00 100644 --- a/codex-lens/src/codexlens/semantic/splade_encoder.py +++ b/codex-lens/src/codexlens/semantic/splade_encoder.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging import threading +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -68,6 +69,7 @@ def get_splade_encoder( use_gpu: bool = True, max_length: int = 512, sparsity_threshold: float = 0.01, + cache_dir: Optional[str] = None, ) -> "SpladeEncoder": """Get or create cached SPLADE encoder (thread-safe singleton). @@ -80,6 +82,7 @@ def get_splade_encoder( use_gpu: If True, use GPU acceleration when available max_length: Maximum sequence length for tokenization sparsity_threshold: Minimum weight to include in sparse vector + cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) Returns: Cached SpladeEncoder instance for the given configuration @@ -100,6 +103,7 @@ def get_splade_encoder( use_gpu=use_gpu, max_length=max_length, sparsity_threshold=sparsity_threshold, + cache_dir=cache_dir, ) # Pre-load model to ensure it's ready encoder._load_model() @@ -151,6 +155,7 @@ class SpladeEncoder: max_length: int = 512, sparsity_threshold: float = 0.01, providers: Optional[List[Any]] = None, + cache_dir: Optional[str] = None, ) -> None: """Initialize SPLADE encoder. @@ -160,6 +165,7 @@ class SpladeEncoder: max_length: Maximum sequence length for tokenization sparsity_threshold: Minimum weight to include in sparse vector providers: Explicit ONNX providers list (overrides use_gpu) + cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) """ self.model_name = (model_name or self.DEFAULT_MODEL).strip() if not self.model_name: @@ -170,13 +176,33 @@ class SpladeEncoder: self.sparsity_threshold = float(sparsity_threshold) self.providers = providers + # Setup ONNX cache directory + if cache_dir: + self._cache_dir = Path(cache_dir) + else: + self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade" + self._tokenizer: Any | None = None self._model: Any | None = None self._vocab_size: int | None = None self._lock = threading.RLock() + def _get_local_cache_path(self) -> Path: + """Get local cache path for this model's ONNX files. + + Returns: + Path to the local ONNX cache directory for this model + """ + # Replace / with -- for filesystem-safe naming + safe_name = self.model_name.replace("/", "--") + return self._cache_dir / safe_name + def _load_model(self) -> None: - """Lazy load ONNX model and tokenizer.""" + """Lazy load ONNX model and tokenizer. + + First checks local cache for ONNX model, falling back to + HuggingFace download and conversion if not cached. + """ if self._model is not None and self._tokenizer is not None: return @@ -214,18 +240,48 @@ class SpladeEncoder: first = self.providers[0] provider_name = first[0] if isinstance(first, tuple) else str(first) model_kwargs["provider"] = provider_name - except Exception: + except Exception as e: + logger.debug(f"Failed to inspect ORTModel signature: {e}") model_kwargs = {} + # Check for local ONNX cache first + local_cache = self._get_local_cache_path() + onnx_model_path = local_cache / "model.onnx" + + if onnx_model_path.exists(): + # Load from local cache + logger.info(f"Loading SPLADE from local cache: {local_cache}") + try: + self._model = ORTModelForMaskedLM.from_pretrained( + str(local_cache), + **model_kwargs, + ) + self._tokenizer = AutoTokenizer.from_pretrained( + str(local_cache), use_fast=True + ) + self._vocab_size = len(self._tokenizer) + logger.info( + f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}" + ) + return + except Exception as e: + logger.warning(f"Failed to load from cache, redownloading: {e}") + + # Download and convert from HuggingFace + logger.info(f"Downloading SPLADE model: {self.model_name}") try: self._model = ORTModelForMaskedLM.from_pretrained( self.model_name, + export=True, # Export to ONNX **model_kwargs, ) logger.debug(f"SPLADE model loaded: {self.model_name}") except TypeError: # Fallback for older Optimum versions: retry without provider arguments - self._model = ORTModelForMaskedLM.from_pretrained(self.model_name) + self._model = ORTModelForMaskedLM.from_pretrained( + self.model_name, + export=True, + ) logger.warning( "Optimum version doesn't support provider parameters. " "Upgrade optimum for GPU acceleration: pip install --upgrade optimum" @@ -237,6 +293,15 @@ class SpladeEncoder: self._vocab_size = len(self._tokenizer) logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}") + # Save to local cache for future use + try: + local_cache.mkdir(parents=True, exist_ok=True) + self._model.save_pretrained(str(local_cache)) + self._tokenizer.save_pretrained(str(local_cache)) + logger.info(f"SPLADE model cached to: {local_cache}") + except Exception as e: + logger.warning(f"Failed to cache SPLADE model: {e}") + @staticmethod def _splade_activation(logits: Any, attention_mask: Any) -> Any: """Apply SPLADE activation function to model outputs. diff --git a/codex-lens/src/codexlens/storage/splade_index.py b/codex-lens/src/codexlens/storage/splade_index.py index 65fa6b41..6a7c2fa1 100644 --- a/codex-lens/src/codexlens/storage/splade_index.py +++ b/codex-lens/src/codexlens/storage/splade_index.py @@ -40,15 +40,25 @@ class SpladeIndex: self._local = threading.local() def _get_connection(self) -> sqlite3.Connection: - """Get or create a thread-local database connection.""" + """Get or create a thread-local database connection. + + Each thread gets its own connection to ensure thread safety. + Connections are stored in thread-local storage. + """ conn = getattr(self._local, "conn", None) if conn is None: - conn = sqlite3.connect(self.db_path, check_same_thread=False) + # Thread-local connection - each thread has its own + conn = sqlite3.connect( + self.db_path, + timeout=30.0, # Wait up to 30s for locks + check_same_thread=True, # Enforce thread safety + ) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") conn.execute("PRAGMA foreign_keys=ON") - conn.execute("PRAGMA mmap_size=30000000000") # 30GB limit + # Limit mmap to 1GB to avoid OOM on smaller systems + conn.execute("PRAGMA mmap_size=1073741824") self._local.conn = conn return conn