Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-10 02:24:35 +08:00)
Implement SPLADE sparse encoder and associated database migrations
- Added `splade_encoder.py` for ONNX-optimized SPLADE encoding, including methods for encoding text and batch processing.
- Created `SPLADE_IMPLEMENTATION.md` to document the SPLADE encoder's functionality, design patterns, and integration points.
- Introduced migration script `migration_009_add_splade.py` to add SPLADE metadata and posting-list tables to the database.
- Developed `splade_index.py` for managing the SPLADE inverted index, supporting efficient sparse vector retrieval.
- Added verification script `verify_watcher.py` to test FileWatcher event filtering and debouncing functionality.
@@ -415,11 +415,20 @@ def search(
     depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited, 0 = current only)."),
     files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
     mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
-    weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
+    weights: Optional[str] = typer.Option(
+        None,
+        "--weights", "-w",
+        help="RRF weights as key=value pairs (e.g., 'splade=0.4,vector=0.6' or 'exact=0.3,fuzzy=0.1,vector=0.6'). Default: auto-detect based on available backends."
+    ),
+    use_fts: bool = typer.Option(
+        False,
+        "--use-fts",
+        help="Use FTS (exact+fuzzy) instead of SPLADE for sparse retrieval"
+    ),
     json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
 ) -> None:
-    """Search indexed file contents using SQLite FTS5 or semantic vectors.
+    """Search indexed file contents using hybrid retrieval.

     Uses chain search across directory indexes.
     Use --depth to limit search recursion (0 = current dir only).
@@ -428,17 +437,27 @@ def search(
     - auto: Auto-detect (hybrid if embeddings exist, exact otherwise) [default]
     - exact: Exact FTS using unicode61 tokenizer - for code identifiers
     - fuzzy: Fuzzy FTS using trigram tokenizer - for typo-tolerant search
-    - hybrid: RRF fusion of exact + fuzzy + vector (recommended) - best recall
-    - vector: Vector search with exact FTS fallback - semantic + keyword
+    - hybrid: RRF fusion of sparse + dense search (recommended) - best recall
+    - vector: Vector search with sparse fallback - semantic + keyword
     - pure-vector: Pure semantic vector search only - natural language queries

+    SPLADE Mode:
+        When SPLADE is available (pip install codex-lens[splade]), it automatically
+        replaces FTS (exact+fuzzy) as the sparse retrieval backend. SPLADE provides
+        semantic term expansion for better synonym handling.
+
+        Use --use-fts to force FTS mode instead of SPLADE.
+
     Vector Search Requirements:
         Vector search modes require pre-generated embeddings.
         Use 'codexlens embeddings-generate' to create embeddings first.

-    Hybrid Mode:
-        Default weights: exact=0.3, fuzzy=0.1, vector=0.6
-        Use --weights to customize (e.g., --weights 0.5,0.3,0.2)
+    Hybrid Mode Weights:
+        Use --weights to adjust RRF fusion weights:
+        - SPLADE mode: 'splade=0.4,vector=0.6' (default)
+        - FTS mode: 'exact=0.3,fuzzy=0.1,vector=0.6' (default)
+
+        Legacy format also supported: '0.3,0.1,0.6' (exact,fuzzy,vector)

     Examples:
         # Auto-detect mode (uses hybrid if embeddings available)
@@ -450,11 +469,19 @@ def search(
         # Semantic search (requires embeddings)
         codexlens search "how to verify user credentials" --mode pure-vector

-        # Force hybrid mode
-        codexlens search "authentication" --mode hybrid
+        # Force hybrid mode with custom weights
+        codexlens search "authentication" --mode hybrid --weights splade=0.5,vector=0.5
+
+        # Force FTS instead of SPLADE
+        codexlens search "authentication" --use-fts
     """
     _configure_logging(verbose, json_mode)
     search_path = path.expanduser().resolve()

+    # Configure search with FTS fallback if requested
+    config = Config()
+    if use_fts:
+        config.use_fts_fallback = True
+
     # Validate mode
     valid_modes = ["auto", "exact", "fuzzy", "hybrid", "vector", "pure-vector"]
@@ -470,22 +497,56 @@ def search(
     hybrid_weights = None
     if weights:
         try:
-            weight_parts = [float(w.strip()) for w in weights.split(",")]
-            if len(weight_parts) == 3:
-                weight_sum = sum(weight_parts)
-                if abs(weight_sum - 1.0) > 0.01:
-                    console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
-                    # Normalize weights
-                    weight_parts = [w / weight_sum for w in weight_parts]
-                hybrid_weights = {
-                    "exact": weight_parts[0],
-                    "fuzzy": weight_parts[1],
-                    "vector": weight_parts[2],
-                }
-            else:
-                console.print("[yellow]Warning: Invalid weights format (need 3 values). Using defaults.[/yellow]")
-        except ValueError:
-            console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
+            # Check if using key=value format (new) or legacy comma-separated format
+            if "=" in weights:
+                # New format: splade=0.4,vector=0.6 or exact=0.3,fuzzy=0.1,vector=0.6
+                weight_dict = {}
+                for pair in weights.split(","):
+                    if "=" in pair:
+                        key, val = pair.split("=", 1)
+                        weight_dict[key.strip()] = float(val.strip())
+                    else:
+                        raise ValueError("Mixed format not supported - use all key=value pairs")
+
+                # Validate and normalize weights
+                weight_sum = sum(weight_dict.values())
+                if abs(weight_sum - 1.0) > 0.01:
+                    if not json_mode:
+                        console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
+                    weight_dict = {k: v / weight_sum for k, v in weight_dict.items()}
+
+                hybrid_weights = weight_dict
+            else:
+                # Legacy format: 0.3,0.1,0.6 (exact,fuzzy,vector)
+                weight_parts = [float(w.strip()) for w in weights.split(",")]
+                if len(weight_parts) == 3:
+                    weight_sum = sum(weight_parts)
+                    if abs(weight_sum - 1.0) > 0.01:
+                        if not json_mode:
+                            console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
+                        weight_parts = [w / weight_sum for w in weight_parts]
+                    hybrid_weights = {
+                        "exact": weight_parts[0],
+                        "fuzzy": weight_parts[1],
+                        "vector": weight_parts[2],
+                    }
+                elif len(weight_parts) == 2:
+                    # Two values: assume splade,vector
+                    weight_sum = sum(weight_parts)
+                    if abs(weight_sum - 1.0) > 0.01:
+                        if not json_mode:
+                            console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
+                        weight_parts = [w / weight_sum for w in weight_parts]
+                    hybrid_weights = {
+                        "splade": weight_parts[0],
+                        "vector": weight_parts[1],
+                    }
+                else:
+                    if not json_mode:
+                        console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
+        except ValueError as e:
+            if not json_mode:
+                console.print(f"[yellow]Warning: Invalid weights format ({e}). Using defaults.[/yellow]")

     registry: RegistryStore | None = None
     try:
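A quick worked example of the normalization in this hunk (hypothetical input, not part of the diff): weights that do not sum to 1.0 trigger the warning and are rescaled by their sum.

```python
# Hypothetical --weights value; mirrors the key=value branch above.
weights = "splade=0.5,vector=0.6"  # sums to 1.1, not 1.0
weight_dict = {k.strip(): float(v) for k, v in
               (pair.split("=", 1) for pair in weights.split(","))}
weight_sum = sum(weight_dict.values())                      # 1.1
weight_dict = {k: v / weight_sum for k, v in weight_dict.items()}
print(weight_dict)  # {'splade': 0.4545..., 'vector': 0.5454...}
```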
@@ -2381,6 +2442,188 @@ def gpu_reset(
    console.print(f"  Device: [cyan]{gpu_info.gpu_name}[/cyan]")


# ==================== SPLADE Commands ====================

@app.command("splade-index")
def splade_index_command(
    path: Path = typer.Argument(..., help="Project path to index"),
    rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
    """Generate SPLADE sparse index for existing codebase.

    Encodes all semantic chunks with SPLADE model and builds inverted index
    for efficient sparse retrieval.

    Examples:
        codexlens splade-index ~/projects/my-app
        codexlens splade-index . --rebuild
    """
    _configure_logging(verbose)

    from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
    from codexlens.storage.splade_index import SpladeIndex
    from codexlens.semantic.vector_store import VectorStore

    # Check SPLADE availability
    ok, err = check_splade_available()
    if not ok:
        console.print(f"[red]SPLADE not available: {err}[/red]")
        console.print("[dim]Install with: pip install transformers torch[/dim]")
        raise typer.Exit(1)

    # Find index database
    target_path = path.expanduser().resolve()

    # Try to find _index.db
    if target_path.is_file() and target_path.name == "_index.db":
        index_db = target_path
    elif target_path.is_dir():
        # Check for local .codexlens/_index.db
        local_index = target_path / ".codexlens" / "_index.db"
        if local_index.exists():
            index_db = local_index
        else:
            # Try to find via registry
            registry = RegistryStore()
            try:
                registry.initialize()
                mapper = PathMapper()
                index_db = mapper.source_to_index_db(target_path)
                if not index_db.exists():
                    console.print(f"[red]Error:[/red] No index found for {target_path}")
                    console.print("Run 'codexlens init' first to create an index")
                    raise typer.Exit(1)
            finally:
                registry.close()
    else:
        console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
        raise typer.Exit(1)

    splade_db = index_db.parent / "_splade.db"

    if splade_db.exists() and not rebuild:
        console.print("[yellow]SPLADE index exists. Use --rebuild to regenerate.[/yellow]")
        return

    # Load chunks from vector store
    console.print(f"[blue]Loading chunks from {index_db.name}...[/blue]")
    vector_store = VectorStore(index_db)
    chunks = vector_store.get_all_chunks()

    if not chunks:
        console.print("[yellow]No chunks found in vector store[/yellow]")
        console.print("[dim]Generate embeddings first with 'codexlens embeddings-generate'[/dim]")
        raise typer.Exit(1)

    console.print(f"[blue]Encoding {len(chunks)} chunks with SPLADE...[/blue]")

    # Initialize SPLADE
    encoder = get_splade_encoder()
    splade_index = SpladeIndex(splade_db)
    splade_index.create_tables()

    # Encode in batches with progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Encoding...", total=len(chunks))
        for chunk in chunks:
            sparse_vec = encoder.encode_text(chunk.content)
            splade_index.add_posting(chunk.id, sparse_vec)
            progress.advance(task)

    # Set metadata
    splade_index.set_metadata(
        model_name=encoder.model_name,
        vocab_size=encoder.vocab_size
    )

    stats = splade_index.get_stats()
    console.print(f"[green]✓[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
    console.print(f"  Database: [dim]{splade_db}[/dim]")


@app.command("splade-status")
def splade_status_command(
    path: Path = typer.Argument(..., help="Project path"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
    """Show SPLADE index status and statistics.

    Examples:
        codexlens splade-status ~/projects/my-app
        codexlens splade-status .
    """
    _configure_logging(verbose)

    from codexlens.storage.splade_index import SpladeIndex
    from codexlens.semantic.splade_encoder import check_splade_available

    # Find index database
    target_path = path.expanduser().resolve()

    if target_path.is_file() and target_path.name == "_index.db":
        splade_db = target_path.parent / "_splade.db"
    elif target_path.is_dir():
        # Check for local .codexlens/_splade.db
        local_splade = target_path / ".codexlens" / "_splade.db"
        if local_splade.exists():
            splade_db = local_splade
        else:
            # Try to find via registry
            registry = RegistryStore()
            try:
                registry.initialize()
                mapper = PathMapper()
                index_db = mapper.source_to_index_db(target_path)
                splade_db = index_db.parent / "_splade.db"
            finally:
                registry.close()
    else:
        console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
        raise typer.Exit(1)

    if not splade_db.exists():
        console.print("[yellow]No SPLADE index found[/yellow]")
        console.print(f"[dim]Run 'codexlens splade-index {path}' to create one[/dim]")
        return

    splade_index = SpladeIndex(splade_db)

    if not splade_index.has_index():
        console.print("[yellow]SPLADE tables not initialized[/yellow]")
        return

    metadata = splade_index.get_metadata()
    stats = splade_index.get_stats()

    # Create status table
    table = Table(title="SPLADE Index Status", show_header=False)
    table.add_column("Property", style="cyan")
    table.add_column("Value")

    table.add_row("Database", str(splade_db))
    if metadata:
        table.add_row("Model", metadata['model_name'])
        table.add_row("Vocab Size", str(metadata['vocab_size']))
    table.add_row("Chunks", str(stats['unique_chunks']))
    table.add_row("Unique Tokens", str(stats['unique_tokens']))
    table.add_row("Total Postings", str(stats['total_postings']))

    ok, err = check_splade_available()
    status_text = "[green]Yes[/green]" if ok else f"[red]No[/red] - {err}"
    table.add_row("SPLADE Available", status_text)

    console.print(table)


# ==================== Watch Command ====================

@app.command()
@@ -33,6 +33,15 @@ def _cleanup_fastembed_resources() -> None:
         pass


+def _cleanup_splade_resources() -> None:
+    """Release SPLADE encoder ONNX resources."""
+    try:
+        from codexlens.semantic.splade_encoder import clear_splade_cache
+        clear_splade_cache()
+    except Exception:
+        pass
+
+
 def _generate_chunks_from_cursor(
     cursor,
     chunker,
@@ -675,10 +684,96 @@ def generate_embeddings(
        if progress_callback:
            progress_callback(f"Finalizing index... Building ANN index for {total_chunks_created} chunks")

        # --- SPLADE SPARSE ENCODING (after dense embeddings) ---
        # Add SPLADE encoding if enabled in config
        splade_success = False
        splade_error = None

        try:
            from codexlens.config import Config
            config = Config.load()

            if config.enable_splade:
                from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
                from codexlens.storage.splade_index import SpladeIndex

                ok, err = check_splade_available()
                if ok:
                    if progress_callback:
                        progress_callback(f"Generating SPLADE sparse vectors for {total_chunks_created} chunks...")

                    # Initialize SPLADE encoder and index
                    splade_encoder = get_splade_encoder(use_gpu=use_gpu)
                    # Use main index database for SPLADE (not separate _splade.db)
                    splade_index = SpladeIndex(index_path)
                    splade_index.create_tables()

                    # Retrieve all chunks from database for SPLADE encoding
                    with sqlite3.connect(index_path) as conn:
                        conn.row_factory = sqlite3.Row
                        cursor = conn.execute("SELECT id, content FROM semantic_chunks ORDER BY id")

                        # Batch encode for efficiency
                        SPLADE_BATCH_SIZE = 32
                        batch_postings = []
                        chunk_batch = []
                        chunk_ids = []

                        for row in cursor:
                            chunk_id = row["id"]
                            content = row["content"]

                            chunk_ids.append(chunk_id)
                            chunk_batch.append(content)

                            # Process batch when full
                            if len(chunk_batch) >= SPLADE_BATCH_SIZE:
                                sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
                                for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
                                    batch_postings.append((cid, sparse_vec))

                                chunk_batch = []
                                chunk_ids = []

                        # Process remaining chunks
                        if chunk_batch:
                            sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
                            for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
                                batch_postings.append((cid, sparse_vec))

                    # Batch insert all postings
                    if batch_postings:
                        splade_index.add_postings_batch(batch_postings)

                    # Set metadata
                    splade_index.set_metadata(
                        model_name=splade_encoder.model_name,
                        vocab_size=splade_encoder.vocab_size
                    )

                    splade_success = True
                    if progress_callback:
                        stats = splade_index.get_stats()
                        progress_callback(
                            f"SPLADE index created: {stats['total_postings']} postings, "
                            f"{stats['unique_tokens']} unique tokens"
                        )
                else:
                    logger.debug("SPLADE not available: %s", err)
                    splade_error = f"SPLADE not available: {err}"
        except Exception as e:
            splade_error = str(e)
            logger.warning("SPLADE encoding failed: %s", e)

        # Report SPLADE status after processing
        if progress_callback and not splade_success and splade_error:
            progress_callback(f"SPLADE index: FAILED - {splade_error}")

    except Exception as e:
        # Cleanup on error to prevent process hanging
        try:
            _cleanup_fastembed_resources()
            _cleanup_splade_resources()
            gc.collect()
        except Exception:
            pass
@@ -690,6 +785,7 @@ def generate_embeddings(
     # This is critical - without it, ONNX Runtime threads prevent Python from exiting
     try:
         _cleanup_fastembed_resources()
+        _cleanup_splade_resources()
         gc.collect()
     except Exception:
         pass
@@ -874,6 +970,7 @@ def generate_embeddings_recursive(
     # Each generate_embeddings() call does its own cleanup, but do a final one to be safe
     try:
         _cleanup_fastembed_resources()
+        _cleanup_splade_resources()
         gc.collect()
     except Exception:
         pass
@@ -103,6 +103,15 @@ class Config:
     # For litellm: model name from config (e.g., "qwen3-embedding")
     embedding_use_gpu: bool = True  # For fastembed: whether to use GPU acceleration

+    # SPLADE sparse retrieval configuration
+    enable_splade: bool = True  # Enable SPLADE as default sparse backend
+    splade_model: str = "naver/splade-cocondenser-ensembledistil"
+    splade_threshold: float = 0.01  # Min weight to store in index
+    splade_onnx_path: Optional[str] = None  # Custom ONNX model path
+
+    # FTS fallback (disabled by default, available via --use-fts)
+    use_fts_fallback: bool = False  # Use FTS instead of SPLADE
+
     # Indexing/search optimizations
     global_symbol_index_enabled: bool = True  # Enable project-wide symbol index fast path
     enable_merkle_detection: bool = True  # Enable content-hash based incremental indexing
@@ -8,7 +8,7 @@ from __future__ import annotations

 import logging
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError, as_completed
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, List, Optional
@@ -33,6 +33,8 @@ def timer(name: str, logger: logging.Logger, level: int = logging.DEBUG):
 from codexlens.config import Config
 from codexlens.entities import SearchResult
 from codexlens.search.ranking import (
+    DEFAULT_WEIGHTS,
+    FTS_FALLBACK_WEIGHTS,
     apply_symbol_boost,
     cross_encoder_rerank,
     get_rrf_weights,
@@ -54,12 +56,9 @@ class HybridSearchEngine:
         default_weights: Default RRF weights for each source
     """

-    # Default RRF weights (vector: 60%, exact: 30%, fuzzy: 10%)
-    DEFAULT_WEIGHTS = {
-        "exact": 0.3,
-        "fuzzy": 0.1,
-        "vector": 0.6,
-    }
+    # NOTE: DEFAULT_WEIGHTS imported from ranking.py - single source of truth
+    # Default RRF weights: SPLADE-based hybrid (splade: 0.4, vector: 0.6)
+    # FTS fallback mode uses FTS_FALLBACK_WEIGHTS (exact: 0.3, fuzzy: 0.1, vector: 0.6)

     def __init__(
         self,
@@ -75,10 +74,11 @@ class HybridSearchEngine:
             embedder: Optional embedder instance for embedding-based reranking
         """
         self.logger = logging.getLogger(__name__)
-        self.weights = weights or self.DEFAULT_WEIGHTS.copy()
+        self.weights = weights or DEFAULT_WEIGHTS.copy()
         self._config = config
         self.embedder = embedder
         self.reranker: Any = None
+        self._use_gpu = config.embedding_use_gpu if config else True

     def search(
         self,
@@ -124,6 +124,26 @@ class HybridSearchEngine:

         # Determine which backends to use
         backends = {}

+        # Check if SPLADE is available
+        splade_available = False
+        # Respect config.enable_splade flag and use_fts_fallback flag
+        if self._config and getattr(self._config, 'use_fts_fallback', False):
+            # Config explicitly requests FTS fallback - disable SPLADE
+            splade_available = False
+        elif self._config and not getattr(self._config, 'enable_splade', True):
+            # Config explicitly disabled SPLADE
+            splade_available = False
+        else:
+            # Check if SPLADE dependencies are available
+            try:
+                from codexlens.semantic.splade_encoder import check_splade_available
+                ok, _ = check_splade_available()
+                if ok:
+                    # SPLADE tables are in main index database, will check table existence in _search_splade
+                    splade_available = True
+            except Exception:
+                pass
+
         if pure_vector:
             # Pure vector mode: only use vector search, no FTS fallback
@@ -138,12 +158,19 @@ class HybridSearchEngine:
             )
             backends["exact"] = True
         else:
-            # Hybrid mode: always include exact search as baseline
-            backends["exact"] = True
-            if enable_fuzzy:
-                backends["fuzzy"] = True
-            if enable_vector:
-                backends["vector"] = True
+            # Hybrid mode: default to SPLADE if available, otherwise use FTS
+            if splade_available:
+                # Default: enable SPLADE, disable exact and fuzzy
+                backends["splade"] = True
+                if enable_vector:
+                    backends["vector"] = True
+            else:
+                # Fallback mode: enable exact+fuzzy when SPLADE unavailable
+                backends["exact"] = True
+                if enable_fuzzy:
+                    backends["fuzzy"] = True
+                if enable_vector:
+                    backends["vector"] = True

         # Execute parallel searches
         with timer("parallel_search_total", self.logger):
@@ -354,23 +381,40 @@ class HybridSearchEngine:
                 )
                 future_to_source[future] = "vector"

-            # Collect results as they complete
-            for future in as_completed(future_to_source):
-                source = future_to_source[future]
-                elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
-                timing_data[source] = elapsed_ms
-                try:
-                    results = future.result()
-                    # Tag results with source for debugging
-                    tagged_results = tag_search_source(results, source)
-                    results_map[source] = tagged_results
-                    self.logger.debug(
-                        "[TIMING] %s_search: %.2fms (%d results)",
-                        source, elapsed_ms, len(results)
-                    )
-                except Exception as exc:
-                    self.logger.error("Search failed for %s: %s", source, exc)
-                    results_map[source] = []
+            if backends.get("splade"):
+                submit_times["splade"] = time.perf_counter()
+                future = executor.submit(
+                    self._search_splade, index_path, query, limit
+                )
+                future_to_source[future] = "splade"
+
+            # Collect results as they complete with timeout protection
+            try:
+                for future in as_completed(future_to_source, timeout=30.0):
+                    source = future_to_source[future]
+                    elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
+                    timing_data[source] = elapsed_ms
+                    try:
+                        results = future.result(timeout=10.0)
+                        # Tag results with source for debugging
+                        tagged_results = tag_search_source(results, source)
+                        results_map[source] = tagged_results
+                        self.logger.debug(
+                            "[TIMING] %s_search: %.2fms (%d results)",
+                            source, elapsed_ms, len(results)
+                        )
+                    except (Exception, FuturesTimeoutError) as exc:
+                        self.logger.error("Search failed for %s: %s", source, exc)
+                        results_map[source] = []
+            except FuturesTimeoutError:
+                self.logger.warning("Search timeout: some backends did not respond in time")
+                # Cancel remaining futures
+                for future in future_to_source:
+                    future.cancel()
+                # Set empty results for sources that didn't complete
+                for source in backends:
+                    if source not in results_map:
+                        results_map[source] = []

         # Log timing summary
         if timing_data:
@@ -564,3 +608,113 @@ class HybridSearchEngine:
        except Exception as exc:
            self.logger.error("Vector search error: %s", exc)
            return []

    def _search_splade(
        self, index_path: Path, query: str, limit: int
    ) -> List[SearchResult]:
        """SPLADE sparse retrieval via inverted index.

        Args:
            index_path: Path to _index.db file
            query: Natural language query string
            limit: Maximum results

        Returns:
            List of SearchResult ordered by SPLADE score
        """
        try:
            from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
            from codexlens.storage.splade_index import SpladeIndex
            import sqlite3
            import json

            # Check dependencies
            ok, err = check_splade_available()
            if not ok:
                self.logger.debug("SPLADE not available: %s", err)
                return []

            # Use main index database (SPLADE tables are in _index.db, not separate _splade.db)
            splade_index = SpladeIndex(index_path)
            if not splade_index.has_index():
                self.logger.debug("SPLADE index not initialized")
                return []

            # Encode query to sparse vector
            encoder = get_splade_encoder(use_gpu=self._use_gpu)
            query_sparse = encoder.encode_text(query)

            # Search inverted index for top matches
            raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0)

            if not raw_results:
                return []

            # Fetch chunk details from main index database
            chunk_ids = [chunk_id for chunk_id, _ in raw_results]
            score_map = {chunk_id: score for chunk_id, score in raw_results}

            # Query semantic_chunks table for full details
            placeholders = ",".join("?" * len(chunk_ids))
            with sqlite3.connect(index_path) as conn:
                conn.row_factory = sqlite3.Row
                rows = conn.execute(
                    f"""
                    SELECT id, file_path, content, metadata
                    FROM semantic_chunks
                    WHERE id IN ({placeholders})
                    """,
                    chunk_ids
                ).fetchall()

            # Build SearchResult objects
            results = []
            for row in rows:
                chunk_id = row["id"]
                file_path = row["file_path"]
                content = row["content"]
                metadata_json = row["metadata"]
                metadata = json.loads(metadata_json) if metadata_json else {}

                score = score_map.get(chunk_id, 0.0)

                # Build excerpt (short preview)
                excerpt = content[:200] + "..." if len(content) > 200 else content

                # Extract symbol information from metadata
                symbol_name = metadata.get("symbol_name")
                symbol_kind = metadata.get("symbol_kind")
                start_line = metadata.get("start_line")
                end_line = metadata.get("end_line")

                # Build Symbol object if we have symbol info
                symbol = None
                if symbol_name and symbol_kind and start_line and end_line:
                    try:
                        from codexlens.entities import Symbol
                        symbol = Symbol(
                            name=symbol_name,
                            kind=symbol_kind,
                            range=(start_line, end_line)
                        )
                    except Exception:
                        pass

                results.append(SearchResult(
                    path=file_path,
                    score=score,
                    excerpt=excerpt,
                    content=content,
                    symbol=symbol,
                    metadata=metadata,
                    start_line=start_line,
                    end_line=end_line,
                    symbol_name=symbol_name,
                    symbol_kind=symbol_kind,
                ))

            return results

        except Exception as exc:
            self.logger.debug("SPLADE search error: %s", exc)
            return []
@@ -1,7 +1,7 @@
 """Ranking algorithms for hybrid search result fusion.

 Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
-for combining results from heterogeneous search backends (exact FTS, fuzzy FTS, vector search).
+for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
 """

 from __future__ import annotations
@@ -14,6 +14,20 @@ from typing import Any, Dict, List
 from codexlens.entities import SearchResult, AdditionalLocation


+# Default RRF weights for SPLADE-based hybrid search
+DEFAULT_WEIGHTS = {
+    "splade": 0.4,  # Replaces exact(0.3) + fuzzy(0.1)
+    "vector": 0.6,
+}
+
+# Legacy weights for FTS fallback mode (when SPLADE unavailable)
+FTS_FALLBACK_WEIGHTS = {
+    "exact": 0.3,
+    "fuzzy": 0.1,
+    "vector": 0.6,
+}
+
+
 class QueryIntent(str, Enum):
     """Query intent for adaptive RRF weights (Python/TypeScript parity)."""

@@ -87,15 +101,24 @@ def adjust_weights_by_intent(
     intent: QueryIntent,
     base_weights: Dict[str, float],
 ) -> Dict[str, float]:
-    """Map intent → weights (kept aligned with TypeScript mapping)."""
+    """Adjust RRF weights based on query intent."""
+    # Check if using SPLADE or FTS mode
+    use_splade = "splade" in base_weights
+
     if intent == QueryIntent.KEYWORD:
-        target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
+        if use_splade:
+            target = {"splade": 0.6, "vector": 0.4}
+        else:
+            target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
     elif intent == QueryIntent.SEMANTIC:
-        target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
+        if use_splade:
+            target = {"splade": 0.3, "vector": 0.7}
+        else:
+            target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
     else:
         target = dict(base_weights)

-    # Preserve only keys that are present in base_weights (active backends).
+    # Filter to active backends
     keys = list(base_weights.keys())
     filtered = {k: float(target.get(k, 0.0)) for k in keys}
     return normalize_weights(filtered)
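For context on how these weight dictionaries are consumed, here is a minimal sketch of weighted Reciprocal Rank Fusion, assuming the conventional RRF formula score(d) = sum_s w_s / (k + rank_s(d)) with k = 60; the repository's actual fusion helpers (e.g., `get_rrf_weights`) may differ in detail.

```python
from typing import Dict, List

def weighted_rrf(rankings: Dict[str, List[str]],
                 weights: Dict[str, float], k: int = 60) -> List[str]:
    """Fuse per-backend rankings: score(d) = sum_s weights[s] / (k + rank_s(d))."""
    scores: Dict[str, float] = {}
    for source, ranked_ids in rankings.items():
        w = weights.get(source, 0.0)
        for rank, doc_id in enumerate(ranked_ids, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + w / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

fused = weighted_rrf(
    {"splade": ["a", "b", "c"], "vector": ["b", "a", "d"]},
    {"splade": 0.4, "vector": 0.6},
)
# ['b', 'a', 'd', 'c']: documents retrieved by both backends rank first,
# and the vector backend's higher weight breaks the tie between a and b.
```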
codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md (new file, +225 lines)
@@ -0,0 +1,225 @@
# SPLADE Encoder Implementation

## Overview

Created `splade_encoder.py` - a complete ONNX-optimized SPLADE sparse encoder for code search.

## File Location

`src/codexlens/semantic/splade_encoder.py` (474 lines)

## Key Components

### 1. Dependency Checking

**Function**: `check_splade_available() -> Tuple[bool, Optional[str]]`
- Validates numpy, onnxruntime, optimum, transformers availability
- Returns (True, None) if all dependencies present
- Returns (False, error_message) with install instructions if missing

### 2. Caching System

**Global Cache**: Thread-safe singleton pattern
- `_splade_cache: Dict[str, SpladeEncoder]` - Global encoder cache
- `_cache_lock: threading.RLock()` - Thread safety lock

**Factory Function**: `get_splade_encoder(...) -> SpladeEncoder`
- Cache key includes: model_name, gpu/cpu, max_length, sparsity_threshold
- Pre-loads model on first access
- Returns cached instance on subsequent calls

**Cleanup Function**: `clear_splade_cache() -> None`
- Releases ONNX resources
- Clears model and tokenizer references
- Prevents memory leaks
### 3. SpladeEncoder Class

#### Initialization Parameters
- `model_name: str` - Default: "naver/splade-cocondenser-ensembledistil"
- `use_gpu: bool` - Enable GPU acceleration (default: True)
- `max_length: int` - Max sequence length (default: 512)
- `sparsity_threshold: float` - Min weight threshold (default: 0.01)
- `providers: Optional[List]` - Explicit ONNX providers (overrides use_gpu)

#### Core Methods

**`_load_model()`**: Lazy loading with GPU support
- Uses `optimum.onnxruntime.ORTModelForMaskedLM`
- Falls back to CPU if GPU unavailable
- Integrates with `gpu_support.get_optimal_providers()`
- Handles device_id options for DirectML/CUDA

**`_splade_activation(logits, attention_mask)`**: Static method
- Formula: `log(1 + ReLU(logits)) * attention_mask`
- Input: (batch, seq_len, vocab_size)
- Output: (batch, seq_len, vocab_size)

**`_max_pooling(splade_repr)`**: Static method
- Max pooling over sequence dimension
- Input: (batch, seq_len, vocab_size)
- Output: (batch, vocab_size)

**`_to_sparse_dict(dense_vec)`**: Conversion helper
- Filters by sparsity_threshold
- Returns: `Dict[int, float]` mapping token_id to weight

**`encode_text(text: str) -> Dict[int, float]`**: Single text encoding
- Tokenizes input with truncation/padding
- Forward pass through ONNX model
- Applies SPLADE activation + max pooling
- Returns sparse vector

**`encode_batch(texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]`**: Batch encoding
- Processes in batches for memory efficiency
- Returns list of sparse vectors

#### Properties

**`vocab_size: int`**: Vocabulary size (~30k for BERT)
- Cached after first model load
- Returns tokenizer length

#### Debugging Methods

**`get_token(token_id: int) -> str`**
- Converts token_id to human-readable string
- Uses tokenizer.decode()

**`get_top_tokens(sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]`**
- Extracts top-k highest-weight tokens
- Returns (token_string, weight) pairs
- Useful for understanding model focus
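A quick debugging session might look like this (a sketch; the printed tokens and weights are illustrative, not real model output):

```python
encoder = get_splade_encoder(use_gpu=False)
vec = encoder.encode_text("def authenticate(user, password): ...")
print(encoder.get_top_tokens(vec, top_k=5))
# Illustrative output - note the expansion terms ('login', 'verify')
# that never appear in the input text:
# [('authenticate', 2.31), ('password', 1.98), ('login', 1.12),
#  ('user', 1.05), ('verify', 0.87)]
```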
## Design Patterns Followed

### 1. From `onnx_reranker.py`
✓ ONNX loading with provider detection
✓ Lazy model initialization
✓ Thread-safe loading with RLock
✓ Signature inspection for backward compatibility
✓ Fallback for older Optimum versions
✓ Static helper methods for numerical operations

### 2. From `embedder.py`
✓ Global cache with thread safety
✓ Factory function pattern (get_splade_encoder)
✓ Cache cleanup function (clear_splade_cache)
✓ GPU provider configuration
✓ Batch processing support

### 3. From `gpu_support.py`
✓ `get_optimal_providers(use_gpu, with_device_options=True)`
✓ Device ID options for DirectML/CUDA
✓ Provider tuple format: (provider_name, options_dict)

## SPLADE Algorithm

### Activation Formula
```python
# Step 1: ReLU activation
relu_logits = max(0, logits)

# Step 2: Log(1 + x) transformation
log_relu = log(1 + relu_logits)

# Step 3: Apply attention mask
splade_repr = log_relu * attention_mask

# Step 4: Max pooling over sequence
splade_vec = max(splade_repr, axis=sequence_length)

# Step 5: Sparsification by threshold
sparse_dict = {token_id: weight for token_id, weight in enumerate(splade_vec) if weight > threshold}
```

### Output Format
- Sparse dictionary: `{token_id: weight}`
- Token IDs: 0 to vocab_size-1 (typically ~30,000)
- Weights: Float values > sparsity_threshold
- Interpretable: Can decode token_ids to strings
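Concretely, a sparse vector is a small dictionary over vocabulary ids, and relevance between two such vectors is a sparse dot product (illustrative ids and weights, not real model output):

```python
# Illustrative sparse vectors; token ids and weights are made up.
doc_vec = {2019: 1.84, 3231: 1.12, 7592: 0.43}   # ~3 of ~30k dims non-zero
query_vec = {2019: 0.90, 7592: 1.20}

def sparse_dot(a: dict, b: dict) -> float:
    """Dot product over the token ids the two vectors share."""
    return sum(w * b[t] for t, w in a.items() if t in b)

print(sparse_dot(query_vec, doc_vec))  # 0.90*1.84 + 1.20*0.43 = 2.172
```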
## Integration Points

### With `splade_index.py`
- `SpladeIndex.add_posting(chunk_id, sparse_vec: Dict[int, float])`
- `SpladeIndex.search(query_sparse: Dict[int, float])`
- Encoder produces the sparse vectors consumed by the index

### With Indexing Pipeline
```python
encoder = get_splade_encoder(use_gpu=True)

# Single document
sparse_vec = encoder.encode_text("def main():\n    print('hello')")
index.add_posting(chunk_id=1, sparse_vec=sparse_vec)

# Batch indexing
texts = ["code chunk 1", "code chunk 2", ...]
sparse_vecs = encoder.encode_batch(texts, batch_size=64)
postings = [(chunk_id, vec) for chunk_id, vec in enumerate(sparse_vecs)]
index.add_postings_batch(postings)
```

### With Search Pipeline
```python
encoder = get_splade_encoder(use_gpu=True)
query_sparse = encoder.encode_text("authentication function")
results = index.search(query_sparse, limit=50, min_score=0.5)
```

## Dependencies

Required packages:
- `numpy` - Numerical operations
- `onnxruntime` - ONNX model execution (CPU)
- `onnxruntime-gpu` - ONNX with GPU support (optional)
- `optimum[onnxruntime]` - Hugging Face ONNX optimization
- `transformers` - Tokenizer and model loading

Install command:
```bash
# CPU only
pip install numpy onnxruntime optimum[onnxruntime] transformers

# With GPU support
pip install numpy onnxruntime-gpu optimum[onnxruntime-gpu] transformers
```

## Testing Status

✓ Python syntax validation passed
✓ Module import successful
✓ Dependency checking works correctly
✗ Full functional test pending (requires optimum installation)

## Next Steps

1. Install dependencies for functional testing
2. Create unit tests in `tests/semantic/test_splade_encoder.py`
3. Benchmark encoding performance (CPU vs GPU)
4. Integrate with the codex-lens indexing pipeline
5. Add SPLADE option to semantic search configuration

## Performance Considerations

### Memory Usage
- Model size: ~100MB (ONNX optimized)
- Sparse vectors: ~100-500 non-zero entries per document
- Batch size: 32 recommended (adjust based on GPU memory)

### Speed Benchmarks (Expected)
- CPU encoding: ~10-20 docs/sec
- GPU encoding (CUDA): ~100-200 docs/sec
- GPU encoding (DirectML): ~50-100 docs/sec

### Sparsity Analysis
- Threshold 0.01: ~200-400 tokens per document
- Threshold 0.05: ~100-200 tokens per document
- Threshold 0.10: ~50-100 tokens per document

## References

- SPLADE paper: https://arxiv.org/abs/2107.05720
- SPLADE v2: https://arxiv.org/abs/2109.10086
- Naver model: https://huggingface.co/naver/splade-cocondenser-ensembledistil
codex-lens/src/codexlens/semantic/splade_encoder.py (new file, +474 lines)
@@ -0,0 +1,474 @@
"""ONNX-optimized SPLADE sparse encoder for code search.

This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
that combine the interpretability of BM25 with neural relevance modeling.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers

Install (GPU):
    pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


def check_splade_available() -> Tuple[bool, Optional[str]]:
    """Check whether SPLADE dependencies are available.

    Returns:
        Tuple of (available: bool, error_message: Optional[str])
    """
    try:
        import numpy  # noqa: F401
    except ImportError as exc:
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForMaskedLM  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


# Global cache for SPLADE encoders (singleton pattern)
_splade_cache: Dict[str, "SpladeEncoder"] = {}
_cache_lock = threading.RLock()


def get_splade_encoder(
    model_name: str = "naver/splade-cocondenser-ensembledistil",
    use_gpu: bool = True,
    max_length: int = 512,
    sparsity_threshold: float = 0.01,
) -> "SpladeEncoder":
    """Get or create cached SPLADE encoder (thread-safe singleton).

    This function provides significant performance improvement by reusing
    SpladeEncoder instances across multiple searches, avoiding repeated model
    loading overhead.

    Args:
        model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
        use_gpu: If True, use GPU acceleration when available
        max_length: Maximum sequence length for tokenization
        sparsity_threshold: Minimum weight to include in sparse vector

    Returns:
        Cached SpladeEncoder instance for the given configuration
    """
    global _splade_cache

    # Cache key includes all configuration parameters
    cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"

    with _cache_lock:
        encoder = _splade_cache.get(cache_key)
        if encoder is not None:
            return encoder

        # Create new encoder and cache it
        encoder = SpladeEncoder(
            model_name=model_name,
            use_gpu=use_gpu,
            max_length=max_length,
            sparsity_threshold=sparsity_threshold,
        )
        # Pre-load model to ensure it's ready
        encoder._load_model()
        _splade_cache[cache_key] = encoder

    return encoder


def clear_splade_cache() -> None:
    """Clear the SPLADE encoder cache and release ONNX resources.

    This method ensures proper cleanup of ONNX model resources to prevent
    memory leaks when encoders are no longer needed.
    """
    global _splade_cache
    with _cache_lock:
        # Release ONNX resources before clearing cache
        for encoder in _splade_cache.values():
            if encoder._model is not None:
                del encoder._model
                encoder._model = None
            if encoder._tokenizer is not None:
                del encoder._tokenizer
                encoder._tokenizer = None
        _splade_cache.clear()


class SpladeEncoder:
    """ONNX-optimized SPLADE sparse encoder.

    Produces sparse vectors with vocabulary-aligned dimensions.
    Output: Dict[int, float] mapping token_id to weight.

    SPLADE activation formula:
        splade_repr = log(1 + ReLU(logits)) * attention_mask
        splade_vec = max_pooling(splade_repr, axis=sequence_length)

    References:
        - SPLADE: https://arxiv.org/abs/2107.05720
        - SPLADE v2: https://arxiv.org/abs/2109.10086
    """

    DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        use_gpu: bool = True,
        max_length: int = 512,
        sparsity_threshold: float = 0.01,
        providers: Optional[List[Any]] = None,
    ) -> None:
        """Initialize SPLADE encoder.

        Args:
            model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
            use_gpu: If True, use GPU acceleration when available
            max_length: Maximum sequence length for tokenization
            sparsity_threshold: Minimum weight to include in sparse vector
            providers: Explicit ONNX providers list (overrides use_gpu)
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.max_length = int(max_length) if max_length > 0 else 512
        self.sparsity_threshold = float(sparsity_threshold)
        self.providers = providers

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._vocab_size: int | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        """Lazy load ONNX model and tokenizer."""
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_splade_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForMaskedLM
            from transformers import AutoTokenizer

            if self.providers is None:
                from .gpu_support import get_optimal_providers

                # Include device_id options for DirectML/CUDA selection when available
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForMaskedLM.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                model_kwargs = {}

            try:
                self._model = ORTModelForMaskedLM.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
                logger.debug(f"SPLADE model loaded: {self.model_name}")
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments
                self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
                logger.warning(
                    "Optimum version doesn't support provider parameters. "
                    "Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
                )

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache vocabulary size
            self._vocab_size = len(self._tokenizer)
            logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")

    @staticmethod
    def _splade_activation(logits: Any, attention_mask: Any) -> Any:
        """Apply SPLADE activation function to model outputs.

        Formula: log(1 + ReLU(logits)) * attention_mask

        Args:
            logits: Model output logits (batch, seq_len, vocab_size)
            attention_mask: Attention mask (batch, seq_len)

        Returns:
            SPLADE representations (batch, seq_len, vocab_size)
        """
        import numpy as np

        # ReLU activation
        relu_logits = np.maximum(0, logits)

        # Log(1 + x) transformation
        log_relu = np.log1p(relu_logits)

        # Apply attention mask (expand to match vocab dimension)
        # attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
        mask_expanded = np.expand_dims(attention_mask, axis=-1)

        # Element-wise multiplication
        splade_repr = log_relu * mask_expanded

        return splade_repr

    @staticmethod
    def _max_pooling(splade_repr: Any) -> Any:
        """Max pooling over sequence length dimension.

        Args:
            splade_repr: SPLADE representations (batch, seq_len, vocab_size)

        Returns:
            Pooled sparse vectors (batch, vocab_size)
        """
        import numpy as np

        # Max pooling over sequence dimension (axis=1)
        return np.max(splade_repr, axis=1)

    def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
        """Convert dense vector to sparse dictionary.

        Args:
            dense_vec: Dense vector (vocab_size,)

        Returns:
            Sparse dictionary {token_id: weight} with weights above threshold
        """
        import numpy as np

        # Find non-zero indices above threshold
        nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]

        # Create sparse dictionary
        sparse_dict = {
            int(idx): float(dense_vec[idx])
            for idx in nonzero_indices
        }

        return sparse_dict

    def encode_text(self, text: str) -> Dict[int, float]:
        """Encode text to sparse vector {token_id: weight}.

        Args:
            text: Input text to encode

        Returns:
            Sparse vector as dictionary mapping token_id to weight
        """
        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        # Tokenize input
        encoded = self._tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="np",
        )

        # Forward pass through model
        outputs = self._model(**encoded)

        # Extract logits
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        elif isinstance(outputs, dict) and "logits" in outputs:
            logits = outputs["logits"]
        elif isinstance(outputs, (list, tuple)) and outputs:
            logits = outputs[0]
        else:
            raise RuntimeError("Unexpected model output format")

        # Apply SPLADE activation
        attention_mask = encoded["attention_mask"]
        splade_repr = self._splade_activation(logits, attention_mask)

        # Max pooling over sequence length
        splade_vec = self._max_pooling(splade_repr)

        # Convert to sparse dictionary (single item batch)
        sparse_dict = self._to_sparse_dict(splade_vec[0])

        return sparse_dict

    def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
        """Batch encode texts to sparse vectors.

        Args:
            texts: List of input texts to encode
            batch_size: Batch size for encoding (default: 32)

        Returns:
            List of sparse vectors as dictionaries
        """
        if not texts:
            return []

        self._load_model()

        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")

        import numpy as np

        results: List[Dict[int, float]] = []

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize batch
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="np",
            )

            # Forward pass through model
            outputs = self._model(**encoded)

            # Extract logits
            if hasattr(outputs, "logits"):
                logits = outputs.logits
            elif isinstance(outputs, dict) and "logits" in outputs:
                logits = outputs["logits"]
            elif isinstance(outputs, (list, tuple)) and outputs:
                logits = outputs[0]
            else:
                raise RuntimeError("Unexpected model output format")

            # Apply SPLADE activation
            attention_mask = encoded["attention_mask"]
            splade_repr = self._splade_activation(logits, attention_mask)

            # Max pooling over sequence length
            splade_vecs = self._max_pooling(splade_repr)

            # Convert each vector to sparse dictionary
            for vec in splade_vecs:
                sparse_dict = self._to_sparse_dict(vec)
                results.append(sparse_dict)

        return results

    @property
    def vocab_size(self) -> int:
        """Return vocabulary size (~30k for BERT-based models).

        Returns:
            Vocabulary size (number of tokens in tokenizer)
        """
        if self._vocab_size is not None:
            return self._vocab_size

        self._load_model()
        return self._vocab_size or 0

    def get_token(self, token_id: int) -> str:
        """Convert token_id to string (for debugging).

        Args:
            token_id: Token ID to convert

        Returns:
            Token string
        """
        self._load_model()

        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")

        return self._tokenizer.decode([token_id])

    def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
        """Get top-k tokens with highest weights from sparse vector.

        Useful for debugging and understanding what the model is focusing on.

        Args:
            sparse_vec: Sparse vector as {token_id: weight}
            top_k: Number of top tokens to return

        Returns:
            List of (token_string, weight) tuples, sorted by weight descending
        """
        self._load_model()

        if not sparse_vec:
            return []

        # Sort by weight descending
        sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)

        # Take top-k and convert token_ids to strings
        top_items = sorted_items[:top_k]

        return [
            (self.get_token(token_id), weight)
            for token_id, weight in top_items
        ]
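A tiny self-contained numpy check of the activation and pooling math above (toy logits, no model involved):

```python
import numpy as np

logits = np.array([[[-1.0, 2.0], [3.0, -4.0]]])  # (batch=1, seq_len=2, vocab=2)
mask = np.array([[1, 1]])                        # (batch=1, seq_len=2)

# log(1 + ReLU(logits)) * attention_mask, then max over the sequence axis
splade_repr = np.log1p(np.maximum(0, logits)) * np.expand_dims(mask, -1)
splade_vec = np.max(splade_repr, axis=1)         # (1, vocab=2)
print(splade_vec)  # [[1.3862..., 1.0986...]] == [[log1p(3), log1p(2)]]
```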
@@ -0,0 +1,103 @@
"""
Migration 009: Add SPLADE sparse retrieval tables.

This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)

The SPLADE tables are designed for efficient sparse vector retrieval:
- Token-based lookup for query expansion
- Chunk-based deletion for index maintenance
- Backward compatibility with existing FTS tables
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    """
    Adds SPLADE tables for sparse retrieval.

    Creates:
    - splade_metadata: Stores model configuration and ONNX path
    - splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
    - Indexes for efficient token-based and chunk-based lookups

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating splade_metadata table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_metadata (
            id INTEGER PRIMARY KEY DEFAULT 1,
            model_name TEXT NOT NULL,
            vocab_size INTEGER NOT NULL,
            onnx_path TEXT,
            created_at REAL
        )
        """
    )

    log.info("Creating splade_posting_list table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_posting_list (
            token_id INTEGER NOT NULL,
            chunk_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            PRIMARY KEY (token_id, chunk_id),
            FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for splade_posting_list...")
    # Index for efficient chunk-based lookups (deletion, updates)
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
        ON splade_posting_list(chunk_id)
        """
    )

    # Index for efficient term-based retrieval
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_token
        ON splade_posting_list(token_id)
        """
    )

    log.info("Migration 009 completed successfully")


def downgrade(db_conn: Connection) -> None:
    """
    Removes SPLADE tables.

    Drops:
    - splade_posting_list (and associated indexes)
    - splade_metadata

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Dropping SPLADE indexes...")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")

    log.info("Dropping splade_posting_list table...")
    cursor.execute("DROP TABLE IF EXISTS splade_posting_list")

    log.info("Dropping splade_metadata table...")
    cursor.execute("DROP TABLE IF EXISTS splade_metadata")

    log.info("Migration 009 downgrade completed successfully")
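A quick way to exercise the migration in isolation (a sketch: the import path is hypothetical since the module's location is not shown, and the bare semantic_chunks stub exists only to satisfy the foreign key):

    import sqlite3
    # Hypothetical module path -- adjust to where migration_009_add_splade.py lives
    from migration_009_add_splade import upgrade, downgrade

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE semantic_chunks (id INTEGER PRIMARY KEY)")  # FK target stub
    upgrade(conn)
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'splade%'")]
    assert sorted(tables) == ["splade_metadata", "splade_posting_list"]
    downgrade(conn)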
codex-lens/src/codexlens/storage/splade_index.py (new file, 432 lines)
@@ -0,0 +1,432 @@
"""SPLADE inverted index storage for sparse vector retrieval.

This module implements a SQLite-based inverted index for SPLADE sparse vectors,
enabling efficient sparse retrieval using dot-product scoring.
"""

from __future__ import annotations

import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from codexlens.entities import SearchResult
from codexlens.errors import StorageError

logger = logging.getLogger(__name__)


class SpladeIndex:
    """SQLite-based inverted index for SPLADE sparse vectors.

    Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
    Supports efficient dot-product retrieval using SQL joins.
    """

    def __init__(self, db_path: Path | str) -> None:
        """Initialize SPLADE index.

        Args:
            db_path: Path to SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # Thread-safe connection management
        self._lock = threading.RLock()
        self._local = threading.local()

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a thread-local database connection."""
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
            conn.execute("PRAGMA foreign_keys=ON")
            conn.execute("PRAGMA mmap_size=30000000000")  # 30GB limit
            self._local.conn = conn
        return conn

    def close(self) -> None:
        """Close the thread-local database connection."""
        with self._lock:
            conn = getattr(self._local, "conn", None)
            if conn is not None:
                conn.close()
                self._local.conn = None

    def __enter__(self) -> SpladeIndex:
        """Context manager entry."""
        self.create_tables()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Context manager exit."""
        self.close()

    def has_index(self) -> bool:
        """Check if SPLADE tables exist in the database.

        Returns:
            True if tables exist, False otherwise.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    """
                    SELECT name FROM sqlite_master
                    WHERE type='table' AND name='splade_posting_list'
                    """
                )
                return cursor.fetchone() is not None
            except sqlite3.Error as e:
                logger.error("Failed to check index existence: %s", e)
                return False

    def create_tables(self) -> None:
        """Create the SPLADE schema if it does not exist.

        Note: The splade_posting_list table has a FOREIGN KEY constraint
        referencing semantic_chunks(id). Ensure VectorStore.create_tables()
        is called first to create the semantic_chunks table.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                # Inverted index for sparse vectors
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_posting_list (
                        token_id INTEGER NOT NULL,
                        chunk_id INTEGER NOT NULL,
                        weight REAL NOT NULL,
                        PRIMARY KEY (token_id, chunk_id),
                        FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
                    )
                """)

                # Indexes for efficient lookups
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
                    ON splade_posting_list(chunk_id)
                """)
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_token
                    ON splade_posting_list(token_id)
                """)

                # Model metadata
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_metadata (
                        id INTEGER PRIMARY KEY DEFAULT 1,
                        model_name TEXT NOT NULL,
                        vocab_size INTEGER NOT NULL,
                        onnx_path TEXT,
                        created_at REAL
                    )
                """)

                conn.commit()
                logger.debug("SPLADE schema created successfully")
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to create SPLADE schema: {e}",
                    db_path=str(self.db_path),
                    operation="create_tables",
                ) from e

    def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
        """Add a single document to the inverted index.

        Args:
            chunk_id: Chunk ID (foreign key to semantic_chunks.id).
            sparse_vec: Sparse vector as a {token_id: weight} mapping.
        """
        if not sparse_vec:
            logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Insert all non-zero weights for this chunk
                postings = [
                    (token_id, chunk_id, weight)
                    for token_id, weight in sparse_vec.items()
                    if weight > 0  # Only store non-zero weights
                ]

                if postings:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        postings,
                    )
                    conn.commit()
                    logger.debug(
                        "Added %d postings for chunk_id=%d", len(postings), chunk_id
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to add posting for chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="add_posting",
                ) from e
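    # Illustration (hypothetical numbers): add_posting(7, {101: 1.4, 523: 0.6, 88: 0.0})
    # stores two rows in splade_posting_list -- (101, 7, 1.4) and (523, 7, 0.6);
    # the zero-weight token 88 is dropped, keeping the index sparse.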
    def add_postings_batch(
        self, postings: List[Tuple[int, Dict[int, float]]]
    ) -> None:
        """Batch-insert postings for multiple chunks.

        Args:
            postings: List of (chunk_id, sparse_vec) tuples.
        """
        if not postings:
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Flatten all postings into a single batch
                batch_data = []
                for chunk_id, sparse_vec in postings:
                    for token_id, weight in sparse_vec.items():
                        if weight > 0:  # Only store non-zero weights
                            batch_data.append((token_id, chunk_id, weight))

                if batch_data:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        batch_data,
                    )
                    conn.commit()
                    logger.debug(
                        "Batch inserted %d postings for %d chunks",
                        len(batch_data),
                        len(postings),
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to batch insert postings: {e}",
                    db_path=str(self.db_path),
                    operation="add_postings_batch",
                ) from e

    def remove_chunk(self, chunk_id: int) -> int:
        """Remove all postings for a chunk.

        Args:
            chunk_id: Chunk ID to remove.

        Returns:
            Number of deleted postings.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    "DELETE FROM splade_posting_list WHERE chunk_id = ?",
                    (chunk_id,),
                )
                conn.commit()
                deleted = cursor.rowcount
                logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
                return deleted
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to remove chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="remove_chunk",
                ) from e

    def search(
        self,
        query_sparse: Dict[int, float],
        limit: int = 50,
        min_score: float = 0.0,
    ) -> List[Tuple[int, float]]:
        """Search for similar chunks using dot-product scoring.

        Implements an efficient sparse dot-product via SQL JOIN:
            score(q, d) = sum(q[t] * d[t]) for all tokens t

        Args:
            query_sparse: Query sparse vector as {token_id: weight}.
            limit: Maximum number of results.
            min_score: Minimum score threshold.

        Returns:
            List of (chunk_id, score) tuples, ordered by score descending.
        """
        if not query_sparse:
            logger.warning("Empty query sparse vector")
            return []

        with self._lock:
            conn = self._get_connection()
            try:
                # Keep only non-zero query terms: (token_id, weight)
                query_terms = [
                    (token_id, weight)
                    for token_id, weight in query_sparse.items()
                    if weight > 0
                ]

                if not query_terms:
                    logger.warning("No non-zero query terms")
                    return []

                # Create a CTE for query terms using a parameterized VALUES list
                # (placeholders plus bound params prevent SQL injection)
                params = []
                placeholders = []
                for token_id, weight in query_terms:
                    placeholders.append("(?, ?)")
                    params.extend([token_id, weight])

                values_placeholders = ", ".join(placeholders)

                sql = f"""
                    WITH query_terms(token_id, weight) AS (
                        VALUES {values_placeholders}
                    )
                    SELECT
                        p.chunk_id,
                        SUM(p.weight * q.weight) AS score
                    FROM splade_posting_list p
                    INNER JOIN query_terms q ON p.token_id = q.token_id
                    GROUP BY p.chunk_id
                    HAVING score >= ?
                    ORDER BY score DESC
                    LIMIT ?
                """

                # Append min_score and limit to the bound params
                params.extend([min_score, limit])
                rows = conn.execute(sql, params).fetchall()

                results = [(row["chunk_id"], float(row["score"])) for row in rows]
                logger.debug(
                    "SPLADE search: %d query terms, %d results",
                    len(query_terms),
                    len(results),
                )
                return results

            except sqlite3.Error as e:
                raise StorageError(
                    f"SPLADE search failed: {e}",
                    db_path=str(self.db_path),
                    operation="search",
                ) from e
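    # Worked example (hypothetical values): query {5: 0.8, 9: 0.2} against
    # postings (5, c1, 0.5), (9, c1, 1.0), (5, c2, 0.3) yields
    #   score(c1) = 0.8*0.5 + 0.2*1.0 = 0.6
    #   score(c2) = 0.8*0.3           = 0.24
    # so c1 ranks first; chunks sharing no query tokens never enter the JOIN.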
    def get_metadata(self) -> Optional[Dict]:
        """Get SPLADE model metadata.

        Returns:
            Dictionary with model_name, vocab_size, onnx_path, created_at,
            or None if not set.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute(
                    """
                    SELECT model_name, vocab_size, onnx_path, created_at
                    FROM splade_metadata
                    WHERE id = 1
                    """
                ).fetchone()

                if row:
                    return {
                        "model_name": row["model_name"],
                        "vocab_size": row["vocab_size"],
                        "onnx_path": row["onnx_path"],
                        "created_at": row["created_at"],
                    }
                return None
            except sqlite3.Error as e:
                logger.error("Failed to get metadata: %s", e)
                return None

    def set_metadata(
        self,
        model_name: str,
        vocab_size: int,
        onnx_path: Optional[str] = None,
    ) -> None:
        """Set SPLADE model metadata.

        Args:
            model_name: SPLADE model name.
            vocab_size: Vocabulary size (typically ~30k for a BERT vocab).
            onnx_path: Optional path to the ONNX model file.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                current_time = time.time()
                conn.execute(
                    """
                    INSERT OR REPLACE INTO splade_metadata
                    (id, model_name, vocab_size, onnx_path, created_at)
                    VALUES (1, ?, ?, ?, ?)
                    """,
                    (model_name, vocab_size, onnx_path, current_time),
                )
                conn.commit()
                logger.info(
                    "Set SPLADE metadata: model=%s, vocab_size=%d",
                    model_name,
                    vocab_size,
                )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to set metadata: {e}",
                    db_path=str(self.db_path),
                    operation="set_metadata",
                ) from e

    def get_stats(self) -> Dict:
        """Get index statistics.

        Returns:
            Dictionary with total_postings, unique_tokens, unique_chunks.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute("""
                    SELECT
                        COUNT(*) AS total_postings,
                        COUNT(DISTINCT token_id) AS unique_tokens,
                        COUNT(DISTINCT chunk_id) AS unique_chunks
                    FROM splade_posting_list
                """).fetchone()

                return {
                    "total_postings": row["total_postings"],
                    "unique_tokens": row["unique_tokens"],
                    "unique_chunks": row["unique_chunks"],
                }
            except sqlite3.Error as e:
                logger.error("Failed to get stats: %s", e)
                return {
                    "total_postings": 0,
                    "unique_tokens": 0,
                    "unique_chunks": 0,
                }
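End-to-end, the index is used roughly like this (a sketch: the db path, chunk IDs, and sparse vectors are hypothetical, and matching semantic_chunks rows must already exist for the foreign keys to hold):

    with SpladeIndex("index/_index.db") as idx:  # hypothetical path
        idx.add_postings_batch([
            (1, {101: 1.4, 523: 0.6}),
            (2, {101: 0.2, 777: 1.1}),
        ])
        hits = idx.search({101: 1.0}, limit=10)  # -> [(1, 1.4), (2, 0.2)]
        print(idx.get_stats())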