From 13960104371019b87781f1ea68b5de90946298c9 Mon Sep 17 00:00:00 2001 From: catlog22 Date: Mon, 29 Dec 2025 12:33:23 +0800 Subject: [PATCH] fix(embedder): add lock protection for cache read operations Protect fast path cache read in get_embedder() to prevent KeyError during concurrent access and cache clearing operations. Solution-ID: SOL-1735392000001 Issue-ID: ISS-1766921318981-2 Task-ID: T1 --- codex-lens/src/codexlens/semantic/embedder.py | 15 ++-- codex-lens/tests/test_embedder.py | 85 +++++++++++++++++++ 2 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 codex-lens/tests/test_embedder.py diff --git a/codex-lens/src/codexlens/semantic/embedder.py b/codex-lens/src/codexlens/semantic/embedder.py index bb1bc856..e2d21717 100644 --- a/codex-lens/src/codexlens/semantic/embedder.py +++ b/codex-lens/src/codexlens/semantic/embedder.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) # Global embedder cache for singleton pattern _embedder_cache: Dict[str, "Embedder"] = {} -_cache_lock = threading.Lock() +_cache_lock = threading.RLock() def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder": @@ -43,15 +43,12 @@ def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder": # Cache key includes GPU preference to support mixed configurations cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}" - # Fast path: check cache without lock - if cache_key in _embedder_cache: - return _embedder_cache[cache_key] - - # Slow path: acquire lock for initialization + # All cache access is protected by _cache_lock to avoid races with + # clear_embedder_cache() during concurrent access. with _cache_lock: - # Double-check after acquiring lock - if cache_key in _embedder_cache: - return _embedder_cache[cache_key] + embedder = _embedder_cache.get(cache_key) + if embedder is not None: + return embedder # Create new embedder and cache it embedder = Embedder(profile=profile, use_gpu=use_gpu) diff --git a/codex-lens/tests/test_embedder.py b/codex-lens/tests/test_embedder.py new file mode 100644 index 00000000..3d6850a1 --- /dev/null +++ b/codex-lens/tests/test_embedder.py @@ -0,0 +1,85 @@ +"""Tests for embedder cache concurrency.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +import codexlens.semantic.embedder as embedder_module + + +def _patch_embedder_for_unit_tests(monkeypatch: pytest.MonkeyPatch) -> None: + """Make get_embedder() tests deterministic and fast (no model downloads).""" + + monkeypatch.setattr(embedder_module, "SEMANTIC_AVAILABLE", True) + monkeypatch.setattr(embedder_module, "get_optimal_providers", lambda *args, **kwargs: []) + monkeypatch.setattr(embedder_module, "is_gpu_available", lambda: False) + monkeypatch.setattr(embedder_module.Embedder, "_load_model", lambda self: None) + + +def test_embedder_instances_are_cached_and_reused(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_embedder_for_unit_tests(monkeypatch) + embedder_module.clear_embedder_cache() + + first = embedder_module.get_embedder(profile="code", use_gpu=False) + second = embedder_module.get_embedder(profile="code", use_gpu=False) + + assert first is second + + +def test_concurrent_cache_access(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_embedder_for_unit_tests(monkeypatch) + embedder_module.clear_embedder_cache() + + profiles = ["fast", "code", "balanced", "multilingual"] + for profile in profiles: + embedder_module.get_embedder(profile=profile, use_gpu=False) + + errors: list[BaseException] = [] + errors_lock = threading.Lock() + + def record_error(err: BaseException) -> None: + with errors_lock: + errors.append(err) + + worker_count = 20 + start_barrier = threading.Barrier(worker_count + 1) + stop_at = time.monotonic() + 1.0 + + def clear_worker() -> None: + try: + start_barrier.wait() + while time.monotonic() < stop_at: + embedder_module.clear_embedder_cache() + time.sleep(0) + except BaseException as err: + record_error(err) + + def access_worker(profile: str) -> None: + try: + start_barrier.wait() + while time.monotonic() < stop_at: + embedder_module.get_embedder(profile=profile, use_gpu=False) + except BaseException as err: + record_error(err) + + threads: list[threading.Thread] = [ + threading.Thread(target=clear_worker, name="clear-embedder-cache"), + ] + for idx in range(worker_count): + threads.append( + threading.Thread( + target=access_worker, + name=f"get-embedder-{idx}", + args=(profiles[idx % len(profiles)],), + ) + ) + + for thread in threads: + thread.start() + for thread in threads: + thread.join(timeout=10) + + assert not errors, f"Unexpected errors during concurrent access: {errors!r}"