feat(splade): add cache directory support for ONNX models and improve thread-local database connection handling
@@ -15,6 +15,7 @@ from __future__ import annotations
 import logging
 import threading
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple


 logger = logging.getLogger(__name__)
@@ -68,6 +69,7 @@ def get_splade_encoder(
     use_gpu: bool = True,
     max_length: int = 512,
     sparsity_threshold: float = 0.01,
+    cache_dir: Optional[str] = None,
 ) -> "SpladeEncoder":
     """Get or create cached SPLADE encoder (thread-safe singleton).

@@ -80,6 +82,7 @@ def get_splade_encoder(
         use_gpu: If True, use GPU acceleration when available
         max_length: Maximum sequence length for tokenization
         sparsity_threshold: Minimum weight to include in sparse vector
+        cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)

     Returns:
         Cached SpladeEncoder instance for the given configuration
@@ -100,6 +103,7 @@ def get_splade_encoder(
             use_gpu=use_gpu,
             max_length=max_length,
             sparsity_threshold=sparsity_threshold,
+            cache_dir=cache_dir,
         )
         # Pre-load model to ensure it's ready
         encoder._load_model()
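
Note: the factory change is backward compatible, since existing callers omit cache_dir and fall back to the default under ~/.cache. A minimal usage sketch; the import path below is an assumption (the module layout is not shown in this diff), only the parameter names come from the change itself:

    # Hypothetical import path -- the real module layout may differ.
    from codexlens.splade import get_splade_encoder

    encoder = get_splade_encoder(
        use_gpu=False,                   # CPU-only ONNX execution
        max_length=512,
        sparsity_threshold=0.01,
        cache_dir="/tmp/splade-cache",   # overrides ~/.cache/codexlens/splade
    )
    # Repeat calls with the same arguments return the cached singleton,
    # and the ONNX export under cache_dir is reused across processes.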
@@ -151,6 +155,7 @@ class SpladeEncoder:
         max_length: int = 512,
         sparsity_threshold: float = 0.01,
         providers: Optional[List[Any]] = None,
+        cache_dir: Optional[str] = None,
     ) -> None:
         """Initialize SPLADE encoder.

@@ -160,6 +165,7 @@ class SpladeEncoder:
             max_length: Maximum sequence length for tokenization
             sparsity_threshold: Minimum weight to include in sparse vector
             providers: Explicit ONNX providers list (overrides use_gpu)
+            cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
         """
         self.model_name = (model_name or self.DEFAULT_MODEL).strip()
         if not self.model_name:
@@ -170,13 +176,33 @@ class SpladeEncoder:
         self.sparsity_threshold = float(sparsity_threshold)
         self.providers = providers

+        # Setup ONNX cache directory
+        if cache_dir:
+            self._cache_dir = Path(cache_dir)
+        else:
+            self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
+
         self._tokenizer: Any | None = None
         self._model: Any | None = None
         self._vocab_size: int | None = None
         self._lock = threading.RLock()

+    def _get_local_cache_path(self) -> Path:
+        """Get local cache path for this model's ONNX files.
+
+        Returns:
+            Path to the local ONNX cache directory for this model
+        """
+        # Replace / with -- for filesystem-safe naming
+        safe_name = self.model_name.replace("/", "--")
+        return self._cache_dir / safe_name
+
     def _load_model(self) -> None:
-        """Lazy load ONNX model and tokenizer."""
+        """Lazy load ONNX model and tokenizer.
+
+        First checks local cache for ONNX model, falling back to
+        HuggingFace download and conversion if not cached.
+        """
         if self._model is not None and self._tokenizer is not None:
             return

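
Note: the "/" to "--" mangling keeps namespaced HuggingFace model ids to a single directory level. A quick sketch of the resulting layout, using the hypothetical model id naver/splade-v3 (DEFAULT_MODEL is not shown in this diff):

    from pathlib import Path

    model_name = "naver/splade-v3"  # hypothetical model id
    cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
    safe_name = model_name.replace("/", "--")  # same mangling as _get_local_cache_path
    print(cache_dir / safe_name)
    # -> ~/.cache/codexlens/splade/naver--splade-v3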
@@ -214,18 +240,48 @@ class SpladeEncoder:
             first = self.providers[0]
             provider_name = first[0] if isinstance(first, tuple) else str(first)
             model_kwargs["provider"] = provider_name
-        except Exception:
+        except Exception as e:
+            logger.debug(f"Failed to inspect ORTModel signature: {e}")
             model_kwargs = {}

+        # Check for local ONNX cache first
+        local_cache = self._get_local_cache_path()
+        onnx_model_path = local_cache / "model.onnx"
+
+        if onnx_model_path.exists():
+            # Load from local cache
+            logger.info(f"Loading SPLADE from local cache: {local_cache}")
+            try:
+                self._model = ORTModelForMaskedLM.from_pretrained(
+                    str(local_cache),
+                    **model_kwargs,
+                )
+                self._tokenizer = AutoTokenizer.from_pretrained(
+                    str(local_cache), use_fast=True
+                )
+                self._vocab_size = len(self._tokenizer)
+                logger.info(
+                    f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
+                )
+                return
+            except Exception as e:
+                logger.warning(f"Failed to load from cache, redownloading: {e}")
+
+        # Download and convert from HuggingFace
+        logger.info(f"Downloading SPLADE model: {self.model_name}")
         try:
             self._model = ORTModelForMaskedLM.from_pretrained(
                 self.model_name,
+                export=True,  # Export to ONNX
                 **model_kwargs,
             )
             logger.debug(f"SPLADE model loaded: {self.model_name}")
         except TypeError:
             # Fallback for older Optimum versions: retry without provider arguments
-            self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
+            self._model = ORTModelForMaskedLM.from_pretrained(
+                self.model_name,
+                export=True,
+            )
             logger.warning(
                 "Optimum version doesn't support provider parameters. "
                 "Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
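
Note: the hunk above boils down to a cache-or-export decision. A condensed, standalone sketch of the same flow; ORTModelForMaskedLM and AutoTokenizer are the APIs the diff already uses, while the model id and path are illustrative:

    from pathlib import Path
    from optimum.onnxruntime import ORTModelForMaskedLM
    from transformers import AutoTokenizer

    model_id = "naver/splade-v3"  # illustrative
    cache = Path.home() / ".cache" / "codexlens" / "splade" / "naver--splade-v3"

    if (cache / "model.onnx").exists():
        # Fast path: reuse the ONNX graph already on disk
        model = ORTModelForMaskedLM.from_pretrained(str(cache))
        tokenizer = AutoTokenizer.from_pretrained(str(cache), use_fast=True)
    else:
        # Slow path: download PyTorch weights and export them to ONNX
        model = ORTModelForMaskedLM.from_pretrained(model_id, export=True)
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)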
@@ -237,6 +293,15 @@ class SpladeEncoder:
         self._vocab_size = len(self._tokenizer)
         logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")

+        # Save to local cache for future use
+        try:
+            local_cache.mkdir(parents=True, exist_ok=True)
+            self._model.save_pretrained(str(local_cache))
+            self._tokenizer.save_pretrained(str(local_cache))
+            logger.info(f"SPLADE model cached to: {local_cache}")
+        except Exception as e:
+            logger.warning(f"Failed to cache SPLADE model: {e}")
+
     @staticmethod
     def _splade_activation(logits: Any, attention_mask: Any) -> Any:
         """Apply SPLADE activation function to model outputs.
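
Note: persisting with save_pretrained is what makes the exists() check succeed on the next run, since it writes model.onnx plus the tokenizer files into the cache directory. Continuing the sketch from the previous note:

    cache.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(cache))       # writes model.onnx and config.json
    tokenizer.save_pretrained(str(cache))   # writes tokenizer/vocab files
    assert (cache / "model.onnx").exists()  # next load takes the cache branch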
@@ -40,15 +40,25 @@ class SpladeIndex:
         self._local = threading.local()

     def _get_connection(self) -> sqlite3.Connection:
-        """Get or create a thread-local database connection."""
+        """Get or create a thread-local database connection.
+
+        Each thread gets its own connection to ensure thread safety.
+        Connections are stored in thread-local storage.
+        """
         conn = getattr(self._local, "conn", None)
         if conn is None:
-            conn = sqlite3.connect(self.db_path, check_same_thread=False)
+            # Thread-local connection - each thread has its own
+            conn = sqlite3.connect(
+                self.db_path,
+                timeout=30.0,  # Wait up to 30s for locks
+                check_same_thread=True,  # Enforce thread safety
+            )
             conn.row_factory = sqlite3.Row
             conn.execute("PRAGMA journal_mode=WAL")
             conn.execute("PRAGMA synchronous=NORMAL")
             conn.execute("PRAGMA foreign_keys=ON")
-            conn.execute("PRAGMA mmap_size=30000000000")  # 30GB limit
+            # Limit mmap to 1GB to avoid OOM on smaller systems
+            conn.execute("PRAGMA mmap_size=1073741824")
             self._local.conn = conn
         return conn

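Note: check_same_thread=True is the stricter setting and is safe here precisely because every thread builds and keeps its own connection; sqlite3 will now raise instead of silently permitting cross-thread use. For the mmap limit, 1073741824 bytes is exactly 2^30, i.e. 1 GiB. A minimal sketch of the same thread-local pattern outside the class ("index.db" is illustrative):

    import sqlite3
    import threading

    _local = threading.local()

    def get_conn(db_path: str) -> sqlite3.Connection:
        conn = getattr(_local, "conn", None)
        if conn is None:
            # One connection per thread, never shared, so the
            # same-thread check can stay enabled.
            conn = sqlite3.connect(db_path, timeout=30.0, check_same_thread=True)
            conn.execute("PRAGMA journal_mode=WAL")      # readers don't block the writer
            conn.execute("PRAGMA mmap_size=1073741824")  # 1 GiB
            _local.conn = conn
        return conn

    # Each worker thread transparently gets its own connection:
    workers = [threading.Thread(target=get_conn, args=("index.db",)) for _ in range(4)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()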