"""SPLADE inverted index storage for sparse vector retrieval.
|
|
|
|
This module implements SQLite-based inverted index for SPLADE sparse vectors,
|
|
enabling efficient sparse retrieval using dot-product scoring.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sqlite3
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
from codexlens.entities import SearchResult
|
|
from codexlens.errors import StorageError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SpladeIndex:
|
|
"""SQLite-based inverted index for SPLADE sparse vectors.
|
|
|
|
Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
|
|
Supports efficient dot-product retrieval using SQL joins.
|
|
"""

    def __init__(self, db_path: Path | str) -> None:
        """Initialize SPLADE index.

        Args:
            db_path: Path to SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # Thread-safe connection management
        self._lock = threading.RLock()
        self._local = threading.local()

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a thread-local database connection.

        Each thread gets its own connection to ensure thread safety.
        Connections are stored in thread-local storage.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None:
            # Thread-local connection - each thread has its own
            conn = sqlite3.connect(
                self.db_path,
                timeout=30.0,  # Wait up to 30s for locks
                check_same_thread=True,  # Enforce thread safety
            )
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
            conn.execute("PRAGMA foreign_keys=ON")
            # Limit mmap to 1GB to avoid OOM on smaller systems
            conn.execute("PRAGMA mmap_size=1073741824")
            # Increase cache size for better query performance (20MB = -20000 pages)
            conn.execute("PRAGMA cache_size=-20000")
            self._local.conn = conn
        return conn

    def close(self) -> None:
        """Close thread-local database connection."""
        with self._lock:
            conn = getattr(self._local, "conn", None)
            if conn is not None:
                conn.close()
                self._local.conn = None

    def __enter__(self) -> SpladeIndex:
        """Context manager entry."""
        self.create_tables()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Context manager exit."""
        self.close()

    def has_index(self) -> bool:
        """Check if SPLADE tables exist in database.

        Returns:
            True if tables exist, False otherwise.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    """
                    SELECT name FROM sqlite_master
                    WHERE type='table' AND name='splade_posting_list'
                    """
                )
                return cursor.fetchone() is not None
            except sqlite3.Error as e:
                logger.error("Failed to check index existence: %s", e)
                return False

    def create_tables(self) -> None:
        """Create SPLADE schema if not exists.

        Note: When used with distributed indexes (multiple _index.db files),
        the SPLADE database stores chunk IDs from multiple sources. In this case,
        foreign key constraints are not enforced to allow cross-database references.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                # Inverted index for sparse vectors
                # Note: No FOREIGN KEY constraint to support distributed index architecture
                # where chunks may come from multiple _index.db files
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_posting_list (
                        token_id INTEGER NOT NULL,
                        chunk_id INTEGER NOT NULL,
                        weight REAL NOT NULL,
                        PRIMARY KEY (token_id, chunk_id)
                    )
                """)

                # Indexes for efficient lookups
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
                    ON splade_posting_list(chunk_id)
                """)
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_token
                    ON splade_posting_list(token_id)
                """)

                # Model metadata
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_metadata (
                        id INTEGER PRIMARY KEY DEFAULT 1,
                        model_name TEXT NOT NULL,
                        vocab_size INTEGER NOT NULL,
                        onnx_path TEXT,
                        created_at REAL
                    )
                """)

                # Chunk metadata for self-contained search results
                # Stores all chunk info needed to build SearchResult without querying _index.db
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_chunks (
                        id INTEGER PRIMARY KEY,
                        file_path TEXT NOT NULL,
                        content TEXT NOT NULL,
                        metadata TEXT,
                        source_db TEXT
                    )
                """)

                conn.commit()
                logger.debug("SPLADE schema created successfully")
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to create SPLADE schema: {e}",
                    db_path=str(self.db_path),
                    operation="create_tables"
                ) from e

    def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
        """Add a single document to inverted index.

        Args:
            chunk_id: Chunk ID (foreign key to semantic_chunks.id).
            sparse_vec: Sparse vector as {token_id: weight} mapping.
        """
        if not sparse_vec:
            logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Insert all non-zero weights for this chunk
                postings = [
                    (token_id, chunk_id, weight)
                    for token_id, weight in sparse_vec.items()
                    if weight > 0  # Only store non-zero weights
                ]

                if postings:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        postings
                    )
                    conn.commit()
                    logger.debug(
                        "Added %d postings for chunk_id=%d", len(postings), chunk_id
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to add posting for chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="add_posting"
                ) from e

    def add_postings_batch(
        self, postings: List[Tuple[int, Dict[int, float]]]
    ) -> None:
        """Batch insert postings for multiple chunks.

        Args:
            postings: List of (chunk_id, sparse_vec) tuples.
        """
        if not postings:
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Flatten all postings into single batch
                batch_data = []
                for chunk_id, sparse_vec in postings:
                    for token_id, weight in sparse_vec.items():
                        if weight > 0:  # Only store non-zero weights
                            batch_data.append((token_id, chunk_id, weight))

                if batch_data:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        batch_data
                    )
                    conn.commit()
                    logger.debug(
                        "Batch inserted %d postings for %d chunks",
                        len(batch_data),
                        len(postings)
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to batch insert postings: {e}",
                    db_path=str(self.db_path),
                    operation="add_postings_batch"
                ) from e
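
    # Input-shape illustration (hypothetical values, for documentation only):
    #     add_postings_batch([(1, {101: 0.8}), (2, {101: 0.2, 2057: 1.1})])
    # flattens into three (token_id, chunk_id, weight) rows written via a single
    # executemany call and one commit, instead of one commit per chunk as in
    # add_posting.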

    def add_chunk_metadata(
        self,
        chunk_id: int,
        file_path: str,
        content: str,
        metadata: Optional[str] = None,
        source_db: Optional[str] = None
    ) -> None:
        """Store chunk metadata for self-contained search results.

        Args:
            chunk_id: Global chunk ID.
            file_path: Path to source file.
            content: Chunk text content.
            metadata: JSON metadata string.
            source_db: Path to source _index.db.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO splade_chunks
                    (id, file_path, content, metadata, source_db)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (chunk_id, file_path, content, metadata, source_db)
                )
                conn.commit()
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to add chunk metadata for chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="add_chunk_metadata"
                ) from e

    def add_chunks_metadata_batch(
        self,
        chunks: List[Tuple[int, str, str, Optional[str], Optional[str]]]
    ) -> None:
        """Batch insert chunk metadata.

        Args:
            chunks: List of (chunk_id, file_path, content, metadata, source_db) tuples.
        """
        if not chunks:
            return

        with self._lock:
            conn = self._get_connection()
            try:
                conn.executemany(
                    """
                    INSERT OR REPLACE INTO splade_chunks
                    (id, file_path, content, metadata, source_db)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    chunks
                )
                conn.commit()
                logger.debug("Batch inserted %d chunk metadata records", len(chunks))
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to batch insert chunk metadata: {e}",
                    db_path=str(self.db_path),
                    operation="add_chunks_metadata_batch"
                ) from e

    def get_chunks_by_ids(self, chunk_ids: List[int]) -> List[Dict]:
        """Get chunk metadata by IDs.

        Args:
            chunk_ids: List of chunk IDs to retrieve.

        Returns:
            List of dicts with id, file_path, content, metadata, source_db.
        """
        if not chunk_ids:
            return []

        with self._lock:
            conn = self._get_connection()
            try:
                placeholders = ",".join("?" * len(chunk_ids))
                rows = conn.execute(
                    f"""
                    SELECT id, file_path, content, metadata, source_db
                    FROM splade_chunks
                    WHERE id IN ({placeholders})
                    """,
                    chunk_ids
                ).fetchall()

                return [
                    {
                        "id": row["id"],
                        "file_path": row["file_path"],
                        "content": row["content"],
                        "metadata": row["metadata"],
                        "source_db": row["source_db"]
                    }
                    for row in rows
                ]
            except sqlite3.Error as e:
                logger.error("Failed to get chunks by IDs: %s", e)
                return []

    def remove_chunk(self, chunk_id: int) -> int:
        """Remove all postings for a chunk.

        Args:
            chunk_id: Chunk ID to remove.

        Returns:
            Number of deleted postings.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    "DELETE FROM splade_posting_list WHERE chunk_id = ?",
                    (chunk_id,)
                )
                conn.commit()
                deleted = cursor.rowcount
                logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
                return deleted
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to remove chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="remove_chunk"
                ) from e

    def search(
        self,
        query_sparse: Dict[int, float],
        limit: int = 50,
        min_score: float = 0.0,
        max_query_terms: int = 64
    ) -> List[Tuple[int, float]]:
        """Search for similar chunks using dot-product scoring.

        Implements efficient sparse dot-product via SQL JOIN:
            score(q, d) = sum(q[t] * d[t]) for all tokens t

        Args:
            query_sparse: Query sparse vector as {token_id: weight}.
            limit: Maximum number of results.
            min_score: Minimum score threshold.
            max_query_terms: Maximum query terms to use (default: 64).
                Pruning to top-K terms reduces search time with minimal impact on quality.
                Set to 0 or negative to disable pruning (use all terms).

        Returns:
            List of (chunk_id, score) tuples, ordered by score descending.
        """
        if not query_sparse:
            logger.warning("Empty query sparse vector")
            return []

        with self._lock:
            conn = self._get_connection()
            try:
                # Build VALUES clause for query terms
                # Each term: (token_id, weight)
                query_terms = [
                    (token_id, weight)
                    for token_id, weight in query_sparse.items()
                    if weight > 0
                ]

                if not query_terms:
                    logger.warning("No non-zero query terms")
                    return []

                # Query pruning: keep only top-K terms by weight
                # max_query_terms <= 0 means no limit (use all terms)
                if max_query_terms > 0 and len(query_terms) > max_query_terms:
                    query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms]
                    logger.debug(
                        "Query pruned from %d to %d terms",
                        len(query_sparse),
                        len(query_terms)
                    )

                # Create CTE for query terms using parameterized VALUES
                # Build placeholders and params to prevent SQL injection
                params = []
                placeholders = []
                for token_id, weight in query_terms:
                    placeholders.append("(?, ?)")
                    params.extend([token_id, weight])

                values_placeholders = ", ".join(placeholders)

                sql = f"""
                    WITH query_terms(token_id, weight) AS (
                        VALUES {values_placeholders}
                    )
                    SELECT
                        p.chunk_id,
                        SUM(p.weight * q.weight) as score
                    FROM splade_posting_list p
                    INNER JOIN query_terms q ON p.token_id = q.token_id
                    GROUP BY p.chunk_id
                    HAVING score >= ?
                    ORDER BY score DESC
                    LIMIT ?
                """

                # Append min_score and limit to params
                params.extend([min_score, limit])
                rows = conn.execute(sql, params).fetchall()

                results = [(row["chunk_id"], float(row["score"])) for row in rows]
                logger.debug(
                    "SPLADE search: %d query terms, %d results",
                    len(query_terms),
                    len(results)
                )
                return results

            except sqlite3.Error as e:
                raise StorageError(
                    f"SPLADE search failed: {e}",
                    db_path=str(self.db_path),
                    operation="search"
                ) from e

    def get_metadata(self) -> Optional[Dict]:
        """Get SPLADE model metadata.

        Returns:
            Dictionary with model_name, vocab_size, onnx_path, created_at,
            or None if not set.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute(
                    """
                    SELECT model_name, vocab_size, onnx_path, created_at
                    FROM splade_metadata
                    WHERE id = 1
                    """
                ).fetchone()

                if row:
                    return {
                        "model_name": row["model_name"],
                        "vocab_size": row["vocab_size"],
                        "onnx_path": row["onnx_path"],
                        "created_at": row["created_at"]
                    }
                return None
            except sqlite3.Error as e:
                logger.error("Failed to get metadata: %s", e)
                return None

    def set_metadata(
        self,
        model_name: str,
        vocab_size: int,
        onnx_path: Optional[str] = None
    ) -> None:
        """Set SPLADE model metadata.

        Args:
            model_name: SPLADE model name.
            vocab_size: Vocabulary size (typically ~30k for BERT vocab).
            onnx_path: Optional path to ONNX model file.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                current_time = time.time()
                conn.execute(
                    """
                    INSERT OR REPLACE INTO splade_metadata
                    (id, model_name, vocab_size, onnx_path, created_at)
                    VALUES (1, ?, ?, ?, ?)
                    """,
                    (model_name, vocab_size, onnx_path, current_time)
                )
                conn.commit()
                logger.info(
                    "Set SPLADE metadata: model=%s, vocab_size=%d",
                    model_name,
                    vocab_size
                )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to set metadata: {e}",
                    db_path=str(self.db_path),
                    operation="set_metadata"
                ) from e

    def get_stats(self) -> Dict:
        """Get index statistics.

        Returns:
            Dictionary with total_postings, unique_tokens, unique_chunks.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute("""
                    SELECT
                        COUNT(*) as total_postings,
                        COUNT(DISTINCT token_id) as unique_tokens,
                        COUNT(DISTINCT chunk_id) as unique_chunks
                    FROM splade_posting_list
                """).fetchone()

                return {
                    "total_postings": row["total_postings"],
                    "unique_tokens": row["unique_tokens"],
                    "unique_chunks": row["unique_chunks"]
                }
            except sqlite3.Error as e:
                logger.error("Failed to get stats: %s", e)
                return {
                    "total_postings": 0,
                    "unique_tokens": 0,
                    "unique_chunks": 0
                }
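

# The block below is an illustrative usage sketch added for documentation: it
# exercises the SpladeIndex API defined above end to end (schema creation, batch
# indexing, chunk metadata, model metadata, search, stats). All token IDs,
# weights, file paths, and the model name are made-up placeholder values.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # __enter__ creates the schema via create_tables()
        with SpladeIndex(Path(tmp_dir) / "splade.db") as index:
            # Two chunks with hand-written sparse vectors (placeholder token IDs)
            index.add_postings_batch([
                (1, {101: 0.8, 2057: 1.3}),
                (2, {101: 0.2, 4096: 0.9}),
            ])
            # Chunk metadata so search results are self-contained
            index.add_chunks_metadata_batch([
                (1, "src/foo.py", "def foo(): ...", None, None),
                (2, "src/bar.py", "def bar(): ...", None, None),
            ])
            index.set_metadata(model_name="placeholder-splade-model", vocab_size=30522)

            # Token 101 appears in both chunks; chunk 1 scores higher (0.8 vs 0.2)
            hits = index.search({101: 1.0}, limit=10)
            chunks = index.get_chunks_by_ids([chunk_id for chunk_id, _ in hits])
            print("hits:", hits)
            print("files:", [c["file_path"] for c in chunks])
            print("stats:", index.get_stats())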