feat: Enhance BinaryANNIndex with vectorized search and performance benchmarking

This commit is contained in:
catlog22
2026-01-02 11:49:54 +08:00
parent da68ba0b82
commit 9129c981a4
4 changed files with 479 additions and 140 deletions

View File

@@ -608,31 +608,43 @@ class ChainSearchEngine:
for index_path, chunk_ids in candidates_by_index.items():
try:
store = SQLiteStore(index_path)
dense_embeddings = store.get_dense_embeddings(chunk_ids)
chunks_data = store.get_chunks_by_ids(chunk_ids)
# Read directly from semantic_chunks table (where cascade-index stores data)
import sqlite3
conn = sqlite3.connect(str(index_path))
conn.row_factory = sqlite3.Row
# Create lookup for chunk content
chunk_content: Dict[int, Dict[str, Any]] = {
c["id"]: c for c in chunks_data
}
placeholders = ",".join("?" * len(chunk_ids))
rows = conn.execute(
f"SELECT id, file_path, content, embedding_dense FROM semantic_chunks WHERE id IN ({placeholders})",
chunk_ids
).fetchall()
conn.close()
for chunk_id in chunk_ids:
dense_bytes = dense_embeddings.get(chunk_id)
chunk_info = chunk_content.get(chunk_id)
# Batch processing: collect all valid embeddings first
valid_rows = []
dense_vectors = []
for row in rows:
dense_bytes = row["embedding_dense"]
if dense_bytes is not None:
valid_rows.append(row)
dense_vectors.append(np.frombuffer(dense_bytes, dtype=np.float32))
if dense_bytes is None or chunk_info is None:
continue
if not dense_vectors:
continue
# Compute cosine similarity
dense_vec = np.frombuffer(dense_bytes, dtype=np.float32)
score = self._compute_cosine_similarity(query_dense, dense_vec)
# Stack into matrix for batch computation
doc_matrix = np.vstack(dense_vectors)
# Create search result
excerpt = chunk_info.get("content", "")[:500]
# Batch compute cosine similarities
scores = self._compute_cosine_similarity_batch(query_dense, doc_matrix)
# Create search results
for i, row in enumerate(valid_rows):
score = float(scores[i])
excerpt = (row["content"] or "")[:500]
result = SearchResult(
path=chunk_info.get("file_path", ""),
score=float(score),
path=row["file_path"] or "",
score=score,
excerpt=excerpt,
)
scored_results.append((score, result))
@@ -783,6 +795,58 @@ class ChainSearchEngine:
return float(dot_product / (norm_q * norm_d))
def _compute_cosine_similarity_batch(
self,
query_vec: "np.ndarray",
doc_matrix: "np.ndarray",
) -> "np.ndarray":
"""Compute cosine similarity between query and multiple document vectors.
Uses vectorized matrix operations for efficient batch computation.
Args:
query_vec: Query embedding vector of shape (dim,)
doc_matrix: Document embeddings matrix of shape (n_docs, dim)
Returns:
Array of cosine similarity scores of shape (n_docs,)
"""
if not NUMPY_AVAILABLE:
return np.zeros(doc_matrix.shape[0])
# Ensure query is 1D
if query_vec.ndim > 1:
query_vec = query_vec.flatten()
# Handle dimension mismatch by truncating to smaller dimension
min_dim = min(len(query_vec), doc_matrix.shape[1])
q = query_vec[:min_dim]
docs = doc_matrix[:, :min_dim]
# Compute query norm once
norm_q = np.linalg.norm(q)
if norm_q == 0:
return np.zeros(docs.shape[0])
# Normalize query
q_normalized = q / norm_q
# Compute document norms (vectorized)
doc_norms = np.linalg.norm(docs, axis=1)
# Avoid division by zero
nonzero_mask = doc_norms > 0
scores = np.zeros(docs.shape[0], dtype=np.float32)
if np.any(nonzero_mask):
# Normalize documents with non-zero norms
docs_normalized = docs[nonzero_mask] / doc_norms[nonzero_mask, np.newaxis]
# Batch dot product: (n_docs, dim) @ (dim,) = (n_docs,)
scores[nonzero_mask] = docs_normalized @ q_normalized
return scores
def _build_results_from_candidates(
self,
candidates: List[Tuple[int, int, Path]],

View File

@@ -487,6 +487,11 @@ class BinaryANNIndex:
self._vectors: dict[int, bytes] = {}
self._id_list: list[int] = [] # Ordered list for efficient iteration
# Cached numpy array for vectorized search (invalidated on add/remove)
self._vectors_matrix: Optional[np.ndarray] = None
self._ids_array: Optional[np.ndarray] = None
self._cache_valid: bool = False
logger.info(
f"Initialized BinaryANNIndex with dim={dim}, packed_dim={self.packed_dim}"
)
@@ -524,6 +529,9 @@ class BinaryANNIndex:
self._id_list.append(vec_id)
self._vectors[vec_id] = vec
# Invalidate cache on modification
self._cache_valid = False
logger.debug(
f"Added {len(ids)} binary vectors to index (total: {len(self._vectors)})"
)
@@ -599,6 +607,8 @@ class BinaryANNIndex:
# Rebuild ID list efficiently - O(N) once instead of O(N) per removal
if removed_count > 0:
self._id_list = [id_ for id_ in self._id_list if id_ not in ids_to_remove]
# Invalidate cache on modification
self._cache_valid = False
logger.debug(f"Removed {removed_count}/{len(ids)} vectors from index")
@@ -610,11 +620,42 @@ class BinaryANNIndex:
f"Failed to remove vectors from Binary ANN index: {e}"
)
def _build_cache(self) -> None:
"""Build numpy array cache from vectors dict for vectorized search.
Pre-computes a contiguous numpy array from all vectors for efficient
batch distance computation. Called lazily on first search after modification.
"""
if self._cache_valid:
return
n_vectors = len(self._id_list)
if n_vectors == 0:
self._vectors_matrix = None
self._ids_array = None
self._cache_valid = True
return
# Build contiguous numpy array of all packed vectors
# Shape: (n_vectors, packed_dim) with uint8 dtype
self._vectors_matrix = np.empty((n_vectors, self.packed_dim), dtype=np.uint8)
self._ids_array = np.array(self._id_list, dtype=np.int64)
for i, vec_id in enumerate(self._id_list):
vec_bytes = self._vectors[vec_id]
self._vectors_matrix[i] = np.frombuffer(vec_bytes, dtype=np.uint8)
self._cache_valid = True
logger.debug(f"Built vectorized cache for {n_vectors} binary vectors")
def search(
self, query: bytes, top_k: int = 10
) -> Tuple[List[int], List[int]]:
"""Search for nearest neighbors using Hamming distance.
Uses vectorized batch computation for O(N) search with SIMD acceleration.
Pre-computes and caches numpy arrays for efficient repeated queries.
Args:
query: Packed binary query vector (size: packed_dim bytes)
top_k: Number of nearest neighbors to return
@@ -638,27 +679,48 @@ class BinaryANNIndex:
if len(self._vectors) == 0:
return [], []
# Compute Hamming distances to all vectors
# Build cache if needed (lazy initialization)
self._build_cache()
if self._vectors_matrix is None or self._ids_array is None:
return [], []
# Vectorized Hamming distance computation
# 1. Convert query to numpy array
query_arr = np.frombuffer(query, dtype=np.uint8)
distances = []
for vec_id in self._id_list:
vec = self._vectors[vec_id]
vec_arr = np.frombuffer(vec, dtype=np.uint8)
# XOR and popcount for Hamming distance
xor = np.bitwise_xor(query_arr, vec_arr)
dist = int(np.unpackbits(xor).sum())
distances.append((vec_id, dist))
# 2. Broadcast XOR: (1, packed_dim) XOR (n_vectors, packed_dim)
# Result shape: (n_vectors, packed_dim)
xor_result = np.bitwise_xor(query_arr, self._vectors_matrix)
# Sort by distance (ascending)
distances.sort(key=lambda x: x[1])
# 3. Vectorized popcount using lookup table for efficiency
# np.unpackbits is slow for large arrays, use popcount LUT instead
popcount_lut = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
bit_counts = popcount_lut[xor_result]
# Return top-k
top_results = distances[:top_k]
ids = [r[0] for r in top_results]
dists = [r[1] for r in top_results]
# 4. Sum across packed bytes to get Hamming distance per vector
distances = bit_counts.sum(axis=1)
return ids, dists
# 5. Get top-k using argpartition (O(N) instead of O(N log N) for full sort)
n_vectors = len(distances)
k = min(top_k, n_vectors)
if k == n_vectors:
# No partitioning needed, just sort all
sorted_indices = np.argsort(distances)
else:
# Use argpartition for O(N) partial sort
partition_indices = np.argpartition(distances, k)[:k]
# Sort only the top-k
top_k_distances = distances[partition_indices]
sorted_order = np.argsort(top_k_distances)
sorted_indices = partition_indices[sorted_order]
# 6. Return results
result_ids = self._ids_array[sorted_indices].tolist()
result_dists = distances[sorted_indices].tolist()
return result_ids, result_dists
except Exception as e:
raise StorageError(f"Failed to search Binary ANN index: {e}")
@@ -797,6 +859,7 @@ class BinaryANNIndex:
# Clear existing data
self._vectors.clear()
self._id_list.clear()
self._cache_valid = False
# Read vectors
for _ in range(num_vectors):
@@ -853,6 +916,9 @@ class BinaryANNIndex:
with self._lock:
self._vectors.clear()
self._id_list.clear()
self._vectors_matrix = None
self._ids_array = None
self._cache_valid = False
logger.debug("Cleared binary index")