Add comprehensive tests for semantic chunking and search functionality

- Implemented tests for the ChunkConfig and Chunker classes, covering default and custom configurations.
- Added tests for symbol-based chunking, including single and multiple symbols, handling of empty symbols, and preservation of line numbers.
- Developed tests for sliding window chunking, ensuring correct chunking behavior with various content sizes and configurations.
- Created integration tests for semantic search, validating embedding generation, vector storage, and search accuracy across a complex codebase.
- Included performance tests for embedding generation and search operations.
- Established tests for chunking strategies, comparing symbol-based and sliding window approaches.
- Enhanced test coverage for edge cases, including handling of Unicode characters and out-of-bounds symbol ranges.
This commit is contained in:
catlog22
2025-12-12 19:55:35 +08:00
parent c42f91a7fe
commit 4faa5f1c95
27 changed files with 4812 additions and 129 deletions

BIN
codex-lens/.coverage Normal file

Binary file not shown.

View File

@@ -24,9 +24,9 @@ dependencies = [
]
[project.optional-dependencies]
# Semantic search using fastembed (ONNX-based, lightweight ~200MB)
semantic = [
"numpy>=1.24",
"sentence-transformers>=2.2",
"fastembed>=0.2",
]

View File

@@ -67,7 +67,14 @@ class SearchResult(BaseModel):
path: str = Field(..., min_length=1)
score: float = Field(..., ge=0.0)
excerpt: Optional[str] = None
content: Optional[str] = Field(default=None, description="Full content of matched code block")
symbol: Optional[Symbol] = None
chunk: Optional[SemanticChunk] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
# Additional context for complete code blocks
start_line: Optional[int] = Field(default=None, description="Start line of code block (1-based)")
end_line: Optional[int] = Field(default=None, description="End line of code block (1-based)")
symbol_name: Optional[str] = Field(default=None, description="Name of matched symbol/function/class")
symbol_kind: Optional[str] = Field(default=None, description="Kind of symbol (function/class/method)")

View File

@@ -1,28 +1,32 @@
"""Optional semantic search module for CodexLens.
Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)
"""
from __future__ import annotations
SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
_import_error: str | None = None
try:
import numpy as np
def _detect_backend() -> tuple[bool, str | None, str | None]:
"""Detect if fastembed is available."""
try:
import numpy as np
except ImportError as e:
return False, None, f"numpy not available: {e}"
try:
from fastembed import TextEmbedding
SEMANTIC_BACKEND = "fastembed"
return True, "fastembed", None
except ImportError:
try:
from sentence_transformers import SentenceTransformer
SEMANTIC_BACKEND = "sentence-transformers"
except ImportError:
raise ImportError("Neither fastembed nor sentence-transformers available")
SEMANTIC_AVAILABLE = True
except ImportError as e:
_import_error = str(e)
SEMANTIC_BACKEND = None
pass
return False, None, "fastembed not available. Install with: pip install codexlens[semantic]"
# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, _import_error = _detect_backend()
def check_semantic_available() -> tuple[bool, str | None]:
"""Check if semantic search dependencies are available."""

View File

@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SearchResult, Symbol
def extract_complete_code_block(
    result: SearchResult,
    source_file_path: Optional[str] = None,
    context_lines: int = 0,
) -> str:
    """Extract the complete code block for a search result.

    Prefers the full content stored on the result; re-reads the source file
    only when extra context lines are requested or no content is stored.

    Args:
        result: SearchResult from semantic search.
        source_file_path: Optional path to the source file for re-reading.
        context_lines: Additional lines of context to include above/below.

    Returns:
        Complete code block as a string ("" when nothing is available).
    """
    # Stored content is authoritative when no extra context is needed.
    if result.content and context_lines == 0:
        return result.content
    # Otherwise try to re-read the file so context lines can be added.
    file_path = source_file_path or result.path
    if not file_path or not Path(file_path).exists():
        # Bug fix: prefer the full stored content over the short excerpt
        # when the file cannot be read (previously only the excerpt was
        # returned here, silently dropping available full content).
        return result.content or result.excerpt or ""
    try:
        text = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = text.splitlines()
        # Line numbers are 1-based; default to the whole file when unset.
        start_line = result.start_line or 1
        end_line = result.end_line or len(lines)
        # Clamp the context-expanded window to the file bounds.
        start_idx = max(0, start_line - 1 - context_lines)
        end_idx = min(len(lines), end_line + context_lines)
        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        # Best-effort: fall back to whatever text the result already carries.
        return result.content or result.excerpt or ""
def extract_symbol_with_context(
    file_path: str,
    symbol: Symbol,
    include_docstring: bool = True,
    include_decorators: bool = True,
) -> str:
    """Extract a symbol (function/class) with its docstring and decorators.

    Args:
        file_path: Path to source file.
        symbol: Symbol to extract.
        include_docstring: Include docstring if present.
            NOTE(review): currently unused — the symbol's stored range is
            assumed to already cover the docstring; confirm with the indexer.
        include_decorators: Include decorators/annotations above symbol.

    Returns:
        Complete symbol code with context ("" on any read/parse failure).
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()
        # symbol.range is (start_line, end_line), 1-based and inclusive.
        start_line, end_line = symbol.range
        start_idx = start_line - 1
        end_idx = end_line
        # Look for decorators above the symbol
        if include_decorators and start_idx > 0:
            decorator_start = start_idx
            # Search backwards for decorators, bounded to 20 lines so a
            # malformed file cannot drag in half the module.
            i = start_idx - 1
            while i >= 0 and i >= start_idx - 20:  # Look up to 20 lines back
                line = lines[i].strip()
                if line.startswith("@"):
                    # Decorator found: extend the slice to include it (and
                    # anything skipped between it and the symbol).
                    decorator_start = i
                    i -= 1
                elif line == "" or line.startswith("#"):
                    # Skip empty lines and '#' comments without extending the
                    # slice; they are only included if a decorator sits above.
                    i -= 1
                elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
                    # JavaScript/Java style comments.
                    # NOTE(review): unlike '#' comments these DO extend the
                    # slice even when no decorator follows — presumably to
                    # capture JSDoc/Javadoc blocks; confirm this is intended.
                    decorator_start = i
                    i -= 1
                else:
                    # Found non-decorator, non-comment line, stop
                    break
            start_idx = decorator_start
        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        # Missing file, bad range, etc. — callers treat "" as "no context".
        return ""
def format_search_result_code(
    result: SearchResult,
    max_lines: Optional[int] = None,
    show_line_numbers: bool = True,
    highlight_match: bool = False,
) -> str:
    """Render a search result's code for terminal display.

    Args:
        result: SearchResult to format.
        max_lines: Cap on the number of lines shown (None shows all).
        show_line_numbers: Prefix each line with its 1-based line number.
        highlight_match: Reserved for marking the matched region (unused).

    Returns:
        Formatted code string ("" when the result carries no code).
    """
    source = result.content or result.excerpt or ""
    if not source:
        return ""
    code_lines = source.splitlines()
    # Truncate overly long blocks, remembering that we did so.
    was_cut = bool(max_lines) and len(code_lines) > max_lines
    if was_cut:
        code_lines = code_lines[:max_lines]
    if show_line_numbers:
        first = result.start_line or 1
        rendered = "\n".join(
            f"{first + offset:4d} | {text}"
            for offset, text in enumerate(code_lines)
        )
    else:
        rendered = "\n".join(code_lines)
    return rendered + "\n... (truncated)" if was_cut else rendered
def get_code_block_summary(result: SearchResult) -> str:
    """Build a one-line human-readable summary of a code block.

    Args:
        result: SearchResult to summarize.

    Returns:
        Summary string like "function `hello_world` (lines 10-25) in app.py".
    """
    pieces = []
    if result.symbol_kind:
        pieces.append(result.symbol_kind)
    if result.symbol_name:
        pieces.append(f"`{result.symbol_name}`")
    elif result.excerpt:
        # No symbol name: fall back to the start of the excerpt.
        preview = result.excerpt.split("\n")[0][:50]
        pieces.append(f'"{preview}..."')
    start, end = result.start_line, result.end_line
    if start and end:
        location = f"(line {start})" if start == end else f"(lines {start}-{end})"
        pieces.append(location)
    if result.path:
        pieces.append(f"in {Path(result.path).name}")
    return " ".join(pieces) if pieces else "unknown code block"
class CodeBlockResult:
    """Enhanced search result with complete code block.

    Thin wrapper around a SearchResult that exposes convenience properties
    and lazily loads the full code block on first access.
    """
    def __init__(self, result: SearchResult, source_path: Optional[str] = None):
        """Wrap *result*, optionally overriding the path used for re-reads."""
        self.result = result
        self.source_path = source_path or result.path
        # Cache for the lazily-extracted full code block (None = not loaded).
        self._full_code: Optional[str] = None
    @property
    def score(self) -> float:
        """Similarity score of the underlying result."""
        return self.result.score
    @property
    def path(self) -> str:
        """Path of the matched file."""
        return self.result.path
    @property
    def file_name(self) -> str:
        """Basename of the matched file."""
        return Path(self.result.path).name
    @property
    def symbol_name(self) -> Optional[str]:
        """Name of the matched symbol, if any."""
        return self.result.symbol_name
    @property
    def symbol_kind(self) -> Optional[str]:
        """Kind of the matched symbol (function/class/...), if any."""
        return self.result.symbol_kind
    @property
    def line_range(self) -> Tuple[int, int]:
        """(start, end) 1-based line range; defaults to (1, 1) when unset."""
        return (
            self.result.start_line or 1,
            self.result.end_line or 1
        )
    @property
    def full_code(self) -> str:
        """Get full code block content."""
        # Extracted once and cached; may re-read the source file.
        if self._full_code is None:
            self._full_code = extract_complete_code_block(self.result, self.source_path)
        return self._full_code
    @property
    def excerpt(self) -> str:
        """Get short excerpt."""
        return self.result.excerpt or ""
    @property
    def summary(self) -> str:
        """Get code block summary."""
        return get_code_block_summary(self.result)
    def format(
        self,
        max_lines: Optional[int] = None,
        show_line_numbers: bool = True,
    ) -> str:
        """Format code for display."""
        # Build a throwaway SearchResult carrying the full code so the
        # shared formatter renders the complete block, not the excerpt.
        display_result = SearchResult(
            path=self.result.path,
            score=self.result.score,
            content=self.full_code,
            start_line=self.result.start_line,
            end_line=self.result.end_line,
        )
        return format_search_result_code(
            display_result,
            max_lines=max_lines,
            show_line_numbers=show_line_numbers
        )
    def __repr__(self) -> str:
        return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"
def enhance_search_results(
    results: List[SearchResult],
) -> List[CodeBlockResult]:
    """Wrap plain search results with complete code block access.

    Args:
        results: List of SearchResult from semantic search.

    Returns:
        List of CodeBlockResult wrappers, one per input result, in order.
    """
    return list(map(CodeBlockResult, results))

View File

@@ -1,17 +1,14 @@
"""Embedder for semantic code search."""
"""Embedder for semantic code search using fastembed."""
from __future__ import annotations
from typing import Iterable, List
from . import SEMANTIC_AVAILABLE, SEMANTIC_BACKEND
if SEMANTIC_AVAILABLE:
import numpy as np
from . import SEMANTIC_AVAILABLE
class Embedder:
"""Generate embeddings for code chunks using fastembed or sentence-transformers."""
"""Generate embeddings for code chunks using fastembed (ONNX-based)."""
MODEL_NAME = "BAAI/bge-small-en-v1.5"
EMBEDDING_DIM = 384
@@ -25,19 +22,14 @@ class Embedder:
self.model_name = model_name or self.MODEL_NAME
self._model = None
self._backend = SEMANTIC_BACKEND
def _load_model(self) -> None:
"""Lazy load the embedding model."""
if self._model is not None:
return
if self._backend == "fastembed":
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name=self.model_name)
else:
from sentence_transformers import SentenceTransformer
self._model = SentenceTransformer(self.model_name)
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name=self.model_name)
def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
"""Generate embeddings for one or more texts.
@@ -55,12 +47,8 @@ class Embedder:
else:
texts = list(texts)
if self._backend == "fastembed":
embeddings = list(self._model.embed(texts))
return [emb.tolist() for emb in embeddings]
else:
embeddings = self._model.encode(texts)
return embeddings.tolist()
embeddings = list(self._model.embed(texts))
return [emb.tolist() for emb in embeddings]
def embed_single(self, text: str) -> List[float]:
"""Generate embedding for a single text."""

View File

