feat(splade): add cache directory support for ONNX models and improve thread-local database connection handling
@@ -15,6 +15,7 @@ from __future__ import annotations
 
 import logging
 import threading
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
 logger = logging.getLogger(__name__)
@@ -68,6 +69,7 @@ def get_splade_encoder(
     use_gpu: bool = True,
     max_length: int = 512,
     sparsity_threshold: float = 0.01,
+    cache_dir: Optional[str] = None,
 ) -> "SpladeEncoder":
     """Get or create cached SPLADE encoder (thread-safe singleton).
 
@@ -80,6 +82,7 @@ def get_splade_encoder(
         use_gpu: If True, use GPU acceleration when available
         max_length: Maximum sequence length for tokenization
         sparsity_threshold: Minimum weight to include in sparse vector
+        cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
 
     Returns:
         Cached SpladeEncoder instance for the given configuration
@@ -100,6 +103,7 @@ def get_splade_encoder(
             use_gpu=use_gpu,
             max_length=max_length,
             sparsity_threshold=sparsity_threshold,
+            cache_dir=cache_dir,
         )
         # Pre-load model to ensure it's ready
         encoder._load_model()
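
The three hunks above thread the new cache_dir argument through get_splade_encoder's signature, docstring, and the encoder construction. A minimal usage sketch follows; the function and parameter names come from this diff, but the import path is an assumption:

    # Hypothetical import path; only the function and parameter names
    # are confirmed by this diff.
    from codexlens.splade import get_splade_encoder

    encoder = get_splade_encoder(
        use_gpu=False,
        max_length=512,
        sparsity_threshold=0.01,
        cache_dir="/data/models/splade",  # overrides ~/.cache/codexlens/splade
    )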
@@ -151,6 +155,7 @@ class SpladeEncoder:
         max_length: int = 512,
         sparsity_threshold: float = 0.01,
         providers: Optional[List[Any]] = None,
+        cache_dir: Optional[str] = None,
     ) -> None:
         """Initialize SPLADE encoder.
 
@@ -160,6 +165,7 @@ class SpladeEncoder:
             max_length: Maximum sequence length for tokenization
             sparsity_threshold: Minimum weight to include in sparse vector
             providers: Explicit ONNX providers list (overrides use_gpu)
+            cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
         """
         self.model_name = (model_name or self.DEFAULT_MODEL).strip()
         if not self.model_name:
@@ -170,13 +176,33 @@ class SpladeEncoder:
         self.sparsity_threshold = float(sparsity_threshold)
         self.providers = providers
 
+        # Setup ONNX cache directory
+        if cache_dir:
+            self._cache_dir = Path(cache_dir)
+        else:
+            self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
+
         self._tokenizer: Any | None = None
         self._model: Any | None = None
         self._vocab_size: int | None = None
         self._lock = threading.RLock()
 
+    def _get_local_cache_path(self) -> Path:
+        """Get local cache path for this model's ONNX files.
+
+        Returns:
+            Path to the local ONNX cache directory for this model
+        """
+        # Replace / with -- for filesystem-safe naming
+        safe_name = self.model_name.replace("/", "--")
+        return self._cache_dir / safe_name
+
     def _load_model(self) -> None:
-        """Lazy load ONNX model and tokenizer."""
+        """Lazy load ONNX model and tokenizer.
+
+        First checks local cache for ONNX model, falling back to
+        HuggingFace download and conversion if not cached.
+        """
         if self._model is not None and self._tokenizer is not None:
             return
 
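
The _get_local_cache_path helper added above maps a HuggingFace model id to a filesystem-safe directory under the cache root. A worked sketch of that mapping; the example model id is an assumption, since DEFAULT_MODEL is not shown in this diff:

    from pathlib import Path

    cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
    model_name = "naver/splade-cocondenser-ensembledistil"  # assumed example id
    safe_name = model_name.replace("/", "--")  # "/" is not valid in directory names
    print(cache_dir / safe_name)
    # -> ~/.cache/codexlens/splade/naver--splade-cocondenser-ensembledistil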
@@ -214,18 +240,48 @@ class SpladeEncoder:
                 first = self.providers[0]
                 provider_name = first[0] if isinstance(first, tuple) else str(first)
                 model_kwargs["provider"] = provider_name
-            except Exception:
+            except Exception as e:
+                logger.debug(f"Failed to inspect ORTModel signature: {e}")
                 model_kwargs = {}
 
+        # Check for local ONNX cache first
+        local_cache = self._get_local_cache_path()
+        onnx_model_path = local_cache / "model.onnx"
+
+        if onnx_model_path.exists():
+            # Load from local cache
+            logger.info(f"Loading SPLADE from local cache: {local_cache}")
+            try:
+                self._model = ORTModelForMaskedLM.from_pretrained(
+                    str(local_cache),
+                    **model_kwargs,
+                )
+                self._tokenizer = AutoTokenizer.from_pretrained(
+                    str(local_cache), use_fast=True
+                )
+                self._vocab_size = len(self._tokenizer)
+                logger.info(
+                    f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
+                )
+                return
+            except Exception as e:
+                logger.warning(f"Failed to load from cache, redownloading: {e}")
+
+        # Download and convert from HuggingFace
+        logger.info(f"Downloading SPLADE model: {self.model_name}")
         try:
             self._model = ORTModelForMaskedLM.from_pretrained(
                 self.model_name,
                 export=True,  # Export to ONNX
                 **model_kwargs,
             )
             logger.debug(f"SPLADE model loaded: {self.model_name}")
         except TypeError:
             # Fallback for older Optimum versions: retry without provider arguments
-            self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
+            self._model = ORTModelForMaskedLM.from_pretrained(
+                self.model_name,
+                export=True,
+            )
             logger.warning(
                 "Optimum version doesn't support provider parameters. "
                 "Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
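
The except TypeError branch above is version-tolerant keyword passing: attempt the call with the newer provider kwargs, and retry without them when an older Optimum release rejects the signature. A stripped-down sketch of the pattern; the model name and provider value are illustrative:

    from optimum.onnxruntime import ORTModelForMaskedLM

    kwargs = {"provider": "CUDAExecutionProvider"}  # not accepted by older optimum
    try:
        model = ORTModelForMaskedLM.from_pretrained("some/model", export=True, **kwargs)
    except TypeError:
        # Older optimum: from_pretrained lacks provider kwargs; fall back to defaults
        model = ORTModelForMaskedLM.from_pretrained("some/model", export=True)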
@@ -237,6 +293,15 @@ class SpladeEncoder:
         self._vocab_size = len(self._tokenizer)
         logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
 
+        # Save to local cache for future use
+        try:
+            local_cache.mkdir(parents=True, exist_ok=True)
+            self._model.save_pretrained(str(local_cache))
+            self._tokenizer.save_pretrained(str(local_cache))
+            logger.info(f"SPLADE model cached to: {local_cache}")
+        except Exception as e:
+            logger.warning(f"Failed to cache SPLADE model: {e}")
+
     @staticmethod
     def _splade_activation(logits: Any, attention_mask: Any) -> Any:
         """Apply SPLADE activation function to model outputs.
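
Together, the two hunks above form a load-or-export round trip: the first run downloads the model, exports it to ONNX, and persists it with save_pretrained; later runs find model.onnx on disk and skip the download entirely. A self-contained sketch of the same flow, using the optimum/transformers calls that appear in this diff (the function name and paths are illustrative):

    from pathlib import Path
    from optimum.onnxruntime import ORTModelForMaskedLM
    from transformers import AutoTokenizer

    def load_or_export(model_name: str, cache: Path):
        if (cache / "model.onnx").exists():
            # Warm start: load the previously exported ONNX model from disk
            model = ORTModelForMaskedLM.from_pretrained(str(cache))
            tokenizer = AutoTokenizer.from_pretrained(str(cache), use_fast=True)
            return model, tokenizer
        # Cold start: download from HuggingFace, export to ONNX, then persist
        model = ORTModelForMaskedLM.from_pretrained(model_name, export=True)
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        cache.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(cache))
        tokenizer.save_pretrained(str(cache))
        return model, tokenizer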
@@ -40,15 +40,25 @@ class SpladeIndex:
         self._local = threading.local()
 
     def _get_connection(self) -> sqlite3.Connection:
-        """Get or create a thread-local database connection."""
+        """Get or create a thread-local database connection.
+
+        Each thread gets its own connection to ensure thread safety.
+        Connections are stored in thread-local storage.
+        """
         conn = getattr(self._local, "conn", None)
         if conn is None:
-            conn = sqlite3.connect(self.db_path, check_same_thread=False)
+            # Thread-local connection - each thread has its own
+            conn = sqlite3.connect(
+                self.db_path,
+                timeout=30.0,  # Wait up to 30s for locks
+                check_same_thread=True,  # Enforce thread safety
+            )
             conn.row_factory = sqlite3.Row
             conn.execute("PRAGMA journal_mode=WAL")
             conn.execute("PRAGMA synchronous=NORMAL")
             conn.execute("PRAGMA foreign_keys=ON")
-            conn.execute("PRAGMA mmap_size=30000000000")  # 30GB limit
+            # Limit mmap to 1GB to avoid OOM on smaller systems
+            conn.execute("PRAGMA mmap_size=1073741824")
             self._local.conn = conn
         return conn
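
The connection change above pairs thread-local storage with check_same_thread=True: each thread lazily opens its own connection, so SQLite's same-thread enforcement can stay enabled instead of being disabled as it was before. A minimal standalone sketch of the pattern; the class name and demo query are illustrative:

    import sqlite3
    import threading

    class ThreadLocalDB:
        def __init__(self, db_path: str) -> None:
            self.db_path = db_path
            self._local = threading.local()

        def connection(self) -> sqlite3.Connection:
            conn = getattr(self._local, "conn", None)
            if conn is None:
                # One connection per thread; same-thread checks stay enabled
                conn = sqlite3.connect(
                    self.db_path, timeout=30.0, check_same_thread=True
                )
                conn.execute("PRAGMA journal_mode=WAL")  # readers don't block the writer
                self._local.conn = conn
            return conn

    db = ThreadLocalDB("index.db")
    workers = [
        threading.Thread(target=lambda: db.connection().execute("SELECT 1").fetchone())
        for _ in range(4)
    ]
    for t in workers:
        t.start()
    for t in workers:
        t.join()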