refactor: remove graph indexing, fix memory leaks, optimize embedding generation

Main changes:

1. Remove the graph indexing feature
   - Delete graph_analyzer.py and its related migration files
   - Remove the CLI graph command and the --enrich flag
   - Remove the graph query methods from chain_search.py (370 lines)
   - Delete the related test files

2. Fix memory issues in embedding generation
   - Refactor generate_embeddings.py to use streaming batch processing
   - Delegate to the memory-safe implementation in embedding_manager
   - Trim the file from 548 to 259 lines (a 52.7% reduction)

3. Fix memory leaks
   - chain_search.py: quick_search now manages ChainSearchEngine with a with statement
   - embedding_manager.py: manage VectorStore with a with statement
   - vector_store.py: add a memory warning for brute-force search

4. Code cleanup
   - Remove the token_count and symbol_type fields from the Symbol model
   - Clean up the related test cases

Tests: 760 passed, 7 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
catlog22
2025-12-21 16:22:03 +08:00
parent 15d5890861
commit 3e9a309079
19 changed files with 165 additions and 3909 deletions

View File

@@ -1,15 +1,9 @@
#!/usr/bin/env python3
"""Generate vector embeddings for existing CodexLens indexes.
This script processes all files in a CodexLens index database and generates
semantic vector embeddings for code chunks. The embeddings are stored in the
same SQLite database in the 'semantic_chunks' table.
Performance optimizations:
- Parallel file processing using ProcessPoolExecutor
- Batch embedding generation for efficient GPU/CPU utilization
- Batch database writes to minimize I/O overhead
- HNSW index auto-generation for fast similarity search
This script is a CLI wrapper around the memory-efficient streaming implementation
in codexlens.cli.embedding_manager. It uses batch processing to keep memory usage
under 2GB regardless of project size.
Requirements:
pip install codexlens[semantic]
@@ -20,27 +14,21 @@ Usage:
# Generate embeddings for a single index
python generate_embeddings.py /path/to/_index.db
# Generate embeddings with parallel processing
python generate_embeddings.py /path/to/_index.db --workers 4
# Use specific embedding model and batch size
python generate_embeddings.py /path/to/_index.db --model code --batch-size 256
# Use specific embedding model
python generate_embeddings.py /path/to/_index.db --model code
# Generate embeddings for all indexes in a directory
python generate_embeddings.py --scan ~/.codexlens/indexes
# Force regeneration
python generate_embeddings.py /path/to/_index.db --force
"""
import argparse
import logging
import multiprocessing
import os
import sqlite3
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List
# Configure logging
logging.basicConfig(
@@ -50,100 +38,32 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
@dataclass
class FileData:
"""Data for a single file to process."""
full_path: str
content: str
language: str
@dataclass
class ChunkData:
"""Processed chunk data ready for embedding."""
file_path: str
content: str
metadata: dict
# Import the memory-efficient implementation
try:
from codexlens.cli.embedding_manager import (
generate_embeddings,
generate_embeddings_recursive,
)
from codexlens.semantic import SEMANTIC_AVAILABLE
except ImportError as exc:
logger.error(f"Failed to import codexlens: {exc}")
logger.error("Make sure codexlens is installed: pip install codexlens")
SEMANTIC_AVAILABLE = False
def check_dependencies():
"""Check if semantic search dependencies are available."""
try:
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
logger.error("Semantic search dependencies not available")
logger.error("Install with: pip install codexlens[semantic]")
logger.error("Or: pip install fastembed numpy hnswlib")
return False
return True
except ImportError as exc:
logger.error(f"Failed to import codexlens: {exc}")
logger.error("Make sure codexlens is installed: pip install codexlens")
if not SEMANTIC_AVAILABLE:
logger.error("Semantic search dependencies not available")
logger.error("Install with: pip install codexlens[semantic]")
logger.error("Or: pip install fastembed numpy hnswlib")
return False
return True
def count_files(index_db_path: Path) -> int:
"""Count total files in index."""
try:
with sqlite3.connect(index_db_path) as conn:
cursor = conn.execute("SELECT COUNT(*) FROM files")
return cursor.fetchone()[0]
except Exception as exc:
logger.error(f"Failed to count files: {exc}")
return 0
def check_existing_chunks(index_db_path: Path) -> int:
"""Check if semantic chunks already exist."""
try:
with sqlite3.connect(index_db_path) as conn:
# Check if table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_chunks'"
)
if not cursor.fetchone():
return 0
# Count existing chunks
cursor = conn.execute("SELECT COUNT(*) FROM semantic_chunks")
return cursor.fetchone()[0]
except Exception:
return 0
def process_file_worker(args: Tuple[str, str, str, int]) -> List[ChunkData]:
"""Worker function to process a single file (runs in separate process).
Args:
args: Tuple of (file_path, content, language, chunk_size)
Returns:
List of ChunkData objects
"""
file_path, content, language, chunk_size = args
try:
from codexlens.semantic.chunker import Chunker, ChunkConfig
chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))
chunks = chunker.chunk_sliding_window(
content,
file_path=file_path,
language=language
)
return [
ChunkData(
file_path=file_path,
content=chunk.content,
metadata=chunk.metadata or {}
)
for chunk in chunks
]
except Exception as exc:
logger.debug(f"Error processing {file_path}: {exc}")
return []
def progress_callback(message: str):
"""Callback function for progress updates."""
logger.info(message)
def generate_embeddings_for_index(
@@ -151,259 +71,63 @@ def generate_embeddings_for_index(
model_profile: str = "code",
force: bool = False,
chunk_size: int = 2000,
workers: int = 0,
batch_size: int = 256,
**kwargs # Ignore unused parameters (workers, batch_size) for backward compatibility
) -> dict:
"""Generate embeddings for all files in an index.
"""Generate embeddings for an index using memory-efficient streaming.
Performance optimizations:
- Parallel file processing (chunking)
- Batch embedding generation
- Batch database writes
- HNSW index auto-generation
This function wraps the streaming implementation from embedding_manager
to maintain CLI compatibility while using the memory-optimized approach.
Args:
index_db_path: Path to _index.db file
model_profile: Model profile to use (fast, code, multilingual, balanced)
force: If True, regenerate even if embeddings exist
chunk_size: Maximum chunk size in characters
workers: Number of parallel workers (0 = auto-detect CPU count)
batch_size: Batch size for embedding generation
**kwargs: Additional parameters (ignored for compatibility)
Returns:
Dictionary with generation statistics
"""
logger.info(f"Processing index: {index_db_path}")
# Check existing chunks
existing_chunks = check_existing_chunks(index_db_path)
if existing_chunks > 0 and not force:
logger.warning(f"Index already has {existing_chunks} chunks")
logger.warning("Use --force to regenerate")
return {
"success": False,
"error": "Embeddings already exist",
"existing_chunks": existing_chunks,
}
# Call the memory-efficient streaming implementation
result = generate_embeddings(
index_path=index_db_path,
model_profile=model_profile,
force=force,
chunk_size=chunk_size,
progress_callback=progress_callback,
)
if force and existing_chunks > 0:
logger.info(f"Force mode: clearing {existing_chunks} existing chunks")
try:
with sqlite3.connect(index_db_path) as conn:
conn.execute("DELETE FROM semantic_chunks")
conn.commit()
# Also remove HNSW index file
hnsw_path = index_db_path.parent / "_vectors.hnsw"
if hnsw_path.exists():
hnsw_path.unlink()
logger.info("Removed existing HNSW index")
except Exception as exc:
logger.error(f"Failed to clear existing data: {exc}")
if not result["success"]:
if "error" in result:
logger.error(result["error"])
return result
# Import dependencies
try:
from codexlens.semantic.embedder import Embedder
from codexlens.semantic.vector_store import VectorStore
from codexlens.entities import SemanticChunk
except ImportError as exc:
return {
"success": False,
"error": f"Import failed: {exc}",
}
# Initialize components
try:
embedder = Embedder(profile=model_profile)
vector_store = VectorStore(index_db_path)
logger.info(f"Using model: {embedder.model_name}")
logger.info(f"Embedding dimension: {embedder.embedding_dim}")
except Exception as exc:
return {
"success": False,
"error": f"Failed to initialize components: {exc}",
}
# Read files from index
try:
with sqlite3.connect(index_db_path) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.execute("SELECT full_path, content, language FROM files")
files = [
FileData(
full_path=row["full_path"],
content=row["content"],
language=row["language"] or "python"
)
for row in cursor.fetchall()
]
except Exception as exc:
return {
"success": False,
"error": f"Failed to read files: {exc}",
}
logger.info(f"Found {len(files)} files to process")
if len(files) == 0:
return {
"success": False,
"error": "No files found in index",
}
# Determine worker count
if workers <= 0:
workers = min(multiprocessing.cpu_count(), len(files), 8)
logger.info(f"Using {workers} worker(s) for parallel processing")
logger.info(f"Batch size for embeddings: {batch_size}")
start_time = time.time()
# Phase 1: Parallel chunking
logger.info("Phase 1: Chunking files...")
chunk_start = time.time()
all_chunks: List[ChunkData] = []
failed_files = []
# Prepare work items
work_items = [
(f.full_path, f.content, f.language, chunk_size)
for f in files
]
if workers == 1:
# Single-threaded for debugging
for i, item in enumerate(work_items, 1):
try:
chunks = process_file_worker(item)
all_chunks.extend(chunks)
if i % 100 == 0:
logger.info(f"Chunked {i}/{len(files)} files ({len(all_chunks)} chunks)")
except Exception as exc:
failed_files.append((item[0], str(exc)))
else:
# Parallel processing
with ProcessPoolExecutor(max_workers=workers) as executor:
futures = {
executor.submit(process_file_worker, item): item[0]
for item in work_items
}
completed = 0
for future in as_completed(futures):
file_path = futures[future]
completed += 1
try:
chunks = future.result()
all_chunks.extend(chunks)
if completed % 100 == 0:
logger.info(
f"Chunked {completed}/{len(files)} files "
f"({len(all_chunks)} chunks)"
)
except Exception as exc:
failed_files.append((file_path, str(exc)))
chunk_time = time.time() - chunk_start
logger.info(f"Chunking completed in {chunk_time:.1f}s: {len(all_chunks)} chunks")
if not all_chunks:
return {
"success": False,
"error": "No chunks created from files",
"files_processed": len(files) - len(failed_files),
"files_failed": len(failed_files),
}
# Phase 2: Batch embedding generation
logger.info("Phase 2: Generating embeddings...")
embed_start = time.time()
# Extract all content for batch embedding
all_contents = [c.content for c in all_chunks]
# Generate embeddings in batches
all_embeddings = []
for i in range(0, len(all_contents), batch_size):
batch_contents = all_contents[i:i + batch_size]
batch_embeddings = embedder.embed(batch_contents)
all_embeddings.extend(batch_embeddings)
progress = min(i + batch_size, len(all_contents))
if progress % (batch_size * 4) == 0 or progress == len(all_contents):
logger.info(f"Generated embeddings: {progress}/{len(all_contents)}")
embed_time = time.time() - embed_start
logger.info(f"Embedding completed in {embed_time:.1f}s")
# Phase 3: Batch database write
logger.info("Phase 3: Storing chunks...")
store_start = time.time()
# Create SemanticChunk objects with embeddings
semantic_chunks_with_paths = []
for chunk_data, embedding in zip(all_chunks, all_embeddings):
semantic_chunk = SemanticChunk(
content=chunk_data.content,
metadata=chunk_data.metadata,
)
semantic_chunk.embedding = embedding
semantic_chunks_with_paths.append((semantic_chunk, chunk_data.file_path))
# Batch write (handles both SQLite and HNSW)
write_batch_size = 1000
total_stored = 0
for i in range(0, len(semantic_chunks_with_paths), write_batch_size):
batch = semantic_chunks_with_paths[i:i + write_batch_size]
vector_store.add_chunks_batch(batch)
total_stored += len(batch)
if total_stored % 5000 == 0 or total_stored == len(semantic_chunks_with_paths):
logger.info(f"Stored: {total_stored}/{len(semantic_chunks_with_paths)} chunks")
store_time = time.time() - store_start
logger.info(f"Storage completed in {store_time:.1f}s")
elapsed_time = time.time() - start_time
# Generate summary
# Extract result data and log summary
data = result["result"]
logger.info("=" * 60)
logger.info(f"Completed in {elapsed_time:.1f}s")
logger.info(f" Chunking: {chunk_time:.1f}s")
logger.info(f" Embedding: {embed_time:.1f}s")
logger.info(f" Storage: {store_time:.1f}s")
logger.info(f"Total chunks created: {len(all_chunks)}")
logger.info(f"Files processed: {len(files) - len(failed_files)}/{len(files)}")
if vector_store.ann_available:
logger.info(f"HNSW index vectors: {vector_store.ann_count}")
if failed_files:
logger.warning(f"Failed files: {len(failed_files)}")
for file_path, error in failed_files[:5]: # Show first 5 failures
logger.warning(f" {file_path}: {error}")
logger.info(f"Completed in {data['elapsed_time']:.1f}s")
logger.info(f"Total chunks created: {data['chunks_created']}")
logger.info(f"Files processed: {data['files_processed']}")
if data['files_failed'] > 0:
logger.warning(f"Failed files: {data['files_failed']}")
if data.get('failed_files'):
for file_path, error in data['failed_files']:
logger.warning(f" {file_path}: {error}")
return {
"success": True,
"chunks_created": len(all_chunks),
"files_processed": len(files) - len(failed_files),
"files_failed": len(failed_files),
"elapsed_time": elapsed_time,
"chunk_time": chunk_time,
"embed_time": embed_time,
"store_time": store_time,
"ann_vectors": vector_store.ann_count if vector_store.ann_available else 0,
"chunks_created": data["chunks_created"],
"files_processed": data["files_processed"],
"files_failed": data["files_failed"],
"elapsed_time": data["elapsed_time"],
}
def find_index_databases(scan_dir: Path) -> List[Path]:
"""Find all _index.db files in directory tree."""
logger.info(f"Scanning for indexes in: {scan_dir}")
index_files = list(scan_dir.rglob("_index.db"))
logger.info(f"Found {len(index_files)} index databases")
return index_files
def main():
parser = argparse.ArgumentParser(
description="Generate vector embeddings for CodexLens indexes",
description="Generate vector embeddings for CodexLens indexes (memory-efficient streaming)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__
)
@@ -439,14 +163,14 @@ def main():
"--workers",
type=int,
default=0,
help="Number of parallel workers for chunking (default: auto-detect CPU count)"
help="(Deprecated) Kept for backward compatibility, ignored"
)
parser.add_argument(
"--batch-size",
type=int,
default=256,
help="Batch size for embedding generation (default: 256)"
help="(Deprecated) Kept for backward compatibility, ignored"
)
parser.add_argument(
@@ -481,43 +205,33 @@ def main():
# Determine if scanning or single file
if args.scan or index_path.is_dir():
# Scan mode
# Scan mode - use recursive implementation
if index_path.is_file():
logger.error("--scan requires a directory path")
sys.exit(1)
index_files = find_index_databases(index_path)
if not index_files:
logger.error(f"No index databases found in: {index_path}")
result = generate_embeddings_recursive(
index_root=index_path,
model_profile=args.model,
force=args.force,
chunk_size=args.chunk_size,
progress_callback=progress_callback,
)
if not result["success"]:
logger.error(f"Failed: {result.get('error', 'Unknown error')}")
sys.exit(1)
# Process each index
total_chunks = 0
successful = 0
for idx, index_file in enumerate(index_files, 1):
logger.info(f"\n{'='*60}")
logger.info(f"Processing index {idx}/{len(index_files)}")
logger.info(f"{'='*60}")
result = generate_embeddings_for_index(
index_file,
model_profile=args.model,
force=args.force,
chunk_size=args.chunk_size,
workers=args.workers,
batch_size=args.batch_size,
)
if result["success"]:
total_chunks += result["chunks_created"]
successful += 1
# Final summary
# Log summary
data = result["result"]
logger.info(f"\n{'='*60}")
logger.info("BATCH PROCESSING COMPLETE")
logger.info(f"{'='*60}")
logger.info(f"Indexes processed: {successful}/{len(index_files)}")
logger.info(f"Total chunks created: {total_chunks}")
logger.info(f"Indexes processed: {data['indexes_successful']}/{data['indexes_processed']}")
logger.info(f"Total chunks created: {data['total_chunks_created']}")
logger.info(f"Total files processed: {data['total_files_processed']}")
if data['total_files_failed'] > 0:
logger.warning(f"Total files failed: {data['total_files_failed']}")
else:
# Single index mode
@@ -530,8 +244,6 @@ def main():
model_profile=args.model,
force=args.force,
chunk_size=args.chunk_size,
workers=args.workers,
batch_size=args.batch_size,
)
if not result["success"]:
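
For reference, a minimal sketch of driving the streaming implementation directly, using the keyword arguments and result keys visible in this diff; the index path is a hypothetical example:

from pathlib import Path

from codexlens.cli.embedding_manager import generate_embeddings

# Hypothetical location; any CodexLens _index.db works here.
index_db = Path.home() / ".codexlens" / "indexes" / "myproject" / "_index.db"

result = generate_embeddings(
    index_path=index_db,
    model_profile="code",     # fast, code, multilingual, or balanced
    force=False,              # skip indexes that already have chunks
    chunk_size=2000,          # max chunk size in characters
    progress_callback=print,  # any callable accepting a message string
)

if result["success"]:
    data = result["result"]
    print(f"{data['chunks_created']} chunks from {data['files_processed']} files "
          f"in {data['elapsed_time']:.1f}s")
else:
    print(result.get("error", "Unknown error"))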

View File

@@ -268,7 +268,6 @@ def search(
files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
enrich: bool = typer.Option(False, "--enrich", help="Enrich results with code graph relationships (calls, imports)."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
@@ -423,30 +422,10 @@ def search(
for r in result.results
]
# Enrich results with relationship data if requested
enriched = False
if enrich:
try:
from codexlens.search.enrichment import RelationshipEnricher
# Find index path for the search path
project_record = registry.find_by_source_path(str(search_path))
if project_record:
index_path = Path(project_record["index_root"]) / "_index.db"
if index_path.exists():
with RelationshipEnricher(index_path) as enricher:
results_list = enricher.enrich(results_list, limit=limit)
enriched = True
except Exception as e:
# Enrichment failure should not break search
if verbose:
console.print(f"[yellow]Warning: Enrichment failed: {e}[/yellow]")
payload = {
"query": query,
"mode": actual_mode,
"count": len(results_list),
"enriched": enriched,
"results": results_list,
"stats": {
"dirs_searched": result.stats.dirs_searched,
@@ -458,8 +437,7 @@ def search(
print_json(success=True, result=payload)
else:
render_search_results(result.results, verbose=verbose)
enrich_status = " | [green]Enriched[/green]" if enriched else ""
console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms{enrich_status}[/dim]")
console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms[/dim]")
except SearchError as exc:
if json_mode:
@@ -1376,103 +1354,6 @@ def clean(
raise typer.Exit(code=1)
@app.command()
def graph(
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
symbol: str = typer.Argument(..., help="Symbol name to query"),
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
"""Query semantic graph for code relationships.
Supported query types:
- callers: Find all functions/methods that call the given symbol
- callees: Find all functions/methods called by the given symbol
- inheritance: Find inheritance relationships for the given class
Examples:
codex-lens graph callers my_function
codex-lens graph callees MyClass.method --path src/
codex-lens graph inheritance BaseClass
"""
_configure_logging(verbose)
search_path = path.expanduser().resolve()
# Validate query type
valid_types = ["callers", "callees", "inheritance"]
if query_type not in valid_types:
if json_mode:
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
else:
console.print(f"[red]Invalid query type:[/red] {query_type}")
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
raise typer.Exit(code=1)
registry: RegistryStore | None = None
try:
registry = RegistryStore()
registry.initialize()
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
options = SearchOptions(depth=depth, total_limit=limit)
# Execute graph query based on type
if query_type == "callers":
results = engine.search_callers(symbol, search_path, options=options)
result_type = "callers"
elif query_type == "callees":
results = engine.search_callees(symbol, search_path, options=options)
result_type = "callees"
else: # inheritance
results = engine.search_inheritance(symbol, search_path, options=options)
result_type = "inheritance"
payload = {
"query_type": query_type,
"symbol": symbol,
"count": len(results),
"relationships": results
}
if json_mode:
print_json(success=True, result=payload)
else:
from .output import render_graph_results
render_graph_results(results, query_type=query_type, symbol=symbol)
except SearchError as exc:
if json_mode:
print_json(success=False, error=f"Graph search error: {exc}")
else:
console.print(f"[red]Graph query failed (search):[/red] {exc}")
raise typer.Exit(code=1)
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
else:
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
raise typer.Exit(code=1)
except CodexLensError as exc:
if json_mode:
print_json(success=False, error=str(exc))
else:
console.print(f"[red]Graph query failed:[/red] {exc}")
raise typer.Exit(code=1)
except Exception as exc:
if json_mode:
print_json(success=False, error=f"Unexpected error: {exc}")
else:
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
raise typer.Exit(code=1)
finally:
if registry is not None:
registry.close()
@app.command("semantic-list")
def semantic_list(
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),

View File

@@ -194,7 +194,6 @@ def generate_embeddings(
try:
# Use cached embedder (singleton) for performance
embedder = get_embedder(profile=model_profile)
vector_store = VectorStore(index_path)
chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))
if progress_callback:
@@ -217,85 +216,86 @@ def generate_embeddings(
EMBEDDING_BATCH_SIZE = 8 # jina-embeddings-v2-base-code needs small batches
try:
with sqlite3.connect(index_path) as conn:
conn.row_factory = sqlite3.Row
path_column = _get_path_column(conn)
with VectorStore(index_path) as vector_store:
with sqlite3.connect(index_path) as conn:
conn.row_factory = sqlite3.Row
path_column = _get_path_column(conn)
# Get total file count for progress reporting
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
if total_files == 0:
return {"success": False, "error": "No files found in index"}
# Get total file count for progress reporting
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
if total_files == 0:
return {"success": False, "error": "No files found in index"}
if progress_callback:
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")
cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
batch_number = 0
while True:
# Fetch a batch of files (streaming, not fetchall)
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
if not file_batch:
break
batch_number += 1
batch_chunks_with_paths = []
files_in_batch_with_chunks = set()
# Step 1: Chunking for the current file batch
for file_row in file_batch:
file_path = file_row[path_column]
content = file_row["content"]
language = file_row["language"] or "python"
try:
chunks = chunker.chunk_sliding_window(
content,
file_path=file_path,
language=language
)
if chunks:
for chunk in chunks:
batch_chunks_with_paths.append((chunk, file_path))
files_in_batch_with_chunks.add(file_path)
except Exception as e:
logger.error(f"Failed to chunk {file_path}: {e}")
failed_files.append((file_path, str(e)))
if not batch_chunks_with_paths:
continue
batch_chunk_count = len(batch_chunks_with_paths)
if progress_callback:
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")
# Step 2: Generate embeddings for this batch
batch_embeddings = []
try:
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
embeddings = embedder.embed(batch_contents)
batch_embeddings.extend(embeddings)
except Exception as e:
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
continue
cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
batch_number = 0
# Step 3: Assign embeddings to chunks
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
chunk.embedding = embedding
while True:
# Fetch a batch of files (streaming, not fetchall)
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
if not file_batch:
break
# Step 4: Store this batch to database immediately (releases memory)
try:
vector_store.add_chunks_batch(batch_chunks_with_paths)
total_chunks_created += batch_chunk_count
total_files_processed += len(files_in_batch_with_chunks)
except Exception as e:
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
batch_number += 1
batch_chunks_with_paths = []
files_in_batch_with_chunks = set()
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
# Step 1: Chunking for the current file batch
for file_row in file_batch:
file_path = file_row[path_column]
content = file_row["content"]
language = file_row["language"] or "python"
try:
chunks = chunker.chunk_sliding_window(
content,
file_path=file_path,
language=language
)
if chunks:
for chunk in chunks:
batch_chunks_with_paths.append((chunk, file_path))
files_in_batch_with_chunks.add(file_path)
except Exception as e:
logger.error(f"Failed to chunk {file_path}: {e}")
failed_files.append((file_path, str(e)))
if not batch_chunks_with_paths:
continue
batch_chunk_count = len(batch_chunks_with_paths)
if progress_callback:
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
# Step 2: Generate embeddings for this batch
batch_embeddings = []
try:
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
embeddings = embedder.embed(batch_contents)
batch_embeddings.extend(embeddings)
except Exception as e:
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
continue
# Step 3: Assign embeddings to chunks
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
chunk.embedding = embedding
# Step 4: Store this batch to database immediately (releases memory)
try:
vector_store.add_chunks_batch(batch_chunks_with_paths)
total_chunks_created += batch_chunk_count
total_files_processed += len(files_in_batch_with_chunks)
except Exception as e:
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
except Exception as e:
return {"success": False, "error": f"Failed to read or process files: {str(e)}"}

View File

@@ -122,68 +122,3 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
console.print(header)
render_symbols(list(symbols), title="Discovered Symbols")
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
"""Render semantic graph query results.
Args:
results: List of relationship dicts
query_type: Type of query (callers, callees, inheritance)
symbol: Symbol name that was queried
"""
if not results:
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
return
title_map = {
"callers": f"Callers of '{symbol}' ({len(results)} found)",
"callees": f"Callees of '{symbol}' ({len(results)} found)",
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
}
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
if query_type == "callers":
table.add_column("Caller", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
elif query_type == "callees":
table.add_column("Target", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("target_symbol", "-"),
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
else: # inheritance
table.add_column("Derived Class", style="green")
table.add_column("Base Class", style="magenta")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("target_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-"))
)
console.print(table)

View File

@@ -14,8 +14,6 @@ class Symbol(BaseModel):
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")
@field_validator("range")
@classmethod
@@ -29,13 +27,6 @@ class Symbol(BaseModel):
raise ValueError("end_line must be >= start_line")
return value
@field_validator("token_count")
@classmethod
def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
if value is not None and value < 0:
raise ValueError("token_count must be >= 0")
return value
class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""

View File

@@ -302,108 +302,6 @@ class ChainSearchEngine:
index_paths, name, kind, options.total_limit
)
def search_callers(self, target_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callers of a given symbol across directory hierarchy.
Args:
target_symbol: Name of the symbol to find callers for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with caller information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callers = engine.search_callers("my_function", Path("D:/project"))
>>> for caller in callers:
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callers_parallel(
index_paths, target_symbol, options.total_limit
)
def search_callees(self, source_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callees (what a symbol calls) across directory hierarchy.
Args:
source_symbol: Name of the symbol to find callees for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with callee information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
>>> for callee in callees:
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callees_parallel(
index_paths, source_symbol, options.total_limit
)
def search_inheritance(self, class_name: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find inheritance relationships for a class across directory hierarchy.
Args:
class_name: Name of the class to find inheritance for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with inheritance information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
>>> for rel in inheritance:
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_inheritance_parallel(
index_paths, class_name, options.total_limit
)
# === Internal Methods ===
def _find_start_index(self, source_path: Path) -> Optional[Path]:
@@ -711,273 +609,6 @@ class ChainSearchEngine:
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
return []
def _search_callers_parallel(self, index_paths: List[Path],
target_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callers across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
target_symbol: Target symbol name
limit: Total result limit
Returns:
Deduplicated list of caller relationships
"""
all_callers = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callers_single,
idx_path,
target_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callers = future.result()
all_callers.extend(callers)
except Exception as exc:
self.logger.error(f"Caller search failed: {exc}")
# Deduplicate by (source_file, source_line)
seen = set()
unique_callers = []
for caller in all_callers:
key = (caller.get("source_file"), caller.get("source_line"))
if key not in seen:
seen.add(key)
unique_callers.append(caller)
# Sort by source file and line
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
return unique_callers[:limit]
def _search_callers_single(self, index_path: Path,
target_symbol: str) -> List[Dict[str, Any]]:
"""Search for callers in a single index.
Args:
index_path: Path to _index.db file
target_symbol: Target symbol name
Returns:
List of caller relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
return store.query_relationships_by_target(target_symbol)
except Exception as exc:
self.logger.debug(f"Caller search error in {index_path}: {exc}")
return []
def _search_callees_parallel(self, index_paths: List[Path],
source_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callees across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
source_symbol: Source symbol name
limit: Total result limit
Returns:
Deduplicated list of callee relationships
"""
all_callees = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callees_single,
idx_path,
source_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callees = future.result()
all_callees.extend(callees)
except Exception as exc:
self.logger.error(f"Callee search failed: {exc}")
# Deduplicate by (target_symbol, source_line)
seen = set()
unique_callees = []
for callee in all_callees:
key = (callee.get("target_symbol"), callee.get("source_line"))
if key not in seen:
seen.add(key)
unique_callees.append(callee)
# Sort by source line
unique_callees.sort(key=lambda c: c.get("source_line", 0))
return unique_callees[:limit]
def _search_callees_single(self, index_path: Path,
source_symbol: str) -> List[Dict[str, Any]]:
"""Search for callees in a single index.
Args:
index_path: Path to _index.db file
source_symbol: Source symbol name
Returns:
List of callee relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
# Single JOIN query to get all callees (fixes N+1 query problem)
# Uses public execute_query API instead of _get_connection bypass
rows = store.execute_query(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name AS target_symbol,
r.relationship_type,
r.source_line,
f.full_path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE s.name = ? AND r.relationship_type = 'call'
ORDER BY f.full_path, r.source_line
LIMIT 100
""",
(source_symbol,)
)
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_symbol"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
except Exception as exc:
self.logger.debug(f"Callee search error in {index_path}: {exc}")
return []
def _search_inheritance_parallel(self, index_paths: List[Path],
class_name: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for inheritance relationships across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
class_name: Class name to search for
limit: Total result limit
Returns:
Deduplicated list of inheritance relationships
"""
all_inheritance = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_inheritance_single,
idx_path,
class_name
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
inheritance = future.result()
all_inheritance.extend(inheritance)
except Exception as exc:
self.logger.error(f"Inheritance search failed: {exc}")
# Deduplicate by (source_symbol, target_symbol)
seen = set()
unique_inheritance = []
for rel in all_inheritance:
key = (rel.get("source_symbol"), rel.get("target_symbol"))
if key not in seen:
seen.add(key)
unique_inheritance.append(rel)
# Sort by source file
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
return unique_inheritance[:limit]
def _search_inheritance_single(self, index_path: Path,
class_name: str) -> List[Dict[str, Any]]:
"""Search for inheritance relationships in a single index.
Args:
index_path: Path to _index.db file
class_name: Class name to search for
Returns:
List of inheritance relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
# Use UNION to find relationships where class is either:
# 1. The base class (target) - find derived classes
# 2. The derived class (source) - find parent classes
# Uses public execute_query API instead of _get_connection bypass
rows = store.execute_query(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.full_path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE r.target_qualified_name = ? AND r.relationship_type = 'inherits'
UNION
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.full_path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE s.name = ? AND r.relationship_type = 'inherits'
LIMIT 100
""",
(class_name, class_name)
)
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
except Exception as exc:
self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
return []
# === Convenience Functions ===
@@ -1007,10 +638,9 @@ def quick_search(query: str,
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
options = SearchOptions(depth=depth)
result = engine.search(query, source_path, options)
with ChainSearchEngine(registry, mapper) as engine:
options = SearchOptions(depth=depth)
result = engine.search(query, source_path, options)
registry.close()
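
A self-contained sketch of the context-manager pattern quick_search now relies on; SearchEngine is a stand-in for ChainSearchEngine, whose with-statement support the change above implies:

class SearchEngine:
    """Stand-in for ChainSearchEngine: owns an executor that must be released."""
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc, tb):
        self.close()     # runs even if search() raised
        return False     # propagate any exception
    def close(self):
        print("executor shut down")
    def search(self, query: str) -> list:
        return []

with SearchEngine() as engine:
    results = engine.search("memory leak")
# engine is closed here on success and on failure alike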

View File

@@ -1,542 +0,0 @@
"""Graph analyzer for extracting code relationships using tree-sitter.
Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Node as TreeSitterNode
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterNode = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class GraphAnalyzer:
"""Analyzer for extracting semantic relationships from code using AST traversal."""
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
"""Initialize graph analyzer for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
parser: Optional TreeSitterSymbolParser instance for dependency injection.
If None, creates a new parser instance (backward compatibility).
"""
self.language_id = language_id
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
def is_available(self) -> bool:
"""Check if graph analyzer is available.
Returns:
True if tree-sitter parser is initialized and ready
"""
return self._parser.is_available()
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
"""Analyze source code and extract relationships.
Args:
text: Source code text
file_path: File path for relationship context
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def analyze_with_symbols(
self, text: str, file_path: Path, symbols: List[Symbol]
) -> List[CodeRelationship]:
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
Args:
text: Source code text
file_path: File path for relationship context
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
# Convert Symbol objects to internal symbol format
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
# Extract relationships using provided symbols
relationships = self._extract_relationships_with_symbols(
source_bytes, root, str(file_path.resolve()), defined_symbols
)
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def _convert_symbols_to_dict(
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
) -> List[dict]:
"""Convert Symbol objects to internal dict format for relationship extraction.
Args:
source_bytes: Source code as bytes
root: Root AST node
symbols: Pre-parsed Symbol objects
Returns:
List of symbol info dicts with name, node, and type
"""
symbol_dicts = []
symbol_names = {s.name for s in symbols}
# Find AST nodes corresponding to symbols
for node in self._iter_nodes(root):
node_type = node.type
# Check if this node matches any of our symbols
if node_type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "method"
})
elif node_type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
return symbol_dicts
def _extract_relationships_with_symbols(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
) -> List[CodeRelationship]:
"""Extract relationships from AST using pre-parsed symbols.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
defined_symbols: Pre-parsed symbol dicts
Returns:
List of extracted relationships
"""
relationships: List[CodeRelationship] = []
# Determine call node type based on language
if self.language_id == "python":
call_node_type = "call"
extract_target = self._extract_call_target
elif self.language_id in {"javascript", "typescript"}:
call_node_type = "call_expression"
extract_target = self._extract_js_call_target
else:
return []
# Find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == call_node_type:
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = extract_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of extracted relationships
"""
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, file_path)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, file_path)
else:
return []
def _extract_python_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract Python relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of Python relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols with their scopes
defined_symbols = self._collect_python_symbols(source_bytes, root)
# Second pass: find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == "call":
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = self._extract_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_js_ts_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract JavaScript/TypeScript relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of JS/TS relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
# Second pass: find call expressions
for node in self._iter_nodes(root):
if node.type == "call_expression":
# Extract caller context
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
source_symbol = "<module>"
# Extract callee
target_symbol = self._extract_js_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=line_number,
)
)
return relationships
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all Python function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
return symbols
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all JS/TS function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "method"
})
elif node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
return symbols
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
"""Find the enclosing function/method/class for a node.
Returns fully qualified name (e.g., "MyClass.my_method") by traversing up
the AST tree and collecting parent class/function names.
Args:
node: AST node to find enclosure for
symbols: List of defined symbols
Returns:
Fully qualified name of enclosing symbol, or None if at module level
"""
# Walk up the tree to find all enclosing symbols
enclosing_names = []
parent = node.parent
while parent is not None:
for symbol in symbols:
if symbol["node"] == parent:
# Prepend to maintain order (innermost to outermost)
enclosing_names.insert(0, symbol["name"])
break
parent = parent.parent
# Return fully qualified name or None if at module level
if enclosing_names:
return ".".join(enclosing_names)
return None
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a Python call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers (e.g., "foo()")
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle attribute access (e.g., "obj.method()")
if function_node.type == "attribute":
attr_node = function_node.child_by_field_name("attribute")
if attr_node:
return self._node_text(source_bytes, attr_node)
return None
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a JS/TS call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle member expressions (e.g., "obj.method()")
if function_node.type == "member_expression":
property_node = function_node.child_by_field_name("property")
if property_node:
return self._node_text(source_bytes, property_node)
return None
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")

View File

@@ -602,6 +602,12 @@ class VectorStore:
Returns:
List of SearchResult ordered by similarity (highest first)
"""
logger.warning(
"Using brute-force vector search (hnswlib not available). "
"This may cause high memory usage for large indexes. "
"Install hnswlib for better performance: pip install hnswlib"
)
with self._cache_lock:
# Refresh cache if needed
if self._embedding_matrix is None:
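
For context on the new warning: brute force means scoring the query against every cached embedding, which is why memory grows with index size. A minimal numpy sketch of such a scan, independent of VectorStore internals (the names here are illustrative, not the class's actual API):

import numpy as np

def brute_force_search(matrix: np.ndarray, query: np.ndarray, k: int = 10):
    """Cosine similarity against every stored vector -- O(n) time and memory.

    matrix: (n, d) array of stored embeddings; query: (d,) array.
    """
    # Normalize so the dot product equals cosine similarity.
    m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    q = query / np.linalg.norm(query)
    scores = m @ q
    top = np.argsort(scores)[::-1][:k]  # highest similarity first
    return [(int(i), float(scores[i])) for i in top]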

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from codexlens.entities import CodeRelationship, SearchResult, Symbol
from codexlens.entities import SearchResult, Symbol
from codexlens.errors import StorageError
@@ -237,116 +237,6 @@ class DirIndexStore:
conn.rollback()
raise StorageError(f"Failed to add file {name}: {exc}") from exc
def add_relationships(
self,
file_path: str | Path,
relationships: List[CodeRelationship],
) -> int:
"""Store code relationships for a file.
Args:
file_path: Path to the source file
relationships: List of CodeRelationship objects to store
Returns:
Number of relationships stored
Raises:
StorageError: If database operations fail
"""
if not relationships:
return 0
with self._lock:
conn = self._get_connection()
file_path_str = str(Path(file_path).resolve())
try:
# Get file_id
row = conn.execute(
"SELECT id FROM files WHERE full_path=?", (file_path_str,)
).fetchone()
if not row:
return 0
file_id = int(row["id"])
# Delete existing relationships for symbols in this file
conn.execute(
"""
DELETE FROM code_relationships
WHERE source_symbol_id IN (
SELECT id FROM symbols WHERE file_id=?
)
""",
(file_id,),
)
# Insert new relationships
relationship_rows = []
skipped_relationships = []
for rel in relationships:
# Extract simple name from fully qualified name (e.g., "MyClass.my_method" -> "my_method")
# This handles cases where GraphAnalyzer generates qualified names but symbols table stores simple names
source_symbol_simple = rel.source_symbol.split(".")[-1] if "." in rel.source_symbol else rel.source_symbol
# Find symbol_id by name and file
symbol_row = conn.execute(
"""
SELECT id FROM symbols
WHERE file_id=? AND name=? AND start_line<=? AND end_line>=?
LIMIT 1
""",
(file_id, source_symbol_simple, rel.source_line, rel.source_line),
).fetchone()
if not symbol_row:
# Try matching by simple name only
symbol_row = conn.execute(
"SELECT id FROM symbols WHERE file_id=? AND name=? LIMIT 1",
(file_id, source_symbol_simple),
).fetchone()
if symbol_row:
relationship_rows.append((
int(symbol_row["id"]),
rel.target_symbol,
rel.relationship_type,
rel.source_line,
rel.target_file,
))
else:
# Log warning when symbol lookup fails
skipped_relationships.append(rel.source_symbol)
# Log skipped relationships for debugging
if skipped_relationships:
self.logger.warning(
"Failed to find source symbol IDs for %d relationships in %s: %s",
len(skipped_relationships),
file_path_str,
", ".join(set(skipped_relationships))
)
if relationship_rows:
conn.executemany(
"""
INSERT INTO code_relationships(
source_symbol_id, target_qualified_name, relationship_type,
source_line, target_file
)
VALUES(?, ?, ?, ?, ?)
""",
relationship_rows,
)
conn.commit()
return len(relationship_rows)
except sqlite3.DatabaseError as exc:
conn.rollback()
raise StorageError(f"Failed to add relationships: {exc}") from exc
def add_files_batch(
self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
) -> int:

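The qualified-to-simple name fallback in the removed add_relationships reduces to a one-line mapping; a hypothetical helper showing what the lookup normalizes (the conditional in the original is equivalent, since split on a missing "." returns the whole string):

def simple_name(qualified: str) -> str:
    # "MyClass.my_method" -> "my_method"; plain names pass through unchanged.
    return qualified.split(".")[-1]

assert simple_name("MyClass.my_method") == "my_method"
assert simple_name("helper") == "helper"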
View File

@@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Set
from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
from codexlens.semantic.graph_analyzer import GraphAnalyzer
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import ProjectInfo, RegistryStore
@@ -525,16 +524,6 @@ class IndexTreeBuilder:
symbols=indexed_file.symbols,
)
# Extract and store code relationships for graph visualization
if language_id in {"python", "javascript", "typescript"}:
graph_analyzer = GraphAnalyzer(language_id)
if graph_analyzer.is_available():
relationships = graph_analyzer.analyze_with_symbols(
text, file_path, indexed_file.symbols
)
if relationships:
store.add_relationships(file_path, relationships)
files_count += 1
symbols_count += len(indexed_file.symbols)
@@ -742,16 +731,6 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
symbols=indexed_file.symbols,
)
# Extract and store code relationships for graph visualization
if language_id in {"python", "javascript", "typescript"}:
graph_analyzer = GraphAnalyzer(language_id)
if graph_analyzer.is_available():
relationships = graph_analyzer.analyze_with_symbols(
text, item, indexed_file.symbols
)
if relationships:
store.add_relationships(item, relationships)
files_count += 1
symbols_count += len(indexed_file.symbols)

View File

@@ -1,57 +0,0 @@
"""
Migration 003: Add code relationships storage.
This migration introduces the `code_relationships` table to store semantic
relationships between code symbols (function calls, inheritance, imports).
This enables graph-based code navigation and dependency analysis.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add code relationships table.
- Creates `code_relationships` table with foreign key to symbols
- Creates indexes for efficient relationship queries
- Supports lazy expansion with target_symbol being qualified names
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating 'code_relationships' table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT,
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for code_relationships...")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
)
log.info("Finished creating code_relationships table and indexes.")

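For reference, the schema this migration created can be exercised end to end with nothing but the sqlite3 standard library; an in-memory demo (foreign key omitted for brevity):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE code_relationships (
        id INTEGER PRIMARY KEY,
        source_symbol_id INTEGER NOT NULL,
        target_qualified_name TEXT NOT NULL,
        relationship_type TEXT NOT NULL,
        source_line INTEGER NOT NULL,
        target_file TEXT
    )
    """
)
conn.execute(
    "CREATE INDEX idx_relationships_target"
    " ON code_relationships (target_qualified_name)"
)
conn.execute(
    "INSERT INTO code_relationships(source_symbol_id, target_qualified_name,"
    " relationship_type, source_line) VALUES (1, 'bar', 'call', 2)"
)
rows = conn.execute(
    "SELECT source_symbol_id, source_line FROM code_relationships"
    " WHERE target_qualified_name = ?",
    ("bar",),
).fetchall()
print(rows)  # -> [(1, 2)]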
View File

@@ -10,7 +10,7 @@ from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
from codexlens.entities import IndexedFile, SearchResult, Symbol
from codexlens.errors import StorageError
@@ -420,167 +420,6 @@ class SQLiteStore:
}
def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
"""Store code relationships for a file.
Args:
file_path: Path to the file containing the relationships
relationships: List of CodeRelationship objects to store
"""
if not relationships:
return
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(file_path).resolve())
# Get file_id
row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
if not row:
raise StorageError(f"File not found in index: {file_path}")
file_id = int(row["id"])
# Delete existing relationships for symbols in this file
conn.execute(
"""
DELETE FROM code_relationships
WHERE source_symbol_id IN (
SELECT id FROM symbols WHERE file_id=?
)
""",
(file_id,)
)
# Insert new relationships
relationship_rows = []
for rel in relationships:
# Find source symbol ID
symbol_row = conn.execute(
"""
SELECT id FROM symbols
WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
ORDER BY (end_line - start_line) ASC
LIMIT 1
""",
(file_id, rel.source_symbol, rel.source_line, rel.source_line)
).fetchone()
if symbol_row:
source_symbol_id = int(symbol_row["id"])
relationship_rows.append((
source_symbol_id,
rel.target_symbol,
rel.relationship_type,
rel.source_line,
rel.target_file
))
if relationship_rows:
conn.executemany(
"""
INSERT INTO code_relationships(
source_symbol_id, target_qualified_name, relationship_type,
source_line, target_file
)
VALUES(?, ?, ?, ?, ?)
""",
relationship_rows
)
conn.commit()
def query_relationships_by_target(
self, target_name: str, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by target symbol name (find all callers).
Args:
target_name: Name of the target symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info with file paths and line numbers
"""
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.full_path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE r.target_qualified_name = ?
ORDER BY f.full_path, r.source_line
LIMIT ?
""",
(target_name, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def query_relationships_by_source(
self, source_symbol: str, source_file: str | Path, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by source symbol (find what a symbol calls).
Args:
source_symbol: Name of the source symbol
source_file: File path containing the source symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info
"""
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(source_file).resolve())
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE s.name = ? AND f.path = ?
ORDER BY r.source_line
LIMIT ?
""",
(source_symbol, resolved_path, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def _connect(self) -> sqlite3.Connection:
"""Legacy method for backward compatibility."""
return self._get_connection()

View File
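One detail worth noting in the removed lookup: ORDER BY (end_line - start_line) ASC selects, among all symbols whose range covers the call line, the one with the smallest span, i.e. the innermost enclosing definition. The same selection in pure Python:

symbols = [
    {"id": 1, "name": "outer", "start_line": 1, "end_line": 30},
    {"id": 2, "name": "inner", "start_line": 10, "end_line": 15},
]
source_line = 12
enclosing = min(
    (s for s in symbols if s["start_line"] <= source_line <= s["end_line"]),
    key=lambda s: s["end_line"] - s["start_line"],
)
print(enclosing["name"])  # -> inner: the tighter range wins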

@@ -1,644 +0,0 @@
"""Unit tests for ChainSearchEngine.
Tests the graph query methods (search_callers, search_callees, search_inheritance)
with mocked SQLiteStore dependency to test logic in isolation.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch, call
from concurrent.futures import ThreadPoolExecutor
from codexlens.search.chain_search import (
ChainSearchEngine,
SearchOptions,
SearchStats,
ChainSearchResult,
)
from codexlens.entities import SearchResult, Symbol
from codexlens.storage.registry import RegistryStore, DirMapping
from codexlens.storage.path_mapper import PathMapper
@pytest.fixture
def mock_registry():
"""Create a mock RegistryStore."""
registry = Mock(spec=RegistryStore)
return registry
@pytest.fixture
def mock_mapper():
"""Create a mock PathMapper."""
mapper = Mock(spec=PathMapper)
return mapper
@pytest.fixture
def search_engine(mock_registry, mock_mapper):
"""Create a ChainSearchEngine with mocked dependencies."""
return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2)
@pytest.fixture
def sample_index_path():
"""Sample index database path."""
return Path("/test/project/_index.db")
class TestChainSearchEngineCallers:
"""Tests for search_callers method."""
def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callers returns caller relationships."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
# Mock finding the start index
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
# Mock collect_index_paths to return single index
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Mock the parallel search to return caller data
expected_callers = [
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller_function"
assert result[0]["target_symbol"] == "my_function"
assert result[0]["relationship_type"] == "calls"
assert result[0]["source_line"] == 42
def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callers handles no results gracefully."""
# Setup
source_path = Path("/test/project")
target_symbol = "nonexistent_function"
# Mock finding the start index
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
# Mock collect_index_paths
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Mock empty results
with patch.object(search_engine, '_search_callers_parallel', return_value=[]):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert result == []
def test_search_callers_no_index_found(self, search_engine, mock_registry):
"""Test that search_callers returns empty list when no index found."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
# Mock no index found
mock_registry.find_nearest_index.return_value = None
with patch.object(search_engine, '_find_start_index', return_value=None):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert
assert result == []
def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path):
"""Test that search_callers respects SearchOptions."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
options = SearchOptions(depth=1, total_limit=50)
# Configure mapper to return a path that exists
mock_mapper.source_to_index_db.return_value = sample_index_path
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect:
with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search:
# Patch Path.exists to return True so the exact match is found
with patch.object(Path, 'exists', return_value=True):
# Execute
search_engine.search_callers(target_symbol, source_path, options)
# Assert that depth was passed to collect_index_paths
mock_collect.assert_called_once_with(sample_index_path, 1)
# Assert that total_limit was passed to parallel search
mock_search.assert_called_once_with([sample_index_path], target_symbol, 50)
class TestChainSearchEngineCallees:
"""Tests for search_callees method."""
def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees returns callee relationships."""
# Setup
source_path = Path("/test/project")
source_symbol = "caller_function"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_callees = [
{
"source_symbol": "caller_function",
"target_symbol": "callee_function",
"relationship_type": "calls",
"source_line": 15,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller_function"
assert result[0]["target_symbol"] == "callee_function"
assert result[0]["source_line"] == 15
def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees correctly handles file-specific queries."""
# Setup
source_path = Path("/test/project")
source_symbol = "MyClass.method"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
# Multiple callees from same source symbol
expected_callees = [
{
"source_symbol": "MyClass.method",
"target_symbol": "helper_a",
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/project/module.py",
"target_file": "/test/project/utils.py",
},
{
"source_symbol": "MyClass.method",
"target_symbol": "helper_b",
"relationship_type": "calls",
"source_line": 20,
"source_file": "/test/project/module.py",
"target_file": "/test/project/utils.py",
}
]
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert len(result) == 2
assert result[0]["target_symbol"] == "helper_a"
assert result[1]["target_symbol"] == "helper_b"
def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test that search_callees handles no callees gracefully."""
source_path = Path("/test/project")
source_symbol = "leaf_function"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
with patch.object(search_engine, '_search_callees_parallel', return_value=[]):
# Execute
result = search_engine.search_callees(source_symbol, source_path)
# Assert
assert result == []
class TestChainSearchEngineInheritance:
"""Tests for search_inheritance method."""
def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path):
"""Test that search_inheritance returns inheritance relationships."""
# Setup
source_path = Path("/test/project")
class_name = "BaseClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_inheritance = [
{
"source_symbol": "DerivedClass",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/project/derived.py",
"target_file": "/test/project/base.py",
}
]
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "DerivedClass"
assert result[0]["target_symbol"] == "BaseClass"
assert result[0]["relationship_type"] == "inherits"
def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path):
"""Test inheritance search with multiple derived classes."""
source_path = Path("/test/project")
class_name = "BaseClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
expected_inheritance = [
{
"source_symbol": "DerivedClassA",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/project/derived_a.py",
"target_file": "/test/project/base.py",
},
{
"source_symbol": "DerivedClassB",
"target_symbol": "BaseClass",
"relationship_type": "inherits",
"source_line": 10,
"source_file": "/test/project/derived_b.py",
"target_file": "/test/project/base.py",
}
]
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert len(result) == 2
assert result[0]["source_symbol"] == "DerivedClassA"
assert result[1]["source_symbol"] == "DerivedClassB"
def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path):
"""Test inheritance search with no subclasses found."""
source_path = Path("/test/project")
class_name = "FinalClass"
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=sample_index_path,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]):
# Execute
result = search_engine.search_inheritance(class_name, source_path)
# Assert
assert result == []
class TestChainSearchEngineParallelSearch:
"""Tests for parallel search aggregation."""
def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path):
"""Test that parallel search aggregates results from multiple indexes."""
# Setup
source_path = Path("/test/project")
target_symbol = "my_function"
index_path_1 = Path("/test/project/_index.db")
index_path_2 = Path("/test/project/subdir/_index.db")
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=index_path_1,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
# Mock parallel search results from multiple indexes
callers_from_multiple = [
{
"source_symbol": "caller_in_root",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/project/root.py",
"target_file": "/test/project/lib.py",
},
{
"source_symbol": "caller_in_subdir",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 20,
"source_file": "/test/project/subdir/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert results from both indexes are included
assert len(result) == 2
assert any(r["source_file"] == "/test/project/root.py" for r in result)
assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result)
def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path):
"""Test that parallel search deduplicates results by (source_file, source_line)."""
# Note: This test verifies the behavior of _search_callers_parallel deduplication
source_path = Path("/test/project")
target_symbol = "my_function"
index_path_1 = Path("/test/project/_index.db")
index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate)
mock_registry.find_nearest_index.return_value = DirMapping(
id=1,
project_id=1,
source_path=source_path,
index_path=index_path_1,
depth=0,
files_count=10,
last_updated=0.0
)
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
# Mock duplicate results from same location
duplicate_callers = [
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
},
{
"source_symbol": "caller_function",
"target_symbol": "my_function",
"relationship_type": "calls",
"source_line": 42,
"source_file": "/test/project/module.py",
"target_file": "/test/project/lib.py",
}
]
with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers):
# Execute
result = search_engine.search_callers(target_symbol, source_path)
# Assert: even with duplicates in input, output may contain both
# (actual deduplication happens in _search_callers_parallel)
assert len(result) >= 1
class TestChainSearchEngineContextManager:
"""Tests for context manager functionality."""
def test_context_manager_closes_executor(self, mock_registry, mock_mapper):
"""Test that context manager properly closes executor."""
with ChainSearchEngine(mock_registry, mock_mapper) as engine:
# Force executor creation
engine._get_executor()
assert engine._executor is not None
# Executor should be closed after exiting context
assert engine._executor is None
def test_close_method_shuts_down_executor(self, search_engine):
"""Test that close() method shuts down executor."""
# Create executor
search_engine._get_executor()
assert search_engine._executor is not None
# Close
search_engine.close()
assert search_engine._executor is None
class TestSearchCallersSingle:
"""Tests for _search_callers_single internal method."""
def test_search_callers_single_queries_store(self, search_engine, sample_index_path):
"""Test that _search_callers_single queries SQLiteStore correctly."""
target_symbol = "my_function"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MockStore.return_value.__enter__.return_value
mock_store_instance.query_relationships_by_target.return_value = [
{
"source_symbol": "caller",
"target_symbol": target_symbol,
"relationship_type": "calls",
"source_line": 10,
"source_file": "/test/file.py",
"target_file": "/test/lib.py",
}
]
# Execute
result = search_engine._search_callers_single(sample_index_path, target_symbol)
# Assert
assert len(result) == 1
assert result[0]["source_symbol"] == "caller"
mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol)
def test_search_callers_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_callers_single returns empty list on error."""
target_symbol = "my_function"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("Database error")
# Execute
result = search_engine._search_callers_single(sample_index_path, target_symbol)
# Assert - should return empty list, not raise exception
assert result == []
class TestSearchCalleesSingle:
"""Tests for _search_callees_single internal method."""
def test_search_callees_single_queries_database(self, search_engine, sample_index_path):
"""Test that _search_callees_single queries SQLiteStore correctly."""
source_symbol = "caller_function"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MagicMock()
MockStore.return_value.__enter__.return_value = mock_store_instance
# Mock execute_query to return relationship data (using new public API)
mock_store_instance.execute_query.return_value = [
{
"source_symbol": source_symbol,
"target_symbol": "callee_function",
"relationship_type": "call",
"source_line": 15,
"source_file": "/test/module.py",
"target_file": "/test/lib.py",
}
]
# Execute
result = search_engine._search_callees_single(sample_index_path, source_symbol)
# Assert - verify execute_query was called (public API)
assert mock_store_instance.execute_query.called
assert len(result) == 1
assert result[0]["source_symbol"] == source_symbol
assert result[0]["target_symbol"] == "callee_function"
def test_search_callees_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_callees_single returns empty list on error."""
source_symbol = "caller_function"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("DB error")
# Execute
result = search_engine._search_callees_single(sample_index_path, source_symbol)
# Assert - should return empty list, not raise exception
assert result == []
class TestSearchInheritanceSingle:
"""Tests for _search_inheritance_single internal method."""
def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path):
"""Test that _search_inheritance_single queries SQLiteStore correctly."""
class_name = "BaseClass"
# Mock SQLiteStore
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
mock_store_instance = MagicMock()
MockStore.return_value.__enter__.return_value = mock_store_instance
# Mock execute_query to return relationship data (using new public API)
mock_store_instance.execute_query.return_value = [
{
"source_symbol": "DerivedClass",
"target_qualified_name": "BaseClass",
"relationship_type": "inherits",
"source_line": 5,
"source_file": "/test/derived.py",
"target_file": "/test/base.py",
}
]
# Execute
result = search_engine._search_inheritance_single(sample_index_path, class_name)
# Assert
assert mock_store_instance.execute_query.called
assert len(result) == 1
assert result[0]["source_symbol"] == "DerivedClass"
assert result[0]["relationship_type"] == "inherits"
# Verify execute_query was called with 'inherits' filter
call_args = mock_store_instance.execute_query.call_args
sql_query = call_args[0][0]
assert "relationship_type = 'inherits'" in sql_query
def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path):
"""Test that _search_inheritance_single returns empty list on error."""
class_name = "BaseClass"
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
MockStore.return_value.__enter__.side_effect = Exception("DB error")
# Execute
result = search_engine._search_inheritance_single(sample_index_path, class_name)
# Assert - should return empty list, not raise exception
assert result == []

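The context-manager behavior pinned down by TestChainSearchEngineContextManager is the heart of the quick_search leak fix: the engine owns a lazily created executor and must tear it down deterministically. A minimal sketch of the pattern, with illustrative names rather than the exact ChainSearchEngine implementation:

from concurrent.futures import ThreadPoolExecutor
from typing import Optional

class Engine:
    def __init__(self, max_workers: int = 2) -> None:
        self._max_workers = max_workers
        self._executor: Optional[ThreadPoolExecutor] = None

    def _get_executor(self) -> ThreadPoolExecutor:
        # Created on first use, so engines that never search spawn no threads.
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
        return self._executor

    def close(self) -> None:
        if self._executor is not None:
            self._executor.shutdown(wait=True)
            self._executor = None

    def __enter__(self) -> "Engine":
        return self

    def __exit__(self, *exc) -> None:
        self.close()

with Engine() as engine:
    engine._get_executor()
# On exit the executor is shut down and dropped, matching the assertions above.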
View File

@@ -1,122 +0,0 @@
"""Tests for CLI search command with --enrich flag."""
import json
import pytest
from typer.testing import CliRunner
from codexlens.cli.commands import app
class TestCLISearchEnrich:
"""Test CLI search command with --enrich flag integration."""
@pytest.fixture
def runner(self):
"""Create CLI test runner."""
return CliRunner()
def test_search_with_enrich_flag_help(self, runner):
"""Test --enrich flag is documented in help."""
result = runner.invoke(app, ["search", "--help"])
assert result.exit_code == 0
assert "--enrich" in result.output
assert "relationships" in result.output.lower() or "graph" in result.output.lower()
def test_search_with_enrich_flag_accepted(self, runner):
"""Test --enrich flag is accepted by the CLI."""
result = runner.invoke(app, ["search", "test", "--enrich"])
# Should not show 'unknown option' error
assert "No such option" not in result.output
assert "error: unrecognized" not in result.output.lower()
def test_search_without_enrich_flag(self, runner):
"""Test search without --enrich flag has no relationships."""
result = runner.invoke(app, ["search", "test", "--json"])
# Even without an index, JSON should be attempted
if result.exit_code == 0:
try:
data = json.loads(result.output)
# If we get results, they should not have enriched=true
if data.get("success") and "result" in data:
assert data["result"].get("enriched", False) is False
except json.JSONDecodeError:
pass # Not JSON output, that's fine for error cases
def test_search_enrich_json_output_structure(self, runner):
"""Test JSON output structure includes enriched flag."""
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
# If we get valid JSON output, check structure
if result.exit_code == 0:
try:
data = json.loads(result.output)
if data.get("success") and "result" in data:
# enriched field should exist
assert "enriched" in data["result"]
except json.JSONDecodeError:
pass # Not JSON output
def test_search_enrich_with_mode(self, runner):
"""Test --enrich works with different search modes."""
modes = ["exact", "fuzzy", "hybrid"]
for mode in modes:
result = runner.invoke(
app, ["search", "test", "--mode", mode, "--enrich"]
)
# Should not show validation errors
assert "Invalid" not in result.output
class TestEnrichFlagBehavior:
"""Test behavioral aspects of --enrich flag."""
@pytest.fixture
def runner(self):
"""Create CLI test runner."""
return CliRunner()
def test_enrich_failure_does_not_break_search(self, runner):
"""Test that enrichment failure doesn't prevent search from returning results."""
# Even without proper index, search should not crash due to enrich
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
# Should not have unhandled exception
assert "Traceback" not in result.output
def test_enrich_flag_with_files_only(self, runner):
"""Test --enrich is accepted with --files-only mode."""
result = runner.invoke(app, ["search", "test", "--enrich", "--files-only"])
# Should not show option conflict error
assert "conflict" not in result.output.lower()
def test_enrich_flag_with_limit(self, runner):
"""Test --enrich works with --limit parameter."""
result = runner.invoke(app, ["search", "test", "--enrich", "--limit", "5"])
# Should not show validation error
assert "Invalid" not in result.output
class TestEnrichOutputFormat:
"""Test output format with --enrich flag."""
@pytest.fixture
def runner(self):
"""Create CLI test runner."""
return CliRunner()
def test_enrich_verbose_shows_status(self, runner):
"""Test verbose mode shows enrichment status."""
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
# Verbose mode may show enrichment info or warnings
# Just ensure it doesn't crash
assert result.exit_code in [0, 1] # 0 = success, 1 = no index
def test_json_output_has_enriched_field(self, runner):
"""Test JSON output always has enriched field when --enrich used."""
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
if result.exit_code == 0:
try:
data = json.loads(result.output)
if data.get("success"):
result_data = data.get("result", {})
assert "enriched" in result_data
assert isinstance(result_data["enriched"], bool)
except json.JSONDecodeError:
pass

View File

@@ -204,8 +204,6 @@ class TestEntitySerialization:
"kind": "function",
"range": (1, 10),
"file": None,
"token_count": None,
"symbol_type": None,
}
def test_indexed_file_model_dump(self):

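The expected dict in this test can be mirrored with a hypothetical Pydantic model; this assumes the codexlens entities are Pydantic v2 models, and the real Symbol may differ beyond the two dropped fields:

from typing import Optional, Tuple
from pydantic import BaseModel

class Symbol(BaseModel):  # hypothetical stand-in, not the codexlens class
    name: str
    kind: str
    range: Tuple[int, int]
    file: Optional[str] = None

assert Symbol(name="foo", kind="function", range=(1, 10)).model_dump() == {
    "name": "foo",
    "kind": "function",
    "range": (1, 10),
    "file": None,
}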
View File

@@ -1,436 +0,0 @@
"""Tests for GraphAnalyzer - code relationship extraction."""
from pathlib import Path
import pytest
from codexlens.semantic.graph_analyzer import GraphAnalyzer
TREE_SITTER_PYTHON_AVAILABLE = True
try:
import tree_sitter_python # type: ignore[import-not-found] # noqa: F401
except Exception:
TREE_SITTER_PYTHON_AVAILABLE = False
TREE_SITTER_JS_AVAILABLE = True
try:
import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401
except Exception:
TREE_SITTER_JS_AVAILABLE = False
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
class TestPythonGraphAnalyzer:
"""Tests for Python relationship extraction."""
def test_simple_function_call(self):
"""Test extraction of simple function call."""
code = """def helper():
pass
def main():
helper()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
assert rel.relationship_type == "call"
assert rel.source_line == 5
def test_multiple_calls_in_function(self):
"""Test extraction of multiple calls from same function."""
code = """def foo():
pass
def bar():
pass
def main():
foo()
bar()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find main -> foo and main -> bar
assert len(relationships) == 2
targets = {rel.target_symbol for rel in relationships}
assert targets == {"foo", "bar"}
assert all(rel.source_symbol == "main" for rel in relationships)
def test_nested_function_calls(self):
"""Test extraction of calls from nested functions."""
code = """def inner_helper():
pass
def outer():
def inner():
inner_helper()
inner()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find outer.inner -> inner_helper and outer -> inner (with fully qualified names)
assert len(relationships) == 2
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
assert ("outer.inner", "inner_helper") in call_pairs
assert ("outer", "inner") in call_pairs
def test_method_call_in_class(self):
"""Test extraction of method calls within class."""
code = """class Calculator:
def add(self, a, b):
return a + b
def compute(self, x, y):
result = self.add(x, y)
return result
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find Calculator.compute -> add (with fully qualified source)
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "Calculator.compute"
assert rel.target_symbol == "add"
def test_module_level_call(self):
"""Test extraction of module-level function calls."""
code = """def setup():
pass
setup()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find <module> -> setup
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "<module>"
assert rel.target_symbol == "setup"
def test_async_function_call(self):
"""Test extraction of calls involving async functions."""
code = """async def fetch_data():
pass
async def process():
await fetch_data()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Should find process -> fetch_data
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "process"
assert rel.target_symbol == "fetch_data"
def test_complex_python_file(self):
"""Test extraction from realistic Python file with multiple patterns."""
code = """class DataProcessor:
def __init__(self):
self.data = []
def load(self, filename):
self.data = read_file(filename)
def process(self):
self.validate()
self.transform()
def validate(self):
pass
def transform(self):
pass
def read_file(filename):
pass
def main():
processor = DataProcessor()
processor.load("data.txt")
processor.process()
main()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Extract call pairs
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Expected relationships (with fully qualified source symbols for methods)
expected = {
("DataProcessor.load", "read_file"),
("DataProcessor.process", "validate"),
("DataProcessor.process", "transform"),
("main", "DataProcessor"),
("main", "load"),
("main", "process"),
("<module>", "main"),
}
# Should find all expected relationships
assert call_pairs >= expected
def test_empty_file(self):
"""Test handling of empty file."""
code = ""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
assert len(relationships) == 0
def test_file_with_no_calls(self):
"""Test handling of file with definitions but no calls."""
code = """def func1():
pass
def func2():
pass
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
assert len(relationships) == 0
@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
class TestJavaScriptGraphAnalyzer:
"""Tests for JavaScript relationship extraction."""
def test_simple_function_call(self):
"""Test extraction of simple JavaScript function call."""
code = """function helper() {}
function main() {
helper();
}
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
assert rel.relationship_type == "call"
def test_arrow_function_call(self):
"""Test extraction of calls from arrow functions."""
code = """const helper = () => {};
const main = () => {
helper();
};
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find main -> helper call
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "main"
assert rel.target_symbol == "helper"
def test_class_method_call(self):
"""Test extraction of method calls in JavaScript class."""
code = """class Calculator {
add(a, b) {
return a + b;
}
compute(x, y) {
return this.add(x, y);
}
}
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Should find Calculator.compute -> add (with fully qualified source)
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "Calculator.compute"
assert rel.target_symbol == "add"
def test_complex_javascript_file(self):
"""Test extraction from realistic JavaScript file."""
code = """function readFile(filename) {
return "";
}
class DataProcessor {
constructor() {
this.data = [];
}
load(filename) {
this.data = readFile(filename);
}
process() {
this.validate();
this.transform();
}
validate() {}
transform() {}
}
function main() {
const processor = new DataProcessor();
processor.load("data.txt");
processor.process();
}
main();
"""
analyzer = GraphAnalyzer("javascript")
relationships = analyzer.analyze_file(code, Path("test.js"))
# Extract call pairs
call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Expected relationships (with fully qualified source symbols for methods)
# Note: constructor calls like "new DataProcessor()" are not tracked
expected = {
("DataProcessor.load", "readFile"),
("DataProcessor.process", "validate"),
("DataProcessor.process", "transform"),
("main", "load"),
("main", "process"),
("<module>", "main"),
}
# Should find all expected relationships
assert call_pairs >= expected
class TestGraphAnalyzerEdgeCases:
"""Edge case tests for GraphAnalyzer."""
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_unavailable_language(self):
"""Test handling of unsupported language."""
code = "some code"
analyzer = GraphAnalyzer("rust")
relationships = analyzer.analyze_file(code, Path("test.rs"))
assert len(relationships) == 0
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_malformed_python_code(self):
"""Test handling of malformed Python code."""
code = "def broken(\n pass"
analyzer = GraphAnalyzer("python")
# Should not crash
relationships = analyzer.analyze_file(code, Path("test.py"))
assert isinstance(relationships, list)
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_file_path_in_relationship(self):
"""Test that file path is correctly set in relationships."""
code = """def foo():
pass
def bar():
foo()
"""
test_path = Path("test.py")
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, test_path)
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_file == str(test_path.resolve())
assert rel.target_file is None # Intra-file
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_performance_large_file(self):
"""Test performance on larger file (1000 lines)."""
import time
# Generate file with many functions and calls
lines = []
for i in range(100):
lines.append(f"def func_{i}():")
if i > 0:
lines.append(f" func_{i-1}()")
else:
lines.append(" pass")
code = "\n".join(lines)
analyzer = GraphAnalyzer("python")
start_time = time.time()
relationships = analyzer.analyze_file(code, Path("test.py"))
elapsed_ms = (time.time() - start_time) * 1000
# Should complete in under 500ms
assert elapsed_ms < 500
# Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
assert len(relationships) == 99
@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
def test_call_accuracy_rate(self):
"""Test >95% accuracy on known call graph."""
code = """def a(): pass
def b(): pass
def c(): pass
def d(): pass
def e(): pass
def test1():
a()
b()
def test2():
c()
d()
def test3():
e()
def main():
test1()
test2()
test3()
"""
analyzer = GraphAnalyzer("python")
relationships = analyzer.analyze_file(code, Path("test.py"))
# Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
expected_calls = {
("test1", "a"),
("test1", "b"),
("test2", "c"),
("test2", "d"),
("test3", "e"),
("main", "test1"),
("main", "test2"),
("main", "test3"),
}
found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
# Calculate accuracy
correct = len(expected_calls & found_calls)
total = len(expected_calls)
accuracy = (correct / total) * 100 if total > 0 else 0
# Should have >95% accuracy
assert accuracy >= 95.0
assert correct == total # Should be 100% for this simple case

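The extraction these tests covered rests on a plain tree-sitter walk: parse, iterate nodes depth-first, and read the function field of each call node. A minimal sketch, assuming py-tree-sitter >= 0.22 and tree-sitter-python are installed:

import tree_sitter_python
from tree_sitter import Language, Parser

parser = Parser()
parser.language = Language(tree_sitter_python.language())

source = b"def helper():\n    pass\n\ndef main():\n    helper()\n"
tree = parser.parse(source)

def iter_nodes(root):
    # Depth-first traversal, same shape as the analyzer's _iter_nodes.
    stack = [root]
    while stack:
        node = stack.pop()
        yield node
        stack.extend(reversed(node.children))

for node in iter_nodes(tree.root_node):
    if node.type == "call":
        fn = node.child_by_field_name("function")
        if fn is not None and fn.type == "identifier":
            print(source[fn.start_byte:fn.end_byte].decode("utf8"))  # helper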
View File

@@ -1,392 +0,0 @@
"""End-to-end tests for graph search CLI commands."""
import tempfile
from pathlib import Path
from typer.testing import CliRunner
import pytest
from codexlens.cli.commands import app
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.entities import IndexedFile, Symbol, CodeRelationship
runner = CliRunner()
@pytest.fixture
def temp_project():
"""Create a temporary project with indexed code and relationships."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir) / "test_project"
project_root.mkdir()
# Create test Python files
(project_root / "main.py").write_text("""
def main():
result = calculate(5, 3)
print(result)
def calculate(a, b):
return add(a, b)
def add(x, y):
return x + y
""")
(project_root / "utils.py").write_text("""
class BaseClass:
def method(self):
pass
class DerivedClass(BaseClass):
def method(self):
super().method()
helper()
def helper():
return True
""")
# Create a custom index directory for graph testing
# Skip the standard init to avoid schema conflicts
mapper = PathMapper()
index_root = mapper.source_to_index_dir(project_root)
index_root.mkdir(parents=True, exist_ok=True)
test_db = index_root / "_index.db"
# Register project manually
registry = RegistryStore()
registry.initialize()
project_info = registry.register_project(
source_root=project_root,
index_root=index_root
)
registry.register_dir(
project_id=project_info.id,
source_path=project_root,
index_path=test_db,
depth=0,
files_count=2
)
# Initialize the store with proper SQLiteStore schema and add files
with SQLiteStore(test_db) as store:
# Read and add files to the store
main_content = (project_root / "main.py").read_text()
utils_content = (project_root / "utils.py").read_text()
main_indexed = IndexedFile(
path=str(project_root / "main.py"),
language="python",
symbols=[
Symbol(name="main", kind="function", range=(2, 4)),
Symbol(name="calculate", kind="function", range=(6, 7)),
Symbol(name="add", kind="function", range=(9, 10))
]
)
utils_indexed = IndexedFile(
path=str(project_root / "utils.py"),
language="python",
symbols=[
Symbol(name="BaseClass", kind="class", range=(2, 4)),
Symbol(name="DerivedClass", kind="class", range=(6, 9)),
Symbol(name="helper", kind="function", range=(11, 12))
]
)
store.add_file(main_indexed, main_content)
store.add_file(utils_indexed, utils_content)
with SQLiteStore(test_db) as store:
# Add relationships for main.py
main_file = project_root / "main.py"
relationships_main = [
CodeRelationship(
source_symbol="main",
target_symbol="calculate",
relationship_type="call",
source_file=str(main_file),
source_line=3,
target_file=str(main_file)
),
CodeRelationship(
source_symbol="calculate",
target_symbol="add",
relationship_type="call",
source_file=str(main_file),
source_line=7,
target_file=str(main_file)
),
]
store.add_relationships(main_file, relationships_main)
# Add relationships for utils.py
utils_file = project_root / "utils.py"
relationships_utils = [
CodeRelationship(
source_symbol="DerivedClass",
target_symbol="BaseClass",
relationship_type="inherits",
source_file=str(utils_file),
source_line=6, # DerivedClass is defined on line 6
target_file=str(utils_file)
),
CodeRelationship(
source_symbol="DerivedClass.method",
target_symbol="helper",
relationship_type="call",
source_file=str(utils_file),
source_line=8,
target_file=str(utils_file)
),
]
store.add_relationships(utils_file, relationships_utils)
registry.close()
yield project_root
class TestGraphCallers:
"""Test callers query type."""
def test_find_callers_basic(self, temp_project):
"""Test finding functions that call a given function."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "calculate" in result.stdout
assert "Callers of 'add'" in result.stdout
def test_find_callers_json_mode(self, temp_project):
"""Test callers query with JSON output."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
assert "relationships" in result.stdout
def test_find_callers_no_results(self, temp_project):
"""Test callers query when no callers exist."""
result = runner.invoke(app, [
"graph",
"callers",
"nonexistent_function",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "No callers found" in result.stdout or "0 found" in result.stdout
class TestGraphCallees:
"""Test callees query type."""
def test_find_callees_basic(self, temp_project):
"""Test finding functions called by a given function."""
result = runner.invoke(app, [
"graph",
"callees",
"main",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "calculate" in result.stdout
assert "Callees of 'main'" in result.stdout
def test_find_callees_chain(self, temp_project):
"""Test finding callees in a call chain."""
result = runner.invoke(app, [
"graph",
"callees",
"calculate",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "add" in result.stdout
def test_find_callees_json_mode(self, temp_project):
"""Test callees query with JSON output."""
result = runner.invoke(app, [
"graph",
"callees",
"main",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
class TestGraphInheritance:
"""Test inheritance query type."""
def test_find_inheritance_basic(self, temp_project):
"""Test finding inheritance relationships."""
result = runner.invoke(app, [
"graph",
"inheritance",
"BaseClass",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "DerivedClass" in result.stdout
assert "Inheritance relationships" in result.stdout
def test_find_inheritance_derived(self, temp_project):
"""Test finding inheritance from derived class perspective."""
result = runner.invoke(app, [
"graph",
"inheritance",
"DerivedClass",
"--path", str(temp_project)
])
assert result.exit_code == 0
assert "BaseClass" in result.stdout
def test_find_inheritance_json_mode(self, temp_project):
"""Test inheritance query with JSON output."""
result = runner.invoke(app, [
"graph",
"inheritance",
"BaseClass",
"--path", str(temp_project),
"--json"
])
assert result.exit_code == 0
assert "success" in result.stdout
class TestGraphValidation:
"""Test query validation and error handling."""
def test_invalid_query_type(self, temp_project):
"""Test error handling for invalid query type."""
result = runner.invoke(app, [
"graph",
"invalid_type",
"symbol",
"--path", str(temp_project)
])
assert result.exit_code == 1
assert "Invalid query type" in result.stdout
def test_invalid_path(self):
"""Test error handling for non-existent path."""
result = runner.invoke(app, [
"graph",
"callers",
"symbol",
"--path", "/nonexistent/path"
])
# Should handle gracefully (may exit with error or return empty results)
assert result.exit_code in [0, 1]
class TestGraphPerformance:
"""Test graph query performance requirements."""
def test_query_response_time(self, temp_project):
"""Verify graph queries complete in under 1 second."""
import time
start = time.time()
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project)
])
elapsed = time.time() - start
assert result.exit_code == 0
assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"
def test_multiple_query_types(self, temp_project):
"""Test all three query types complete successfully."""
import time
queries = [
("callers", "add"),
("callees", "main"),
("inheritance", "BaseClass")
]
total_start = time.time()
for query_type, symbol in queries:
result = runner.invoke(app, [
"graph",
query_type,
symbol,
"--path", str(temp_project)
])
assert result.exit_code == 0
total_elapsed = time.time() - total_start
assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"
class TestGraphOptions:
"""Test graph command options."""
def test_limit_option(self, temp_project):
"""Test limit option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--limit", "1"
])
assert result.exit_code == 0
def test_depth_option(self, temp_project):
"""Test depth option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--depth", "0"
])
assert result.exit_code == 0
def test_verbose_option(self, temp_project):
"""Test verbose option works correctly."""
result = runner.invoke(app, [
"graph",
"callers",
"add",
"--path", str(temp_project),
"--verbose"
])
assert result.exit_code == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -1,355 +0,0 @@
"""Tests for code relationship storage."""
import sqlite3
import tempfile
from pathlib import Path
import pytest
from codexlens.entities import CodeRelationship, IndexedFile, Symbol
from codexlens.storage.migration_manager import MigrationManager
from codexlens.storage.sqlite_store import SQLiteStore
@pytest.fixture
def temp_db():
"""Create a temporary database for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
yield db_path
@pytest.fixture
def store(temp_db):
"""Create a SQLiteStore with migrations applied."""
store = SQLiteStore(temp_db)
store.initialize()
# Manually apply migration_003 (code_relationships table)
conn = store._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT,
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
)
"""
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
)
conn.commit()
yield store
# Cleanup
store.close()
def test_relationship_table_created(store):
"""Test that the code_relationships table is created by migration."""
conn = store._get_connection()
cursor = conn.cursor()
# Check table exists
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
)
result = cursor.fetchone()
assert result is not None, "code_relationships table should exist"
# Check indexes exist
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
)
indexes = [row[0] for row in cursor.fetchall()]
assert "idx_relationships_source" in indexes
assert "idx_relationships_target" in indexes
assert "idx_relationships_type" in indexes
def test_add_relationships(store):
"""Test storing code relationships."""
# First add a file with symbols
indexed_file = IndexedFile(
path=str(Path(__file__).parent / "sample.py"),
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 5)),
Symbol(name="bar", kind="function", range=(7, 10)),
]
)
content = """def foo():
bar()
baz()
def bar():
print("hello")
"""
store.add_file(indexed_file, content)
# Add relationships
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=indexed_file.path,
target_file=None,
source_line=2
),
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=indexed_file.path,
target_file=None,
source_line=3
),
]
store.add_relationships(indexed_file.path, relationships)
# Verify relationships were stored
conn = store._get_connection()
count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0]
assert count == 2, "Should have stored 2 relationships"
def test_query_relationships_by_target(store):
"""Test querying relationships by target symbol (find callers)."""
# Setup: Add file and relationships
file_path = str(Path(__file__).parent / "sample.py")
# Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main()
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 2)),
Symbol(name="bar", kind="function", range=(4, 5)),
Symbol(name="main", kind="function", range=(7, 8)),
]
)
content = "def foo():\n bar()\n\ndef bar():\n pass\n\ndef main():\n bar()\n"
store.add_file(indexed_file, content)
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2 # Call inside foo (line 2)
),
CodeRelationship(
source_symbol="main",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=8 # Call inside main (line 8)
),
]
store.add_relationships(file_path, relationships)
# Query: Find all callers of "bar"
callers = store.query_relationships_by_target("bar")
assert len(callers) == 2, "Should find 2 callers of bar"
assert any(r["source_symbol"] == "foo" for r in callers)
assert any(r["source_symbol"] == "main" for r in callers)
assert all(r["target_symbol"] == "bar" for r in callers)
assert all(r["relationship_type"] == "call" for r in callers)
def test_query_relationships_by_source(store):
"""Test querying relationships by source symbol (find callees)."""
# Setup
file_path = str(Path(__file__).parent / "sample.py")
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[
Symbol(name="foo", kind="function", range=(1, 6)),
]
)
content = "def foo():\n bar()\n baz()\n qux()\n"
store.add_file(indexed_file, content)
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
),
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=3
),
CodeRelationship(
source_symbol="foo",
target_symbol="qux",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=4
),
]
store.add_relationships(file_path, relationships)
# Query: Find all functions called by foo
callees = store.query_relationships_by_source("foo", file_path)
assert len(callees) == 3, "Should find 3 functions called by foo"
targets = {r["target_symbol"] for r in callees}
assert targets == {"bar", "baz", "qux"}
assert all(r["source_symbol"] == "foo" for r in callees)
def test_query_performance(store):
"""Test that relationship queries execute within performance threshold."""
import time
# Setup: Create a file with many relationships
file_path = str(Path(__file__).parent / "large_file.py")
symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=symbols
)
content = "\n".join([f"def func_{i}():\n pass\n" for i in range(100)])
store.add_file(indexed_file, content)
# Create many relationships
relationships = []
for i in range(100):
for j in range(10):
relationships.append(
CodeRelationship(
source_symbol=f"func_{i}",
target_symbol=f"target_{j}",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=i*10 + 1
)
)
store.add_relationships(file_path, relationships)
# Query and measure time
start = time.time()
results = store.query_relationships_by_target("target_5")
elapsed_ms = (time.time() - start) * 1000
assert len(results) == 100, "Should find 100 callers"
assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"
def test_stats_includes_relationships(store):
"""Test that stats() includes relationship count."""
# Add a file with relationships
file_path = str(Path(__file__).parent / "sample.py")
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
)
store.add_file(indexed_file, "def foo():\n bar()\n")
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, relationships)
# Check stats
stats = store.stats()
assert "relationships" in stats
assert stats["relationships"] == 1
assert stats["files"] == 1
assert stats["symbols"] == 1
def test_update_relationships_on_file_reindex(store):
"""Test that relationships are updated when file is re-indexed."""
file_path = str(Path(__file__).parent / "sample.py")
# Initial index
indexed_file = IndexedFile(
path=file_path,
language="python",
symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
)
store.add_file(indexed_file, "def foo():\n bar()\n")
relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="bar",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, relationships)
# Re-index with different relationships
new_relationships = [
CodeRelationship(
source_symbol="foo",
target_symbol="baz",
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=2
)
]
store.add_relationships(file_path, new_relationships)
# Verify old relationships are replaced
all_rels = store.query_relationships_by_source("foo", file_path)
assert len(all_rels) == 1
assert all_rels[0]["target_symbol"] == "baz"

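For reference, the caller lookup these deleted tests exercised can be reproduced with plain sqlite3 against the schema the store fixture creates above. A minimal sketch, assuming the symbols table exposes id and name columns (the foreign key implies id; name is an assumption), and find_callers is a hypothetical helper, not CodexLens API:

import sqlite3

def find_callers(db_path, target_name):
    """Find (caller_name, source_line) pairs for calls to target_name.

    Mirrors what query_relationships_by_target implied: a single probe
    of idx_relationships_target plus a join back to symbols, which is
    why the deleted performance test could demand <50ms over 1000 rows.
    """
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(
            """
            SELECT s.name, r.source_line
            FROM code_relationships AS r
            JOIN symbols AS s ON s.id = r.source_symbol_id
            WHERE r.target_qualified_name = ?
              AND r.relationship_type = 'call'
            """,
            (target_name,),
        ).fetchall()
    finally:
        conn.close()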
View File

@@ -188,60 +188,3 @@ class TestTokenCountPerformance:
        # Precomputed should be at least 10% faster
        speedup = ((computed_time - precomputed_time) / computed_time) * 100
        assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"


class TestSymbolEntityTokenCount:
    """Tests for Symbol entity token_count field."""

    def test_symbol_with_token_count(self):
        """Test creating Symbol with token_count."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10),
            token_count=42
        )
        assert symbol.token_count == 42

    def test_symbol_without_token_count(self):
        """Test creating Symbol without token_count (defaults to None)."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10)
        )
        assert symbol.token_count is None

    def test_symbol_with_symbol_type(self):
        """Test creating Symbol with symbol_type."""
        symbol = Symbol(
            name="TestClass",
            kind="class",
            range=(1, 20),
            symbol_type="class_definition"
        )
        assert symbol.symbol_type == "class_definition"

    def test_symbol_token_count_validation(self):
        """Test that negative token counts are rejected."""
        with pytest.raises(ValueError, match="token_count must be >= 0"):
            Symbol(
                name="test",
                kind="function",
                range=(1, 2),
                token_count=-1
            )

    def test_symbol_zero_token_count(self):
        """Test that zero token count is allowed."""
        symbol = Symbol(
            name="empty",
            kind="function",
            range=(1, 1),
            token_count=0
        )
        assert symbol.token_count == 0
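
These were the last references to token_count and symbol_type on Symbol; after this change the entity carries only its core identification fields. A minimal sketch of the slimmed model, assuming Symbol is a dataclass and ignoring any unrelated fields the real entity may keep:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class Symbol:
    # Post-commit shape: token_count and symbol_type removed.
    name: str                # symbol identifier, e.g. "foo"
    kind: str                # e.g. "function", "class"
    range: Tuple[int, int]   # (start_line, end_line) in the source file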