@@ -119,6 +119,7 @@ class VectorStore:
query_embedding: List[float],
top_k: int = 10,
min_score: float = 0.0,
return_full_content: bool = True,
) -> List[SearchResult]:
"""Find chunks most similar to query embedding.
@@ -126,6 +127,7 @@ class VectorStore:
query_embedding: Query vector.
top_k: Maximum results to return.
min_score: Minimum similarity score (0-1).
return_full_content: If True, return full code block content.
Returns:
List of SearchResult ordered by similarity (highest first).
@@ -144,14 +146,39 @@ class VectorStore:
if score >= min_score:
metadata = json.loads(metadata_json) if metadata_json else {}
# Build excerpt
# Build excerpt (short preview)
excerpt = content[:200] + "..." if len(content) > 200 else content
# Extract symbol information from metadata
symbol_name = metadata.get("symbol_name")
symbol_kind = metadata.get("symbol_kind")
start_line = metadata.get("start_line")
end_line = metadata.get("end_line")
# Build Symbol object if we have symbol info
symbol = None
if symbol_name and symbol_kind and start_line and end_line:
try:
from codexlens.entities import Symbol
symbol = Symbol(
name=symbol_name,
kind=symbol_kind,
range=(start_line, end_line)
)
except Exception:
pass
results.append((score, SearchResult(
path=file_path,
score=score,
excerpt=excerpt,
symbol=None,
content=content if return_full_content else None,
symbol=symbol,
metadata=metadata,
start_line=start_line,
end_line=end_line,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
)))
# Sort by score descending

View File

@@ -0,0 +1,280 @@
"""Tests for CodexLens CLI output functions."""
import json
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
from unittest.mock import patch
import pytest
from rich.console import Console
from codexlens.cli.output import (
_to_jsonable,
print_json,
render_file_inspect,
render_search_results,
render_status,
render_symbols,
)
from codexlens.entities import SearchResult, Symbol
class TestToJsonable:
    """Tests for _to_jsonable helper function.

    Covers None, primitives, Path, dict/list/tuple/set containers,
    Pydantic models and dataclasses.
    """
    def test_none_value(self):
        """Test converting None."""
        assert _to_jsonable(None) is None
    def test_primitive_values(self):
        """Test converting primitive values."""
        assert _to_jsonable("string") == "string"
        assert _to_jsonable(42) == 42
        assert _to_jsonable(3.14) == 3.14
        assert _to_jsonable(True) is True
    def test_path_conversion(self):
        """Test converting Path to string."""
        path = Path("/test/file.py")
        result = _to_jsonable(path)
        assert result == str(path)
    def test_dict_conversion(self):
        """Test converting dict with nested values."""
        data = {"key": "value", "path": Path("/test.py"), "nested": {"a": 1}}
        result = _to_jsonable(data)
        assert result["key"] == "value"
        # Path conversion uses str(), which may differ by OS
        assert result["path"] == str(Path("/test.py"))
        assert result["nested"]["a"] == 1
    def test_list_conversion(self):
        """Test converting list with various items."""
        data = ["string", 42, Path("/test.py")]
        result = _to_jsonable(data)
        assert result == ["string", 42, str(Path("/test.py"))]
    def test_tuple_conversion(self):
        """Test converting tuple (tuples become JSON lists)."""
        data = ("a", "b", Path("/test.py"))
        result = _to_jsonable(data)
        assert result == ["a", "b", str(Path("/test.py"))]
    def test_set_conversion(self):
        """Test converting set (order is not guaranteed, compare as set)."""
        data = {1, 2, 3}
        result = _to_jsonable(data)
        assert set(result) == {1, 2, 3}
    def test_pydantic_model_conversion(self):
        """Test converting Pydantic model."""
        symbol = Symbol(name="test", kind="function", range=(1, 5))
        result = _to_jsonable(symbol)
        assert result["name"] == "test"
        assert result["kind"] == "function"
        assert result["range"] == (1, 5)
    def test_dataclass_conversion(self):
        """Test converting dataclass."""
        @dataclass
        class TestData:
            name: str
            value: int
        data = TestData(name="test", value=42)
        result = _to_jsonable(data)
        assert result["name"] == "test"
        assert result["value"] == 42
class TestPrintJson:
    """Tests for print_json function.

    Captures output by replacing the patched console's print_json with a
    recording lambda rather than parsing terminal output.
    """
    def test_print_success_json(self, capsys):
        """Test printing success JSON."""
        # NOTE(review): capsys is unused here (output goes through the
        # patched console) — kept for signature consistency.
        with patch("codexlens.cli.output.console") as mock_console:
            captured_output = []
            mock_console.print_json = lambda x: captured_output.append(x)
            print_json(success=True, result={"key": "value"})
            output = json.loads(captured_output[0])
            assert output["success"] is True
            assert output["result"]["key"] == "value"
    def test_print_error_json(self, capsys):
        """Test printing error JSON."""
        with patch("codexlens.cli.output.console") as mock_console:
            captured_output = []
            mock_console.print_json = lambda x: captured_output.append(x)
            print_json(success=False, error="Something went wrong")
            output = json.loads(captured_output[0])
            assert output["success"] is False
            assert output["error"] == "Something went wrong"
    def test_print_error_default_message(self, capsys):
        """Test printing error with default message."""
        with patch("codexlens.cli.output.console") as mock_console:
            captured_output = []
            mock_console.print_json = lambda x: captured_output.append(x)
            print_json(success=False)
            output = json.loads(captured_output[0])
            assert output["error"] == "Unknown error"
class TestRenderSearchResults:
    """Tests for render_search_results function.

    Only asserts that the console was invoked once; rich table content
    itself is not inspected.
    """
    def test_render_empty_results(self):
        """Test rendering empty results."""
        with patch("codexlens.cli.output.console") as mock_console:
            render_search_results([])
            mock_console.print.assert_called_once()
    def test_render_results_with_data(self):
        """Test rendering results with data."""
        results = [
            SearchResult(path="/test/a.py", score=0.95, excerpt="test excerpt"),
            SearchResult(path="/test/b.py", score=0.85, excerpt="another excerpt"),
        ]
        with patch("codexlens.cli.output.console") as mock_console:
            render_search_results(results)
            mock_console.print.assert_called_once()
    def test_render_results_custom_title(self):
        """Test rendering results with custom title."""
        results = [SearchResult(path="/test.py", score=0.5)]
        with patch("codexlens.cli.output.console") as mock_console:
            render_search_results(results, title="Custom Title")
            mock_console.print.assert_called_once()
class TestRenderSymbols:
    """Tests for render_symbols function.

    Mirrors TestRenderSearchResults: asserts a single console.print call.
    """
    def test_render_empty_symbols(self):
        """Test rendering empty symbols list."""
        with patch("codexlens.cli.output.console") as mock_console:
            render_symbols([])
            mock_console.print.assert_called_once()
    def test_render_symbols_with_data(self):
        """Test rendering symbols with data."""
        symbols = [
            Symbol(name="MyClass", kind="class", range=(1, 10)),
            Symbol(name="my_func", kind="function", range=(12, 20)),
        ]
        with patch("codexlens.cli.output.console") as mock_console:
            render_symbols(symbols)
            mock_console.print.assert_called_once()
    def test_render_symbols_custom_title(self):
        """Test rendering symbols with custom title."""
        symbols = [Symbol(name="test", kind="function", range=(1, 1))]
        with patch("codexlens.cli.output.console") as mock_console:
            render_symbols(symbols, title="Functions Found")
            mock_console.print.assert_called_once()
class TestRenderStatus:
    """Tests for render_status function.

    Exercises flat, nested-dict and list-valued stats payloads.
    """
    def test_render_basic_stats(self):
        """Test rendering basic stats."""
        stats = {"files": 100, "symbols": 500}
        with patch("codexlens.cli.output.console") as mock_console:
            render_status(stats)
            mock_console.print.assert_called_once()
    def test_render_stats_with_nested_dict(self):
        """Test rendering stats with nested dict."""
        stats = {
            "files": 100,
            "languages": {"python": 50, "javascript": 30, "go": 20},
        }
        with patch("codexlens.cli.output.console") as mock_console:
            render_status(stats)
            mock_console.print.assert_called_once()
    def test_render_stats_with_list(self):
        """Test rendering stats with list value."""
        stats = {
            "files": 100,
            "recent_files": ["/a.py", "/b.py", "/c.py"],
        }
        with patch("codexlens.cli.output.console") as mock_console:
            render_status(stats)
            mock_console.print.assert_called_once()
class TestRenderFileInspect:
    """Tests for render_file_inspect function."""
    def test_render_file_with_symbols(self):
        """Test rendering file inspection with symbols."""
        symbols = [
            Symbol(name="hello", kind="function", range=(1, 5)),
            Symbol(name="MyClass", kind="class", range=(7, 20)),
        ]
        with patch("codexlens.cli.output.console") as mock_console:
            render_file_inspect("/test/file.py", "python", symbols)
            # Should be called twice: once for header, once for symbols table
            assert mock_console.print.call_count == 2
    def test_render_file_without_symbols(self):
        """Test rendering file inspection without symbols."""
        with patch("codexlens.cli.output.console") as mock_console:
            render_file_inspect("/test/file.py", "python", [])
            # Still two prints: header plus an empty-symbols message/table.
            assert mock_console.print.call_count == 2
class TestJsonOutputIntegration:
    """Integration tests for JSON output.

    Verifies that _to_jsonable output round-trips through json.dumps/loads.
    """
    def test_search_result_to_json(self):
        """Test converting SearchResult to JSON."""
        result = SearchResult(
            path="/test.py",
            score=0.95,
            excerpt="test code here",
            metadata={"line": 10},
        )
        jsonable = _to_jsonable(result)
        # Verify it can be JSON serialized
        json_str = json.dumps(jsonable)
        parsed = json.loads(json_str)
        assert parsed["path"] == "/test.py"
        assert parsed["score"] == 0.95
        assert parsed["excerpt"] == "test code here"
    def test_nested_results_to_json(self):
        """Test converting nested structure to JSON."""
        data = {
            "query": "test",
            "results": [
                SearchResult(path="/a.py", score=0.9),
                SearchResult(path="/b.py", score=0.8),
            ],
        }
        jsonable = _to_jsonable(data)
        json_str = json.dumps(jsonable)
        parsed = json.loads(json_str)
        assert parsed["query"] == "test"
        assert len(parsed["results"]) == 2

View File

@@ -0,0 +1,342 @@
"""Tests for code extractor functionality."""
import tempfile
from pathlib import Path
import pytest
from codexlens.entities import SearchResult, Symbol
from codexlens.semantic.code_extractor import (
CodeBlockResult,
extract_complete_code_block,
extract_symbol_with_context,
format_search_result_code,
get_code_block_summary,
enhance_search_results,
)
class TestExtractCompleteCodeBlock:
    """Test extract_complete_code_block function."""
    def test_returns_stored_content(self):
        """Test returns content when available in result."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            content="def hello():\n    return 'world'",
            start_line=1,
            end_line=2,
        )
        code = extract_complete_code_block(result)
        assert code == "def hello():\n    return 'world'"
    def test_reads_from_file_when_no_content(self, tmp_path):
        """Test reads from file when content not in result."""
        test_file = tmp_path / "test.py"
        # Lines 2-4 of this file are the hello() definition.
        test_file.write_text("""# Header comment
def hello():
    '''Docstring'''
    return 'world'
def goodbye():
    pass
""")
        result = SearchResult(
            path=str(test_file),
            score=0.9,
            excerpt="def hello():",
            start_line=2,
            end_line=4,
        )
        code = extract_complete_code_block(result)
        assert "def hello():" in code
        assert "return 'world'" in code
    def test_adds_context_lines(self, tmp_path):
        """Test adding context lines."""
        test_file = tmp_path / "test.py"
        # hello() sits on lines 3-4; one context line each side should pull
        # in lines 2 and 5.
        test_file.write_text("""# Line 1
# Line 2
def hello():
    return 'world'
# Line 5
# Line 6
""")
        result = SearchResult(
            path=str(test_file),
            score=0.9,
            start_line=3,
            end_line=4,
        )
        code = extract_complete_code_block(result, context_lines=1)
        assert "# Line 2" in code
        assert "# Line 5" in code
class TestExtractSymbolWithContext:
    """Test extract_symbol_with_context function."""
    def test_extracts_with_decorators(self, tmp_path):
        """Test extracting symbol with decorators."""
        test_file = tmp_path / "test.py"
        # Line 1: @decorator
        # Line 2: @another_decorator
        # Line 3: def hello():
        # Line 4:     return 'world'
        test_file.write_text("@decorator\n@another_decorator\ndef hello():\n    return 'world'\n")
        # Symbol range covers only the def (lines 3-4); the extractor must
        # walk backwards and include both decorators.
        symbol = Symbol(name="hello", kind="function", range=(3, 4))
        code = extract_symbol_with_context(str(test_file), symbol)
        assert "@decorator" in code
        assert "@another_decorator" in code
        assert "def hello():" in code
class TestFormatSearchResultCode:
    """Test format_search_result_code function."""
    def test_format_with_line_numbers(self):
        """Test formatting with line numbers."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            content="def hello():\n    return 'world'",
            start_line=10,
            end_line=11,
        )
        formatted = format_search_result_code(result, show_line_numbers=True)
        # Line numbers are right-aligned in a 4-char field.
        assert " 10 |" in formatted
        assert " 11 |" in formatted
    def test_format_truncation(self):
        """Test max_lines truncation."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            content="line1\nline2\nline3\nline4\nline5",
            start_line=1,
            end_line=5,
        )
        formatted = format_search_result_code(result, max_lines=2)
        assert "(truncated)" in formatted
    def test_format_without_line_numbers(self):
        """Test formatting without line numbers."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            content="def hello():\n    pass",
            start_line=1,
            end_line=2,
        )
        formatted = format_search_result_code(result, show_line_numbers=False)
        assert "def hello():" in formatted
        # The " | " gutter only appears when line numbers are on.
        assert " | " not in formatted
class TestGetCodeBlockSummary:
    """Test get_code_block_summary function."""
    def test_summary_with_symbol(self):
        """Test summary with symbol info."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            symbol_name="hello",
            symbol_kind="function",
            start_line=10,
            end_line=20,
        )
        summary = get_code_block_summary(result)
        assert "function" in summary
        assert "hello" in summary
        assert "10-20" in summary
        assert "test.py" in summary
    def test_summary_single_line(self):
        """Test summary for single line (uses 'line N', not 'lines N-N')."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            start_line=5,
            end_line=5,
        )
        summary = get_code_block_summary(result)
        assert "line 5" in summary
class TestCodeBlockResult:
    """Test CodeBlockResult class."""
    def test_properties(self):
        """Test CodeBlockResult properties delegate to the wrapped result."""
        result = SearchResult(
            path="/path/to/test.py",
            score=0.85,
            content="def hello(): pass",
            symbol_name="hello",
            symbol_kind="function",
            start_line=1,
            end_line=1,
        )
        block = CodeBlockResult(result)
        assert block.score == 0.85
        assert block.path == "/path/to/test.py"
        assert block.file_name == "test.py"
        assert block.symbol_name == "hello"
        assert block.symbol_kind == "function"
        assert block.line_range == (1, 1)
        # full_code falls back to the stored content (no file read needed).
        assert block.full_code == "def hello(): pass"
    def test_summary(self):
        """Test CodeBlockResult summary."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            symbol_name="Calculator",
            symbol_kind="class",
            start_line=10,
            end_line=50,
        )
        block = CodeBlockResult(result)
        summary = block.summary
        assert "class" in summary
        assert "Calculator" in summary
    def test_format(self):
        """Test CodeBlockResult format."""
        result = SearchResult(
            path="/test.py",
            score=0.9,
            content="def hello():\n    return 42",
            start_line=1,
            end_line=2,
        )
        block = CodeBlockResult(result)
        formatted = block.format(show_line_numbers=True)
        assert "   1 |" in formatted
        assert "def hello():" in formatted
