fix(vector-store): protect bulk insert mode transitions with lock

Ensure begin_bulk_insert() and end_bulk_insert() are fully
lock-protected to prevent TOCTOU race conditions.

Solution-ID: SOL-1735392000003
Issue-ID: ISS-1766921318981-12
Task-ID: T2
This commit is contained in:
catlog22
2025-12-29 19:20:02 +08:00
parent d8be23fa83
commit 5914b1c5fc
2 changed files with 83 additions and 16 deletions

View File

@@ -1,5 +1,6 @@
import tempfile
import threading
import time
from pathlib import Path
import numpy as np
@@ -65,3 +66,66 @@ def test_concurrent_bulk_insert(monkeypatch: pytest.MonkeyPatch, temp_db: Path)
assert len(store._bulk_insert_embeddings) == 50
assert store.count_chunks() == 50
def test_bulk_insert_mode_transitions(monkeypatch: pytest.MonkeyPatch, temp_db: Path) -> None:
"""begin/end bulk insert should be thread-safe with concurrent add operations."""
store = VectorStore(temp_db)
class DummyAnn:
def __init__(self) -> None:
self.total_added = 0
self.save_calls = 0
def add_vectors(self, ids, embeddings) -> None:
self.total_added += len(ids)
def save(self) -> None:
self.save_calls += 1
dummy_ann = DummyAnn()
store._ann_index = dummy_ann
monkeypatch.setattr(store, "_ensure_ann_index", lambda dim: True)
errors: list[Exception] = []
lock = threading.Lock()
stop_event = threading.Event()
def adder(worker_id: int) -> None:
try:
while not stop_event.is_set():
chunk = SemanticChunk(content=f"chunk {worker_id}", metadata={})
chunk.embedding = np.random.randn(8).astype(np.float32)
store.add_chunks_batch([(chunk, f"file_{worker_id}.py")], auto_save_ann=False)
except Exception as exc:
with lock:
errors.append(exc)
def toggler() -> None:
try:
for _ in range(5):
store.begin_bulk_insert()
time.sleep(0.05)
store.end_bulk_insert()
time.sleep(0.05)
except Exception as exc:
with lock:
errors.append(exc)
threads = [threading.Thread(target=adder, args=(i,)) for i in range(3)]
toggle_thread = threading.Thread(target=toggler)
for t in threads:
t.start()
toggle_thread.start()
toggle_thread.join(timeout=10)
stop_event.set()
for t in threads:
t.join(timeout=10)
assert not errors
assert toggle_thread.is_alive() is False
assert store._bulk_insert_mode is False
assert store._bulk_insert_ids == []
assert store._bulk_insert_embeddings == []
assert dummy_ann.total_added == store.count_chunks()