mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
refactor(vector_store): use safer SQL query construction pattern
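Replaces f-string interpolation of the IN-clause placeholders with `str.format()` applied to a placeholder string that is first checked by `_validate_sql_placeholders`. Adds documentation on SQL injection prevention. No functional changes - user-provided values are still passed as parameters to the parameterized query. Fixes: ISS-1766921318981-9 Solution-ID: SOL-1735386000-9 Issue-ID: ISS-1766921318981-9 Task-ID: T1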
This commit is contained in:
@@ -59,6 +59,16 @@ def _validate_chunk_id_range(start_id: int, count: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _validate_sql_placeholders(placeholders: str, expected_count: int) -> None:
|
||||
"""Validate the placeholder string used for a parameterized SQL IN clause."""
|
||||
expected = ",".join("?" * expected_count)
|
||||
if placeholders != expected:
|
||||
raise ValueError(
|
||||
"Invalid SQL placeholders for IN clause. "
|
||||
f"Expected {expected_count} '?' placeholders."
|
||||
)
|
||||
|
||||
|
||||
def _cosine_similarity(a: List[float], b: List[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors."""
|
||||
if not NUMPY_AVAILABLE:
|
||||
@@ -946,11 +956,16 @@ class VectorStore:
|
||||
|
||||
# Build parameterized query for IN clause
|
||||
placeholders = ",".join("?" * len(chunk_ids))
|
||||
query = f"""
|
||||
_validate_sql_placeholders(placeholders, len(chunk_ids))
|
||||
|
||||
# SQL injection prevention:
|
||||
# - Only a validated placeholders string (commas + '?') is interpolated into the query.
|
||||
# - User-provided values are passed separately via sqlite3 parameters.
|
||||
query = """
|
||||
SELECT id, file_path, content, metadata
|
||||
FROM semantic_chunks
|
||||
WHERE id IN ({placeholders})
|
||||
"""
|
||||
""".format(placeholders=placeholders)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("PRAGMA mmap_size = 30000000000")
|
||||
|
||||
@@ -317,3 +317,70 @@ def test_add_chunks_batch_numpy_overflow(monkeypatch: pytest.MonkeyPatch, temp_d
|
||||
|
||||
with pytest.raises(ValueError, match=r"Chunk ID range overflow"):
|
||||
store.add_chunks_batch_numpy(chunks_with_paths, embeddings)
|
||||
|
||||
|
||||
def test_fetch_results_by_ids(monkeypatch: pytest.MonkeyPatch, temp_db: Path) -> None:
    """_fetch_results_by_ids should use parameterized IN queries and return ordered results."""
    store = VectorStore(temp_db)

    # Every execute() call made through the stubbed connection, recorded as
    # (kind, sql, params) where kind is "pragma" or "query".
    calls: list[tuple[str, str, object]] = []
    rows = [
        (1, "a.py", "content A", None),
        (2, "b.py", "content B", None),
    ]

    class DummyCursor:
        """Cursor stand-in that returns a fixed list of rows."""

        def __init__(self, result_rows):
            self._rows = result_rows

        def fetchall(self):
            return self._rows

    class DummyConn:
        """Connection stand-in usable as a context manager; records execute() calls."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Returning False lets exceptions raised inside the `with` propagate.
            return False

        def execute(self, query, params=None):
            if isinstance(query, str) and query.strip().upper().startswith("PRAGMA"):
                calls.append(("pragma", query, params))
                return DummyCursor([])
            calls.append(("query", query, params))
            return DummyCursor(rows)

    # Patched only after the store is constructed, so the constructor above
    # still ran against the real sqlite3.connect.
    monkeypatch.setattr(vector_store_module.sqlite3, "connect", lambda _: DummyConn())

    chunk_ids = [1, 2]
    scores = [0.9, 0.8]
    results = store._fetch_results_by_ids(chunk_ids, scores, return_full_content=False)

    assert [r.path for r in results] == ["a.py", "b.py"]
    assert [r.score for r in results] == scores
    assert all(r.content is None for r in results)

    assert any(kind == "pragma" for kind, _, _ in calls)

    # Fail with a clear message if no SELECT was issued, rather than falling
    # back to a dummy tuple and reporting a confusing string mismatch.
    query_calls = [c for c in calls if c[0] == "query"]
    assert query_calls, "expected a non-PRAGMA query to be executed"
    _, query, params = query_calls[0]

    # Compare with whitespace collapsed: the statement text is what matters,
    # not how the SQL literal happens to be indented in the implementation.
    expected_query = (
        "SELECT id, file_path, content, metadata "
        "FROM semantic_chunks "
        "WHERE id IN ({placeholders})"
    ).format(placeholders=",".join("?" * len(chunk_ids)))
    assert " ".join(query.split()) == expected_query
    # The ids must travel as bound parameters, not as part of the SQL text.
    assert params == chunk_ids

    assert store._fetch_results_by_ids([], [], return_full_content=False) == []
|
||||
|
||||
|
||||
def test_fetch_results_sql_safety() -> None:
    """Well-formed placeholder strings pass validation; tampered or mismatched ones raise."""
    validate = vector_store_module._validate_sql_placeholders

    # Strings built the same way as the production code must be accepted.
    for size in (0, 1, 10, 100):
        validate(",".join("?" * size), size)

    # (placeholders, expected_count) pairs that must be rejected:
    # an injected SQL tail, then a count that does not match the string.
    rejected_cases = (
        ("?,?); DROP TABLE semantic_chunks;--", 2),
        ("?,?", 3),
    )
    for bad_placeholders, size in rejected_cases:
        with pytest.raises(ValueError):
            validate(bad_placeholders, size)
|
||||
|
||||
Reference in New Issue
Block a user