class TestEnhanceSearchResults:
    """Test enhance_search_results function."""
    def test_enhances_results(self):
        """Test enhancing search results preserves order and scores."""
        results = [
            SearchResult(path="/a.py", score=0.9, content="def a(): pass"),
            SearchResult(path="/b.py", score=0.8, content="def b(): pass"),
        ]
        enhanced = enhance_search_results(results)
        assert len(enhanced) == 2
        assert all(isinstance(r, CodeBlockResult) for r in enhanced)
        assert enhanced[0].score == 0.9
        assert enhanced[1].score == 0.8
class TestIntegration:
    """Integration tests for code extraction."""
    def test_full_workflow(self, tmp_path):
        """Test complete code extraction workflow."""
        # Create test file
        test_file = tmp_path / "calculator.py"
        test_file.write_text('''"""Calculator module."""

@staticmethod
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: First number
        b: Second number

    Returns:
        Sum of a and b
    """
    return a + b

class Calculator:
    """A simple calculator."""
    def __init__(self):
        self.result = 0
    def compute(self, operation: str, value: int) -> int:
        """Perform computation."""
        if operation == "add":
            self.result += value
        elif operation == "sub":
            self.result -= value
        return self.result
''')
        # Simulate search result for 'add' function; content is stored
        # directly so full_code does not need to re-read the file.
        result = SearchResult(
            path=str(test_file),
            score=0.92,
            content='''@staticmethod
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: First number
        b: Second number

    Returns:
        Sum of a and b
    """
    return a + b''',
            symbol_name="add",
            symbol_kind="function",
            start_line=3,
            end_line=14,
        )
        block = CodeBlockResult(result)
        # Test properties
        assert block.symbol_name == "add"
        assert block.symbol_kind == "function"
        assert block.line_range == (3, 14)
        # Test full code
        assert "@staticmethod" in block.full_code
        assert "def add(" in block.full_code
        assert "return a + b" in block.full_code
        # Test summary
        summary = block.summary
        assert "function" in summary
        assert "add" in summary
        # Test format
        formatted = block.format(show_line_numbers=True)
        assert "   3 |" in formatted or "3 |" in formatted
        # Prints aid debugging with `pytest -s`; they are not assertions.
        print("\n--- Full Code Block ---")
        print(block.full_code)
        print("\n--- Formatted Output ---")
        print(formatted)
        print("\n--- Summary ---")
        print(summary)

View File

@@ -0,0 +1,350 @@
"""Tests for CodexLens configuration system."""
import os
import tempfile
from pathlib import Path
import pytest
from codexlens.config import (
WORKSPACE_DIR_NAME,
Config,
WorkspaceConfig,
_default_global_dir,
find_workspace_root,
)
from codexlens.errors import ConfigError
class TestDefaultGlobalDir:
    """Tests for _default_global_dir function.

    NOTE(review): environment mutation is done by hand with try/finally;
    pytest's monkeypatch.setenv/delenv would be safer against test crashes.
    """
    def test_default_location(self):
        """Test default location is ~/.codexlens."""
        # Clear any environment override
        env_backup = os.environ.get("CODEXLENS_DATA_DIR")
        if "CODEXLENS_DATA_DIR" in os.environ:
            del os.environ["CODEXLENS_DATA_DIR"]
        try:
            result = _default_global_dir()
            assert result == (Path.home() / ".codexlens").resolve()
        finally:
            # Restore the override so later tests see the original env.
            if env_backup is not None:
                os.environ["CODEXLENS_DATA_DIR"] = env_backup
    def test_env_override(self):
        """Test CODEXLENS_DATA_DIR environment variable override."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                result = _default_global_dir()
                assert result == Path(tmpdir).resolve()
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]
class TestFindWorkspaceRoot:
    """Tests for find_workspace_root function."""
    def test_finds_workspace_in_current_dir(self):
        """Test finding workspace when .codexlens is in current directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            (base / WORKSPACE_DIR_NAME).mkdir()
            result = find_workspace_root(base)
            assert result == base.resolve()
    def test_finds_workspace_in_parent_dir(self):
        """Test finding workspace in parent directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            (base / WORKSPACE_DIR_NAME).mkdir()
            subdir = base / "src" / "components"
            subdir.mkdir(parents=True)
            result = find_workspace_root(subdir)
            assert result == base.resolve()
    def test_returns_none_when_not_found(self):
        """Test returns None when no workspace found in isolated directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a deep nested directory to avoid finding user's home .codexlens
            isolated = Path(tmpdir) / "a" / "b" / "c"
            isolated.mkdir(parents=True)
            result = find_workspace_root(isolated)
            # May find user's .codexlens if it exists in parent dirs
            # So we just check it doesn't find one in our temp directory
            if result is not None:
                assert WORKSPACE_DIR_NAME not in str(isolated)
    def test_does_not_find_file_as_workspace(self):
        """Test that a file named .codexlens is not recognized as workspace."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            # Create isolated subdirectory
            subdir = base / "project"
            subdir.mkdir()
            # A plain file, not a directory — must not count as a workspace.
            (subdir / WORKSPACE_DIR_NAME).write_text("not a directory")
            result = find_workspace_root(subdir)
            # Should not find the file as workspace
            if result is not None:
                assert result != subdir
class TestConfig:
    """Tests for Config class.

    Config reads CODEXLENS_DATA_DIR from the process environment, so tests
    that construct ``Config()`` without an explicit ``data_dir`` point that
    variable at a temporary directory and restore the environment in
    ``finally`` to keep tests isolated from each other.
    """

    def test_default_config(self):
        """Test creating config with defaults."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                assert config.data_dir == Path(tmpdir).resolve()
                assert config.venv_path == Path(tmpdir).resolve() / "venv"
            finally:
                # Always restore the environment, even on assertion failure.
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_creates_data_dir(self):
        """Test that data_dir is created on init."""
        with tempfile.TemporaryDirectory() as tmpdir:
            data_dir = Path(tmpdir) / "new_dir"
            # Constructing Config is expected to create the directory eagerly.
            config = Config(data_dir=data_dir)
            assert data_dir.exists()

    def test_supported_languages(self):
        """Test default supported languages."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                assert "python" in config.supported_languages
                assert "javascript" in config.supported_languages
                assert "typescript" in config.supported_languages
                assert "java" in config.supported_languages
                assert "go" in config.supported_languages
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_cache_dir_property(self):
        """Test cache_dir property."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = Config(data_dir=Path(tmpdir))
            assert config.cache_dir == Path(tmpdir).resolve() / "cache"

    def test_index_dir_property(self):
        """Test index_dir property."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = Config(data_dir=Path(tmpdir))
            assert config.index_dir == Path(tmpdir).resolve() / "index"

    def test_db_path_property(self):
        """Test db_path property."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = Config(data_dir=Path(tmpdir))
            assert config.db_path == Path(tmpdir).resolve() / "index" / "codexlens.db"

    def test_ensure_runtime_dirs(self):
        """Test ensure_runtime_dirs creates directories."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config = Config(data_dir=Path(tmpdir))
            config.ensure_runtime_dirs()
            assert config.cache_dir.exists()
            assert config.index_dir.exists()

    def test_language_for_path_python(self):
        """Test language detection for Python files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                # Both bare names and absolute paths are resolved by suffix.
                assert config.language_for_path("test.py") == "python"
                assert config.language_for_path("/path/to/file.py") == "python"
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_language_for_path_javascript(self):
        """Test language detection for JavaScript files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                assert config.language_for_path("test.js") == "javascript"
                assert config.language_for_path("component.jsx") == "javascript"
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_language_for_path_typescript(self):
        """Test language detection for TypeScript files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                assert config.language_for_path("test.ts") == "typescript"
                assert config.language_for_path("component.tsx") == "typescript"
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_language_for_path_unknown(self):
        """Test language detection for unknown files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                # Unrecognised suffixes map to None rather than raising.
                assert config.language_for_path("test.xyz") is None
                assert config.language_for_path("README.md") is None
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_language_for_path_case_insensitive(self):
        """Test language detection is case insensitive."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                assert config.language_for_path("TEST.PY") == "python"
                assert config.language_for_path("File.Js") == "javascript"
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_rules_for_language(self):
        """Test getting parsing rules for a language."""
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["CODEXLENS_DATA_DIR"] = tmpdir
            try:
                config = Config()
                rules = config.rules_for_language("python")
                # Chunking rule keys expected by the chunker.
                assert "max_chunk_chars" in rules
                assert "max_chunk_lines" in rules
                assert "overlap_lines" in rules
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]
class TestWorkspaceConfig:
    """Tests for WorkspaceConfig class."""

    def test_create_workspace_config(self):
        """Test creating a workspace config resolves the root path."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            assert workspace.workspace_root == Path(tmpdir).resolve()

    def test_codexlens_dir_property(self):
        """Test codexlens_dir property points inside the workspace root."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            assert workspace.codexlens_dir == Path(tmpdir).resolve() / WORKSPACE_DIR_NAME

    def test_db_path_property(self):
        """Test db_path property."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            expected = Path(tmpdir).resolve() / WORKSPACE_DIR_NAME / "index.db"
            assert workspace.db_path == expected

    def test_cache_dir_property(self):
        """Test cache_dir property."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            expected = Path(tmpdir).resolve() / WORKSPACE_DIR_NAME / "cache"
            assert workspace.cache_dir == expected

    def test_initialize_creates_directories(self):
        """Test initialize creates the .codexlens directory structure."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            workspace.initialize()
            assert workspace.codexlens_dir.exists()
            assert workspace.cache_dir.exists()
            assert (workspace.codexlens_dir / ".gitignore").exists()

    def test_initialize_creates_gitignore(self):
        """Test initialize creates .gitignore with correct content."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            workspace.initialize()
            gitignore = workspace.codexlens_dir / ".gitignore"
            content = gitignore.read_text()
            assert "cache/" in content

    def test_exists_false_when_not_initialized(self):
        """Test exists returns False when not initialized."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            assert not workspace.exists()

    def test_exists_true_when_initialized_with_db(self):
        """Test exists returns True when initialized with db."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig(workspace_root=Path(tmpdir))
            workspace.initialize()
            # Create the db file to simulate full initialization.
            workspace.db_path.write_text("")
            assert workspace.exists()

    def test_from_path_finds_workspace(self):
        """Test from_path finds an existing workspace."""
        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir)
            (base / WORKSPACE_DIR_NAME).mkdir()
            workspace = WorkspaceConfig.from_path(base)
            assert workspace is not None
            assert workspace.workspace_root == base.resolve()

    def test_from_path_returns_none_when_not_found(self):
        """Test from_path finds nothing *inside* an isolated directory tree.

        The search may still walk up past the temp dir and hit a real
        workspace (e.g. in the user's home), so instead of a vacuous check
        we assert that any workspace found does NOT live inside the temp
        tree, which contains no ``.codexlens`` directory.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            isolated = Path(tmpdir) / "a" / "b" / "c"
            isolated.mkdir(parents=True)
            workspace = WorkspaceConfig.from_path(isolated)
            if workspace is not None:
                tmp_root = Path(tmpdir).resolve()
                root = workspace.workspace_root
                # The discovered root must be an ancestor outside tmpdir.
                assert root != tmp_root
                assert tmp_root not in (root, *root.parents) or root in tmp_root.parents

    def test_create_at_initializes_workspace(self):
        """Test create_at creates and initializes a workspace."""
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = WorkspaceConfig.create_at(Path(tmpdir))
            assert workspace.codexlens_dir.exists()
            assert workspace.cache_dir.exists()
class TestConfigEdgeCases:
    """Edge case tests for configuration."""

    def test_config_with_path_object(self):
        """Config accepts a pathlib.Path for data_dir."""
        with tempfile.TemporaryDirectory() as tmp:
            cfg = Config(data_dir=Path(tmp))
            assert isinstance(cfg.data_dir, Path)

    def test_config_expands_user_path(self):
        """Config resolves its data_dir to an absolute path without crashing."""
        with tempfile.TemporaryDirectory() as tmp:
            os.environ["CODEXLENS_DATA_DIR"] = tmp
            try:
                assert Config().data_dir.is_absolute()
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_workspace_config_from_subdir(self):
        """from_path walks up from a deeply nested subdirectory to the root."""
        with tempfile.TemporaryDirectory() as tmp:
            root = Path(tmp)
            (root / WORKSPACE_DIR_NAME).mkdir()
            nested = root / "a" / "b" / "c" / "d"
            nested.mkdir(parents=True)
            found = WorkspaceConfig.from_path(nested)
            assert found is not None
            assert found.workspace_root == root.resolve()

View File

@@ -0,0 +1,222 @@
"""Tests for CodexLens entity models."""
import pytest
from pydantic import ValidationError
from codexlens.entities import IndexedFile, SearchResult, SemanticChunk, Symbol
class TestSymbol:
    """Tests for the Symbol entity."""

    def test_create_valid_symbol(self):
        """A well-formed symbol round-trips all of its fields."""
        sym = Symbol(name="hello", kind="function", range=(1, 10))
        assert (sym.name, sym.kind, sym.range) == ("hello", "function", (1, 10))

    def test_symbol_range_validation(self):
        """Ranges with start < 1, end < start, or negatives are rejected."""
        for bad_range in [(0, 5), (5, 3), (-1, 5)]:
            with pytest.raises(ValidationError):
                Symbol(name="test", kind="function", range=bad_range)

    def test_symbol_name_required(self):
        """An empty name is rejected."""
        with pytest.raises(ValidationError):
            Symbol(name="", kind="function", range=(1, 1))

    def test_symbol_kind_required(self):
        """An empty kind is rejected."""
        with pytest.raises(ValidationError):
            Symbol(name="test", kind="", range=(1, 1))

    def test_symbol_equal_range(self):
        """A one-line symbol may start and end on the same line."""
        sym = Symbol(name="one_liner", kind="function", range=(5, 5))
        assert sym.range == (5, 5)
class TestSemanticChunk:
    """Tests for the SemanticChunk entity."""

    def test_create_chunk_without_embedding(self):
        """A bare chunk defaults to no embedding and empty metadata."""
        chunk = SemanticChunk(content="def hello(): pass")
        assert chunk.content == "def hello(): pass"
        assert chunk.embedding is None
        assert chunk.metadata == {}

    def test_create_chunk_with_embedding(self):
        """An embedding vector is stored exactly as given."""
        vector = [0.1, 0.2, 0.3, 0.4]
        assert SemanticChunk(content="some code", embedding=vector).embedding == vector

    def test_chunk_with_metadata(self):
        """Arbitrary metadata survives construction unchanged."""
        meta = {"file": "test.py", "language": "python", "line": 10}
        assert SemanticChunk(content="code", metadata=meta).metadata == meta

    def test_chunk_content_required(self):
        """Empty content is rejected."""
        with pytest.raises(ValidationError):
            SemanticChunk(content="")

    def test_chunk_embedding_validation(self):
        """An explicitly empty embedding list is rejected."""
        with pytest.raises(ValidationError):
            SemanticChunk(content="code", embedding=[])

    def test_chunk_embedding_with_floats(self):
        """Zero, negative, and high-precision floats are all preserved."""
        vector = [0.0, 1.0, -0.5, 0.123456789]
        assert SemanticChunk(content="code", embedding=vector).embedding == vector
class TestIndexedFile:
    """Tests for the IndexedFile entity."""

    def test_create_empty_indexed_file(self):
        """A file with no symbols or chunks defaults both lists to empty."""
        indexed = IndexedFile(path="/test/file.py", language="python")
        assert indexed.path == "/test/file.py"
        assert indexed.language == "python"
        assert indexed.symbols == []
        assert indexed.chunks == []

    def test_create_indexed_file_with_symbols(self):
        """Symbols passed at construction are kept in order."""
        syms = [
            Symbol(name="MyClass", kind="class", range=(1, 10)),
            Symbol(name="my_func", kind="function", range=(12, 20)),
        ]
        indexed = IndexedFile(path="/test/file.py", language="python", symbols=syms)
        assert len(indexed.symbols) == 2
        assert indexed.symbols[0].name == "MyClass"

    def test_create_indexed_file_with_chunks(self):
        """Chunks passed at construction are retained."""
        pieces = [
            SemanticChunk(content="chunk 1", metadata={"line": 1}),
            SemanticChunk(content="chunk 2", metadata={"line": 10}),
        ]
        indexed = IndexedFile(path="/test/file.py", language="python", chunks=pieces)
        assert len(indexed.chunks) == 2

    def test_indexed_file_path_strip(self):
        """Surrounding whitespace is stripped from the path."""
        indexed = IndexedFile(path="  /test/file.py  ", language="python")
        assert indexed.path == "/test/file.py"

    def test_indexed_file_language_strip(self):
        """Surrounding whitespace is stripped from the language."""
        indexed = IndexedFile(path="/test/file.py", language="  python  ")
        assert indexed.language == "python"

    def test_indexed_file_path_required(self):
        """Empty or whitespace-only paths are rejected."""
        for bad_path in ["", "   "]:
            with pytest.raises(ValidationError):
                IndexedFile(path=bad_path, language="python")

    def test_indexed_file_language_required(self):
        """An empty language is rejected."""
        with pytest.raises(ValidationError):
            IndexedFile(path="/test/file.py", language="")
class TestSearchResult:
    """Tests for SearchResult entity, including the block-context fields
    (content, start_line, end_line, symbol_name)."""

    def test_create_minimal_search_result(self):
        """A minimal result defaults every optional field off."""
        result = SearchResult(path="/test/file.py", score=0.95)
        assert result.path == "/test/file.py"
        assert result.score == 0.95
        assert result.excerpt is None
        assert result.symbol is None
        assert result.chunk is None
        assert result.metadata == {}
        # Block-context fields default to None as well.
        assert result.content is None
        assert result.start_line is None
        assert result.end_line is None
        assert result.symbol_name is None

    def test_create_full_search_result(self):
        """All fields round-trip when provided together."""
        symbol = Symbol(name="test", kind="function", range=(1, 5))
        chunk = SemanticChunk(content="test code")
        result = SearchResult(
            path="/test/file.py",
            score=0.88,
            excerpt="...matching code...",
            symbol=symbol,
            chunk=chunk,
            metadata={"match_type": "fts"},
        )
        assert result.excerpt == "...matching code..."
        assert result.symbol.name == "test"
        assert result.chunk.content == "test code"

    def test_search_result_block_context_fields(self):
        """The code-block context fields round-trip when provided."""
        result = SearchResult(
            path="/test/file.py",
            score=0.7,
            content="def f():\n    pass",
            start_line=10,
            end_line=11,
            symbol_name="f",
        )
        assert result.content == "def f():\n    pass"
        assert result.start_line == 10
        assert result.end_line == 11
        assert result.symbol_name == "f"

    def test_search_result_score_validation(self):
        """Negative scores are rejected."""
        with pytest.raises(ValidationError):
            SearchResult(path="/test/file.py", score=-0.1)

    def test_search_result_zero_score(self):
        """A zero score is valid."""
        result = SearchResult(path="/test/file.py", score=0.0)
        assert result.score == 0.0

    def test_search_result_path_required(self):
        """An empty path is rejected."""
        with pytest.raises(ValidationError):
            SearchResult(path="", score=0.5)
class TestEntitySerialization:
    """Tests for entity serialization via pydantic's model_dump."""

    def test_symbol_model_dump(self):
        """Symbol serializes to a flat dict of its three fields."""
        dumped = Symbol(name="test", kind="function", range=(1, 10)).model_dump()
        assert dumped == {"name": "test", "kind": "function", "range": (1, 10)}

    def test_indexed_file_model_dump(self):
        """IndexedFile serializes path, language, and nested symbols."""
        indexed = IndexedFile(
            path="/test.py",
            language="python",
            symbols=[Symbol(name="foo", kind="function", range=(1, 1))],
        )
        dumped = indexed.model_dump()
        assert dumped["path"] == "/test.py"
        assert dumped["language"] == "python"
        assert len(dumped["symbols"]) == 1

    def test_search_result_model_dump(self):
        """SearchResult serializes its scalar fields."""
        dumped = SearchResult(path="/test.py", score=0.5, excerpt="test").model_dump()
        assert dumped["path"] == "/test.py"
        assert dumped["score"] == 0.5
        assert dumped["excerpt"] == "test"

