feat: Enhance embedding generation and search capabilities

- Added pre-calculation of the estimated chunk count in `generate_dense_embeddings_centralized` so HNSW index capacity can be sized up front, improving indexing performance.
- Implemented binary vector generation with memory-mapped storage for efficient cascade search, including metadata saving.
- Introduced SPLADE sparse index generation with improved handling and metadata storage.
- Updated `ChainSearchEngine` to prefer centralized binary searcher for improved performance and added fallback to legacy binary index.
- Deprecated `BinaryANNIndex` in favor of `BinarySearcher` for better memory management and performance.
- Enhanced `SpladeEncoder` with warmup functionality to reduce latency spikes during first-time inference.
- Improved `SpladeIndex` with cache size adjustments for better query performance.
- Added methods for managing binary vectors in `VectorMetadataStore`, including batch insertion and retrieval.
- Created a new `BinarySearcher` class for efficient binary vector search using Hamming distance, supporting both memory-mapped and database loading modes (a minimal sketch of the cascade idea follows below).
catlog22
2026-01-02 23:57:55 +08:00
parent 96b44e1482
commit 54fd94547c
12 changed files with 945 additions and 167 deletions
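
For context, the cascade idea behind `BinarySearcher`: dense embeddings are binarized and packed so a cheap Hamming-distance scan can shortlist candidates, which are then re-ranked with full-precision scores. A minimal numpy sketch of that idea (the function names and sign-bit packing scheme below are illustrative assumptions, not the class's actual API):

import numpy as np

def pack_binary(vectors: np.ndarray) -> np.ndarray:
    # Binarize by sign and pack 8 bits per byte (1024 dims -> 128 bytes per vector).
    return np.packbits(vectors > 0, axis=1)

def hamming_topk(query_bits: np.ndarray, db_bits: np.ndarray, k: int) -> np.ndarray:
    # XOR the packed bytes, then count set bits to get Hamming distances.
    xor = np.bitwise_xor(db_bits, query_bits)
    dists = np.unpackbits(xor, axis=1).sum(axis=1)
    return np.argsort(dists)[:k]

rng = np.random.default_rng(0)
corpus = rng.standard_normal((1000, 1024)).astype(np.float32)   # dense chunk embeddings
query = rng.standard_normal((1, 1024)).astype(np.float32)

db_bits = pack_binary(corpus)
candidates = hamming_topk(pack_binary(query), db_bits, k=50)
# Re-rank only the shortlisted candidates with exact float dot products (the "fine" stage).
scores = corpus[candidates] @ query[0]
ranked = candidates[np.argsort(-scores)]

The memory-mapped storage mentioned in the commit serves the same coarse pass while keeping the packed matrix on disk rather than fully in RAM.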


@@ -59,6 +59,8 @@ class SpladeIndex:
conn.execute("PRAGMA foreign_keys=ON")
# Limit mmap to 1GB to avoid OOM on smaller systems
conn.execute("PRAGMA mmap_size=1073741824")
# Increase cache size for better query performance (negative value = KiB, so -20000 ≈ 20MB)
conn.execute("PRAGMA cache_size=-20000")
self._local.conn = conn
return conn
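
A standalone check of that pragma's semantics (a negative value is a KiB budget, not a page count, so -20000 requests roughly 20 MB of page cache regardless of page size):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA cache_size=-20000")              # negative = KiB, independent of page size
print(conn.execute("PRAGMA cache_size").fetchone())   # -> (-20000,)
conn.close()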
@@ -385,25 +387,29 @@ class SpladeIndex:
self,
query_sparse: Dict[int, float],
limit: int = 50,
min_score: float = 0.0,
max_query_terms: int = 64
) -> List[Tuple[int, float]]:
"""Search for similar chunks using dot-product scoring.
Implements efficient sparse dot-product via SQL JOIN:
score(q, d) = sum(q[t] * d[t]) for all tokens t
Args:
query_sparse: Query sparse vector as {token_id: weight}.
limit: Maximum number of results.
min_score: Minimum score threshold.
max_query_terms: Maximum query terms to use (default: 64).
Pruning to top-K terms reduces search time with minimal impact on quality.
Set to 0 or negative to disable pruning (use all terms).
Returns:
List of (chunk_id, score) tuples, ordered by score descending.
"""
if not query_sparse:
logger.warning("Empty query sparse vector")
return []
with self._lock:
conn = self._get_connection()
try:
@@ -414,10 +420,20 @@ class SpladeIndex:
for token_id, weight in query_sparse.items()
if weight > 0
]
if not query_terms:
logger.warning("No non-zero query terms")
return []
# Query pruning: keep only top-K terms by weight
# max_query_terms <= 0 means no limit (use all terms)
if max_query_terms > 0 and len(query_terms) > max_query_terms:
query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms]
logger.debug(
"Query pruned from %d to %d terms",
len(query_sparse),
len(query_terms)
)
# Create CTE for query terms using parameterized VALUES
# Build placeholders and params to prevent SQL injection
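
The hunk cuts off before the statement itself; a self-contained sketch of how a parameterized VALUES CTE can express the sparse dot-product (the `postings` table and its column names are assumptions for illustration, not the actual SpladeIndex schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE postings (token_id INTEGER, chunk_id INTEGER, weight REAL)")
conn.executemany(
    "INSERT INTO postings VALUES (?, ?, ?)",
    [(1, 10, 0.5), (2, 10, 0.3), (1, 11, 0.2)],
)

query_terms = [(1, 0.9), (2, 0.4)]                        # pruned (token_id, weight) pairs
placeholders = ", ".join("(?, ?)" for _ in query_terms)   # only placeholders are interpolated, values stay bound
params = [value for pair in query_terms for value in pair]

sql = f"""
    WITH query(token_id, weight) AS (VALUES {placeholders})
    SELECT p.chunk_id, SUM(q.weight * p.weight) AS score
    FROM query AS q JOIN postings AS p ON p.token_id = q.token_id
    GROUP BY p.chunk_id
    HAVING SUM(q.weight * p.weight) >= ?
    ORDER BY score DESC
    LIMIT ?
"""
print(conn.execute(sql, params + [0.0, 50]).fetchall())   # roughly [(10, 0.57), (11, 0.18)]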


@@ -96,6 +96,13 @@ class VectorMetadataStore:
'CREATE INDEX IF NOT EXISTS idx_chunk_category '
'ON chunk_metadata(category)'
)
# Binary vectors table for cascade search
conn.execute('''
CREATE TABLE IF NOT EXISTS binary_vectors (
chunk_id INTEGER PRIMARY KEY,
vector BLOB NOT NULL
)
''')
conn.commit()
logger.debug("VectorMetadataStore schema created/verified")
except sqlite3.Error as e:
@@ -329,3 +336,80 @@ class VectorMetadataStore:
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()
# ============= Binary Vector Methods for Cascade Search =============
def add_binary_vectors(
self, chunk_ids: List[int], binary_vectors: List[bytes]
) -> None:
"""Batch insert binary vectors for cascade search.
Args:
chunk_ids: List of chunk IDs.
binary_vectors: List of packed binary vectors (as bytes).
"""
if not chunk_ids or len(chunk_ids) != len(binary_vectors):
return
with self._lock:
conn = self._get_connection()
try:
data = list(zip(chunk_ids, binary_vectors))
conn.executemany(
"INSERT OR REPLACE INTO binary_vectors (chunk_id, vector) VALUES (?, ?)",
data
)
conn.commit()
logger.debug("Added %d binary vectors", len(chunk_ids))
except sqlite3.Error as e:
raise StorageError(
f"Failed to add binary vectors: {e}",
db_path=str(self.db_path),
operation="add_binary_vectors"
) from e
def get_all_binary_vectors(self) -> List[tuple]:
"""Get all binary vectors for cascade search.
Returns:
List of (chunk_id, vector_bytes) tuples.
"""
conn = self._get_connection()
try:
rows = conn.execute(
"SELECT chunk_id, vector FROM binary_vectors"
).fetchall()
return [(row[0], row[1]) for row in rows]
except sqlite3.Error as e:
logger.error("Failed to get binary vectors: %s", e)
return []
def get_binary_vector_count(self) -> int:
"""Get total number of binary vectors.
Returns:
Binary vector count.
"""
conn = self._get_connection()
try:
row = conn.execute(
"SELECT COUNT(*) FROM binary_vectors"
).fetchone()
return row[0] if row else 0
except sqlite3.Error:
return 0
def clear_binary_vectors(self) -> None:
"""Clear all binary vectors."""
with self._lock:
conn = self._get_connection()
try:
conn.execute("DELETE FROM binary_vectors")
conn.commit()
logger.info("Cleared all binary vectors")
except sqlite3.Error as e:
raise StorageError(
f"Failed to clear binary vectors: {e}",
db_path=str(self.db_path),
operation="clear_binary_vectors"
) from e
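
Tying the new store methods together, a short usage sketch (the constructor argument, chunk IDs, and sign-bit packing are illustrative assumptions, not taken from the diff):

import numpy as np

embeddings = np.random.randn(3, 1024).astype(np.float32)   # dense chunk embeddings
packed = np.packbits(embeddings > 0, axis=1)                # sign-binarize: 1024 bits -> 128 bytes each

store = VectorMetadataStore("index/vectors.db")             # hypothetical path / constructor arg
store.add_binary_vectors(
    chunk_ids=[101, 102, 103],
    binary_vectors=[row.tobytes() for row in packed],
)
assert store.get_binary_vector_count() == 3

# A BinarySearcher (or any Hamming-distance scan) can later reload the packed matrix:
rows = store.get_all_binary_vectors()                       # [(chunk_id, vector_bytes), ...]
ids = [chunk_id for chunk_id, _ in rows]
matrix = np.frombuffer(b"".join(v for _, v in rows), dtype=np.uint8).reshape(len(rows), -1)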