Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-05 01:50:27 +08:00.
- Implemented BinaryEmbeddingBackend for fast coarse filtering using 256-dimensional binary vectors.
- Developed DenseEmbeddingBackend for high-precision dense vectors (2048 dimensions) for reranking.
- Created CascadeEmbeddingBackend to combine binary and dense embeddings for two-stage retrieval.
- Introduced utility functions for embedding conversion and distance computation.

chore: Migration 010 - Add multi-vector storage support
- Added 'chunks' table to support multi-vector embeddings for cascade retrieval.
- Included new columns: embedding_binary (256-dim) and embedding_dense (2048-dim) for efficient storage.
- Implemented upgrade and downgrade functions to manage schema changes and data migration.
445 lines · 15 KiB · Python
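The commit message describes a two-stage cascade: 256-bit binary codes (32 bytes) for cheap coarse filtering by Hamming distance, and 2048-dimensional float32 vectors (2048 × 4 = 8192 bytes) for reranking; those are exactly the payload sizes the tests below store and read back. The sketch that follows is a rough illustration of how such a cascade fits together; the helper names are assumptions made for this example, not the actual codexlens utilities.

import struct


def pack_binary_embedding(vector: list[float]) -> bytes:
    # Sign-quantize a 256-dim float vector into a 256-bit code (32 bytes).
    if len(vector) != 256:
        raise ValueError("expected a 256-dimensional vector")
    bits = 0
    for i, value in enumerate(vector):
        if value > 0:
            bits |= 1 << i
    return bits.to_bytes(32, "big")


def hamming_distance(a: bytes, b: bytes) -> int:
    # Coarse-stage distance: number of differing bits between two codes.
    return bin(int.from_bytes(a, "big") ^ int.from_bytes(b, "big")).count("1")


def unpack_dense_embedding(blob: bytes) -> list[float]:
    # Decode a float32 BLOB (2048 floats -> 8192 bytes) back into a vector.
    return list(struct.unpack(f"<{len(blob) // 4}f", blob))


def cascade_search(query_binary, query_dense, candidates, coarse_k=100, top_k=10):
    # candidates: iterable of (chunk_id, binary_code, dense_blob) tuples, e.g.
    # assembled from get_binary_embeddings() / get_dense_embeddings() results.
    coarse = sorted(candidates, key=lambda c: hamming_distance(query_binary, c[1]))[:coarse_k]
    reranked = sorted(
        coarse,
        key=lambda c: sum(q * d for q, d in zip(query_dense, unpack_dense_embedding(c[2]))),
        reverse=True,
    )
    return [chunk_id for chunk_id, _, _ in reranked[:top_k]]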
"""Tests for SQLiteStore connection pool behavior."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sqlite3
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from codexlens.entities import IndexedFile
|
|
from codexlens.storage.sqlite_store import SQLiteStore
|
|
|
|
|
|
def test_periodic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Periodic timer should proactively clean up stale thread connections."""
    monkeypatch.setattr(SQLiteStore, "CLEANUP_INTERVAL", 0.2)

    store = SQLiteStore(tmp_path / "periodic_cleanup.db")
    store.initialize()

    cleanup_called = threading.Event()
    original_cleanup = store._cleanup_stale_connections

    def wrapped_cleanup() -> None:
        cleanup_called.set()
        original_cleanup()

    monkeypatch.setattr(store, "_cleanup_stale_connections", wrapped_cleanup)

    created: list[int] = []
    lock = threading.Lock()
    main_tid = threading.get_ident()

    def worker() -> None:
        store._get_connection()
        with lock:
            created.append(threading.get_ident())

    try:
        threads = [threading.Thread(target=worker) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Ensure we created thread-local connections without reaching MAX_POOL_SIZE.
        assert len(store._pool) >= 2
        assert all(tid in store._pool for tid in created)

        # Wait for periodic cleanup to run and prune dead thread connections.
        assert cleanup_called.wait(timeout=3)
        deadline = time.time() + 3
        while time.time() < deadline and any(tid in store._pool for tid in created):
            time.sleep(0.05)

        assert all(tid not in store._pool for tid in created)
        assert set(store._pool.keys()).issubset({main_tid})
    finally:
        store.close()


def test_cleanup_robustness(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, tmp_path: Path
) -> None:
    """Cleanup should handle dead threads, idle timeouts, and invalid connections."""
    monkeypatch.setattr(SQLiteStore, "CLEANUP_INTERVAL", 0)
    caplog.set_level(logging.DEBUG, logger="codexlens.storage.sqlite_store")

    store = SQLiteStore(tmp_path / "cleanup_robustness.db")
    store.initialize()

    try:
        # Invalid connection: active thread but pooled connection is already closed.
        conn = store._get_connection()
        conn.close()
        with store._pool_lock:
            store._pool[threading.get_ident()] = (conn, time.time())
        store._cleanup_stale_connections()

        assert "invalid_connection" in caplog.text
        assert threading.get_ident() not in store._pool

        # Ensure next access recreates a working connection after cleanup.
        fresh_conn = store._get_connection()
        assert fresh_conn is not conn

        # Idle timeout cleanup should be logged distinctly.
        with store._pool_lock:
            store._pool[threading.get_ident()] = (fresh_conn, time.time() - store.IDLE_TIMEOUT - 1)
        store._cleanup_stale_connections()

        assert "idle_timeout" in caplog.text
        assert threading.get_ident() not in store._pool

        # Dead thread cleanup should be logged distinctly.
        created: list[int] = []

        def worker() -> None:
            store._get_connection()
            created.append(threading.get_ident())

        t = threading.Thread(target=worker)
        t.start()
        t.join()

        dead_tid = created[0]
        assert dead_tid in store._pool
        with store._pool_lock:
            store._cleanup_stale_connections()

        assert "dead_thread" in caplog.text
        assert dead_tid not in store._pool
    finally:
        store.close()


def test_add_files_rollback_preserves_original_exception(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """add_files should re-raise the transaction error when rollback succeeds."""
    monkeypatch.setattr(SQLiteStore, "CLEANUP_INTERVAL", 0)
    store = SQLiteStore(tmp_path / "add_files_ok.db")
    store.initialize()

    real_conn = store._get_connection()

    class FailingConnection:
        def __init__(self, conn: sqlite3.Connection) -> None:
            self._conn = conn
            self.rollback_calls = 0

        def execute(self, sql: str, params: tuple = ()):
            if "INSERT INTO files" in sql:
                raise sqlite3.OperationalError("boom")
            return self._conn.execute(sql, params)

        def executemany(self, sql: str, seq):
            return self._conn.executemany(sql, seq)

        def commit(self) -> None:
            self._conn.commit()

        def rollback(self) -> None:
            self.rollback_calls += 1
            self._conn.rollback()

    wrapped = FailingConnection(real_conn)
    monkeypatch.setattr(store, "_get_connection", lambda: wrapped)

    indexed_file = IndexedFile(path=str(tmp_path / "a.py"), language="python", symbols=[])

    try:
        with pytest.raises(sqlite3.OperationalError, match="boom"):
            store.add_files([(indexed_file, "# content")])
        assert wrapped.rollback_calls == 1
    finally:
        store.close()


def test_add_files_rollback_failure_is_chained(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Rollback failures should be logged and chained as the cause."""
    monkeypatch.setattr(SQLiteStore, "CLEANUP_INTERVAL", 0)
    caplog.set_level(logging.ERROR, logger="codexlens.storage.sqlite_store")

    store = SQLiteStore(tmp_path / "add_files_rollback_fail.db")
    store.initialize()
    real_conn = store._get_connection()

    class FailingRollbackConnection:
        def __init__(self, conn: sqlite3.Connection) -> None:
            self._conn = conn

        def execute(self, sql: str, params: tuple = ()):
            if "INSERT INTO files" in sql:
                raise sqlite3.OperationalError("boom")
            return self._conn.execute(sql, params)

        def executemany(self, sql: str, seq):
            return self._conn.executemany(sql, seq)

        def commit(self) -> None:
            self._conn.commit()

        def rollback(self) -> None:
            raise sqlite3.OperationalError("rollback boom")

    monkeypatch.setattr(store, "_get_connection", lambda: FailingRollbackConnection(real_conn))
    indexed_file = IndexedFile(path=str(tmp_path / "b.py"), language="python", symbols=[])

    try:
        with pytest.raises(sqlite3.OperationalError) as exc:
            store.add_files([(indexed_file, "# content")])

        assert exc.value.__cause__ is not None
        assert isinstance(exc.value.__cause__, sqlite3.OperationalError)
        assert "rollback boom" in str(exc.value.__cause__)
        assert "Rollback failed after add_files() error" in caplog.text
        assert "boom" in caplog.text
    finally:
        store.close()


class TestMultiVectorChunks:
    """Tests for multi-vector chunk storage operations."""

    def test_add_chunks_basic(self, tmp_path: Path) -> None:
        """Basic chunk insertion without embeddings."""
        store = SQLiteStore(tmp_path / "chunks_basic.db")
        store.initialize()

        try:
            chunks_data = [
                {"content": "def hello(): pass", "metadata": {"type": "function"}},
                {"content": "class World: pass", "metadata": {"type": "class"}},
            ]

            ids = store.add_chunks("test.py", chunks_data)

            assert len(ids) == 2
            assert ids == [1, 2]
            assert store.count_chunks() == 2
        finally:
            store.close()

    def test_add_chunks_with_binary_embeddings(self, tmp_path: Path) -> None:
        """Chunk insertion with binary embeddings for coarse ranking."""
        store = SQLiteStore(tmp_path / "chunks_binary.db")
        store.initialize()

        try:
            chunks_data = [
                {"content": "content1"},
                {"content": "content2"},
            ]
            # 256-bit binary = 32 bytes
            binary_embs = [b"\x00" * 32, b"\xff" * 32]

            ids = store.add_chunks(
                "test.py", chunks_data, embedding_binary=binary_embs
            )

            assert len(ids) == 2

            retrieved = store.get_binary_embeddings(ids)
            assert len(retrieved) == 2
            assert retrieved[ids[0]] == b"\x00" * 32
            assert retrieved[ids[1]] == b"\xff" * 32
        finally:
            store.close()

    def test_add_chunks_with_dense_embeddings(self, tmp_path: Path) -> None:
        """Chunk insertion with dense embeddings for fine ranking."""
        store = SQLiteStore(tmp_path / "chunks_dense.db")
        store.initialize()

        try:
            chunks_data = [{"content": "content1"}, {"content": "content2"}]
            # 2048 floats = 8192 bytes
            dense_embs = [b"\x00" * 8192, b"\xff" * 8192]

            ids = store.add_chunks(
                "test.py", chunks_data, embedding_dense=dense_embs
            )

            assert len(ids) == 2

            retrieved = store.get_dense_embeddings(ids)
            assert len(retrieved) == 2
            assert retrieved[ids[0]] == b"\x00" * 8192
            assert retrieved[ids[1]] == b"\xff" * 8192
        finally:
            store.close()

    def test_add_chunks_with_all_embeddings(self, tmp_path: Path) -> None:
        """Chunk insertion with all embedding types."""
        store = SQLiteStore(tmp_path / "chunks_all.db")
        store.initialize()

        try:
            chunks_data = [{"content": "full test"}]
            embedding = [[0.1, 0.2, 0.3]]
            binary_embs = [b"\xab" * 32]
            dense_embs = [b"\xcd" * 8192]

            ids = store.add_chunks(
                "test.py",
                chunks_data,
                embedding=embedding,
                embedding_binary=binary_embs,
                embedding_dense=dense_embs,
            )

            assert len(ids) == 1

            binary = store.get_binary_embeddings(ids)
            dense = store.get_dense_embeddings(ids)

            assert binary[ids[0]] == b"\xab" * 32
            assert dense[ids[0]] == b"\xcd" * 8192
        finally:
            store.close()

    def test_add_chunks_length_mismatch_raises(self, tmp_path: Path) -> None:
        """Mismatched embedding length should raise ValueError."""
        store = SQLiteStore(tmp_path / "chunks_mismatch.db")
        store.initialize()

        try:
            chunks_data = [{"content": "a"}, {"content": "b"}]

            with pytest.raises(ValueError, match="embedding_binary length"):
                store.add_chunks(
                    "test.py", chunks_data, embedding_binary=[b"\x00" * 32]
                )

            with pytest.raises(ValueError, match="embedding_dense length"):
                store.add_chunks(
                    "test.py", chunks_data, embedding_dense=[b"\x00" * 8192]
                )

            with pytest.raises(ValueError, match="embedding length"):
                store.add_chunks(
                    "test.py", chunks_data, embedding=[[0.1]]
                )
        finally:
            store.close()

    def test_get_chunks_by_ids(self, tmp_path: Path) -> None:
        """Retrieve chunk data by IDs."""
        store = SQLiteStore(tmp_path / "chunks_get.db")
        store.initialize()

        try:
            chunks_data = [
                {"content": "def foo(): pass", "metadata": {"line": 1}},
                {"content": "def bar(): pass", "metadata": {"line": 5}},
            ]

            ids = store.add_chunks("test.py", chunks_data)
            retrieved = store.get_chunks_by_ids(ids)

            assert len(retrieved) == 2
            assert retrieved[0]["content"] == "def foo(): pass"
            assert retrieved[0]["metadata"]["line"] == 1
            assert retrieved[1]["content"] == "def bar(): pass"
            assert retrieved[1]["file_path"] == "test.py"
        finally:
            store.close()

    def test_delete_chunks_by_file(self, tmp_path: Path) -> None:
        """Delete all chunks for a file."""
        store = SQLiteStore(tmp_path / "chunks_delete.db")
        store.initialize()

        try:
            store.add_chunks("a.py", [{"content": "a1"}, {"content": "a2"}])
            store.add_chunks("b.py", [{"content": "b1"}])

            assert store.count_chunks() == 3

            deleted = store.delete_chunks_by_file("a.py")
            assert deleted == 2
            assert store.count_chunks() == 1

            deleted = store.delete_chunks_by_file("nonexistent.py")
            assert deleted == 0
        finally:
            store.close()

    def test_get_embeddings_empty_list(self, tmp_path: Path) -> None:
        """Empty chunk ID list returns empty dict."""
        store = SQLiteStore(tmp_path / "chunks_empty.db")
        store.initialize()

        try:
            assert store.get_binary_embeddings([]) == {}
            assert store.get_dense_embeddings([]) == {}
            assert store.get_chunks_by_ids([]) == []
        finally:
            store.close()

    def test_add_chunks_empty_list(self, tmp_path: Path) -> None:
        """Empty chunks list returns empty IDs."""
        store = SQLiteStore(tmp_path / "chunks_empty_add.db")
        store.initialize()

        try:
            ids = store.add_chunks("test.py", [])
            assert ids == []
            assert store.count_chunks() == 0
        finally:
            store.close()

    def test_chunks_table_migration(self, tmp_path: Path) -> None:
        """Existing chunks table gets new columns via migration."""
        db_path = tmp_path / "chunks_migration.db"

        # Create old schema without multi-vector columns
        conn = sqlite3.connect(db_path)
        conn.execute(
            """
            CREATE TABLE chunks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_path TEXT NOT NULL,
                content TEXT NOT NULL,
                embedding BLOB,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
        )
        conn.execute("CREATE INDEX idx_chunks_file_path ON chunks(file_path)")
        conn.execute(
            "INSERT INTO chunks (file_path, content) VALUES ('old.py', 'old content')"
        )
        conn.commit()
        conn.close()

        # Open with SQLiteStore - should migrate
        store = SQLiteStore(db_path)
        store.initialize()

        try:
            # Verify new columns exist by using them
            ids = store.add_chunks(
                "new.py",
                [{"content": "new content"}],
                embedding_binary=[b"\x00" * 32],
                embedding_dense=[b"\x00" * 8192],
            )

            assert len(ids) == 1

            # Old data should still be accessible
            assert store.count_chunks() == 2

            # New embeddings should work
            binary = store.get_binary_embeddings(ids)
            assert binary[ids[0]] == b"\x00" * 32
        finally:
            store.close()
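

# The migration test above builds an old-style chunks table and expects
# SQLiteStore.initialize() to add the multi-vector columns in place. The
# function below is an illustrative sketch of what such an upgrade step could
# look like (an assumption for this example, not the actual migration 010
# implementation): it adds the two BLOB columns only when they are missing,
# which is the behaviour the test asserts.
def _example_chunks_multi_vector_upgrade(conn: sqlite3.Connection) -> None:
    existing = {row[1] for row in conn.execute("PRAGMA table_info(chunks)")}
    if "embedding_binary" not in existing:
        # 256-bit binary code stored as a 32-byte BLOB.
        conn.execute("ALTER TABLE chunks ADD COLUMN embedding_binary BLOB")
    if "embedding_dense" not in existing:
        # 2048-dim float32 vector stored as an 8192-byte BLOB.
        conn.execute("ALTER TABLE chunks ADD COLUMN embedding_dense BLOB")
    conn.commit()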