View File

@@ -0,0 +1,165 @@
"""Tests for CodexLens error classes."""
import pytest
from codexlens.errors import (
CodexLensError,
ConfigError,
ParseError,
SearchError,
StorageError,
)
class TestErrorHierarchy:
    """Tests for the error class hierarchy."""

    def test_codexlens_error_is_exception(self):
        """CodexLensError derives from the built-in Exception."""
        assert issubclass(CodexLensError, Exception)

    def test_config_error_inherits_from_base(self):
        """ConfigError derives from CodexLensError."""
        assert issubclass(ConfigError, CodexLensError)

    def test_parse_error_inherits_from_base(self):
        """ParseError derives from CodexLensError."""
        assert issubclass(ParseError, CodexLensError)

    def test_storage_error_inherits_from_base(self):
        """StorageError derives from CodexLensError."""
        assert issubclass(StorageError, CodexLensError)

    def test_search_error_inherits_from_base(self):
        """SearchError derives from CodexLensError."""
        assert issubclass(SearchError, CodexLensError)
class TestErrorMessages:
    """Tests for error message handling."""

    def test_codexlens_error_with_message(self):
        """The construction message becomes the str() of the error."""
        assert str(CodexLensError("Something went wrong")) == "Something went wrong"

    def test_config_error_with_message(self):
        """ConfigError preserves its message."""
        assert str(ConfigError("Invalid configuration")) == "Invalid configuration"

    def test_parse_error_with_message(self):
        """ParseError preserves its message."""
        assert str(ParseError("Failed to parse file.py")) == "Failed to parse file.py"

    def test_storage_error_with_message(self):
        """StorageError preserves its message."""
        assert str(StorageError("Database connection failed")) == "Database connection failed"

    def test_search_error_with_message(self):
        """SearchError preserves its message."""
        assert str(SearchError("FTS query syntax error")) == "FTS query syntax error"
class TestErrorRaising:
    """Tests for raising and catching errors."""

    def test_catch_specific_error(self):
        """A concrete error type is caught by its own class."""
        with pytest.raises(ConfigError):
            raise ConfigError("test")

    def test_catch_base_error(self):
        """CodexLensError catches every concrete subtype."""
        for error_type in (ConfigError, ParseError, StorageError, SearchError):
            with pytest.raises(CodexLensError):
                raise error_type("test")

    def test_error_not_caught_as_wrong_type(self):
        """A ConfigError passes straight through an except for ParseError."""
        with pytest.raises(ConfigError):
            try:
                raise ConfigError("config issue")
            except ParseError:
                pass  # Unreachable: ParseError must not match ConfigError.
class TestErrorChaining:
    """Tests for error chaining."""

    def test_error_with_cause(self):
        """`raise ... from` records the original error as __cause__."""
        root = ValueError("original error")
        try:
            raise StorageError("storage failed") from root
        except StorageError as caught:
            assert caught.__cause__ is root

    def test_nested_error_handling(self):
        """Wrapping a low-level error preserves the cause chain."""
        def fail_low():
            raise ValueError("inner error")

        def fail_high():
            try:
                fail_low()
            except ValueError as low:
                raise ParseError("outer error") from low

        with pytest.raises(ParseError) as exc_info:
            fail_high()
        assert exc_info.value.__cause__ is not None
        assert isinstance(exc_info.value.__cause__, ValueError)
class TestErrorUsagePatterns:
    """Tests for common error usage patterns."""

    def test_error_in_context_manager(self):
        """A StorageError raised inside a context manager propagates out."""
        class FakeStore:
            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                return False  # Do not suppress exceptions.

            def query(self):
                raise StorageError("query failed")

        with pytest.raises(StorageError):
            with FakeStore() as store:
                store.query()

    def test_error_comparison(self):
        """Equal messages do not make distinct instances identical."""
        first = ConfigError("test")
        second = ConfigError("test")
        assert first is not second
        assert str(first) == str(second)

    def test_empty_error_message(self):
        """An empty message yields an empty str()."""
        assert str(CodexLensError("")) == ""

    def test_error_with_format_args(self):
        """Formatted messages keep their interpolated details."""
        path = "/test/file.py"
        error = ParseError(f"Failed to parse {path}: syntax error on line 10")
        assert path in str(error)
        assert "line 10" in str(error)

View File

@@ -0,0 +1,224 @@
"""Tests for CodexLens file cache."""
import tempfile
from pathlib import Path
import pytest
from codexlens.storage.file_cache import FileCache
class TestFileCache:
    """Tests for the FileCache class."""

    def test_create_cache(self):
        """The cache remembers the directory it was created with."""
        with tempfile.TemporaryDirectory() as tmp:
            assert FileCache(cache_path=Path(tmp)).cache_path == Path(tmp)

    def test_store_and_load_mtime(self):
        """A stored mtime is returned unchanged."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            target = Path("/test/file.py")
            cache.store_mtime(target, 1234567890.123)
            assert cache.load_mtime(target) == 1234567890.123

    def test_load_nonexistent_mtime(self):
        """Loading an uncached file yields None."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            assert cache.load_mtime(Path("/nonexistent/file.py")) is None

    def test_update_mtime(self):
        """Storing twice keeps only the latest value."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            target = Path("/test/file.py")
            cache.store_mtime(target, 1000.0)
            cache.store_mtime(target, 2000.0)
            assert cache.load_mtime(target) == 2000.0

    def test_multiple_files(self):
        """Entries for different paths do not interfere."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            expected = {
                Path("/test/a.py"): 1000.0,
                Path("/test/b.py"): 2000.0,
                Path("/test/c.py"): 3000.0,
            }
            for target, mtime in expected.items():
                cache.store_mtime(target, mtime)
            for target, mtime in expected.items():
                assert cache.load_mtime(target) == mtime
class TestFileCacheKeyGeneration:
    """Tests for cache key generation."""

    def test_key_for_simple_path(self):
        """Keys for bare filenames carry the .mtime suffix."""
        with tempfile.TemporaryDirectory() as tmp:
            key = FileCache(cache_path=Path(tmp))._key_for(Path("test.py"))
            assert key.endswith(".mtime")

    def test_key_for_path_with_slashes(self):
        """Path separators never leak into the key."""
        with tempfile.TemporaryDirectory() as tmp:
            key = FileCache(cache_path=Path(tmp))._key_for(Path("/path/to/file.py"))
            assert "/" not in key
            assert key.endswith(".mtime")

    def test_key_for_windows_path(self):
        """Windows separators and drive colons are sanitized out of the key."""
        with tempfile.TemporaryDirectory() as tmp:
            key = FileCache(cache_path=Path(tmp))._key_for(Path("C:\\Users\\test\\file.py"))
            assert "\\" not in key
            assert ":" not in key
            assert key.endswith(".mtime")

    def test_different_paths_different_keys(self):
        """Distinct paths map to distinct keys."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            assert cache._key_for(Path("/test/a.py")) != cache._key_for(Path("/test/b.py"))
class TestFileCacheDirectoryCreation:
    """Tests for lazy cache directory creation."""

    def test_creates_cache_directory(self):
        """The cache directory is created on first store, not at construction."""
        with tempfile.TemporaryDirectory() as tmp:
            target_dir = Path(tmp) / "new_cache_dir"
            cache = FileCache(cache_path=target_dir)
            assert not target_dir.exists()
            cache.store_mtime(Path("/test.py"), 1000.0)
            assert target_dir.exists()

    def test_nested_cache_directory(self):
        """Deeply nested cache directories are created recursively."""
        with tempfile.TemporaryDirectory() as tmp:
            deep_dir = Path(tmp) / "a" / "b" / "c" / "cache"
            FileCache(cache_path=deep_dir).store_mtime(Path("/test.py"), 1000.0)
            assert deep_dir.exists()
class TestFileCacheEdgeCases:
    """Edge case tests for FileCache."""

    def test_mtime_precision(self):
        """Stored mtimes keep sub-millisecond precision."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            precise = 1234567890.123456789
            cache.store_mtime(Path("/test.py"), precise)
            assert abs(cache.load_mtime(Path("/test.py")) - precise) < 0.0001

    def test_zero_mtime(self):
        """A zero mtime round-trips (and is not confused with 'missing')."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            cache.store_mtime(Path("/test.py"), 0.0)
            assert cache.load_mtime(Path("/test.py")) == 0.0

    def test_negative_mtime(self):
        """Negative mtimes (pre-epoch timestamps) round-trip."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            cache.store_mtime(Path("/test.py"), -1000.0)
            assert cache.load_mtime(Path("/test.py")) == -1000.0

    def test_large_mtime(self):
        """Very large mtime values round-trip."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            big = 9999999999.999
            cache.store_mtime(Path("/test.py"), big)
            assert cache.load_mtime(Path("/test.py")) == big

    def test_unicode_path(self):
        """Paths with non-ASCII components are cached correctly."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            target = Path("/测试/文件.py")
            cache.store_mtime(target, 1000.0)
            assert cache.load_mtime(target) == 1000.0

    def test_load_corrupted_cache_file(self):
        """A corrupted cache entry loads as None instead of raising."""
        with tempfile.TemporaryDirectory() as tmp:
            cache = FileCache(cache_path=Path(tmp))
            target = Path("/test.py")
            cache.store_mtime(target, 1000.0)
            # Overwrite the on-disk entry with garbage.
            (Path(tmp) / cache._key_for(target)).write_text("not a number")
            assert cache.load_mtime(target) is None
