diff --git a/codex-lens/tests/test_storage_concurrency.py b/codex-lens/tests/test_storage_concurrency.py
index 7338c6de..6f2608a5 100644
--- a/codex-lens/tests/test_storage_concurrency.py
+++ b/codex-lens/tests/test_storage_concurrency.py
@@ -67,6 +67,16 @@ def dir_index_store(tmp_path):
     store.close()
 
 
+@pytest.fixture()
+def writable_store(tmp_path):
+    """Create a fresh SQLiteStore for concurrent write tests."""
+    db_path = tmp_path / "writes.db"
+    store = SQLiteStore(db_path)
+    store.initialize()
+    yield store
+    store.close()
+
+
 class TestConcurrentReads:
     """Concurrent read tests for storage managers."""
 
@@ -345,3 +355,251 @@ class TestConcurrentReads:
 
         assert not errors
         assert results == [10] * 10
+
+
+class TestConcurrentWrites:
+    """Concurrent write tests for SQLiteStore."""
+
+    def test_concurrent_inserts_commit_all_rows(self, writable_store):
+        """Concurrent inserts from 10 threads should commit all rows."""
+        thread_count = 10
+        files_per_thread = 10
+        errors = []
+        lock = threading.Lock()
+
+        def worker(thread_index: int):
+            try:
+                for i in range(files_per_thread):
+                    path = f"/write/thread_{thread_index}/file_{i}.py"
+                    indexed_file = IndexedFile(
+                        path=path,
+                        language="python",
+                        symbols=[Symbol(name=f"sym_{thread_index}_{i}", kind="function", range=(1, 1))],
+                    )
+                    content = f"# write_token_{thread_index}_{i}\nprint({i})\n"
+                    writable_store.add_file(indexed_file, content)
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=worker, args=(i,)) for i in range(thread_count)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+        stats = writable_store.stats()
+        assert stats["files"] == thread_count * files_per_thread
+        assert stats["symbols"] == thread_count * files_per_thread
+
+    def test_concurrent_updates_same_file_serializes(self, writable_store):
+        """Concurrent updates to the same file should serialize and not lose writes."""
+        target_path = "/write/shared.py"
+        base = IndexedFile(
+            path=target_path,
+            language="python",
+            symbols=[Symbol(name="base", kind="function", range=(1, 1))],
+        )
+        writable_store.add_file(base, "print('base')\n")
+
+        update_contents = []
+        errors = []
+        lock = threading.Lock()
+
+        def worker(version: int):
+            try:
+                content = f"print('v{version}')\n"
+                indexed_file = IndexedFile(
+                    path=target_path,
+                    language="python",
+                    symbols=[Symbol(name=f"v{version}", kind="function", range=(1, 1))],
+                )
+                writable_store.add_file(indexed_file, content)
+                with lock:
+                    update_contents.append(content)
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+
+        resolved = str(Path(target_path).resolve())
+        rows = writable_store.execute_query("SELECT content FROM files WHERE path=?", (resolved,))
+        assert len(rows) == 1
+        assert rows[0]["content"] in set(update_contents)
+
+    def test_wal_mode_is_active_for_thread_connections(self, writable_store):
+        """PRAGMA journal_mode should be WAL for all thread-local connections."""
+        modes = []
+        errors = []
+        lock = threading.Lock()
+
+        def worker():
+            try:
+                conn = writable_store._get_connection()
+                mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
+                with lock:
+                    modes.append(str(mode).lower())
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=worker) for _ in range(10)]
+        for t in threads:
+            t.start()
+        for t in threads:
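+            # join() only synchronizes; worker failures surface through `errors`, not here.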
+            t.join()
+
+        assert not errors
+        assert modes
+        assert all(mode == "wal" for mode in modes)
+
+    def test_transaction_isolation_reader_sees_committed_state(self, writable_store):
+        """Readers should not see uncommitted writer updates and should not block."""
+        target_path = "/write/isolation.py"
+        indexed_file = IndexedFile(path=target_path, language="python", symbols=[])
+        writable_store.add_file(indexed_file, "print('original')\n")
+        resolved = str(Path(target_path).resolve())
+
+        writer_started = threading.Event()
+        reader_done = threading.Event()
+        errors = []
+        lock = threading.Lock()
+        observed = {"reader": None}
+        updated_content = "print('updated')\n"
+
+        def writer():
+            try:
+                conn = writable_store._get_connection()
+                conn.execute("BEGIN IMMEDIATE")
+                conn.execute(
+                    "UPDATE files SET content=? WHERE path=?",
+                    (updated_content, resolved),
+                )
+                writer_started.set()
+                reader_done.wait(timeout=5)
+                conn.commit()
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        def reader():
+            try:
+                writer_started.wait(timeout=5)
+                conn = writable_store._get_connection()
+                row = conn.execute("SELECT content FROM files WHERE path=?", (resolved,)).fetchone()
+                observed["reader"] = row[0] if row else None
+                reader_done.set()
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=writer), threading.Thread(target=reader)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+        assert observed["reader"] == "print('original')\n"
+
+        rows = writable_store.execute_query("SELECT content FROM files WHERE path=?", (resolved,))
+        assert rows[0]["content"] == updated_content
+
+    def test_batch_insert_performance_and_counts(self, writable_store):
+        """Batch inserts across threads should not lose rows."""
+        thread_count = 10
+        files_per_thread = 100
+        errors = []
+        lock = threading.Lock()
+
+        def worker(thread_index: int):
+            try:
+                files = []
+                for i in range(files_per_thread):
+                    path = f"/write/batch_{thread_index}/file_{i}.py"
+                    indexed_file = IndexedFile(
+                        path=path,
+                        language="python",
+                        symbols=[
+                            Symbol(name=f"sym_{thread_index}_{i}", kind="function", range=(1, 1))
+                        ],
+                    )
+                    content = f"# batch_token_{thread_index}_{i}\nprint({i})\n"
+                    files.append((indexed_file, content))
+
+                writable_store.add_files(files)
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        start = time.time()
+        threads = [threading.Thread(target=worker, args=(i,)) for i in range(thread_count)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+        duration = max(time.time() - start, 1e-6)
+
+        assert not errors
+        stats = writable_store.stats()
+        assert stats["files"] == thread_count * files_per_thread
+        assert stats["symbols"] == thread_count * files_per_thread
+        # Throughput smoke check only; an absolute rate threshold would be flaky across machines.
+        assert (thread_count * files_per_thread) / duration > 0
+
+    def test_mixed_read_write_operations_no_errors(self, writable_store):
+        """Mixed reader and writer threads should complete without exceptions."""
+        writer_threads = 5
+        reader_threads = 10
+        writes_per_writer = 20
+        reads_per_reader = 50
+
+        errors = []
+        lock = threading.Lock()
+        target_paths = [
+            f"/write/mixed_{w}/file_{i}.py"
+            for w in range(writer_threads)
+            for i in range(writes_per_writer)
+        ]
+
+        def writer(worker_index: int):
+            try:
+                for i in range(writes_per_writer):
+                    path = f"/write/mixed_{worker_index}/file_{i}.py"
+                    indexed_file = IndexedFile(path=path, language="python", symbols=[])
+                    writable_store.add_file(indexed_file, f"# mixed\nprint({i})\n")
+            except Exception as exc:
+                with lock:
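+                    # Record the failure so the main thread's `assert not errors` can report it.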
+                    errors.append(exc)
+
+        def reader(worker_index: int):
+            try:
+                for i in range(reads_per_reader):
+                    path = target_paths[(worker_index + i) % len(target_paths)]
+                    writable_store.file_exists(path)
+            except Exception as exc:
+                with lock:
+                    errors.append(exc)
+
+        threads = [
+            *[threading.Thread(target=writer, args=(i,)) for i in range(writer_threads)],
+            *[threading.Thread(target=reader, args=(i,)) for i in range(reader_threads)],
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert not errors
+        stats = writable_store.stats()
+        assert stats["files"] == writer_threads * writes_per_writer