mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
fix(vector-store): protect bulk insert mode transitions with lock
Ensure begin_bulk_insert() and end_bulk_insert() are fully lock-protected to prevent time-of-check to time-of-use (TOCTOU) race conditions.

Solution-ID: SOL-1735392000003
Issue-ID: ISS-1766921318981-12
Task-ID: T2
This commit is contained in:
@@ -614,9 +614,10 @@ class VectorStore:
|
|||||||
for batch in batches:
|
for batch in batches:
|
||||||
store.add_chunks_batch(batch)
|
store.add_chunks_batch(batch)
|
||||||
"""
|
"""
|
||||||
self._bulk_insert_mode = True
|
with self._ann_write_lock:
|
||||||
self._bulk_insert_ids.clear()
|
self._bulk_insert_mode = True
|
||||||
self._bulk_insert_embeddings.clear()
|
self._bulk_insert_ids.clear()
|
||||||
|
self._bulk_insert_embeddings.clear()
|
||||||
logger.debug("Entered bulk insert mode")
|
logger.debug("Entered bulk insert mode")
|
||||||
|
|
||||||
def end_bulk_insert(self) -> None:
|
def end_bulk_insert(self) -> None:
|
||||||
@@ -625,30 +626,32 @@ class VectorStore:
|
|||||||
This method should be called after all bulk inserts are complete to
|
This method should be called after all bulk inserts are complete to
|
||||||
update the ANN index in a single batch operation.
|
update the ANN index in a single batch operation.
|
||||||
"""
|
"""
|
||||||
if not self._bulk_insert_mode:
|
with self._ann_write_lock:
|
||||||
logger.warning("end_bulk_insert called but not in bulk insert mode")
|
if not self._bulk_insert_mode:
|
||||||
return
|
logger.warning("end_bulk_insert called but not in bulk insert mode")
|
||||||
|
return
|
||||||
|
|
||||||
self._bulk_insert_mode = False
|
self._bulk_insert_mode = False
|
||||||
|
bulk_ids = list(self._bulk_insert_ids)
|
||||||
|
bulk_embeddings = list(self._bulk_insert_embeddings)
|
||||||
|
self._bulk_insert_ids.clear()
|
||||||
|
self._bulk_insert_embeddings.clear()
|
||||||
|
|
||||||
# Update ANN index with all accumulated data
|
# Update ANN index with accumulated data.
|
||||||
if self._bulk_insert_ids and self._bulk_insert_embeddings:
|
if bulk_ids and bulk_embeddings:
|
||||||
if self._ensure_ann_index(len(self._bulk_insert_embeddings[0])):
|
if self._ensure_ann_index(len(bulk_embeddings[0])):
|
||||||
with self._ann_write_lock:
|
with self._ann_write_lock:
|
||||||
try:
|
try:
|
||||||
embeddings_matrix = np.vstack(self._bulk_insert_embeddings)
|
embeddings_matrix = np.vstack(bulk_embeddings)
|
||||||
self._ann_index.add_vectors(self._bulk_insert_ids, embeddings_matrix)
|
self._ann_index.add_vectors(bulk_ids, embeddings_matrix)
|
||||||
self._ann_index.save()
|
self._ann_index.save()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Bulk insert complete: added %d vectors to ANN index",
|
"Bulk insert complete: added %d vectors to ANN index",
|
||||||
len(self._bulk_insert_ids)
|
len(bulk_ids),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to update ANN index after bulk insert: %s", e)
|
logger.error("Failed to update ANN index after bulk insert: %s", e)
|
||||||
|
|
||||||
# Clear accumulated data
|
|
||||||
self._bulk_insert_ids.clear()
|
|
||||||
self._bulk_insert_embeddings.clear()
|
|
||||||
logger.debug("Exited bulk insert mode")
|
logger.debug("Exited bulk insert mode")
|
||||||
|
|
||||||
class BulkInsertContext:
|
class BulkInsertContext:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -65,3 +66,66 @@ def test_concurrent_bulk_insert(monkeypatch: pytest.MonkeyPatch, temp_db: Path)
|
|||||||
assert len(store._bulk_insert_embeddings) == 50
|
assert len(store._bulk_insert_embeddings) == 50
|
||||||
assert store.count_chunks() == 50
|
assert store.count_chunks() == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_bulk_insert_mode_transitions(monkeypatch: pytest.MonkeyPatch, temp_db: Path) -> None:
|
||||||
|
"""begin/end bulk insert should be thread-safe with concurrent add operations."""
|
||||||
|
store = VectorStore(temp_db)
|
||||||
|
|
||||||
|
class DummyAnn:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.total_added = 0
|
||||||
|
self.save_calls = 0
|
||||||
|
|
||||||
|
def add_vectors(self, ids, embeddings) -> None:
|
||||||
|
self.total_added += len(ids)
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
self.save_calls += 1
|
||||||
|
|
||||||
|
dummy_ann = DummyAnn()
|
||||||
|
store._ann_index = dummy_ann
|
||||||
|
monkeypatch.setattr(store, "_ensure_ann_index", lambda dim: True)
|
||||||
|
|
||||||
|
errors: list[Exception] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
def adder(worker_id: int) -> None:
|
||||||
|
try:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
chunk = SemanticChunk(content=f"chunk {worker_id}", metadata={})
|
||||||
|
chunk.embedding = np.random.randn(8).astype(np.float32)
|
||||||
|
store.add_chunks_batch([(chunk, f"file_{worker_id}.py")], auto_save_ann=False)
|
||||||
|
except Exception as exc:
|
||||||
|
with lock:
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
def toggler() -> None:
|
||||||
|
try:
|
||||||
|
for _ in range(5):
|
||||||
|
store.begin_bulk_insert()
|
||||||
|
time.sleep(0.05)
|
||||||
|
store.end_bulk_insert()
|
||||||
|
time.sleep(0.05)
|
||||||
|
except Exception as exc:
|
||||||
|
with lock:
|
||||||
|
errors.append(exc)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=adder, args=(i,)) for i in range(3)]
|
||||||
|
toggle_thread = threading.Thread(target=toggler)
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
toggle_thread.start()
|
||||||
|
|
||||||
|
toggle_thread.join(timeout=10)
|
||||||
|
stop_event.set()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=10)
|
||||||
|
|
||||||
|
assert not errors
|
||||||
|
assert toggle_thread.is_alive() is False
|
||||||
|
assert store._bulk_insert_mode is False
|
||||||
|
assert store._bulk_insert_ids == []
|
||||||
|
assert store._bulk_insert_embeddings == []
|
||||||
|
assert dummy_ann.total_added == store.count_chunks()
|
||||||
|
|||||||
Reference in New Issue
Block a user