class TestFileCachePersistence:
    """Tests for cache persistence across instances."""

    def test_cache_persists_across_instances(self):
        """Data written by one FileCache instance is visible to a new one."""
        with tempfile.TemporaryDirectory() as tmp:
            shared_path = Path(tmp)
            FileCache(cache_path=shared_path).store_mtime(Path("/test.py"), 1234.0)
            # A fresh instance over the same directory must see the entry.
            reloaded = FileCache(cache_path=shared_path).load_mtime(Path("/test.py"))
            assert reloaded == 1234.0

View File

@@ -1,13 +1,19 @@
"""Tests for CodexLens parsers."""
import tempfile
from pathlib import Path
import pytest
from codexlens.config import Config
from codexlens.parsers.factory import (
ParserFactory,
SimpleRegexParser,
_parse_go_symbols,
_parse_java_symbols,
_parse_js_ts_symbols,
_parse_python_symbols,
_parse_generic_symbols,
)
@@ -137,6 +143,151 @@ class TestJavaScriptParser:
assert all(name != "constructor" for name, _ in names_kinds)
class TestJavaParser:
    """Tests for Java symbol parsing."""

    def test_parse_class(self):
        """A public class declaration yields one class symbol."""
        symbols = _parse_java_symbols("public class MyClass {\n}")
        assert [(s.name, s.kind) for s in symbols] == [("MyClass", "class")]

    def test_parse_class_without_public(self):
        """Package-private classes are detected too."""
        symbols = _parse_java_symbols("class InternalClass {\n}")
        assert len(symbols) == 1
        assert symbols[0].name == "InternalClass"

    def test_parse_method(self):
        """A class with one method yields class + method symbols."""
        symbols = _parse_java_symbols("public class Test {\n    public void doSomething() {}\n}")
        assert len(symbols) == 2
        assert (symbols[0].name, symbols[0].kind) == ("Test", "class")
        assert (symbols[1].name, symbols[1].kind) == ("doSomething", "method")

    def test_parse_static_method(self):
        """Static methods are detected by name."""
        symbols = _parse_java_symbols("public class Test {\n    public static void main(String[] args) {}\n}")
        assert "main" in [s.name for s in symbols if s.kind == "method"]

    def test_parse_private_method(self):
        """Private methods are detected by name."""
        symbols = _parse_java_symbols("public class Test {\n    private int calculate() { return 0; }\n}")
        assert "calculate" in [s.name for s in symbols if s.kind == "method"]

    def test_parse_generic_return_type(self):
        """Methods with generic return types are detected."""
        symbols = _parse_java_symbols("public class Test {\n    public List<String> getItems() { return null; }\n}")
        assert "getItems" in [s.name for s in symbols if s.kind == "method"]
class TestGoParser:
    """Tests for Go symbol parsing."""

    def test_parse_function(self):
        """A bare func declaration yields one function symbol."""
        symbols = _parse_go_symbols("func hello() {\n}")
        assert [(s.name, s.kind) for s in symbols] == [("hello", "function")]

    def test_parse_function_with_params(self):
        """Parameters and return types do not confuse the name extraction."""
        symbols = _parse_go_symbols("func greet(name string) string {\n    return name\n}")
        assert len(symbols) == 1
        assert symbols[0].name == "greet"

    def test_parse_method(self):
        """A method with a pointer receiver is reported under its own name."""
        symbols = _parse_go_symbols("func (s *Server) Start() error {\n    return nil\n}")
        assert len(symbols) == 1
        assert (symbols[0].name, symbols[0].kind) == ("Start", "function")

    def test_parse_struct(self):
        """A struct type maps onto the 'class' kind."""
        symbols = _parse_go_symbols("type User struct {\n    Name string\n}")
        assert [(s.name, s.kind) for s in symbols] == [("User", "class")]

    def test_parse_interface(self):
        """An interface type also maps onto the 'class' kind."""
        symbols = _parse_go_symbols(
            "type Reader interface {\n    Read(p []byte) (n int, err error)\n}"
        )
        assert [(s.name, s.kind) for s in symbols] == [("Reader", "class")]

    def test_parse_multiple_symbols(self):
        """Structs, constructors, and methods in one file are all found."""
        code = """type Config struct {
    Port int
}
func NewConfig() *Config {
    return &Config{}
}
func (c *Config) Validate() error {
    return nil
}
"""
        found = {s.name for s in _parse_go_symbols(code)}
        assert {"Config", "NewConfig", "Validate"} <= found
class TestGenericParser:
    """Tests for the language-agnostic fallback parser."""

    def test_parse_def_keyword(self):
        """Python-style `def` is recognized as a function."""
        symbols = _parse_generic_symbols("def something():\n    pass")
        assert [(s.name, s.kind) for s in symbols] == [("something", "function")]

    def test_parse_function_keyword(self):
        """JavaScript-style `function` is recognized."""
        symbols = _parse_generic_symbols("function doIt() {}")
        assert len(symbols) == 1
        assert symbols[0].name == "doIt"

    def test_parse_func_keyword(self):
        """Go-style `func` is recognized."""
        symbols = _parse_generic_symbols("func test() {}")
        assert len(symbols) == 1
        assert symbols[0].name == "test"

    def test_parse_class_keyword(self):
        """`class` declarations map to the 'class' kind."""
        symbols = _parse_generic_symbols("class MyClass {}")
        assert [(s.name, s.kind) for s in symbols] == [("MyClass", "class")]

    def test_parse_struct_keyword(self):
        """`struct` declarations map to the 'class' kind."""
        symbols = _parse_generic_symbols("struct Point { x: i32, y: i32 }")
        assert [(s.name, s.kind) for s in symbols] == [("Point", "class")]

    def test_parse_interface_keyword(self):
        """`interface` declarations map to the 'class' kind."""
        symbols = _parse_generic_symbols("interface Drawable {}")
        assert [(s.name, s.kind) for s in symbols] == [("Drawable", "class")]
class TestParserInterface:
"""High-level interface tests."""
@@ -146,3 +297,129 @@ class TestParserInterface:
assert indexed.language == "python"
assert len(indexed.symbols) == 1
assert indexed.symbols[0].name == "hello"
    def test_simple_parser_javascript(self):
        """SimpleRegexParser extracts a symbol from a JS function declaration."""
        parser = SimpleRegexParser("javascript")
        indexed = parser.parse("function test() {}", Path("test.js"))
        assert indexed.language == "javascript"
        assert len(indexed.symbols) == 1
    def test_simple_parser_typescript(self):
        """SimpleRegexParser extracts a symbol from an exported TS class."""
        parser = SimpleRegexParser("typescript")
        indexed = parser.parse("export class Service {}", Path("test.ts"))
        assert indexed.language == "typescript"
        assert len(indexed.symbols) == 1
    def test_simple_parser_java(self):
        """SimpleRegexParser extracts a symbol from a Java class declaration."""
        parser = SimpleRegexParser("java")
        indexed = parser.parse("public class Main {}", Path("Main.java"))
        assert indexed.language == "java"
        assert len(indexed.symbols) == 1
    def test_simple_parser_go(self):
        """SimpleRegexParser extracts a symbol from a Go func declaration."""
        parser = SimpleRegexParser("go")
        indexed = parser.parse("func main() {}", Path("main.go"))
        assert indexed.language == "go"
        assert len(indexed.symbols) == 1
    def test_simple_parser_unknown_language(self):
        """Unknown languages fall back to the generic parser without chunks."""
        parser = SimpleRegexParser("zig")
        indexed = parser.parse("fn main() void {}", Path("main.zig"))
        assert indexed.language == "zig"
        # Uses generic parser
        assert indexed.chunks == []
    def test_indexed_file_path_resolved(self):
        """Relative input paths are resolved to absolute paths in the result."""
        parser = SimpleRegexParser("python")
        indexed = parser.parse("def test(): pass", Path("./test.py"))
        # Path should be resolved to absolute
        assert Path(indexed.path).is_absolute()
class TestParserFactory:
    """Tests for ParserFactory."""

    def test_factory_creates_parser(self):
        """The factory returns a parser for a supported language."""
        with tempfile.TemporaryDirectory() as tmp:
            import os
            os.environ["CODEXLENS_DATA_DIR"] = tmp
            try:
                factory = ParserFactory(Config())
                assert factory.get_parser("python") is not None
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_factory_caches_parsers(self):
        """Asking twice for the same language returns the same instance."""
        with tempfile.TemporaryDirectory() as tmp:
            import os
            os.environ["CODEXLENS_DATA_DIR"] = tmp
            try:
                factory = ParserFactory(Config())
                assert factory.get_parser("python") is factory.get_parser("python")
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]

    def test_factory_different_languages(self):
        """Different languages get distinct parser instances."""
        with tempfile.TemporaryDirectory() as tmp:
            import os
            os.environ["CODEXLENS_DATA_DIR"] = tmp
            try:
                factory = ParserFactory(Config())
                assert factory.get_parser("python") is not factory.get_parser("javascript")
            finally:
                del os.environ["CODEXLENS_DATA_DIR"]
class TestParserEdgeCases:
    """Edge case tests for parsers."""

    def test_empty_code(self):
        """Empty input produces no symbols."""
        symbols = _parse_python_symbols("")
        assert len(symbols) == 0

    def test_only_comments(self):
        """Comment-only input produces no symbols."""
        code = "# This is a comment\n# Another comment"
        symbols = _parse_python_symbols(code)
        assert len(symbols) == 0

    def test_nested_functions(self):
        """Both outer and nested function definitions are reported."""
        code = """def outer():
    def inner():
        pass
    return inner
"""
        symbols = _parse_python_symbols(code)
        names = [s.name for s in symbols]
        assert "outer" in names
        assert "inner" in names

    def test_unicode_function_name(self):
        """Unicode identifiers must not crash the parser."""
        code = "def 你好():\n    pass"
        symbols = _parse_python_symbols(code)
        # Regex may not support unicode function names, tree-sitter does,
        # so we only verify the call does not raise.
        assert isinstance(symbols, list)

    def test_long_file(self):
        """A file with many functions yields one symbol per function."""
        code = "\n".join(f"def func_{i}():\n    pass\n" for i in range(100))
        symbols = _parse_python_symbols(code)
        assert len(symbols) == 100

    def test_malformed_code(self):
        """Malformed code is handled gracefully (no exception raised)."""
        code = "def broken(\n    pass"
        symbols = _parse_python_symbols(code)
        # Fix: the original test ended without any assertion, so it passed
        # vacuously; at minimum the parser must still return a list.
        assert isinstance(symbols, list)

View File

@@ -0,0 +1,290 @@
"""Tests for CodexLens semantic module."""
import pytest
from codexlens.entities import SemanticChunk, Symbol
from codexlens.semantic.chunker import ChunkConfig, Chunker
class TestChunkConfig:
    """Tests for ChunkConfig."""

    def test_default_config(self):
        """A bare ChunkConfig carries the documented defaults."""
        defaults = ChunkConfig()
        assert defaults.max_chunk_size == 1000
        assert defaults.overlap == 100
        assert defaults.min_chunk_size == 50

    def test_custom_config(self):
        """Explicit keyword values override every default."""
        custom = ChunkConfig(max_chunk_size=2000, overlap=200, min_chunk_size=100)
        assert (custom.max_chunk_size, custom.overlap, custom.min_chunk_size) == (2000, 200, 100)
class TestChunker:
    """Tests for Chunker class."""

    def test_chunker_default_config(self):
        """A chunker built without arguments uses the default config."""
        assert Chunker().config.max_chunk_size == 1000

    def test_chunker_custom_config(self):
        """An explicitly supplied config is stored on the chunker."""
        small = Chunker(config=ChunkConfig(max_chunk_size=500))
        assert small.config.max_chunk_size == 500
class TestChunkBySymbol:
    """Tests for symbol-based chunking."""

    def test_chunk_single_function(self):
        """A single function symbol becomes one chunk with full metadata."""
        # Lower min_chunk_size so the short snippet is not filtered out.
        chunker = Chunker(config=ChunkConfig(min_chunk_size=10))
        source = "def hello():\n    print('hello')\n    return True\n"
        found = [Symbol(name="hello", kind="function", range=(1, 3))]
        produced = chunker.chunk_by_symbol(source, found, "test.py", "python")
        assert len(produced) == 1
        first = produced[0]
        assert "def hello():" in first.content
        assert first.metadata["symbol_name"] == "hello"
        assert first.metadata["symbol_kind"] == "function"
        assert first.metadata["file"] == "test.py"
        assert first.metadata["language"] == "python"
        assert first.metadata["strategy"] == "symbol"

    def test_chunk_multiple_symbols(self):
        """Each symbol in the file yields its own chunk."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=5))
        source = """def foo():
    pass

