Implement SPLADE sparse encoder and associated database migrations

- Added `splade_encoder.py` for ONNX-optimized SPLADE encoding, including methods for encoding text and batch processing.
- Created `SPLADE_IMPLEMENTATION.md` to document the SPLADE encoder's functionality, design patterns, and integration points.
- Introduced migration script `migration_009_add_splade.py` to add SPLADE metadata and posting list tables to the database.
- Developed `splade_index.py` for managing the SPLADE inverted index, supporting efficient sparse vector retrieval.
- Added verification script `verify_watcher.py` to test FileWatcher event filtering and debouncing functionality.
This commit is contained in:
catlog22
2026-01-01 17:41:22 +08:00
parent 520f2d26f2
commit 5bb01755bc
16 changed files with 3122 additions and 2792 deletions

View File

@@ -0,0 +1,103 @@
"""
Migration 009: Add SPLADE sparse retrieval tables.
This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)
The SPLADE tables are designed for efficient sparse vector retrieval:
- Token-based lookup for query expansion
- Chunk-based deletion for index maintenance
- Maintains backward compatibility with existing FTS tables
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
"""
Adds SPLADE tables for sparse retrieval.
Creates:
- splade_metadata: Stores model configuration and ONNX path
- splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
- Indexes for efficient token-based and chunk-based lookups
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating splade_metadata table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS splade_metadata (
id INTEGER PRIMARY KEY DEFAULT 1,
model_name TEXT NOT NULL,
vocab_size INTEGER NOT NULL,
onnx_path TEXT,
created_at REAL
)
"""
)
log.info("Creating splade_posting_list table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS splade_posting_list (
token_id INTEGER NOT NULL,
chunk_id INTEGER NOT NULL,
weight REAL NOT NULL,
PRIMARY KEY (token_id, chunk_id),
FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for splade_posting_list...")
# Index for efficient chunk-based lookups (deletion, updates)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
ON splade_posting_list(chunk_id)
"""
)
# Index for efficient term-based retrieval
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_splade_by_token
ON splade_posting_list(token_id)
"""
)
log.info("Migration 009 completed successfully")
def downgrade(db_conn: Connection) -> None:
"""
Removes SPLADE tables.
Drops:
- splade_posting_list (and associated indexes)
- splade_metadata
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Dropping SPLADE indexes...")
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")
log.info("Dropping splade_posting_list table...")
cursor.execute("DROP TABLE IF EXISTS splade_posting_list")
log.info("Dropping splade_metadata table...")
cursor.execute("DROP TABLE IF EXISTS splade_metadata")
log.info("Migration 009 downgrade completed successfully")

View File

