feat(splade): add cache directory support for ONNX models and improve thread-local database connection handling

catlog22
2026-01-01 22:40:00 +08:00
parent 5bb01755bc
commit 195438d26a
2 changed files with 81 additions and 6 deletions
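
For context, a minimal usage sketch of the new parameter (the import path is an assumption; only get_splade_encoder and cache_dir appear in this diff):

    # Hypothetical import path; the diff does not show the module name.
    from splade_encoder import get_splade_encoder

    # First call downloads the model, exports it to ONNX, and writes it to
    # the cache directory; later calls load directly from the local cache.
    encoder = get_splade_encoder(cache_dir="/data/models/splade-cache")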

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
 import logging
 import threading
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

 logger = logging.getLogger(__name__)
@@ -68,6 +69,7 @@ def get_splade_encoder(
     use_gpu: bool = True,
     max_length: int = 512,
     sparsity_threshold: float = 0.01,
+    cache_dir: Optional[str] = None,
 ) -> "SpladeEncoder":
     """Get or create cached SPLADE encoder (thread-safe singleton).
@@ -80,6 +82,7 @@ def get_splade_encoder(
         use_gpu: If True, use GPU acceleration when available
         max_length: Maximum sequence length for tokenization
         sparsity_threshold: Minimum weight to include in sparse vector
+        cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)

     Returns:
         Cached SpladeEncoder instance for the given configuration
@@ -100,6 +103,7 @@ def get_splade_encoder(
             use_gpu=use_gpu,
             max_length=max_length,
             sparsity_threshold=sparsity_threshold,
+            cache_dir=cache_dir,
         )
         # Pre-load model to ensure it's ready
         encoder._load_model()
@@ -151,6 +155,7 @@ class SpladeEncoder:
         max_length: int = 512,
         sparsity_threshold: float = 0.01,
         providers: Optional[List[Any]] = None,
+        cache_dir: Optional[str] = None,
     ) -> None:
         """Initialize SPLADE encoder.
@@ -160,6 +165,7 @@ class SpladeEncoder:
             max_length: Maximum sequence length for tokenization
             sparsity_threshold: Minimum weight to include in sparse vector
             providers: Explicit ONNX providers list (overrides use_gpu)
+            cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
         """
         self.model_name = (model_name or self.DEFAULT_MODEL).strip()
         if not self.model_name:
@@ -170,13 +176,33 @@ class SpladeEncoder:
         self.sparsity_threshold = float(sparsity_threshold)
         self.providers = providers

+        # Setup ONNX cache directory
+        if cache_dir:
+            self._cache_dir = Path(cache_dir)
+        else:
+            self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
+
         self._tokenizer: Any | None = None
         self._model: Any | None = None
         self._vocab_size: int | None = None
         self._lock = threading.RLock()

+    def _get_local_cache_path(self) -> Path:
+        """Get local cache path for this model's ONNX files.
+
+        Returns:
+            Path to the local ONNX cache directory for this model
+        """
+        # Replace / with -- for filesystem-safe naming
+        safe_name = self.model_name.replace("/", "--")
+        return self._cache_dir / safe_name
+
     def _load_model(self) -> None:
-        """Lazy load ONNX model and tokenizer."""
+        """Lazy load ONNX model and tokenizer.
+
+        First checks local cache for ONNX model, falling back to
+        HuggingFace download and conversion if not cached.
+        """
         if self._model is not None and self._tokenizer is not None:
             return
@@ -214,18 +240,48 @@ class SpladeEncoder:
                     first = self.providers[0]
                     provider_name = first[0] if isinstance(first, tuple) else str(first)
                     model_kwargs["provider"] = provider_name
-            except Exception:
+            except Exception as e:
+                logger.debug(f"Failed to inspect ORTModel signature: {e}")
                 model_kwargs = {}

+        # Check for local ONNX cache first
+        local_cache = self._get_local_cache_path()
+        onnx_model_path = local_cache / "model.onnx"
+
+        if onnx_model_path.exists():
+            # Load from local cache
+            logger.info(f"Loading SPLADE from local cache: {local_cache}")
+            try:
+                self._model = ORTModelForMaskedLM.from_pretrained(
+                    str(local_cache),
+                    **model_kwargs,
+                )
+                self._tokenizer = AutoTokenizer.from_pretrained(
+                    str(local_cache), use_fast=True
+                )
+                self._vocab_size = len(self._tokenizer)
+                logger.info(
+                    f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
+                )
+                return
+            except Exception as e:
+                logger.warning(f"Failed to load from cache, redownloading: {e}")
+
+        # Download and convert from HuggingFace
+        logger.info(f"Downloading SPLADE model: {self.model_name}")
         try:
             self._model = ORTModelForMaskedLM.from_pretrained(
                 self.model_name,
+                export=True,  # Export to ONNX
                 **model_kwargs,
             )
             logger.debug(f"SPLADE model loaded: {self.model_name}")
         except TypeError:
+            # Fallback for older Optimum versions: retry without provider arguments
-            self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
+            self._model = ORTModelForMaskedLM.from_pretrained(
+                self.model_name,
+                export=True,
+            )
             logger.warning(
                 "Optimum version doesn't support provider parameters. "
                 "Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
@@ -237,6 +293,15 @@ class SpladeEncoder:
         self._vocab_size = len(self._tokenizer)
         logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")

+        # Save to local cache for future use
+        try:
+            local_cache.mkdir(parents=True, exist_ok=True)
+            self._model.save_pretrained(str(local_cache))
+            self._tokenizer.save_pretrained(str(local_cache))
+            logger.info(f"SPLADE model cached to: {local_cache}")
+        except Exception as e:
+            logger.warning(f"Failed to cache SPLADE model: {e}")
+
     @staticmethod
     def _splade_activation(logits: Any, attention_mask: Any) -> Any:
         """Apply SPLADE activation function to model outputs.

View File

@@ -40,15 +40,25 @@ class SpladeIndex:
         self._local = threading.local()

     def _get_connection(self) -> sqlite3.Connection:
-        """Get or create a thread-local database connection."""
+        """Get or create a thread-local database connection.
+
+        Each thread gets its own connection to ensure thread safety.
+        Connections are stored in thread-local storage.
+        """
         conn = getattr(self._local, "conn", None)
         if conn is None:
-            conn = sqlite3.connect(self.db_path, check_same_thread=False)
+            # Thread-local connection - each thread has its own
+            conn = sqlite3.connect(
+                self.db_path,
+                timeout=30.0,  # Wait up to 30s for locks
+                check_same_thread=True,  # Enforce thread safety
+            )
             conn.row_factory = sqlite3.Row
             conn.execute("PRAGMA journal_mode=WAL")
             conn.execute("PRAGMA synchronous=NORMAL")
             conn.execute("PRAGMA foreign_keys=ON")
-            conn.execute("PRAGMA mmap_size=30000000000")  # 30GB limit
+            # Limit mmap to 1GB to avoid OOM on smaller systems
+            conn.execute("PRAGMA mmap_size=1073741824")
             self._local.conn = conn
         return conn
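
A minimal standalone sketch of the pattern the second file adopts, using only the standard library: with one connection per thread, check_same_thread=True (SQLite's default) can be enforced safely, since no connection object ever crosses thread boundaries.

    import sqlite3
    import threading

    _local = threading.local()

    def get_conn(db_path: str) -> sqlite3.Connection:
        # Each thread lazily creates and reuses its own connection;
        # check_same_thread=True raises if a connection leaks across threads.
        conn = getattr(_local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=True)
            _local.conn = conn
        return conn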