def bar():
    pass

class MyClass:
    pass
"""
        found = [
            Symbol(name="foo", kind="function", range=(1, 2)),
            Symbol(name="bar", kind="function", range=(4, 5)),
            Symbol(name="MyClass", kind="class", range=(7, 8)),
        ]
        produced = chunker.chunk_by_symbol(source, found, "test.py", "python")
        assert len(produced) == 3
        labels = [piece.metadata["symbol_name"] for piece in produced]
        for expected in ("foo", "bar", "MyClass"):
            assert expected in labels

    def test_chunk_skips_small_content(self):
        """Chunks shorter than min_chunk_size are dropped."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=100))
        produced = chunker.chunk_by_symbol(
            "def x():\n    pass\n",
            [Symbol(name="x", kind="function", range=(1, 2))],
            "test.py",
            "python",
        )
        assert len(produced) == 0  # content is too small

    def test_chunk_preserves_line_numbers(self):
        """Chunk metadata records the symbol's original line span."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=5))
        source = "# comment\ndef hello():\n    pass\n"
        produced = chunker.chunk_by_symbol(
            source,
            [Symbol(name="hello", kind="function", range=(2, 3))],
            "test.py",
            "python",
        )
        assert len(produced) == 1
        assert produced[0].metadata["start_line"] == 2
        assert produced[0].metadata["end_line"] == 3

    def test_chunk_handles_empty_symbols(self):
        """No symbols means no symbol-based chunks."""
        produced = Chunker().chunk_by_symbol("# just a comment", [], "test.py", "python")
        assert len(produced) == 0
class TestChunkSlidingWindow:
    """Tests for sliding window chunking."""

    def test_sliding_window_basic(self):
        """Multi-line content is split into window chunks tagged with metadata."""
        chunker = Chunker(config=ChunkConfig(max_chunk_size=100, overlap=20, min_chunk_size=10))
        body = "".join(f"line {n} content here\n" for n in range(20))
        produced = chunker.chunk_sliding_window(body, "test.py", "python")
        assert len(produced) > 0
        for piece in produced:
            assert piece.metadata["strategy"] == "sliding_window"
            assert piece.metadata["file"] == "test.py"
            assert piece.metadata["language"] == "python"

    def test_sliding_window_empty_content(self):
        """Empty content yields no chunks."""
        produced = Chunker().chunk_sliding_window("", "test.py", "python")
        assert len(produced) == 0

    def test_sliding_window_small_content(self):
        """Content below max_chunk_size yields at most one chunk."""
        chunker = Chunker(config=ChunkConfig(max_chunk_size=1000, min_chunk_size=10))
        produced = chunker.chunk_sliding_window("small content here", "test.py", "python")
        assert len(produced) <= 1

    def test_sliding_window_chunk_indices(self):
        """chunk_index values count up sequentially from zero."""
        chunker = Chunker(config=ChunkConfig(max_chunk_size=50, overlap=10, min_chunk_size=5))
        body = "".join(f"line {n}\n" for n in range(50))
        produced = chunker.chunk_sliding_window(body, "test.py", "python")
        if len(produced) > 1:
            observed = [piece.metadata["chunk_index"] for piece in produced]
            assert observed == list(range(len(produced)))
class TestChunkFile:
    """Tests for chunk_file method."""

    def test_chunk_file_with_symbols(self):
        """chunk_file uses symbol-based chunking when symbols are available.

        Fix: the original used the default config (min_chunk_size=50) on a
        44-character snippet, so zero chunks were produced and the ``all(...)``
        assertion passed vacuously. A smaller min_chunk_size plus an explicit
        length check makes the test meaningful.
        """
        chunker = Chunker(config=ChunkConfig(min_chunk_size=10))
        content = "def hello():\n    print('world')\n    return 42\n"
        symbols = [Symbol(name="hello", kind="function", range=(1, 3))]
        chunks = chunker.chunk_file(content, symbols, "test.py", "python")
        assert len(chunks) == 1  # guard against a vacuous all() below
        assert all(c.metadata["strategy"] == "symbol" for c in chunks)

    def test_chunk_file_without_symbols(self):
        """chunk_file falls back to sliding window when no symbols exist."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=5))
        content = "# just comments\n# more comments\n# even more\n"
        chunks = chunker.chunk_file(content, [], "test.py", "python")
        # Should use sliding window strategy for whatever it produced.
        if len(chunks) > 0:
            assert all(c.metadata["strategy"] == "sliding_window" for c in chunks)
class TestChunkMetadata:
    """Tests for chunk metadata."""

    def test_symbol_chunk_metadata_complete(self):
        """A symbol chunk carries file, language, symbol and line metadata."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=10))
        source = "class MyClass:\n    def method(self):\n        pass\n"
        produced = chunker.chunk_by_symbol(
            source,
            [Symbol(name="MyClass", kind="class", range=(1, 3))],
            "/path/to/file.py",
            "python",
        )
        assert len(produced) == 1
        meta = produced[0].metadata
        expectations = {
            "file": "/path/to/file.py",
            "language": "python",
            "symbol_name": "MyClass",
            "symbol_kind": "class",
            "start_line": 1,
            "end_line": 3,
            "strategy": "symbol",
        }
        for key, value in expectations.items():
            assert meta[key] == value

    def test_sliding_window_metadata_complete(self):
        """A sliding-window chunk carries file, language, index and line metadata."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=5))
        produced = chunker.chunk_sliding_window(
            "some content here\nmore content\n", "/path/file.js", "javascript"
        )
        if len(produced) > 0:
            meta = produced[0].metadata
            assert meta["file"] == "/path/file.js"
            assert meta["language"] == "javascript"
            assert "chunk_index" in meta
            assert "start_line" in meta
            assert "end_line" in meta
            assert meta["strategy"] == "sliding_window"
class TestChunkEdgeCases:
    """Edge case tests for chunking."""

    def test_chunk_with_unicode(self):
        """Unicode identifiers and emoji in content chunk correctly."""
        chunker = Chunker(config=ChunkConfig(min_chunk_size=10))
        content = "def 你好():\n    print('世界')\n    return '🎉'\n"
        symbols = [Symbol(name="你好", kind="function", range=(1, 3))]
        chunks = chunker.chunk_by_symbol(content, symbols, "test.py", "python")
        assert len(chunks) == 1
        assert "你好" in chunks[0].content

    def test_chunk_with_windows_line_endings(self):
        """CRLF line endings are handled without errors."""
        chunker = Chunker()
        content = "def hello():\r\n    pass\r\n"
        symbols = [Symbol(name="hello", kind="function", range=(1, 2))]
        chunks = chunker.chunk_by_symbol(content, symbols, "test.py", "python")
        assert len(chunks) <= 1

    def test_chunk_range_out_of_bounds(self):
        """A symbol range past the end of content is handled gracefully."""
        chunker = Chunker()
        content = "def hello():\n    pass\n"
        # Symbol range goes well beyond the actual content.
        symbols = [Symbol(name="hello", kind="function", range=(1, 100))]
        chunks = chunker.chunk_by_symbol(content, symbols, "test.py", "python")
        assert len(chunks) <= 1

    def test_chunk_content_returned_as_semantic_chunk(self):
        """Returned chunks are un-embedded SemanticChunk instances.

        Fix: with the default config (min_chunk_size=50) the 28-character
        snippet produced zero chunks, so the original for-loop assertions
        never ran. A smaller min_chunk_size plus a length check makes the
        isinstance/embedding assertions actually execute.
        """
        chunker = Chunker(config=ChunkConfig(min_chunk_size=10))
        content = "def test():\n    return True\n"
        symbols = [Symbol(name="test", kind="function", range=(1, 2))]
        chunks = chunker.chunk_by_symbol(content, symbols, "test.py", "python")
        assert len(chunks) == 1  # guard against vacuous loop below
        for chunk in chunks:
            assert isinstance(chunk, SemanticChunk)
            assert chunk.embedding is None  # not embedded yet

View File

@@ -0,0 +1,804 @@
"""Comprehensive tests for semantic search functionality.
Tests embedding generation, vector storage, and semantic similarity search
across complex codebases with various file types and content patterns.
"""
import json
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import List, Dict, Any
import pytest
from codexlens.entities import SemanticChunk, Symbol
from codexlens.semantic import SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, check_semantic_available
# Skip all tests if semantic search not available
# (SEMANTIC_AVAILABLE is computed at import time from the optional
# embedding dependencies; pytestmark applies the skip module-wide).
pytestmark = pytest.mark.skipif(
    not SEMANTIC_AVAILABLE,
    reason="Semantic search dependencies not installed"
)
class TestEmbedderPerformance:
    """Test Embedder performance and quality."""

    @pytest.fixture
    def embedder(self):
        """Create embedder instance."""
        from codexlens.semantic.embedder import Embedder
        return Embedder()

    def test_single_embedding(self, embedder):
        """Embedding one snippet yields a 384-dim float vector."""
        snippet = "def calculate_sum(a, b): return a + b"
        t0 = time.time()
        vector = embedder.embed_single(snippet)
        duration = time.time() - t0
        assert len(vector) == 384, "Embedding dimension should be 384"
        assert all(isinstance(component, float) for component in vector)
        print(f"\nSingle embedding time: {duration*1000:.2f}ms")

    def test_batch_embedding_performance(self, embedder):
        """Batch embedding returns one vector per input text."""
        samples = [
            "def hello(): print('world')",
            "class Calculator: def add(self, a, b): return a + b",
            "async def fetch_data(url): return await client.get(url)",
            "const processData = (data) => data.map(x => x * 2)",
            "function initializeApp() { console.log('Starting...'); }",
        ] * 10  # 50 texts total
        t0 = time.time()
        vectors = embedder.embed(samples)
        duration = time.time() - t0
        assert len(vectors) == len(samples)
        print(f"\nBatch embedding ({len(samples)} texts): {duration*1000:.2f}ms")
        print(f"Per-text average: {duration/len(samples)*1000:.2f}ms")

    def test_embedding_similarity(self, embedder):
        """Semantically similar code embeds closer than unrelated code."""
        from codexlens.semantic.vector_store import _cosine_similarity
        # Two additive helpers (expected high similarity) versus an
        # unrelated authentication class (expected lower similarity).
        add_fn = "def add(a, b): return a + b"
        sum_fn = "def sum_numbers(x, y): return x + y"
        auth_cls = "class UserAuthentication: def login(self, user, password): pass"
        vec_add = embedder.embed_single(add_fn)
        vec_sum = embedder.embed_single(sum_fn)
        vec_auth = embedder.embed_single(auth_cls)
        sim_add_sum = _cosine_similarity(vec_add, vec_sum)
        sim_add_auth = _cosine_similarity(vec_add, vec_auth)
        print(f"\nSimilarity (add vs sum_numbers): {sim_add_sum:.4f}")
        print(f"Similarity (add vs login): {sim_add_auth:.4f}")
        assert sim_add_sum > sim_add_auth, "Similar code should have higher similarity"
        assert sim_add_sum > 0.6, "Similar functions should have >0.6 similarity"
class TestVectorStore:
    """Test VectorStore functionality."""

    @pytest.fixture
    def temp_db(self, tmp_path):
        """Create temporary database."""
        return tmp_path / "semantic.db"

    @pytest.fixture
    def vector_store(self, temp_db):
        """Create vector store instance."""
        from codexlens.semantic.vector_store import VectorStore
        return VectorStore(temp_db)

    @pytest.fixture
    def embedder(self):
        """Create embedder instance."""
        from codexlens.semantic.embedder import Embedder
        return Embedder()

    def test_add_and_search_chunks(self, vector_store, embedder):
        """Stored chunks are retrievable by semantic similarity."""
        sources = [
            ("def calculate_sum(a, b): return a + b",
             {"symbol": "calculate_sum", "language": "python"}),
            ("class UserManager: def create_user(self): pass",
             {"symbol": "UserManager", "language": "python"}),
            ("async function fetchData(url) { return await fetch(url); }",
             {"symbol": "fetchData", "language": "javascript"}),
        ]
        # Embed and persist every chunk under the same fake file path.
        for text, meta in sources:
            chunk = SemanticChunk(content=text, metadata=meta)
            chunk.embedding = embedder.embed_single(text)
            vector_store.add_chunk(chunk, "/test/file.py")
        query = "function to add two numbers together"
        results = vector_store.search_similar(embedder.embed_single(query), top_k=3)
        assert len(results) > 0, "Should find results"
        assert "calculate_sum" in results[0].excerpt or "sum" in results[0].excerpt.lower()
        print(f"\nQuery: '{query}'")
        for rank, hit in enumerate(results):
            print(f"  {rank+1}. Score: {hit.score:.4f} - {hit.excerpt[:50]}...")

    def test_min_score_filtering(self, vector_store, embedder):
        """Raising min_score can only shrink the result set."""
        stored = SemanticChunk(
            content="def hello_world(): print('Hello, World!')",
            metadata={},
        )
        stored.embedding = embedder.embed_single(stored.content)
        vector_store.add_chunk(stored, "/test/hello.py")
        # Query deliberately unrelated to the stored chunk.
        query_vec = embedder.embed_single("database connection pool management")
        relaxed = vector_store.search_similar(query_vec, min_score=0.0)
        strict = vector_store.search_similar(query_vec, min_score=0.8)
        print(f"\nResults with min_score=0.0: {len(relaxed)}")
        print(f"Results with min_score=0.8: {len(strict)}")
        assert len(relaxed) >= len(strict)
