feat: Enhance embedding generation and search capabilities

- Added pre-calculation of estimated chunk count for HNSW capacity in `generate_dense_embeddings_centralized` to optimize indexing performance.
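- Implemented binary vector generation with memory-mapped storage for efficient cascade search, saving the array shape, chunk IDs, and embedding dimension to a sidecar JSON metadata file (the packed vectors are also stored in the database for backward compatibility).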
- Introduced SPLADE sparse index generation with improved handling and metadata storage.
- Updated `ChainSearchEngine` to prefer the centralized binary searcher for improved performance, and added a fallback to the legacy binary index.
- Deprecated `BinaryANNIndex` in favor of `BinarySearcher` for better memory management and performance.
- Enhanced `SpladeEncoder` with warmup functionality to reduce latency spikes during first-time inference.
- Improved `SpladeIndex` with cache size adjustments for better query performance.
- Added methods for managing binary vectors in `VectorMetadataStore`, including batch insertion and retrieval.
- Created a new `BinarySearcher` class for efficient binary vector search using Hamming distance, supporting both memory-mapped and database loading modes.
This commit is contained in:
catlog22
2026-01-02 23:57:55 +08:00
parent 96b44e1482
commit 54fd94547c
12 changed files with 945 additions and 167 deletions

View File

@@ -1170,6 +1170,22 @@ def generate_dense_embeddings_centralized(
if progress_callback:
progress_callback(f"Found {len(index_files)} index databases for centralized embedding")
# Pre-calculate estimated chunk count for HNSW capacity
# This avoids expensive resize operations during indexing
estimated_total_files = 0
for index_path in index_files:
try:
with sqlite3.connect(index_path) as conn:
cursor = conn.execute("SELECT COUNT(*) FROM files")
estimated_total_files += cursor.fetchone()[0]
except Exception:
pass
# Heuristic: ~15 chunks per file on average
estimated_chunks = max(100000, estimated_total_files * 15)
if progress_callback:
progress_callback(f"Estimated {estimated_total_files} files, ~{estimated_chunks} chunks")
# Check for existing centralized index
central_hnsw_path = index_root / VECTORS_HNSW_NAME
if central_hnsw_path.exists() and not force:
@@ -1217,11 +1233,12 @@ def generate_dense_embeddings_centralized(
"error": f"Failed to initialize components: {str(e)}",
}
# Create centralized ANN index
# Create centralized ANN index with pre-calculated capacity
# Using estimated_chunks avoids expensive resize operations during indexing
central_ann_index = ANNIndex.create_central(
index_root=index_root,
dim=embedder.embedding_dim,
initial_capacity=100000, # Larger capacity for centralized index
initial_capacity=estimated_chunks,
auto_save=False,
)
@@ -1360,6 +1377,148 @@ def generate_dense_embeddings_centralized(
logger.warning("Failed to store vector metadata: %s", e)
# Non-fatal: continue without centralized metadata
# --- Binary Vector Generation for Cascade Search (Memory-Mapped) ---
binary_success = False
binary_count = 0
try:
from codexlens.config import Config, BINARY_VECTORS_MMAP_NAME
config = Config.load()
if getattr(config, 'enable_binary_cascade', True) and all_embeddings:
import numpy as np
if progress_callback:
progress_callback(f"Generating binary vectors for {len(all_embeddings)} chunks...")
# Binarize dense vectors: sign(x) -> 1 if x > 0, 0 otherwise
# Pack into bytes for efficient storage and Hamming distance computation
embeddings_matrix = np.vstack(all_embeddings)
binary_matrix = (embeddings_matrix > 0).astype(np.uint8)
# Pack bits into bytes (8 bits per byte) - vectorized for all rows
packed_matrix = np.packbits(binary_matrix, axis=1)
binary_count = len(packed_matrix)
# Save as memory-mapped file for efficient loading
binary_mmap_path = index_root / BINARY_VECTORS_MMAP_NAME
mmap_array = np.memmap(
str(binary_mmap_path),
dtype=np.uint8,
mode='w+',
shape=packed_matrix.shape
)
mmap_array[:] = packed_matrix
mmap_array.flush()
del mmap_array # Close the memmap
# Save metadata (shape and chunk_ids) to sidecar JSON
import json
meta_path = binary_mmap_path.with_suffix('.meta.json')
with open(meta_path, 'w') as f:
json.dump({
'shape': list(packed_matrix.shape),
'chunk_ids': all_chunk_ids,
'embedding_dim': embeddings_matrix.shape[1],
}, f)
# Also store in DB for backward compatibility
from codexlens.storage.vector_meta_store import VectorMetadataStore
binary_packed_bytes = [row.tobytes() for row in packed_matrix]
with VectorMetadataStore(vectors_meta_path) as meta_store:
meta_store.add_binary_vectors(all_chunk_ids, binary_packed_bytes)
binary_success = True
if progress_callback:
progress_callback(f"Generated {binary_count} binary vectors ({embeddings_matrix.shape[1]} dims -> {packed_matrix.shape[1]} bytes, mmap: {binary_mmap_path.name})")
except Exception as e:
logger.warning("Binary vector generation failed: %s", e)
# Non-fatal: continue without binary vectors
# --- SPLADE Sparse Index Generation (Centralized) ---
splade_success = False
splade_chunks_count = 0
try:
from codexlens.config import Config
config = Config.load()
if config.enable_splade and chunk_id_to_info:
from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
from codexlens.storage.splade_index import SpladeIndex
import json
ok, err = check_splade_available()
if ok:
if progress_callback:
progress_callback(f"Generating SPLADE sparse vectors for {len(chunk_id_to_info)} chunks...")
# Initialize SPLADE encoder and index
splade_encoder = get_splade_encoder(use_gpu=use_gpu)
splade_db_path = index_root / SPLADE_DB_NAME
splade_index = SpladeIndex(splade_db_path)
splade_index.create_tables()
# Batch encode for efficiency
SPLADE_BATCH_SIZE = 32
all_postings = []
all_chunk_metadata = []
# Create batches from chunk_id_to_info
chunk_items = list(chunk_id_to_info.items())
for i in range(0, len(chunk_items), SPLADE_BATCH_SIZE):
batch_items = chunk_items[i:i + SPLADE_BATCH_SIZE]
chunk_ids = [item[0] for item in batch_items]
chunk_contents = [item[1]["content"] for item in batch_items]
# Generate sparse vectors
sparse_vecs = splade_encoder.encode_batch(chunk_contents, batch_size=SPLADE_BATCH_SIZE)
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
all_postings.append((cid, sparse_vec))
if progress_callback and (i + SPLADE_BATCH_SIZE) % 100 == 0:
progress_callback(f"SPLADE encoding: {min(i + SPLADE_BATCH_SIZE, len(chunk_items))}/{len(chunk_items)}")
# Batch insert all postings
if all_postings:
splade_index.add_postings_batch(all_postings)
# CRITICAL FIX: Populate splade_chunks table
for cid, info in chunk_id_to_info.items():
metadata_str = json.dumps(info.get("metadata", {})) if info.get("metadata") else None
all_chunk_metadata.append((
cid,
info["file_path"],
info["content"],
metadata_str,
info.get("source_index_db")
))
if all_chunk_metadata:
splade_index.add_chunks_metadata_batch(all_chunk_metadata)
splade_chunks_count = len(all_chunk_metadata)
# Set metadata
splade_index.set_metadata(
model_name=splade_encoder.model_name,
vocab_size=splade_encoder.vocab_size
)
splade_index.close()
splade_success = True
if progress_callback:
progress_callback(f"SPLADE index created: {len(all_postings)} postings, {splade_chunks_count} chunks")
else:
if progress_callback:
progress_callback(f"SPLADE not available, skipping sparse index: {err}")
except Exception as e:
logger.warning("SPLADE encoding failed: %s", e)
if progress_callback:
progress_callback(f"SPLADE encoding failed: {e}")
elapsed_time = time.time() - start_time
# Cleanup
@@ -1380,6 +1539,10 @@ def generate_dense_embeddings_centralized(
"model_name": embedder.model_name,
"central_index_path": str(central_hnsw_path),
"failed_files": failed_files[:5],
"splade_success": splade_success,
"splade_chunks": splade_chunks_count,
"binary_success": binary_success,
"binary_count": binary_count,
},
}