feat(splade): add cache directory support for ONNX models and improve thread-local database connection handling

This commit is contained in:
catlog22
2026-01-01 22:40:00 +08:00
parent 5bb01755bc
commit 195438d26a
2 changed files with 81 additions and 6 deletions

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
import logging
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
@@ -68,6 +69,7 @@ def get_splade_encoder(
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
cache_dir: Optional[str] = None,
) -> "SpladeEncoder":
"""Get or create cached SPLADE encoder (thread-safe singleton).
@@ -80,6 +82,7 @@ def get_splade_encoder(
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
Returns:
Cached SpladeEncoder instance for the given configuration
@@ -100,6 +103,7 @@ def get_splade_encoder(
use_gpu=use_gpu,
max_length=max_length,
sparsity_threshold=sparsity_threshold,
cache_dir=cache_dir,
)
# Pre-load model to ensure it's ready
encoder._load_model()
@@ -151,6 +155,7 @@ class SpladeEncoder:
max_length: int = 512,
sparsity_threshold: float = 0.01,
providers: Optional[List[Any]] = None,
cache_dir: Optional[str] = None,
) -> None:
"""Initialize SPLADE encoder.
@@ -160,6 +165,7 @@ class SpladeEncoder:
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
providers: Explicit ONNX providers list (overrides use_gpu)
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
@@ -170,13 +176,33 @@ class SpladeEncoder:
self.sparsity_threshold = float(sparsity_threshold)
self.providers = providers
# Setup ONNX cache directory
if cache_dir:
self._cache_dir = Path(cache_dir)
else:
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
self._tokenizer: Any | None = None
self._model: Any | None = None
self._vocab_size: int | None = None
self._lock = threading.RLock()
def _get_local_cache_path(self) -> Path:
"""Get local cache path for this model's ONNX files.
Returns:
Path to the local ONNX cache directory for this model
"""
# Replace / with -- for filesystem-safe naming
safe_name = self.model_name.replace("/", "--")
return self._cache_dir / safe_name
def _load_model(self) -> None:
"""Lazy load ONNX model and tokenizer."""
"""Lazy load ONNX model and tokenizer.
First checks local cache for ONNX model, falling back to
HuggingFace download and conversion if not cached.
"""
if self._model is not None and self._tokenizer is not None:
return
@@ -214,18 +240,48 @@ class SpladeEncoder:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
except Exception as e:
logger.debug(f"Failed to inspect ORTModel signature: {e}")
model_kwargs = {}
# Check for local ONNX cache first
local_cache = self._get_local_cache_path()
onnx_model_path = local_cache / "model.onnx"
if onnx_model_path.exists():
# Load from local cache
logger.info(f"Loading SPLADE from local cache: {local_cache}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
str(local_cache),
**model_kwargs,
)
self._tokenizer = AutoTokenizer.from_pretrained(
str(local_cache), use_fast=True
)
self._vocab_size = len(self._tokenizer)
logger.info(
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
)
return
except Exception as e:
logger.warning(f"Failed to load from cache, redownloading: {e}")
# Download and convert from HuggingFace
logger.info(f"Downloading SPLADE model: {self.model_name}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True, # Export to ONNX
**model_kwargs,
)
logger.debug(f"SPLADE model loaded: {self.model_name}")
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments
self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True,
)
logger.warning(
"Optimum version doesn't support provider parameters. "
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
@@ -237,6 +293,15 @@ class SpladeEncoder:
self._vocab_size = len(self._tokenizer)
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
# Save to local cache for future use
try:
local_cache.mkdir(parents=True, exist_ok=True)
self._model.save_pretrained(str(local_cache))
self._tokenizer.save_pretrained(str(local_cache))
logger.info(f"SPLADE model cached to: {local_cache}")
except Exception as e:
logger.warning(f"Failed to cache SPLADE model: {e}")
@staticmethod
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
"""Apply SPLADE activation function to model outputs.

View File

@@ -40,15 +40,25 @@ class SpladeIndex:
self._local = threading.local()
def _get_connection(self) -> sqlite3.Connection:
"""Get or create a thread-local database connection."""
"""Get or create a thread-local database connection.
Each thread gets its own connection to ensure thread safety.
Connections are stored in thread-local storage.
"""
conn = getattr(self._local, "conn", None)
if conn is None:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
# Thread-local connection - each thread has its own
conn = sqlite3.connect(
self.db_path,
timeout=30.0, # Wait up to 30s for locks
check_same_thread=True, # Enforce thread safety
)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA foreign_keys=ON")
conn.execute("PRAGMA mmap_size=30000000000") # 30GB limit
# Limit mmap to 1GB to avoid OOM on smaller systems
conn.execute("PRAGMA mmap_size=1073741824")
self._local.conn = conn
return conn