class TestSemanticSearchIntegration:
    """Integration tests for semantic search on real-like codebases."""

    @pytest.fixture
    def complex_codebase(self, tmp_path):
        """Create a complex test codebase.

        Writes a small Python backend (auth, database, api) and a JavaScript
        frontend (components, api client) under tmp_path so the indexing
        fixture below has realistic mixed-language material to chunk.
        """
        # Python files
        (tmp_path / "src").mkdir()
        (tmp_path / "src" / "auth.py").write_text('''
"""Authentication module."""


class AuthenticationService:
    """Handle user authentication and authorization."""

    def __init__(self, secret_key: str):
        self.secret_key = secret_key
        self.token_expiry = 3600

    def login(self, username: str, password: str) -> dict:
        """Authenticate user and return JWT token."""
        user = self._validate_credentials(username, password)
        if user:
            return self._generate_token(user)
        raise AuthError("Invalid credentials")

    def logout(self, token: str) -> bool:
        """Invalidate user session."""
        return self._revoke_token(token)

    def verify_token(self, token: str) -> dict:
        """Verify JWT token and return user claims."""
        pass


def hash_password(password: str) -> str:
    """Hash password using bcrypt."""
    import hashlib
    return hashlib.sha256(password.encode()).hexdigest()
''')
        (tmp_path / "src" / "database.py").write_text('''
"""Database connection and ORM."""
from typing import List, Optional


class DatabaseConnection:
    """Manage database connections with pooling."""

    def __init__(self, connection_string: str, pool_size: int = 5):
        self.connection_string = connection_string
        self.pool_size = pool_size
        self._pool = []

    def connect(self) -> "Connection":
        """Get connection from pool."""
        if self._pool:
            return self._pool.pop()
        return self._create_connection()

    def release(self, conn: "Connection"):
        """Return connection to pool."""
        if len(self._pool) < self.pool_size:
            self._pool.append(conn)


class QueryBuilder:
    """SQL query builder with fluent interface."""

    def select(self, *columns) -> "QueryBuilder":
        pass

    def where(self, condition: str) -> "QueryBuilder":
        pass

    def execute(self) -> List[dict]:
        pass
''')
        (tmp_path / "src" / "api.py").write_text('''
"""REST API endpoints."""
from typing import List, Dict, Any


class APIRouter:
    """Route HTTP requests to handlers."""

    def __init__(self):
        self.routes = {}

    def get(self, path: str):
        """Register GET endpoint."""
        def decorator(func):
            self.routes[("GET", path)] = func
            return func
        return decorator

    def post(self, path: str):
        """Register POST endpoint."""
        def decorator(func):
            self.routes[("POST", path)] = func
            return func
        return decorator


async def handle_request(method: str, path: str, body: Dict) -> Dict:
    """Process incoming HTTP request."""
    pass


def validate_json_schema(data: Dict, schema: Dict) -> bool:
    """Validate request data against JSON schema."""
    pass
''')
        # JavaScript files
        (tmp_path / "frontend").mkdir()
        (tmp_path / "frontend" / "components.js").write_text('''
/**
 * React UI Components
 */

class UserProfile extends Component {
    constructor(props) {
        super(props);
        this.state = { user: null, loading: true };
    }

    async componentDidMount() {
        const user = await fetchUserData(this.props.userId);
        this.setState({ user, loading: false });
    }

    render() {
        if (this.state.loading) return <Spinner />;
        return <ProfileCard user={this.state.user} />;
    }
}

function Button({ onClick, children, variant = "primary" }) {
    return (
        <button className={`btn btn-${variant}`} onClick={onClick}>
            {children}
        </button>
    );
}

const FormInput = ({ label, value, onChange, type = "text" }) => {
    return (
        <div className="form-group">
            <label>{label}</label>
            <input type={type} value={value} onChange={onChange} />
        </div>
    );
};
''')
        (tmp_path / "frontend" / "api.js").write_text('''
/**
 * API Client for backend communication
 */

const API_BASE = "/api/v1";

async function fetchUserData(userId) {
    const response = await fetch(`${API_BASE}/users/${userId}`);
    if (!response.ok) throw new Error("Failed to fetch user");
    return response.json();
}

async function createUser(userData) {
    const response = await fetch(`${API_BASE}/users`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(userData)
    });
    return response.json();
}

async function updateUserProfile(userId, updates) {
    const response = await fetch(`${API_BASE}/users/${userId}`, {
        method: "PATCH",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(updates)
    });
    return response.json();
}

class WebSocketClient {
    constructor(url) {
        this.url = url;
        this.ws = null;
        this.handlers = {};
    }

    connect() {
        this.ws = new WebSocket(this.url);
        this.ws.onmessage = (event) => this._handleMessage(event);
    }

    on(eventType, handler) {
        this.handlers[eventType] = handler;
    }
}
''')
        return tmp_path

    @pytest.fixture
    def indexed_codebase(self, complex_codebase, tmp_path):
        """Index the complex codebase with semantic embeddings.

        Parses each .py/.js file, chunks it (symbol-first), embeds every
        chunk and stores it; returns the store/embedder plus bookkeeping.
        """
        from codexlens.semantic.embedder import Embedder
        from codexlens.semantic.vector_store import VectorStore
        from codexlens.semantic.chunker import Chunker, ChunkConfig
        from codexlens.parsers.factory import ParserFactory
        from codexlens.config import Config

        db_path = tmp_path / "semantic.db"
        vector_store = VectorStore(db_path)
        embedder = Embedder()
        config = Config()
        factory = ParserFactory(config)
        chunker = Chunker(ChunkConfig(min_chunk_size=20, max_chunk_size=500))

        # Index all source files
        indexed_files = []
        for ext in ["*.py", "*.js"]:
            for file_path in complex_codebase.rglob(ext):
                content = file_path.read_text()
                language = "python" if file_path.suffix == ".py" else "javascript"

                # Parse symbols
                parser = factory.get_parser(language)
                indexed_file = parser.parse(content, file_path)

                # Create chunks
                chunks = chunker.chunk_file(
                    content,
                    indexed_file.symbols,
                    str(file_path),
                    language
                )

                # Add embeddings and store
                for chunk in chunks:
                    chunk.embedding = embedder.embed_single(chunk.content)
                    vector_store.add_chunk(chunk, str(file_path))

                indexed_files.append(str(file_path))

        return {
            "vector_store": vector_store,
            "embedder": embedder,
            "files": indexed_files,
            "codebase_path": complex_codebase
        }

    def test_semantic_search_accuracy(self, indexed_codebase):
        """Test semantic search accuracy on complex queries.

        NOTE(review): this test only prints found/unexpected terms and never
        asserts on them — it is diagnostic, not a hard accuracy gate.
        """
        vector_store = indexed_codebase["vector_store"]
        embedder = indexed_codebase["embedder"]

        # Each entry: query plus terms expected (and not expected) to show up
        # somewhere in the concatenated excerpts of the top results.
        test_queries = [
            {
                "query": "user authentication login function",
                "expected_contains": ["login", "auth", "credential"],
                "expected_not_contains": ["database", "button"]
            },
            {
                "query": "database connection pooling",
                "expected_contains": ["connect", "pool", "database"],
                "expected_not_contains": ["login", "button"]
            },
            {
                "query": "React component for user profile",
                "expected_contains": ["UserProfile", "component", "render"],
                "expected_not_contains": ["database", "auth"]
            },
            {
                "query": "HTTP API endpoint handler",
                "expected_contains": ["API", "request", "handle"],
                "expected_not_contains": ["UserProfile", "button"]
            },
            {
                "query": "form input UI element",
                "expected_contains": ["input", "form", "label"],
                "expected_not_contains": ["database", "auth"]
            }
        ]

        print("\n" + "="*60)
        print("SEMANTIC SEARCH ACCURACY TEST")
        print("="*60)

        for test in test_queries:
            query = test["query"]
            query_embedding = embedder.embed_single(query)
            results = vector_store.search_similar(query_embedding, top_k=5, min_score=0.3)

            print(f"\nQuery: '{query}'")
            print("-" * 40)

            # Check results
            all_excerpts = " ".join([r.excerpt.lower() for r in results])

            found_expected = []
            for expected in test["expected_contains"]:
                if expected.lower() in all_excerpts:
                    found_expected.append(expected)

            found_unexpected = []
            for unexpected in test["expected_not_contains"]:
                if unexpected.lower() in all_excerpts:
                    found_unexpected.append(unexpected)

            for i, r in enumerate(results[:3]):
                print(f"  {i+1}. Score: {r.score:.4f}")
                print(f"     File: {Path(r.path).name}")
                print(f"     Excerpt: {r.excerpt[:80]}...")

            print(f"\n  [OK] Found expected: {found_expected}")
            if found_unexpected:
                print(f"  [WARN] Found unexpected: {found_unexpected}")

    def test_search_performance(self, indexed_codebase):
        """Test search performance with various parameters.

        NOTE(review): timing is printed only — no latency assertions, so this
        cannot fail on a slow machine.
        """
        vector_store = indexed_codebase["vector_store"]
        embedder = indexed_codebase["embedder"]

        query = "function to handle user data"
        query_embedding = embedder.embed_single(query)

        print("\n" + "="*60)
        print("SEARCH PERFORMANCE TEST")
        print("="*60)

        # Test different top_k values
        for top_k in [5, 10, 20, 50]:
            start = time.time()
            results = vector_store.search_similar(query_embedding, top_k=top_k)
            elapsed = time.time() - start
            print(f"top_k={top_k}: {elapsed*1000:.2f}ms ({len(results)} results)")

        # Test different min_score values
        print("\nMin score filtering:")
        for min_score in [0.0, 0.3, 0.5, 0.7]:
            start = time.time()
            results = vector_store.search_similar(query_embedding, top_k=50, min_score=min_score)
            elapsed = time.time() - start
            print(f"min_score={min_score}: {elapsed*1000:.2f}ms ({len(results)} results)")
class TestChunkerOptimization:
    """Test chunker parameters for optimal semantic search."""

    @pytest.fixture
    def sample_code(self):
        """Long Python file for chunking tests.

        NOTE: this is fixture *text* to be parsed/chunked; names such as
        json/aiohttp/time are intentionally unresolved inside the sample.
        """
        return '''
"""Large module with multiple classes and functions."""
import os
import sys
from typing import List, Dict, Any, Optional

# Constants
MAX_RETRIES = 3
DEFAULT_TIMEOUT = 30


class ConfigManager:
    """Manage application configuration."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self._config: Dict[str, Any] = {}

    def load(self) -> Dict[str, Any]:
        """Load configuration from file."""
        with open(self.config_path) as f:
            self._config = json.load(f)
        return self._config

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value."""
        return self._config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set configuration value."""
        self._config[key] = value


class DataProcessor:
    """Process and transform data."""

    def __init__(self, source: str):
        self.source = source
        self.data: List[Dict] = []

    def load_data(self) -> List[Dict]:
        """Load data from source."""
        # Implementation here
        pass

    def transform(self, transformers: List[callable]) -> List[Dict]:
        """Apply transformations to data."""
        result = self.data
        for transformer in transformers:
            result = [transformer(item) for item in result]
        return result

    def filter(self, predicate: callable) -> List[Dict]:
        """Filter data by predicate."""
        return [item for item in self.data if predicate(item)]

    def aggregate(self, key: str, aggregator: callable) -> Dict:
        """Aggregate data by key."""
        groups: Dict[str, List] = {}
        for item in self.data:
            k = item.get(key)
            if k not in groups:
                groups[k] = []
            groups[k].append(item)
        return {k: aggregator(v) for k, v in groups.items()}


def validate_input(data: Dict, schema: Dict) -> bool:
    """Validate input data against schema."""
    for field, rules in schema.items():
        if rules.get("required") and field not in data:
            return False
        if field in data:
            value = data[field]
            if "type" in rules and not isinstance(value, rules["type"]):
                return False
    return True


def format_output(data: Any, format_type: str = "json") -> str:
    """Format output data."""
    if format_type == "json":
        return json.dumps(data, indent=2)
    elif format_type == "csv":
        # CSV formatting
        pass
    return str(data)


async def fetch_remote_data(url: str, timeout: int = DEFAULT_TIMEOUT) -> Dict:
    """Fetch data from remote URL."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url, timeout=timeout) as response:
            return await response.json()


class CacheManager:
    """Manage caching with TTL support."""

    def __init__(self, default_ttl: int = 300):
        self.default_ttl = default_ttl
        self._cache: Dict[str, tuple] = {}

    def get(self, key: str) -> Optional[Any]:
        """Get cached value if not expired."""
        if key in self._cache:
            value, expiry = self._cache[key]
            if time.time() < expiry:
                return value
            del self._cache[key]
        return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set cached value with TTL."""
        expiry = time.time() + (ttl or self.default_ttl)
        self._cache[key] = (value, expiry)

    def invalidate(self, pattern: str) -> int:
        """Invalidate cache entries matching pattern."""
        keys_to_delete = [k for k in self._cache if pattern in k]
        for k in keys_to_delete:
            del self._cache[k]
        return len(keys_to_delete)
'''

    def test_chunk_size_comparison(self, sample_code):
        """Compare different chunk sizes for search quality.

        NOTE(review): diagnostic only — prints best-match scores per config
        but never asserts a minimum score.
        """
        from codexlens.semantic.chunker import Chunker, ChunkConfig
        from codexlens.semantic.embedder import Embedder
        from codexlens.semantic.vector_store import _cosine_similarity
        from codexlens.parsers.factory import ParserFactory
        from codexlens.config import Config

        config = Config()
        factory = ParserFactory(config)
        parser = factory.get_parser("python")
        indexed_file = parser.parse(sample_code, Path("/test.py"))

        embedder = Embedder()

        print("\n" + "="*60)
        print("CHUNK SIZE OPTIMIZATION TEST")
        print("="*60)

        # Test different chunk configurations
        configs = [
            ChunkConfig(min_chunk_size=20, max_chunk_size=200, overlap=20),
            ChunkConfig(min_chunk_size=50, max_chunk_size=500, overlap=50),
            ChunkConfig(min_chunk_size=100, max_chunk_size=1000, overlap=100),
        ]

        test_query = "cache management with TTL expiration"
        query_embedding = embedder.embed_single(test_query)

        for cfg in configs:
            chunker = Chunker(cfg)
            chunks = chunker.chunk_file(
                sample_code,
                indexed_file.symbols,
                "/test.py",
                "python"
            )

            print(f"\nConfig: min={cfg.min_chunk_size}, max={cfg.max_chunk_size}, overlap={cfg.overlap}")
            print(f"  Chunks generated: {len(chunks)}")

            if chunks:
                # Find best matching chunk
                best_score = 0
                best_chunk = None
                for chunk in chunks:
                    chunk.embedding = embedder.embed_single(chunk.content)
                    score = _cosine_similarity(query_embedding, chunk.embedding)
                    if score > best_score:
                        best_score = score
                        best_chunk = chunk

                if best_chunk:
                    print(f"  Best match score: {best_score:.4f}")
                    print(f"  Best chunk preview: {best_chunk.content[:100]}...")

    def test_symbol_vs_sliding_window(self, sample_code):
        """Compare symbol-based vs sliding window chunking.

        NOTE(review): prints chunk counts/sizes for both strategies; no
        assertions, so this documents behaviour rather than verifying it.
        """
        from codexlens.semantic.chunker import Chunker, ChunkConfig
        from codexlens.parsers.factory import ParserFactory
        from codexlens.config import Config

        config = Config()
        factory = ParserFactory(config)
        parser = factory.get_parser("python")
        indexed_file = parser.parse(sample_code, Path("/test.py"))

        chunker = Chunker(ChunkConfig(min_chunk_size=20))

        print("\n" + "="*60)
        print("CHUNKING STRATEGY COMPARISON")
        print("="*60)

        # Symbol-based chunking
        symbol_chunks = chunker.chunk_by_symbol(
            sample_code,
            indexed_file.symbols,
            "/test.py",
            "python"
        )

        # Sliding window chunking
        window_chunks = chunker.chunk_sliding_window(
            sample_code,
            "/test.py",
            "python"
        )

        print(f"\nSymbol-based chunks: {len(symbol_chunks)}")
        for i, chunk in enumerate(symbol_chunks[:5]):
            symbol_name = chunk.metadata.get("symbol_name", "unknown")
            print(f"  {i+1}. {symbol_name}: {len(chunk.content)} chars")

        print(f"\nSliding window chunks: {len(window_chunks)}")
        for i, chunk in enumerate(window_chunks[:5]):
            lines = f"{chunk.metadata.get('start_line', '?')}-{chunk.metadata.get('end_line', '?')}"
            print(f"  {i+1}. Lines {lines}: {len(chunk.content)} chars")
