mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-11 02:33:51 +08:00
feat: Enhance BinaryANNIndex with vectorized search and performance benchmarking
This commit is contained in:
@@ -487,6 +487,11 @@ class BinaryANNIndex:
|
||||
self._vectors: dict[int, bytes] = {}
|
||||
self._id_list: list[int] = [] # Ordered list for efficient iteration
|
||||
|
||||
# Cached numpy array for vectorized search (invalidated on add/remove)
|
||||
self._vectors_matrix: Optional[np.ndarray] = None
|
||||
self._ids_array: Optional[np.ndarray] = None
|
||||
self._cache_valid: bool = False
|
||||
|
||||
logger.info(
|
||||
f"Initialized BinaryANNIndex with dim={dim}, packed_dim={self.packed_dim}"
|
||||
)
|
||||
@@ -524,6 +529,9 @@ class BinaryANNIndex:
|
||||
self._id_list.append(vec_id)
|
||||
self._vectors[vec_id] = vec
|
||||
|
||||
# Invalidate cache on modification
|
||||
self._cache_valid = False
|
||||
|
||||
logger.debug(
|
||||
f"Added {len(ids)} binary vectors to index (total: {len(self._vectors)})"
|
||||
)
|
||||
@@ -599,6 +607,8 @@ class BinaryANNIndex:
|
||||
# Rebuild ID list efficiently - O(N) once instead of O(N) per removal
|
||||
if removed_count > 0:
|
||||
self._id_list = [id_ for id_ in self._id_list if id_ not in ids_to_remove]
|
||||
# Invalidate cache on modification
|
||||
self._cache_valid = False
|
||||
|
||||
logger.debug(f"Removed {removed_count}/{len(ids)} vectors from index")
|
||||
|
||||
@@ -610,11 +620,42 @@ class BinaryANNIndex:
|
||||
f"Failed to remove vectors from Binary ANN index: {e}"
|
||||
)
|
||||
|
||||
def _build_cache(self) -> None:
|
||||
"""Build numpy array cache from vectors dict for vectorized search.
|
||||
|
||||
Pre-computes a contiguous numpy array from all vectors for efficient
|
||||
batch distance computation. Called lazily on first search after modification.
|
||||
"""
|
||||
if self._cache_valid:
|
||||
return
|
||||
|
||||
n_vectors = len(self._id_list)
|
||||
if n_vectors == 0:
|
||||
self._vectors_matrix = None
|
||||
self._ids_array = None
|
||||
self._cache_valid = True
|
||||
return
|
||||
|
||||
# Build contiguous numpy array of all packed vectors
|
||||
# Shape: (n_vectors, packed_dim) with uint8 dtype
|
||||
self._vectors_matrix = np.empty((n_vectors, self.packed_dim), dtype=np.uint8)
|
||||
self._ids_array = np.array(self._id_list, dtype=np.int64)
|
||||
|
||||
for i, vec_id in enumerate(self._id_list):
|
||||
vec_bytes = self._vectors[vec_id]
|
||||
self._vectors_matrix[i] = np.frombuffer(vec_bytes, dtype=np.uint8)
|
||||
|
||||
self._cache_valid = True
|
||||
logger.debug(f"Built vectorized cache for {n_vectors} binary vectors")
|
||||
|
||||
def search(
|
||||
self, query: bytes, top_k: int = 10
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
"""Search for nearest neighbors using Hamming distance.
|
||||
|
||||
Uses vectorized batch computation for O(N) search with SIMD acceleration.
|
||||
Pre-computes and caches numpy arrays for efficient repeated queries.
|
||||
|
||||
Args:
|
||||
query: Packed binary query vector (size: packed_dim bytes)
|
||||
top_k: Number of nearest neighbors to return
|
||||
@@ -638,27 +679,48 @@ class BinaryANNIndex:
|
||||
if len(self._vectors) == 0:
|
||||
return [], []
|
||||
|
||||
# Compute Hamming distances to all vectors
|
||||
# Build cache if needed (lazy initialization)
|
||||
self._build_cache()
|
||||
|
||||
if self._vectors_matrix is None or self._ids_array is None:
|
||||
return [], []
|
||||
|
||||
# Vectorized Hamming distance computation
|
||||
# 1. Convert query to numpy array
|
||||
query_arr = np.frombuffer(query, dtype=np.uint8)
|
||||
distances = []
|
||||
|
||||
for vec_id in self._id_list:
|
||||
vec = self._vectors[vec_id]
|
||||
vec_arr = np.frombuffer(vec, dtype=np.uint8)
|
||||
# XOR and popcount for Hamming distance
|
||||
xor = np.bitwise_xor(query_arr, vec_arr)
|
||||
dist = int(np.unpackbits(xor).sum())
|
||||
distances.append((vec_id, dist))
|
||||
# 2. Broadcast XOR: (1, packed_dim) XOR (n_vectors, packed_dim)
|
||||
# Result shape: (n_vectors, packed_dim)
|
||||
xor_result = np.bitwise_xor(query_arr, self._vectors_matrix)
|
||||
|
||||
# Sort by distance (ascending)
|
||||
distances.sort(key=lambda x: x[1])
|
||||
# 3. Vectorized popcount using lookup table for efficiency
|
||||
# np.unpackbits is slow for large arrays, use popcount LUT instead
|
||||
popcount_lut = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
|
||||
bit_counts = popcount_lut[xor_result]
|
||||
|
||||
# Return top-k
|
||||
top_results = distances[:top_k]
|
||||
ids = [r[0] for r in top_results]
|
||||
dists = [r[1] for r in top_results]
|
||||
# 4. Sum across packed bytes to get Hamming distance per vector
|
||||
distances = bit_counts.sum(axis=1)
|
||||
|
||||
return ids, dists
|
||||
# 5. Get top-k using argpartition (O(N) instead of O(N log N) for full sort)
|
||||
n_vectors = len(distances)
|
||||
k = min(top_k, n_vectors)
|
||||
|
||||
if k == n_vectors:
|
||||
# No partitioning needed, just sort all
|
||||
sorted_indices = np.argsort(distances)
|
||||
else:
|
||||
# Use argpartition for O(N) partial sort
|
||||
partition_indices = np.argpartition(distances, k)[:k]
|
||||
# Sort only the top-k
|
||||
top_k_distances = distances[partition_indices]
|
||||
sorted_order = np.argsort(top_k_distances)
|
||||
sorted_indices = partition_indices[sorted_order]
|
||||
|
||||
# 6. Return results
|
||||
result_ids = self._ids_array[sorted_indices].tolist()
|
||||
result_dists = distances[sorted_indices].tolist()
|
||||
|
||||
return result_ids, result_dists
|
||||
|
||||
except Exception as e:
|
||||
raise StorageError(f"Failed to search Binary ANN index: {e}")
|
||||
@@ -797,6 +859,7 @@ class BinaryANNIndex:
|
||||
# Clear existing data
|
||||
self._vectors.clear()
|
||||
self._id_list.clear()
|
||||
self._cache_valid = False
|
||||
|
||||
# Read vectors
|
||||
for _ in range(num_vectors):
|
||||
@@ -853,6 +916,9 @@ class BinaryANNIndex:
|
||||
with self._lock:
|
||||
self._vectors.clear()
|
||||
self._id_list.clear()
|
||||
self._vectors_matrix = None
|
||||
self._ids_array = None
|
||||
self._cache_valid = False
|
||||
logger.debug("Cleared binary index")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user