Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-13 02:41:50 +08:00
fix(embedder): add lock protection for cache read operations
Protect the fast-path cache read in get_embedder() to prevent KeyError during
concurrent access and cache clearing operations.

Solution-ID: SOL-1735392000001
Issue-ID: ISS-1766921318981-2
Task-ID: T1
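For context, the removed fast path tested membership and then indexed the dict without holding the lock; between those two steps another thread running clear_embedder_cache() can empty the cache, so the index raises KeyError. A minimal standalone sketch of that time-of-check/time-of-use window (simplified names, not the project's code):

import threading

_cache: dict[str, object] = {"code:cpu": object()}
_lock = threading.Lock()

def get_unsafe(key: str):
    # Unlocked fast path: the membership test and the lookup are two
    # separate dict operations, not one atomic step.
    if key in _cache:
        # A concurrent clear() can run right here ...
        return _cache[key]  # ... and this lookup then raises KeyError.
    return None

def clear() -> None:
    # The lock exists but the fast path above skips it -- the bug's shape.
    with _lock:
        _cache.clear()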
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 # Global embedder cache for singleton pattern
 _embedder_cache: Dict[str, "Embedder"] = {}
-_cache_lock = threading.Lock()
+_cache_lock = threading.RLock()
 
 
 def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
@@ -43,15 +43,12 @@ def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
     # Cache key includes GPU preference to support mixed configurations
     cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"
 
-    # Fast path: check cache without lock
-    if cache_key in _embedder_cache:
-        return _embedder_cache[cache_key]
-
-    # Slow path: acquire lock for initialization
+    # All cache access is protected by _cache_lock to avoid races with
+    # clear_embedder_cache() during concurrent access.
     with _cache_lock:
-        # Double-check after acquiring lock
-        if cache_key in _embedder_cache:
-            return _embedder_cache[cache_key]
+        embedder = _embedder_cache.get(cache_key)
+        if embedder is not None:
+            return embedder
 
         # Create new embedder and cache it
         embedder = Embedder(profile=profile, use_gpu=use_gpu)
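The hunk above also swaps threading.Lock for threading.RLock. An RLock can be re-acquired by the thread that already holds it, which matters if any code path re-enters get_embedder() or clear_embedder_cache() while the lock is held; whether codex-lens actually re-enters the lock is an assumption here, not something this diff shows. A minimal sketch of the difference:

import threading

_rlock = threading.RLock()

def outer() -> None:
    with _rlock:
        inner()  # OK: an RLock counts re-entrant acquisitions by its owner.

def inner() -> None:
    with _rlock:  # With a plain threading.Lock this would deadlock.
        pass

outer()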
codex-lens/tests/test_embedder.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+"""Tests for embedder cache concurrency."""
+
+from __future__ import annotations
+
+import threading
+import time
+
+import pytest
+
+import codexlens.semantic.embedder as embedder_module
+
+
+def _patch_embedder_for_unit_tests(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Make get_embedder() tests deterministic and fast (no model downloads)."""
+
+    monkeypatch.setattr(embedder_module, "SEMANTIC_AVAILABLE", True)
+    monkeypatch.setattr(embedder_module, "get_optimal_providers", lambda *args, **kwargs: [])
+    monkeypatch.setattr(embedder_module, "is_gpu_available", lambda: False)
+    monkeypatch.setattr(embedder_module.Embedder, "_load_model", lambda self: None)
+
+
+def test_embedder_instances_are_cached_and_reused(monkeypatch: pytest.MonkeyPatch) -> None:
+    _patch_embedder_for_unit_tests(monkeypatch)
+    embedder_module.clear_embedder_cache()
+
+    first = embedder_module.get_embedder(profile="code", use_gpu=False)
+    second = embedder_module.get_embedder(profile="code", use_gpu=False)
+
+    assert first is second
+
+
+def test_concurrent_cache_access(monkeypatch: pytest.MonkeyPatch) -> None:
+    _patch_embedder_for_unit_tests(monkeypatch)
+    embedder_module.clear_embedder_cache()
+
+    profiles = ["fast", "code", "balanced", "multilingual"]
+    for profile in profiles:
+        embedder_module.get_embedder(profile=profile, use_gpu=False)
+
+    errors: list[BaseException] = []
+    errors_lock = threading.Lock()
+
+    def record_error(err: BaseException) -> None:
+        with errors_lock:
+            errors.append(err)
+
+    worker_count = 20
+    start_barrier = threading.Barrier(worker_count + 1)
+    stop_at = time.monotonic() + 1.0
+
+    def clear_worker() -> None:
+        try:
+            start_barrier.wait()
+            while time.monotonic() < stop_at:
+                embedder_module.clear_embedder_cache()
+                time.sleep(0)
+        except BaseException as err:
+            record_error(err)
+
+    def access_worker(profile: str) -> None:
+        try:
+            start_barrier.wait()
+            while time.monotonic() < stop_at:
+                embedder_module.get_embedder(profile=profile, use_gpu=False)
+        except BaseException as err:
+            record_error(err)
+
+    threads: list[threading.Thread] = [
+        threading.Thread(target=clear_worker, name="clear-embedder-cache"),
+    ]
+    for idx in range(worker_count):
+        threads.append(
+            threading.Thread(
+                target=access_worker,
+                name=f"get-embedder-{idx}",
+                args=(profiles[idx % len(profiles)],),
+            )
+        )
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join(timeout=10)
+
+    assert not errors, f"Unexpected errors during concurrent access: {errors!r}"
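Design notes on the test: Barrier(worker_count + 1) releases the 20 reader threads and the single clearing thread at the same instant to maximize contention, the one-second time.monotonic() deadline keeps runtime bounded, and exceptions raised inside worker threads (which pytest would otherwise never see) are funneled through record_error() and checked once on the main thread. Against the old unlocked fast path this test would be expected to surface the KeyError intermittently; with every read under _cache_lock it passes.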