class TestRealWorldScenarios:
    """Test real-world semantic search scenarios."""

    @pytest.fixture
    def embedder(self):
        from codexlens.semantic.embedder import Embedder
        return Embedder()

    def test_natural_language_queries(self, embedder):
        """Natural-language questions rank the matching code sample first."""
        from codexlens.semantic.vector_store import _cosine_similarity

        code_samples = {
            "auth": "def authenticate_user(username, password): verify credentials and create session",
            "db": "class DatabasePool: manage connection pooling for efficient database access",
            "api": "async def handle_http_request(req): process incoming REST API calls",
            "ui": "function Button({ onClick }) { return <button onClick={onClick}>Click</button> }",
            "cache": "class LRUCache: implements least recently used caching strategy with TTL",
        }
        # Pre-compute one embedding per labelled sample.
        code_vectors = {label: embedder.embed_single(text) for label, text in code_samples.items()}

        probes = [
            ("How do I log in a user?", "auth"),
            ("Database connection management", "db"),
            ("REST endpoint handler", "api"),
            ("Button component React", "ui"),
            ("Caching with expiration", "cache"),
        ]

        print("\n" + "="*60)
        print("NATURAL LANGUAGE QUERY TEST")
        print("="*60)

        hits = 0
        for question, expected_label in probes:
            probe_vec = embedder.embed_single(question)
            ranking = {label: _cosine_similarity(probe_vec, vec)
                       for label, vec in code_vectors.items()}
            winner = max(ranking.items(), key=lambda item: item[1])
            matched = winner[0] == expected_label
            hits += matched
            status = "[OK]" if matched else "[FAIL]"
            print(f"\n{status} Query: '{question}'")
            print(f"  Expected: {expected_label}, Got: {winner[0]} (score: {winner[1]:.4f})")

        accuracy = hits / len(probes) * 100
        print(f"\n\nAccuracy: {accuracy:.1f}% ({hits}/{len(probes)})")
# Allow running this file directly; -s keeps the diagnostic prints visible.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

View File

@@ -1,12 +1,14 @@
"""Tests for CodexLens storage."""
import sqlite3
import threading
import pytest
import tempfile
from pathlib import Path
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.entities import IndexedFile, Symbol
from codexlens.errors import StorageError
@pytest.fixture
@@ -20,6 +22,13 @@ def temp_db():
store.close()
@pytest.fixture
def temp_db_path():
    """Yield a database path located in a fresh, auto-removed temp directory."""
    with tempfile.TemporaryDirectory() as workdir:
        db_path = Path(workdir) / "test.db"
        yield db_path
class TestSQLiteStore:
"""Tests for SQLiteStore."""
@@ -158,3 +167,368 @@ class TestSQLiteStore:
assert "content='files'" in row["sql"] or "content=files" in row["sql"]
finally:
store.close()
class TestSQLiteStoreAddFiles:
    """Tests covering the add_files batch-insert operation."""

    def test_add_files_batch(self, temp_db):
        """A three-file batch should yield three files and three symbols."""
        specs = [
            ("/test/a.py", "func_a", "def func_a(): pass"),
            ("/test/b.py", "func_b", "def func_b(): pass"),
            ("/test/c.py", "func_c", "def func_c(): pass"),
        ]
        batch = []
        for path, sym_name, content in specs:
            indexed = IndexedFile(
                path=path,
                language="python",
                symbols=[Symbol(name=sym_name, kind="function", range=(1, 1))],
            )
            batch.append((indexed, content))
        temp_db.add_files(batch)

        stats = temp_db.stats()
        assert stats["files"] == 3
        assert stats["symbols"] == 3

    def test_add_files_empty_list(self, temp_db):
        """An empty batch must be a no-op."""
        temp_db.add_files([])
        assert temp_db.stats()["files"] == 0
class TestSQLiteStoreSearch:
    """Tests for FTS, symbol, and path-only search operations."""

    def test_search_fts_with_limit(self, temp_db):
        """FTS results must honor the limit argument."""
        for idx in range(10):
            temp_db.add_file(
                IndexedFile(path=f"/test/file{idx}.py", language="python", symbols=[]),
                f"def test{idx}(): pass",
            )
        hits = temp_db.search_fts("test", limit=3)
        assert len(hits) <= 3

    def test_search_fts_with_offset(self, temp_db):
        """Consecutive pages (limit+offset) must not overlap."""
        for idx in range(10):
            temp_db.add_file(
                IndexedFile(path=f"/test/file{idx}.py", language="python", symbols=[]),
                f"searchterm content {idx}",
            )
        first_page = temp_db.search_fts("searchterm", limit=3, offset=0)
        second_page = temp_db.search_fts("searchterm", limit=3, offset=3)
        assert {hit.path for hit in first_page}.isdisjoint(
            {hit.path for hit in second_page}
        )

    def test_search_fts_no_results(self, temp_db):
        """Querying a missing term returns an empty result set."""
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            "def hello(): pass",
        )
        assert temp_db.search_fts("nonexistent") == []

    def test_search_symbols_by_kind(self, temp_db):
        """Symbol search must respect the kind filter."""
        temp_db.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[
                    Symbol(name="MyClass", kind="class", range=(1, 5)),
                    Symbol(name="my_func", kind="function", range=(7, 10)),
                    Symbol(name="my_method", kind="method", range=(2, 4)),
                ],
            ),
            "class MyClass:\n    def my_method(): pass\ndef my_func(): pass",
        )
        matches = temp_db.search_symbols("my", kind="function")
        assert len(matches) == 1
        assert matches[0].name == "my_func"

    def test_search_symbols_with_limit(self, temp_db):
        """Symbol search must honor the limit argument."""
        # Symbol ranges are 1-based, so shift each index up by one.
        many = [
            Symbol(name=f"func{idx}", kind="function", range=(idx + 1, idx + 1))
            for idx in range(20)
        ]
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=many),
            "# lots of functions",
        )
        assert len(temp_db.search_symbols("func", limit=5)) == 5

    def test_search_files_only(self, temp_db):
        """search_files_only should return bare path strings."""
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            "def hello(): pass",
        )
        paths = temp_db.search_files_only("hello")
        assert len(paths) == 1
        assert isinstance(paths[0], str)
class TestSQLiteStoreFileOperations:
    """Tests for per-file existence, removal, mtime, and update operations."""

    def test_file_exists_true(self, temp_db):
        """An indexed path must report as existing."""
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            "content",
        )
        assert temp_db.file_exists("/test/file.py")

    def test_file_exists_false(self, temp_db):
        """A never-indexed path must report as absent."""
        assert not temp_db.file_exists("/nonexistent/file.py")

    def test_remove_nonexistent_file(self, temp_db):
        """Removing an unknown path returns False rather than raising."""
        assert temp_db.remove_file("/nonexistent/file.py") is False

    def test_get_file_mtime(self, temp_db):
        """mtime is a float when available, else None (path is not on disk)."""
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            "content",
        )
        mtime = temp_db.get_file_mtime("/test/file.py")
        # The indexed path does not exist on disk, so None is acceptable.
        assert mtime is None or isinstance(mtime, float)

    def test_get_file_mtime_nonexistent(self, temp_db):
        """mtime lookup for an unindexed path yields None."""
        assert temp_db.get_file_mtime("/nonexistent/file.py") is None

    def test_update_existing_file(self, temp_db):
        """Re-adding a path replaces its content and symbols in place."""
        temp_db.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[Symbol(name="old_func", kind="function", range=(1, 1))],
            ),
            "def old_func(): pass",
        )
        temp_db.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[Symbol(name="new_func", kind="function", range=(1, 1))],
            ),
            "def new_func(): pass",
        )
        stats = temp_db.stats()
        assert stats["files"] == 1  # Still one file
        assert stats["symbols"] == 1  # Old symbols replaced
        assert len(temp_db.search_symbols("new_func")) == 1
class TestSQLiteStoreStats:
    """Tests for the stats() aggregate report."""

    def test_stats_empty_db(self, temp_db):
        """A fresh database reports zero files/symbols and no languages."""
        stats = temp_db.stats()
        assert stats["files"] == 0
        assert stats["symbols"] == 0
        assert stats["languages"] == {}

    def test_stats_with_data(self, temp_db):
        """Counts and per-language breakdown must reflect inserted data."""
        py_file = IndexedFile(
            path="/test/a.py",
            language="python",
            symbols=[
                Symbol(name="func1", kind="function", range=(1, 1)),
                Symbol(name="func2", kind="function", range=(2, 2)),
            ],
        )
        js_file = IndexedFile(
            path="/test/b.js",
            language="javascript",
            symbols=[Symbol(name="func3", kind="function", range=(1, 1))],
        )
        temp_db.add_files([(py_file, "content"), (js_file, "content")])

        stats = temp_db.stats()
        assert stats["files"] == 2
        assert stats["symbols"] == 3
        assert stats["languages"]["python"] == 1
        assert stats["languages"]["javascript"] == 1
        assert "db_path" in stats
class TestSQLiteStoreContextManager:
    """Tests for `with SQLiteStore(...)` usage."""

    def test_context_manager(self, temp_db_path):
        """The store is fully usable inside a with-block."""
        with SQLiteStore(temp_db_path) as store:
            record = IndexedFile(path="/test/file.py", language="python", symbols=[])
            store.add_file(record, "content")
            assert store.stats()["files"] == 1
class TestSQLiteStoreThreadSafety:
    """Tests for concurrent read access to one store instance."""

    def test_multiple_threads_read(self, temp_db):
        """Five reader threads must all succeed and observe consistent stats."""
        temp_db.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[Symbol(name="test", kind="function", range=(1, 1))],
            ),
            "def test(): pass",
        )

        collected = []
        failures = []

        def reader():
            try:
                collected.append(temp_db.stats())
            except Exception as exc:  # record, don't raise across the thread
                failures.append(exc)

        workers = [threading.Thread(target=reader) for _ in range(5)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        assert len(failures) == 0
        assert len(collected) == 5
        assert all(snapshot["files"] == 1 for snapshot in collected)
class TestSQLiteStoreEdgeCases:
    """Edge-case coverage: odd paths, unicode, large inputs, reopen cycles."""

    def test_special_characters_in_path(self, temp_db):
        """Paths containing spaces index and resolve correctly."""
        temp_db.add_file(
            IndexedFile(path="/test/file with spaces.py", language="python", symbols=[]),
            "content",
        )
        assert temp_db.file_exists("/test/file with spaces.py")

    def test_unicode_content(self, temp_db):
        """Unicode symbol names and content survive indexing and search."""
        temp_db.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[Symbol(name="你好", kind="function", range=(1, 1))],
            ),
            "def 你好(): print('世界')",
        )
        assert len(temp_db.search_symbols("你好")) == 1

    def test_very_long_content(self, temp_db):
        """A ~60KB file indexes without error."""
        bulk = "x = 1\n" * 10000
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            bulk,
        )
        assert temp_db.stats()["files"] == 1

    def test_file_with_no_symbols(self, temp_db):
        """A symbol-free file counts as one file, zero symbols."""
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=[]),
            "# just a comment",
        )
        stats = temp_db.stats()
        assert stats["files"] == 1
        assert stats["symbols"] == 0

    def test_file_with_many_symbols(self, temp_db):
        """100 symbols in one file are all stored."""
        # Symbol ranges are 1-based, so shift each index up by one.
        many = [
            Symbol(name=f"func_{idx}", kind="function", range=(idx + 1, idx + 1))
            for idx in range(100)
        ]
        temp_db.add_file(
            IndexedFile(path="/test/file.py", language="python", symbols=many),
            "# lots of functions",
        )
        assert temp_db.stats()["symbols"] == 100

    def test_close_and_reopen(self, temp_db_path):
        """Data written in one session is readable after close + reopen."""
        first = SQLiteStore(temp_db_path)
        first.initialize()
        first.add_file(
            IndexedFile(
                path="/test/file.py",
                language="python",
                symbols=[Symbol(name="test", kind="function", range=(1, 1))],
            ),
            "def test(): pass",
        )
        first.close()

        second = SQLiteStore(temp_db_path)
        second.initialize()
        stats = second.stats()
        assert stats["files"] == 1
        assert stats["symbols"] == 1
        second.close()