@@ -0,0 +1,432 @@
"""SPLADE inverted index storage for sparse vector retrieval.
This module implements SQLite-based inverted index for SPLADE sparse vectors,
enabling efficient sparse retrieval using dot-product scoring.
"""
from __future__ import annotations
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from codexlens.entities import SearchResult
from codexlens.errors import StorageError
logger = logging.getLogger(__name__)
class SpladeIndex:
"""SQLite-based inverted index for SPLADE sparse vectors.
Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
Supports efficient dot-product retrieval using SQL joins.
"""
def __init__(self, db_path: Path | str) -> None:
"""Initialize SPLADE index.
Args:
db_path: Path to SQLite database file.
"""
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
# Thread-safe connection management
self._lock = threading.RLock()
self._local = threading.local()
def _get_connection(self) -> sqlite3.Connection:
"""Get or create a thread-local database connection."""
conn = getattr(self._local, "conn", None)
if conn is None:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA foreign_keys=ON")
conn.execute("PRAGMA mmap_size=30000000000") # 30GB limit
self._local.conn = conn
return conn
def close(self) -> None:
"""Close thread-local database connection."""
with self._lock:
conn = getattr(self._local, "conn", None)
if conn is not None:
conn.close()
self._local.conn = None
def __enter__(self) -> SpladeIndex:
"""Context manager entry."""
self.create_tables()
return self
def __exit__(self, exc_type, exc, tb) -> None:
"""Context manager exit."""
self.close()
def has_index(self) -> bool:
"""Check if SPLADE tables exist in database.
Returns:
True if tables exist, False otherwise.
"""
with self._lock:
conn = self._get_connection()
try:
cursor = conn.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='splade_posting_list'
"""
)
return cursor.fetchone() is not None
except sqlite3.Error as e:
logger.error("Failed to check index existence: %s", e)
return False
def create_tables(self) -> None:
"""Create SPLADE schema if not exists.
Note: The splade_posting_list table has a FOREIGN KEY constraint
referencing semantic_chunks(id). Ensure VectorStore.create_tables()
is called first to create the semantic_chunks table.
"""
with self._lock:
conn = self._get_connection()
try:
# Inverted index for sparse vectors
conn.execute("""
CREATE TABLE IF NOT EXISTS splade_posting_list (
token_id INTEGER NOT NULL,
chunk_id INTEGER NOT NULL,
weight REAL NOT NULL,
PRIMARY KEY (token_id, chunk_id),
FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
)
""")
# Indexes for efficient lookups
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
ON splade_posting_list(chunk_id)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_splade_by_token
ON splade_posting_list(token_id)
""")
# Model metadata
conn.execute("""
CREATE TABLE IF NOT EXISTS splade_metadata (
id INTEGER PRIMARY KEY DEFAULT 1,
model_name TEXT NOT NULL,
vocab_size INTEGER NOT NULL,
onnx_path TEXT,
created_at REAL
)
""")
conn.commit()
logger.debug("SPLADE schema created successfully")
except sqlite3.Error as e:
raise StorageError(
f"Failed to create SPLADE schema: {e}",
db_path=str(self.db_path),
operation="create_tables"
) from e
def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
"""Add a single document to inverted index.
Args:
chunk_id: Chunk ID (foreign key to semantic_chunks.id).
sparse_vec: Sparse vector as {token_id: weight} mapping.
"""
if not sparse_vec:
logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
return
with self._lock:
conn = self._get_connection()
try:
# Insert all non-zero weights for this chunk
postings = [
(token_id, chunk_id, weight)
for token_id, weight in sparse_vec.items()
if weight > 0 # Only store non-zero weights
]
if postings:
conn.executemany(
"""
INSERT OR REPLACE INTO splade_posting_list
(token_id, chunk_id, weight)
VALUES (?, ?, ?)
""",
postings
)
conn.commit()
logger.debug(
"Added %d postings for chunk_id=%d", len(postings), chunk_id
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to add posting for chunk_id={chunk_id}: {e}",
db_path=str(self.db_path),
operation="add_posting"
) from e
def add_postings_batch(
self, postings: List[Tuple[int, Dict[int, float]]]
) -> None:
"""Batch insert postings for multiple chunks.
Args:
postings: List of (chunk_id, sparse_vec) tuples.
"""
if not postings:
return
with self._lock:
conn = self._get_connection()
try:
# Flatten all postings into single batch
batch_data = []
for chunk_id, sparse_vec in postings:
for token_id, weight in sparse_vec.items():
if weight > 0: # Only store non-zero weights
batch_data.append((token_id, chunk_id, weight))
if batch_data:
conn.executemany(
"""
INSERT OR REPLACE INTO splade_posting_list
(token_id, chunk_id, weight)
VALUES (?, ?, ?)
""",
batch_data
)
conn.commit()
logger.debug(
"Batch inserted %d postings for %d chunks",
len(batch_data),
len(postings)
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to batch insert postings: {e}",
db_path=str(self.db_path),
operation="add_postings_batch"
) from e
def remove_chunk(self, chunk_id: int) -> int:
"""Remove all postings for a chunk.
Args:
chunk_id: Chunk ID to remove.
Returns:
Number of deleted postings.
"""
with self._lock:
conn = self._get_connection()
try:
cursor = conn.execute(
"DELETE FROM splade_posting_list WHERE chunk_id = ?",
(chunk_id,)
)
conn.commit()
deleted = cursor.rowcount
logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
return deleted
except sqlite3.Error as e:
raise StorageError(
f"Failed to remove chunk_id={chunk_id}: {e}",
db_path=str(self.db_path),
operation="remove_chunk"
) from e
def search(
self,
query_sparse: Dict[int, float],
limit: int = 50,
min_score: float = 0.0
) -> List[Tuple[int, float]]:
"""Search for similar chunks using dot-product scoring.
Implements efficient sparse dot-product via SQL JOIN:
score(q, d) = sum(q[t] * d[t]) for all tokens t
Args:
query_sparse: Query sparse vector as {token_id: weight}.
limit: Maximum number of results.
min_score: Minimum score threshold.
Returns:
List of (chunk_id, score) tuples, ordered by score descending.
"""
if not query_sparse:
logger.warning("Empty query sparse vector")
return []
with self._lock:
conn = self._get_connection()
try:
# Build VALUES clause for query terms
# Each term: (token_id, weight)
query_terms = [
(token_id, weight)
for token_id, weight in query_sparse.items()
if weight > 0
]
if not query_terms:
logger.warning("No non-zero query terms")
return []
# Create CTE for query terms using parameterized VALUES
# Build placeholders and params to prevent SQL injection
params = []
placeholders = []
for token_id, weight in query_terms:
placeholders.append("(?, ?)")
params.extend([token_id, weight])
values_placeholders = ", ".join(placeholders)
sql = f"""
WITH query_terms(token_id, weight) AS (
VALUES {values_placeholders}
)
SELECT
p.chunk_id,
SUM(p.weight * q.weight) as score
FROM splade_posting_list p
INNER JOIN query_terms q ON p.token_id = q.token_id
GROUP BY p.chunk_id
HAVING score >= ?
ORDER BY score DESC
LIMIT ?
"""
# Append min_score and limit to params
params.extend([min_score, limit])
rows = conn.execute(sql, params).fetchall()
results = [(row["chunk_id"], float(row["score"])) for row in rows]
logger.debug(
"SPLADE search: %d query terms, %d results",
len(query_terms),
len(results)
)
return results
except sqlite3.Error as e:
raise StorageError(
f"SPLADE search failed: {e}",
db_path=str(self.db_path),
operation="search"
) from e
def get_metadata(self) -> Optional[Dict]:
"""Get SPLADE model metadata.
Returns:
Dictionary with model_name, vocab_size, onnx_path, created_at,
or None if not set.
"""
with self._lock:
conn = self._get_connection()
try:
row = conn.execute(
"""
SELECT model_name, vocab_size, onnx_path, created_at
FROM splade_metadata
WHERE id = 1
"""
).fetchone()
if row:
return {
"model_name": row["model_name"],
"vocab_size": row["vocab_size"],
"onnx_path": row["onnx_path"],
"created_at": row["created_at"]
}
return None
except sqlite3.Error as e:
logger.error("Failed to get metadata: %s", e)
return None
def set_metadata(
self,
model_name: str,
vocab_size: int,
onnx_path: Optional[str] = None
) -> None:
"""Set SPLADE model metadata.
Args:
model_name: SPLADE model name.
vocab_size: Vocabulary size (typically ~30k for BERT vocab).
onnx_path: Optional path to ONNX model file.
"""
with self._lock:
conn = self._get_connection()
try:
current_time = time.time()
conn.execute(
"""
INSERT OR REPLACE INTO splade_metadata
(id, model_name, vocab_size, onnx_path, created_at)
VALUES (1, ?, ?, ?, ?)
""",
(model_name, vocab_size, onnx_path, current_time)
)
conn.commit()
logger.info(
"Set SPLADE metadata: model=%s, vocab_size=%d",
model_name,
vocab_size
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to set metadata: {e}",
db_path=str(self.db_path),
operation="set_metadata"
) from e
def get_stats(self) -> Dict:
"""Get index statistics.
Returns:
Dictionary with total_postings, unique_tokens, unique_chunks.
"""
with self._lock:
conn = self._get_connection()
try:
row = conn.execute("""
SELECT
COUNT(*) as total_postings,
COUNT(DISTINCT token_id) as unique_tokens,
COUNT(DISTINCT chunk_id) as unique_chunks
FROM splade_posting_list
""").fetchone()
return {
"total_postings": row["total_postings"],
"unique_tokens": row["unique_tokens"],
"unique_chunks": row["unique_chunks"]
}
except sqlite3.Error as e:
logger.error("Failed to get stats: %s", e)
return {
"total_postings": 0,
"unique_tokens": 0,
"unique_chunks": 0
}