Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-05 01:50:27 +08:00
refactor: remove graph indexing, fix memory leaks, optimize embedding generation

Main changes:

1. Remove graph indexing
   - Delete graph_analyzer.py and the related migration files
   - Remove the CLI graph command and the --enrich flag
   - Clean up the graph query methods in chain_search.py (370 lines)
   - Delete the related test files
2. Fix memory usage in embedding generation
   - Refactor generate_embeddings.py to use streaming batch processing
   - Switch to the memory-safe implementation in embedding_manager
   - The file shrank from 548 to 259 lines (a 52.7% reduction)
3. Fix memory leaks
   - chain_search.py: quick_search manages ChainSearchEngine with a with statement
   - embedding_manager.py: manage VectorStore with a with statement
   - vector_store.py: add a memory warning for brute-force search
4. Code cleanup
   - Remove the token_count and symbol_type fields from the Symbol model
   - Clean up the related test cases

Tests: 760 passed, 7 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
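A minimal sketch of the streaming, with-managed pattern these fixes introduce (the wrapper name stream_embeddings, the make_chunks callable, and the 50-file batch size are illustrative assumptions; VectorStore, add_chunks_batch, and the files table columns are taken from the diff below):

    import sqlite3
    from pathlib import Path

    from codexlens.semantic.vector_store import VectorStore

    def stream_embeddings(index_path: Path, make_chunks, file_batch_size: int = 50) -> int:
        """Embed one batch of files at a time so memory stays bounded."""
        written = 0
        # The with-blocks guarantee the vector store and the SQLite connection are
        # closed even if chunking or embedding raises -- the leak fixed in this commit.
        with VectorStore(index_path) as store:
            with sqlite3.connect(index_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.execute("SELECT full_path, content, language FROM files")
                while True:
                    rows = cursor.fetchmany(file_batch_size)  # stream rows, never fetchall()
                    if not rows:
                        break
                    chunks = make_chunks(rows)  # -> list of (SemanticChunk, file_path) pairs
                    if chunks:
                        store.add_chunks_batch(chunks)  # flush now; batch memory is freed
                        written += len(chunks)
        return written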
@@ -1,15 +1,9 @@
#!/usr/bin/env python3
"""Generate vector embeddings for existing CodexLens indexes.

This script processes all files in a CodexLens index database and generates
semantic vector embeddings for code chunks. The embeddings are stored in the
same SQLite database in the 'semantic_chunks' table.

Performance optimizations:
- Parallel file processing using ProcessPoolExecutor
- Batch embedding generation for efficient GPU/CPU utilization
- Batch database writes to minimize I/O overhead
- HNSW index auto-generation for fast similarity search
This script is a CLI wrapper around the memory-efficient streaming implementation
in codexlens.cli.embedding_manager. It uses batch processing to keep memory usage
under 2GB regardless of project size.

Requirements:
pip install codexlens[semantic]
@@ -20,27 +14,21 @@ Usage:
# Generate embeddings for a single index
python generate_embeddings.py /path/to/_index.db

# Generate embeddings with parallel processing
python generate_embeddings.py /path/to/_index.db --workers 4

# Use specific embedding model and batch size
python generate_embeddings.py /path/to/_index.db --model code --batch-size 256
# Use specific embedding model
python generate_embeddings.py /path/to/_index.db --model code

# Generate embeddings for all indexes in a directory
python generate_embeddings.py --scan ~/.codexlens/indexes

# Force regeneration
python generate_embeddings.py /path/to/_index.db --force
"""

import argparse
import logging
import multiprocessing
import os
import sqlite3
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List

# Configure logging
logging.basicConfig(
@@ -50,100 +38,32 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)


@dataclass
class FileData:
"""Data for a single file to process."""
full_path: str
content: str
language: str


@dataclass
class ChunkData:
"""Processed chunk data ready for embedding."""
file_path: str
content: str
metadata: dict
# Import the memory-efficient implementation
try:
from codexlens.cli.embedding_manager import (
generate_embeddings,
generate_embeddings_recursive,
)
from codexlens.semantic import SEMANTIC_AVAILABLE
except ImportError as exc:
logger.error(f"Failed to import codexlens: {exc}")
logger.error("Make sure codexlens is installed: pip install codexlens")
SEMANTIC_AVAILABLE = False


def check_dependencies():
"""Check if semantic search dependencies are available."""
try:
from codexlens.semantic import SEMANTIC_AVAILABLE
if not SEMANTIC_AVAILABLE:
logger.error("Semantic search dependencies not available")
logger.error("Install with: pip install codexlens[semantic]")
logger.error("Or: pip install fastembed numpy hnswlib")
return False
return True
except ImportError as exc:
logger.error(f"Failed to import codexlens: {exc}")
logger.error("Make sure codexlens is installed: pip install codexlens")
if not SEMANTIC_AVAILABLE:
logger.error("Semantic search dependencies not available")
logger.error("Install with: pip install codexlens[semantic]")
logger.error("Or: pip install fastembed numpy hnswlib")
return False
return True


def count_files(index_db_path: Path) -> int:
"""Count total files in index."""
try:
with sqlite3.connect(index_db_path) as conn:
cursor = conn.execute("SELECT COUNT(*) FROM files")
return cursor.fetchone()[0]
except Exception as exc:
logger.error(f"Failed to count files: {exc}")
return 0


def check_existing_chunks(index_db_path: Path) -> int:
"""Check if semantic chunks already exist."""
try:
with sqlite3.connect(index_db_path) as conn:
# Check if table exists
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_chunks'"
)
if not cursor.fetchone():
return 0

# Count existing chunks
cursor = conn.execute("SELECT COUNT(*) FROM semantic_chunks")
return cursor.fetchone()[0]
except Exception:
return 0


def process_file_worker(args: Tuple[str, str, str, int]) -> List[ChunkData]:
"""Worker function to process a single file (runs in separate process).

Args:
args: Tuple of (file_path, content, language, chunk_size)

Returns:
List of ChunkData objects
"""
file_path, content, language, chunk_size = args

try:
from codexlens.semantic.chunker import Chunker, ChunkConfig

chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))
chunks = chunker.chunk_sliding_window(
content,
file_path=file_path,
language=language
)

return [
ChunkData(
file_path=file_path,
content=chunk.content,
metadata=chunk.metadata or {}
)
for chunk in chunks
]
except Exception as exc:
logger.debug(f"Error processing {file_path}: {exc}")
return []
def progress_callback(message: str):
"""Callback function for progress updates."""
logger.info(message)


def generate_embeddings_for_index(
@@ -151,259 +71,63 @@ def generate_embeddings_for_index(
model_profile: str = "code",
|
||||
force: bool = False,
|
||||
chunk_size: int = 2000,
|
||||
workers: int = 0,
|
||||
batch_size: int = 256,
|
||||
**kwargs # Ignore unused parameters (workers, batch_size) for backward compatibility
|
||||
) -> dict:
|
||||
"""Generate embeddings for all files in an index.
|
||||
"""Generate embeddings for an index using memory-efficient streaming.
|
||||
|
||||
Performance optimizations:
|
||||
- Parallel file processing (chunking)
|
||||
- Batch embedding generation
|
||||
- Batch database writes
|
||||
- HNSW index auto-generation
|
||||
This function wraps the streaming implementation from embedding_manager
|
||||
to maintain CLI compatibility while using the memory-optimized approach.
|
||||
|
||||
Args:
|
||||
index_db_path: Path to _index.db file
|
||||
model_profile: Model profile to use (fast, code, multilingual, balanced)
|
||||
force: If True, regenerate even if embeddings exist
|
||||
chunk_size: Maximum chunk size in characters
|
||||
workers: Number of parallel workers (0 = auto-detect CPU count)
|
||||
batch_size: Batch size for embedding generation
|
||||
**kwargs: Additional parameters (ignored for compatibility)
|
||||
|
||||
Returns:
|
||||
Dictionary with generation statistics
|
||||
"""
|
||||
logger.info(f"Processing index: {index_db_path}")
|
||||
|
||||
# Check existing chunks
|
||||
existing_chunks = check_existing_chunks(index_db_path)
|
||||
if existing_chunks > 0 and not force:
|
||||
logger.warning(f"Index already has {existing_chunks} chunks")
|
||||
logger.warning("Use --force to regenerate")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Embeddings already exist",
|
||||
"existing_chunks": existing_chunks,
|
||||
}
|
||||
# Call the memory-efficient streaming implementation
|
||||
result = generate_embeddings(
|
||||
index_path=index_db_path,
|
||||
model_profile=model_profile,
|
||||
force=force,
|
||||
chunk_size=chunk_size,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
if force and existing_chunks > 0:
|
||||
logger.info(f"Force mode: clearing {existing_chunks} existing chunks")
|
||||
try:
|
||||
with sqlite3.connect(index_db_path) as conn:
|
||||
conn.execute("DELETE FROM semantic_chunks")
|
||||
conn.commit()
|
||||
# Also remove HNSW index file
|
||||
hnsw_path = index_db_path.parent / "_vectors.hnsw"
|
||||
if hnsw_path.exists():
|
||||
hnsw_path.unlink()
|
||||
logger.info("Removed existing HNSW index")
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to clear existing data: {exc}")
|
||||
if not result["success"]:
|
||||
if "error" in result:
|
||||
logger.error(result["error"])
|
||||
return result
|
||||
|
||||
# Import dependencies
|
||||
try:
|
||||
from codexlens.semantic.embedder import Embedder
|
||||
from codexlens.semantic.vector_store import VectorStore
|
||||
from codexlens.entities import SemanticChunk
|
||||
except ImportError as exc:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Import failed: {exc}",
|
||||
}
|
||||
|
||||
# Initialize components
|
||||
try:
|
||||
embedder = Embedder(profile=model_profile)
|
||||
vector_store = VectorStore(index_db_path)
|
||||
|
||||
logger.info(f"Using model: {embedder.model_name}")
|
||||
logger.info(f"Embedding dimension: {embedder.embedding_dim}")
|
||||
except Exception as exc:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to initialize components: {exc}",
|
||||
}
|
||||
|
||||
# Read files from index
|
||||
try:
|
||||
with sqlite3.connect(index_db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute("SELECT full_path, content, language FROM files")
|
||||
files = [
|
||||
FileData(
|
||||
full_path=row["full_path"],
|
||||
content=row["content"],
|
||||
language=row["language"] or "python"
|
||||
)
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
except Exception as exc:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to read files: {exc}",
|
||||
}
|
||||
|
||||
logger.info(f"Found {len(files)} files to process")
|
||||
if len(files) == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No files found in index",
|
||||
}
|
||||
|
||||
# Determine worker count
|
||||
if workers <= 0:
|
||||
workers = min(multiprocessing.cpu_count(), len(files), 8)
|
||||
logger.info(f"Using {workers} worker(s) for parallel processing")
|
||||
logger.info(f"Batch size for embeddings: {batch_size}")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Phase 1: Parallel chunking
|
||||
logger.info("Phase 1: Chunking files...")
|
||||
chunk_start = time.time()
|
||||
|
||||
all_chunks: List[ChunkData] = []
|
||||
failed_files = []
|
||||
|
||||
# Prepare work items
|
||||
work_items = [
|
||||
(f.full_path, f.content, f.language, chunk_size)
|
||||
for f in files
|
||||
]
|
||||
|
||||
if workers == 1:
|
||||
# Single-threaded for debugging
|
||||
for i, item in enumerate(work_items, 1):
|
||||
try:
|
||||
chunks = process_file_worker(item)
|
||||
all_chunks.extend(chunks)
|
||||
if i % 100 == 0:
|
||||
logger.info(f"Chunked {i}/{len(files)} files ({len(all_chunks)} chunks)")
|
||||
except Exception as exc:
|
||||
failed_files.append((item[0], str(exc)))
|
||||
else:
|
||||
# Parallel processing
|
||||
with ProcessPoolExecutor(max_workers=workers) as executor:
|
||||
futures = {
|
||||
executor.submit(process_file_worker, item): item[0]
|
||||
for item in work_items
|
||||
}
|
||||
|
||||
completed = 0
|
||||
for future in as_completed(futures):
|
||||
file_path = futures[future]
|
||||
completed += 1
|
||||
try:
|
||||
chunks = future.result()
|
||||
all_chunks.extend(chunks)
|
||||
if completed % 100 == 0:
|
||||
logger.info(
|
||||
f"Chunked {completed}/{len(files)} files "
|
||||
f"({len(all_chunks)} chunks)"
|
||||
)
|
||||
except Exception as exc:
|
||||
failed_files.append((file_path, str(exc)))
|
||||
|
||||
chunk_time = time.time() - chunk_start
|
||||
logger.info(f"Chunking completed in {chunk_time:.1f}s: {len(all_chunks)} chunks")
|
||||
|
||||
if not all_chunks:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No chunks created from files",
|
||||
"files_processed": len(files) - len(failed_files),
|
||||
"files_failed": len(failed_files),
|
||||
}
|
||||
|
||||
# Phase 2: Batch embedding generation
|
||||
logger.info("Phase 2: Generating embeddings...")
|
||||
embed_start = time.time()
|
||||
|
||||
# Extract all content for batch embedding
|
||||
all_contents = [c.content for c in all_chunks]
|
||||
|
||||
# Generate embeddings in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(all_contents), batch_size):
|
||||
batch_contents = all_contents[i:i + batch_size]
|
||||
batch_embeddings = embedder.embed(batch_contents)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
progress = min(i + batch_size, len(all_contents))
|
||||
if progress % (batch_size * 4) == 0 or progress == len(all_contents):
|
||||
logger.info(f"Generated embeddings: {progress}/{len(all_contents)}")
|
||||
|
||||
embed_time = time.time() - embed_start
|
||||
logger.info(f"Embedding completed in {embed_time:.1f}s")
|
||||
|
||||
# Phase 3: Batch database write
|
||||
logger.info("Phase 3: Storing chunks...")
|
||||
store_start = time.time()
|
||||
|
||||
# Create SemanticChunk objects with embeddings
|
||||
semantic_chunks_with_paths = []
|
||||
for chunk_data, embedding in zip(all_chunks, all_embeddings):
|
||||
semantic_chunk = SemanticChunk(
|
||||
content=chunk_data.content,
|
||||
metadata=chunk_data.metadata,
|
||||
)
|
||||
semantic_chunk.embedding = embedding
|
||||
semantic_chunks_with_paths.append((semantic_chunk, chunk_data.file_path))
|
||||
|
||||
# Batch write (handles both SQLite and HNSW)
|
||||
write_batch_size = 1000
|
||||
total_stored = 0
|
||||
for i in range(0, len(semantic_chunks_with_paths), write_batch_size):
|
||||
batch = semantic_chunks_with_paths[i:i + write_batch_size]
|
||||
vector_store.add_chunks_batch(batch)
|
||||
total_stored += len(batch)
|
||||
if total_stored % 5000 == 0 or total_stored == len(semantic_chunks_with_paths):
|
||||
logger.info(f"Stored: {total_stored}/{len(semantic_chunks_with_paths)} chunks")
|
||||
|
||||
store_time = time.time() - store_start
|
||||
logger.info(f"Storage completed in {store_time:.1f}s")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# Generate summary
|
||||
# Extract result data and log summary
|
||||
data = result["result"]
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Completed in {elapsed_time:.1f}s")
|
||||
logger.info(f" Chunking: {chunk_time:.1f}s")
|
||||
logger.info(f" Embedding: {embed_time:.1f}s")
|
||||
logger.info(f" Storage: {store_time:.1f}s")
|
||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||
logger.info(f"Files processed: {len(files) - len(failed_files)}/{len(files)}")
|
||||
if vector_store.ann_available:
|
||||
logger.info(f"HNSW index vectors: {vector_store.ann_count}")
|
||||
if failed_files:
|
||||
logger.warning(f"Failed files: {len(failed_files)}")
|
||||
for file_path, error in failed_files[:5]: # Show first 5 failures
|
||||
logger.warning(f" {file_path}: {error}")
|
||||
logger.info(f"Completed in {data['elapsed_time']:.1f}s")
|
||||
logger.info(f"Total chunks created: {data['chunks_created']}")
|
||||
logger.info(f"Files processed: {data['files_processed']}")
|
||||
if data['files_failed'] > 0:
|
||||
logger.warning(f"Failed files: {data['files_failed']}")
|
||||
if data.get('failed_files'):
|
||||
for file_path, error in data['failed_files']:
|
||||
logger.warning(f" {file_path}: {error}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"chunks_created": len(all_chunks),
|
||||
"files_processed": len(files) - len(failed_files),
|
||||
"files_failed": len(failed_files),
|
||||
"elapsed_time": elapsed_time,
|
||||
"chunk_time": chunk_time,
|
||||
"embed_time": embed_time,
|
||||
"store_time": store_time,
|
||||
"ann_vectors": vector_store.ann_count if vector_store.ann_available else 0,
|
||||
"chunks_created": data["chunks_created"],
|
||||
"files_processed": data["files_processed"],
|
||||
"files_failed": data["files_failed"],
|
||||
"elapsed_time": data["elapsed_time"],
|
||||
}
|
||||
|
||||
|
||||
def find_index_databases(scan_dir: Path) -> List[Path]:
|
||||
"""Find all _index.db files in directory tree."""
|
||||
logger.info(f"Scanning for indexes in: {scan_dir}")
|
||||
index_files = list(scan_dir.rglob("_index.db"))
|
||||
logger.info(f"Found {len(index_files)} index databases")
|
||||
return index_files
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate vector embeddings for CodexLens indexes",
|
||||
description="Generate vector embeddings for CodexLens indexes (memory-efficient streaming)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__
|
||||
)
|
||||
@@ -439,14 +163,14 @@ def main():
|
||||
"--workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of parallel workers for chunking (default: auto-detect CPU count)"
|
||||
help="(Deprecated) Kept for backward compatibility, ignored"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Batch size for embedding generation (default: 256)"
|
||||
help="(Deprecated) Kept for backward compatibility, ignored"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -481,43 +205,33 @@ def main():
|
||||
|
||||
# Determine if scanning or single file
|
||||
if args.scan or index_path.is_dir():
|
||||
# Scan mode
|
||||
# Scan mode - use recursive implementation
|
||||
if index_path.is_file():
|
||||
logger.error("--scan requires a directory path")
|
||||
sys.exit(1)
|
||||
|
||||
index_files = find_index_databases(index_path)
|
||||
if not index_files:
|
||||
logger.error(f"No index databases found in: {index_path}")
|
||||
result = generate_embeddings_recursive(
|
||||
index_root=index_path,
|
||||
model_profile=args.model,
|
||||
force=args.force,
|
||||
chunk_size=args.chunk_size,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
logger.error(f"Failed: {result.get('error', 'Unknown error')}")
|
||||
sys.exit(1)
|
||||
|
||||
# Process each index
|
||||
total_chunks = 0
|
||||
successful = 0
|
||||
for idx, index_file in enumerate(index_files, 1):
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Processing index {idx}/{len(index_files)}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
result = generate_embeddings_for_index(
|
||||
index_file,
|
||||
model_profile=args.model,
|
||||
force=args.force,
|
||||
chunk_size=args.chunk_size,
|
||||
workers=args.workers,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
total_chunks += result["chunks_created"]
|
||||
successful += 1
|
||||
|
||||
# Final summary
|
||||
# Log summary
|
||||
data = result["result"]
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info("BATCH PROCESSING COMPLETE")
|
||||
logger.info(f"{'='*60}")
|
||||
logger.info(f"Indexes processed: {successful}/{len(index_files)}")
|
||||
logger.info(f"Total chunks created: {total_chunks}")
|
||||
logger.info(f"Indexes processed: {data['indexes_successful']}/{data['indexes_processed']}")
|
||||
logger.info(f"Total chunks created: {data['total_chunks_created']}")
|
||||
logger.info(f"Total files processed: {data['total_files_processed']}")
|
||||
if data['total_files_failed'] > 0:
|
||||
logger.warning(f"Total files failed: {data['total_files_failed']}")
|
||||
|
||||
else:
|
||||
# Single index mode
|
||||
@@ -530,8 +244,6 @@ def main():
|
||||
model_profile=args.model,
|
||||
force=args.force,
|
||||
chunk_size=args.chunk_size,
|
||||
workers=args.workers,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
|
||||
@@ -268,7 +268,6 @@ def search(
|
||||
files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
|
||||
mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
|
||||
weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
|
||||
enrich: bool = typer.Option(False, "--enrich", help="Enrich results with code graph relationships (calls, imports)."),
|
||||
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
|
||||
) -> None:
|
||||
@@ -423,30 +422,10 @@ def search(
|
||||
for r in result.results
|
||||
]
|
||||
|
||||
# Enrich results with relationship data if requested
|
||||
enriched = False
|
||||
if enrich:
|
||||
try:
|
||||
from codexlens.search.enrichment import RelationshipEnricher
|
||||
|
||||
# Find index path for the search path
|
||||
project_record = registry.find_by_source_path(str(search_path))
|
||||
if project_record:
|
||||
index_path = Path(project_record["index_root"]) / "_index.db"
|
||||
if index_path.exists():
|
||||
with RelationshipEnricher(index_path) as enricher:
|
||||
results_list = enricher.enrich(results_list, limit=limit)
|
||||
enriched = True
|
||||
except Exception as e:
|
||||
# Enrichment failure should not break search
|
||||
if verbose:
|
||||
console.print(f"[yellow]Warning: Enrichment failed: {e}[/yellow]")
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"mode": actual_mode,
|
||||
"count": len(results_list),
|
||||
"enriched": enriched,
|
||||
"results": results_list,
|
||||
"stats": {
|
||||
"dirs_searched": result.stats.dirs_searched,
|
||||
@@ -458,8 +437,7 @@ def search(
|
||||
print_json(success=True, result=payload)
|
||||
else:
|
||||
render_search_results(result.results, verbose=verbose)
|
||||
enrich_status = " | [green]Enriched[/green]" if enriched else ""
|
||||
console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms{enrich_status}[/dim]")
|
||||
console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms[/dim]")
|
||||
|
||||
except SearchError as exc:
|
||||
if json_mode:
|
||||
@@ -1376,103 +1354,6 @@ def clean(
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def graph(
|
||||
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
|
||||
symbol: str = typer.Argument(..., help="Symbol name to query"),
|
||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
|
||||
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
|
||||
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
|
||||
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
|
||||
) -> None:
|
||||
"""Query semantic graph for code relationships.
|
||||
|
||||
Supported query types:
|
||||
- callers: Find all functions/methods that call the given symbol
|
||||
- callees: Find all functions/methods called by the given symbol
|
||||
- inheritance: Find inheritance relationships for the given class
|
||||
|
||||
Examples:
|
||||
codex-lens graph callers my_function
|
||||
codex-lens graph callees MyClass.method --path src/
|
||||
codex-lens graph inheritance BaseClass
|
||||
"""
|
||||
_configure_logging(verbose)
|
||||
search_path = path.expanduser().resolve()
|
||||
|
||||
# Validate query type
|
||||
valid_types = ["callers", "callees", "inheritance"]
|
||||
if query_type not in valid_types:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
|
||||
else:
|
||||
console.print(f"[red]Invalid query type:[/red] {query_type}")
|
||||
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
registry: RegistryStore | None = None
|
||||
try:
|
||||
registry = RegistryStore()
|
||||
registry.initialize()
|
||||
mapper = PathMapper()
|
||||
|
||||
engine = ChainSearchEngine(registry, mapper)
|
||||
options = SearchOptions(depth=depth, total_limit=limit)
|
||||
|
||||
# Execute graph query based on type
|
||||
if query_type == "callers":
|
||||
results = engine.search_callers(symbol, search_path, options=options)
|
||||
result_type = "callers"
|
||||
elif query_type == "callees":
|
||||
results = engine.search_callees(symbol, search_path, options=options)
|
||||
result_type = "callees"
|
||||
else: # inheritance
|
||||
results = engine.search_inheritance(symbol, search_path, options=options)
|
||||
result_type = "inheritance"
|
||||
|
||||
payload = {
|
||||
"query_type": query_type,
|
||||
"symbol": symbol,
|
||||
"count": len(results),
|
||||
"relationships": results
|
||||
}
|
||||
|
||||
if json_mode:
|
||||
print_json(success=True, result=payload)
|
||||
else:
|
||||
from .output import render_graph_results
|
||||
render_graph_results(results, query_type=query_type, symbol=symbol)
|
||||
|
||||
except SearchError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Graph search error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (search):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except StorageError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Storage error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except CodexLensError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=str(exc))
|
||||
else:
|
||||
console.print(f"[red]Graph query failed:[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Unexpected error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
finally:
|
||||
if registry is not None:
|
||||
registry.close()
|
||||
|
||||
|
||||
@app.command("semantic-list")
|
||||
def semantic_list(
|
||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),
|
||||
|
||||
@@ -194,7 +194,6 @@ def generate_embeddings(
try:
# Use cached embedder (singleton) for performance
embedder = get_embedder(profile=model_profile)
vector_store = VectorStore(index_path)
chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))

if progress_callback:
@@ -217,85 +216,86 @@ def generate_embeddings(
EMBEDDING_BATCH_SIZE = 8  # jina-embeddings-v2-base-code needs small batches

try:
with sqlite3.connect(index_path) as conn:
conn.row_factory = sqlite3.Row
path_column = _get_path_column(conn)
with VectorStore(index_path) as vector_store:
with sqlite3.connect(index_path) as conn:
conn.row_factory = sqlite3.Row
path_column = _get_path_column(conn)

# Get total file count for progress reporting
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
if total_files == 0:
return {"success": False, "error": "No files found in index"}
# Get total file count for progress reporting
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
if total_files == 0:
return {"success": False, "error": "No files found in index"}

if progress_callback:
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")

cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
|
||||
batch_number = 0
|
||||
|
||||
while True:
|
||||
# Fetch a batch of files (streaming, not fetchall)
|
||||
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
|
||||
if not file_batch:
|
||||
break
|
||||
|
||||
batch_number += 1
|
||||
batch_chunks_with_paths = []
|
||||
files_in_batch_with_chunks = set()
|
||||
|
||||
# Step 1: Chunking for the current file batch
|
||||
for file_row in file_batch:
|
||||
file_path = file_row[path_column]
|
||||
content = file_row["content"]
|
||||
language = file_row["language"] or "python"
|
||||
|
||||
try:
|
||||
chunks = chunker.chunk_sliding_window(
|
||||
content,
|
||||
file_path=file_path,
|
||||
language=language
|
||||
)
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
batch_chunks_with_paths.append((chunk, file_path))
|
||||
files_in_batch_with_chunks.add(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to chunk {file_path}: {e}")
|
||||
failed_files.append((file_path, str(e)))
|
||||
|
||||
if not batch_chunks_with_paths:
|
||||
continue
|
||||
|
||||
batch_chunk_count = len(batch_chunks_with_paths)
|
||||
if progress_callback:
|
||||
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
|
||||
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")
|
||||
|
||||
# Step 2: Generate embeddings for this batch
|
||||
batch_embeddings = []
|
||||
try:
|
||||
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
|
||||
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
|
||||
embeddings = embedder.embed(batch_contents)
|
||||
batch_embeddings.extend(embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
|
||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||
continue
|
||||
cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
|
||||
batch_number = 0
|
||||
|
||||
# Step 3: Assign embeddings to chunks
|
||||
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
|
||||
chunk.embedding = embedding
|
||||
while True:
|
||||
# Fetch a batch of files (streaming, not fetchall)
|
||||
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
|
||||
if not file_batch:
|
||||
break
|
||||
|
||||
# Step 4: Store this batch to database immediately (releases memory)
|
||||
try:
|
||||
vector_store.add_chunks_batch(batch_chunks_with_paths)
|
||||
total_chunks_created += batch_chunk_count
|
||||
total_files_processed += len(files_in_batch_with_chunks)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
|
||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||
batch_number += 1
|
||||
batch_chunks_with_paths = []
|
||||
files_in_batch_with_chunks = set()
|
||||
|
||||
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
|
||||
# Step 1: Chunking for the current file batch
|
||||
for file_row in file_batch:
|
||||
file_path = file_row[path_column]
|
||||
content = file_row["content"]
|
||||
language = file_row["language"] or "python"
|
||||
|
||||
try:
|
||||
chunks = chunker.chunk_sliding_window(
|
||||
content,
|
||||
file_path=file_path,
|
||||
language=language
|
||||
)
|
||||
if chunks:
|
||||
for chunk in chunks:
|
||||
batch_chunks_with_paths.append((chunk, file_path))
|
||||
files_in_batch_with_chunks.add(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to chunk {file_path}: {e}")
|
||||
failed_files.append((file_path, str(e)))
|
||||
|
||||
if not batch_chunks_with_paths:
|
||||
continue
|
||||
|
||||
batch_chunk_count = len(batch_chunks_with_paths)
|
||||
if progress_callback:
|
||||
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
|
||||
|
||||
# Step 2: Generate embeddings for this batch
|
||||
batch_embeddings = []
|
||||
try:
|
||||
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
|
||||
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
|
||||
embeddings = embedder.embed(batch_contents)
|
||||
batch_embeddings.extend(embeddings)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
|
||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||
continue
|
||||
|
||||
# Step 3: Assign embeddings to chunks
|
||||
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
|
||||
chunk.embedding = embedding
|
||||
|
||||
# Step 4: Store this batch to database immediately (releases memory)
|
||||
try:
|
||||
vector_store.add_chunks_batch(batch_chunks_with_paths)
|
||||
total_chunks_created += batch_chunk_count
|
||||
total_files_processed += len(files_in_batch_with_chunks)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
|
||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||
|
||||
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Failed to read or process files: {str(e)}"}
|
||||
|
||||
@@ -122,68 +122,3 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
|
||||
console.print(header)
|
||||
render_symbols(list(symbols), title="Discovered Symbols")
|
||||
|
||||
|
||||
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
|
||||
"""Render semantic graph query results.
|
||||
|
||||
Args:
|
||||
results: List of relationship dicts
|
||||
query_type: Type of query (callers, callees, inheritance)
|
||||
symbol: Symbol name that was queried
|
||||
"""
|
||||
if not results:
|
||||
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
|
||||
return
|
||||
|
||||
title_map = {
|
||||
"callers": f"Callers of '{symbol}' ({len(results)} found)",
|
||||
"callees": f"Callees of '{symbol}' ({len(results)} found)",
|
||||
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
|
||||
}
|
||||
|
||||
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
|
||||
|
||||
if query_type == "callers":
|
||||
table.add_column("Caller", style="green")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
table.add_column("Type", style="dim")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("source_symbol", "-"),
|
||||
rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-")),
|
||||
rel.get("relationship_type", "-")
|
||||
)
|
||||
|
||||
elif query_type == "callees":
|
||||
table.add_column("Target", style="green")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
table.add_column("Type", style="dim")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("target_symbol", "-"),
|
||||
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-")),
|
||||
rel.get("relationship_type", "-")
|
||||
)
|
||||
|
||||
else: # inheritance
|
||||
table.add_column("Derived Class", style="green")
|
||||
table.add_column("Base Class", style="magenta")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("source_symbol", "-"),
|
||||
rel.get("target_symbol", "-"),
|
||||
rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-"))
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ class Symbol(BaseModel):
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")

@field_validator("range")
@classmethod
@@ -29,13 +27,6 @@ class Symbol(BaseModel):
raise ValueError("end_line must be >= start_line")
return value

@field_validator("token_count")
@classmethod
def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
if value is not None and value < 0:
raise ValueError("token_count must be >= 0")
return value


class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""

@@ -302,108 +302,6 @@ class ChainSearchEngine:
|
||||
index_paths, name, kind, options.total_limit
|
||||
)
|
||||
|
||||
def search_callers(self, target_symbol: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find all callers of a given symbol across directory hierarchy.
|
||||
|
||||
Args:
|
||||
target_symbol: Name of the symbol to find callers for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with caller information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> callers = engine.search_callers("my_function", Path("D:/project"))
|
||||
>>> for caller in callers:
|
||||
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_callers_parallel(
|
||||
index_paths, target_symbol, options.total_limit
|
||||
)
|
||||
|
||||
def search_callees(self, source_symbol: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find all callees (what a symbol calls) across directory hierarchy.
|
||||
|
||||
Args:
|
||||
source_symbol: Name of the symbol to find callees for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with callee information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
|
||||
>>> for callee in callees:
|
||||
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_callees_parallel(
|
||||
index_paths, source_symbol, options.total_limit
|
||||
)
|
||||
|
||||
def search_inheritance(self, class_name: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find inheritance relationships for a class across directory hierarchy.
|
||||
|
||||
Args:
|
||||
class_name: Name of the class to find inheritance for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with inheritance information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
|
||||
>>> for rel in inheritance:
|
||||
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_inheritance_parallel(
|
||||
index_paths, class_name, options.total_limit
|
||||
)
|
||||
|
||||
# === Internal Methods ===
|
||||
|
||||
def _find_start_index(self, source_path: Path) -> Optional[Path]:
|
||||
@@ -711,273 +609,6 @@ class ChainSearchEngine:
|
||||
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_callers_parallel(self, index_paths: List[Path],
|
||||
target_symbol: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for callers across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
target_symbol: Target symbol name
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of caller relationships
|
||||
"""
|
||||
all_callers = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_callers_single,
|
||||
idx_path,
|
||||
target_symbol
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
callers = future.result()
|
||||
all_callers.extend(callers)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Caller search failed: {exc}")
|
||||
|
||||
# Deduplicate by (source_file, source_line)
|
||||
seen = set()
|
||||
unique_callers = []
|
||||
for caller in all_callers:
|
||||
key = (caller.get("source_file"), caller.get("source_line"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_callers.append(caller)
|
||||
|
||||
# Sort by source file and line
|
||||
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
|
||||
|
||||
return unique_callers[:limit]
|
||||
|
||||
def _search_callers_single(self, index_path: Path,
|
||||
target_symbol: str) -> List[Dict[str, Any]]:
|
||||
"""Search for callers in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
target_symbol: Target symbol name
|
||||
|
||||
Returns:
|
||||
List of caller relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
with SQLiteStore(index_path) as store:
|
||||
return store.query_relationships_by_target(target_symbol)
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Caller search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_callees_parallel(self, index_paths: List[Path],
|
||||
source_symbol: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for callees across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
source_symbol: Source symbol name
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of callee relationships
|
||||
"""
|
||||
all_callees = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_callees_single,
|
||||
idx_path,
|
||||
source_symbol
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
callees = future.result()
|
||||
all_callees.extend(callees)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Callee search failed: {exc}")
|
||||
|
||||
# Deduplicate by (target_symbol, source_line)
|
||||
seen = set()
|
||||
unique_callees = []
|
||||
for callee in all_callees:
|
||||
key = (callee.get("target_symbol"), callee.get("source_line"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_callees.append(callee)
|
||||
|
||||
# Sort by source line
|
||||
unique_callees.sort(key=lambda c: c.get("source_line", 0))
|
||||
|
||||
return unique_callees[:limit]
|
||||
|
||||
def _search_callees_single(self, index_path: Path,
|
||||
source_symbol: str) -> List[Dict[str, Any]]:
|
||||
"""Search for callees in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
source_symbol: Source symbol name
|
||||
|
||||
Returns:
|
||||
List of callee relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
with SQLiteStore(index_path) as store:
|
||||
# Single JOIN query to get all callees (fixes N+1 query problem)
|
||||
# Uses public execute_query API instead of _get_connection bypass
|
||||
rows = store.execute_query(
|
||||
"""
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name AS target_symbol,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.full_path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE s.name = ? AND r.relationship_type = 'call'
|
||||
ORDER BY f.full_path, r.source_line
|
||||
LIMIT 100
|
||||
""",
|
||||
(source_symbol,)
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"source_symbol": row["source_symbol"],
|
||||
"target_symbol": row["target_symbol"],
|
||||
"relationship_type": row["relationship_type"],
|
||||
"source_line": row["source_line"],
|
||||
"source_file": row["source_file"],
|
||||
"target_file": row["target_file"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Callee search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_inheritance_parallel(self, index_paths: List[Path],
|
||||
class_name: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for inheritance relationships across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
class_name: Class name to search for
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of inheritance relationships
|
||||
"""
|
||||
all_inheritance = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_inheritance_single,
|
||||
idx_path,
|
||||
class_name
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
inheritance = future.result()
|
||||
all_inheritance.extend(inheritance)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Inheritance search failed: {exc}")
|
||||
|
||||
# Deduplicate by (source_symbol, target_symbol)
|
||||
seen = set()
|
||||
unique_inheritance = []
|
||||
for rel in all_inheritance:
|
||||
key = (rel.get("source_symbol"), rel.get("target_symbol"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_inheritance.append(rel)
|
||||
|
||||
# Sort by source file
|
||||
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
|
||||
|
||||
return unique_inheritance[:limit]
|
||||
|
||||
def _search_inheritance_single(self, index_path: Path,
|
||||
class_name: str) -> List[Dict[str, Any]]:
|
||||
"""Search for inheritance relationships in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
class_name: Class name to search for
|
||||
|
||||
Returns:
|
||||
List of inheritance relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
with SQLiteStore(index_path) as store:
|
||||
# Use UNION to find relationships where class is either:
|
||||
# 1. The base class (target) - find derived classes
|
||||
# 2. The derived class (source) - find parent classes
|
||||
# Uses public execute_query API instead of _get_connection bypass
|
||||
rows = store.execute_query(
|
||||
"""
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.full_path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE r.target_qualified_name = ? AND r.relationship_type = 'inherits'
|
||||
UNION
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.full_path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE s.name = ? AND r.relationship_type = 'inherits'
|
||||
LIMIT 100
|
||||
""",
|
||||
(class_name, class_name)
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"source_symbol": row["source_symbol"],
|
||||
"target_symbol": row["target_qualified_name"],
|
||||
"relationship_type": row["relationship_type"],
|
||||
"source_line": row["source_line"],
|
||||
"source_file": row["source_file"],
|
||||
"target_file": row["target_file"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
|
||||
# === Convenience Functions ===
|
||||
|
||||
@@ -1007,10 +638,9 @@ def quick_search(query: str,

mapper = PathMapper()

engine = ChainSearchEngine(registry, mapper)
options = SearchOptions(depth=depth)

result = engine.search(query, source_path, options)
with ChainSearchEngine(registry, mapper) as engine:
options = SearchOptions(depth=depth)
result = engine.search(query, source_path, options)

registry.close()

@@ -1,542 +0,0 @@
|
||||
"""Graph analyzer for extracting code relationships using tree-sitter.
|
||||
|
||||
Provides AST-based analysis to identify function calls, method invocations,
|
||||
and class inheritance relationships within source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
try:
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
TREE_SITTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TREE_SITTER_AVAILABLE = False
|
||||
|
||||
from codexlens.entities import CodeRelationship, Symbol
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
|
||||
|
||||
class GraphAnalyzer:
|
||||
"""Analyzer for extracting semantic relationships from code using AST traversal."""
|
||||
|
||||
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
|
||||
"""Initialize graph analyzer for a language.
|
||||
|
||||
Args:
|
||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||
parser: Optional TreeSitterSymbolParser instance for dependency injection.
|
||||
If None, creates a new parser instance (backward compatibility).
|
||||
"""
|
||||
self.language_id = language_id
|
||||
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if graph analyzer is available.
|
||||
|
||||
Returns:
|
||||
True if tree-sitter parser is initialized and ready
|
||||
"""
|
||||
return self._parser.is_available()
|
||||
|
||||
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
|
||||
"""Analyze source code and extract relationships.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
file_path: File path for relationship context
|
||||
|
||||
Returns:
|
||||
List of CodeRelationship objects representing intra-file relationships
|
||||
"""
|
||||
if not self.is_available() or self._parser._parser is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
|
||||
|
||||
return relationships
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return []
|
||||
|
||||
def analyze_with_symbols(
|
||||
self, text: str, file_path: Path, symbols: List[Symbol]
|
||||
) -> List[CodeRelationship]:
|
||||
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
file_path: File path for relationship context
|
||||
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
|
||||
|
||||
Returns:
|
||||
List of CodeRelationship objects representing intra-file relationships
|
||||
"""
|
||||
if not self.is_available() or self._parser._parser is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
# Convert Symbol objects to internal symbol format
|
||||
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
|
||||
|
||||
# Extract relationships using provided symbols
|
||||
relationships = self._extract_relationships_with_symbols(
|
||||
source_bytes, root, str(file_path.resolve()), defined_symbols
|
||||
)
|
||||
|
||||
return relationships
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return []
|
||||
|
||||
    def _convert_symbols_to_dict(
        self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
    ) -> List[dict]:
        """Convert Symbol objects to internal dict format for relationship extraction.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            symbols: Pre-parsed Symbol objects

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbol_dicts = []
        symbol_names = {s.name for s in symbols}

        # Find AST nodes corresponding to symbols
        for node in self._iter_nodes(root):
            node_type = node.type

            # Check if this node matches any of our symbols
            if node_type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "method"
                        })
            elif node_type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })

        return symbol_dicts
    def _extract_relationships_with_symbols(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
    ) -> List[CodeRelationship]:
        """Extract relationships from AST using pre-parsed symbols.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path
            defined_symbols: Pre-parsed symbol dicts

        Returns:
            List of extracted relationships
        """
        relationships: List[CodeRelationship] = []

        # Determine call node type based on language
        if self.language_id == "python":
            call_node_type = "call"
            extract_target = self._extract_call_target
        elif self.language_id in {"javascript", "typescript"}:
            call_node_type = "call_expression"
            extract_target = self._extract_js_call_target
        else:
            return []

        # Find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == call_node_type:
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = extract_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships
    def _extract_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of extracted relationships
        """
        if self.language_id == "python":
            return self._extract_python_relationships(source_bytes, root, file_path)
        elif self.language_id in {"javascript", "typescript"}:
            return self._extract_js_ts_relationships(source_bytes, root, file_path)
        else:
            return []
    def _extract_python_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract Python relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of Python relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols with their scopes
        defined_symbols = self._collect_python_symbols(source_bytes, root)

        # Second pass: find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == "call":
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = self._extract_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships
    def _extract_js_ts_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract JavaScript/TypeScript relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of JS/TS relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols
        defined_symbols = self._collect_js_ts_symbols(source_bytes, root)

        # Second pass: find call expressions
        for node in self._iter_nodes(root):
            if node.type == "call_expression":
                # Extract caller context
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    source_symbol = "<module>"

                # Extract callee
                target_symbol = self._extract_js_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,
                        source_line=line_number,
                    )
                )

        return relationships
    def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all Python function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
        return symbols
    def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all JS/TS function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "method"
                    })
            elif node.type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
            elif node.type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
        return symbols
    def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
        """Find the enclosing function/method/class for a node.

        Returns fully qualified name (e.g., "MyClass.my_method") by traversing up
        the AST tree and collecting parent class/function names.

        Args:
            node: AST node to find enclosure for
            symbols: List of defined symbols

        Returns:
            Fully qualified name of enclosing symbol, or None if at module level
        """
        # Walk up the tree to find all enclosing symbols
        enclosing_names = []
        parent = node.parent

        while parent is not None:
            for symbol in symbols:
                if symbol["node"] == parent:
                    # Prepend to maintain order (innermost to outermost)
                    enclosing_names.insert(0, symbol["name"])
                    break
            parent = parent.parent

        # Return fully qualified name or None if at module level
        if enclosing_names:
            return ".".join(enclosing_names)
        return None
    def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a Python call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers (e.g., "foo()")
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle attribute access (e.g., "obj.method()")
        if function_node.type == "attribute":
            attr_node = function_node.child_by_field_name("attribute")
            if attr_node:
                return self._node_text(source_bytes, attr_node)

        return None
    def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a JS/TS call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle member expressions (e.g., "obj.method()")
        if function_node.type == "member_expression":
            property_node = function_node.child_by_field_name("property")
            if property_node:
                return self._node_text(source_bytes, property_node)

        return None
    def _iter_nodes(self, root: TreeSitterNode):
        """Iterate over all nodes in AST.

        Args:
            root: Root node to start iteration

        Yields:
            AST nodes in depth-first order
        """
        stack = [root]
        while stack:
            node = stack.pop()
            yield node
            for child in reversed(node.children):
                stack.append(child)

    def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
        """Extract text for a node.

        Args:
            source_bytes: Source code as bytes
            node: AST node

        Returns:
            Text content of node
        """
        return source_bytes[node.start_byte:node.end_byte].decode("utf8")
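For context, the removed analyzer was exercised roughly as in the tests further below; a minimal sketch against the pre-removal API (the code snippet and file name are illustrative only):

    from pathlib import Path

    from codexlens.semantic.graph_analyzer import GraphAnalyzer

    code = "def helper():\n    pass\n\ndef main():\n    helper()\n"
    analyzer = GraphAnalyzer("python")
    if analyzer.is_available():  # requires the tree-sitter grammar for the language
        for rel in analyzer.analyze_file(code, Path("example.py")):
            # e.g. main -> helper (call) at line 5
            print(rel.source_symbol, "->", rel.target_symbol, rel.relationship_type, rel.source_line)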
@@ -602,6 +602,12 @@ class VectorStore:
        Returns:
            List of SearchResult ordered by similarity (highest first)
        """
        logger.warning(
            "Using brute-force vector search (hnswlib not available). "
            "This may cause high memory usage for large indexes. "
            "Install hnswlib for better performance: pip install hnswlib"
        )

        with self._cache_lock:
            # Refresh cache if needed
            if self._embedding_matrix is None:
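The warning above exists because the brute-force fallback must hold the full embedding matrix in memory and score every row on each query. A rough, illustrative sketch (not the repository's implementation) of what that path does:

    import numpy as np

    def brute_force_top_k(query: np.ndarray, matrix: np.ndarray, k: int = 10) -> np.ndarray:
        """Indices of the k most cosine-similar rows; O(n_chunks * dim) work and memory per search."""
        q = query / np.linalg.norm(query)
        m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
        scores = m @ q  # one dot product per stored chunk
        return np.argsort(-scores)[:k]

hnswlib avoids this by maintaining an approximate-nearest-neighbour graph, so only a small fraction of the stored vectors is visited per query.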
@@ -17,7 +17,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

-from codexlens.entities import CodeRelationship, SearchResult, Symbol
+from codexlens.entities import SearchResult, Symbol
from codexlens.errors import StorageError
@@ -237,116 +237,6 @@ class DirIndexStore:
                conn.rollback()
                raise StorageError(f"Failed to add file {name}: {exc}") from exc

    def add_relationships(
        self,
        file_path: str | Path,
        relationships: List[CodeRelationship],
    ) -> int:
        """Store code relationships for a file.

        Args:
            file_path: Path to the source file
            relationships: List of CodeRelationship objects to store

        Returns:
            Number of relationships stored

        Raises:
            StorageError: If database operations fail
        """
        if not relationships:
            return 0

        with self._lock:
            conn = self._get_connection()
            file_path_str = str(Path(file_path).resolve())

            try:
                # Get file_id
                row = conn.execute(
                    "SELECT id FROM files WHERE full_path=?", (file_path_str,)
                ).fetchone()
                if not row:
                    return 0

                file_id = int(row["id"])

                # Delete existing relationships for symbols in this file
                conn.execute(
                    """
                    DELETE FROM code_relationships
                    WHERE source_symbol_id IN (
                        SELECT id FROM symbols WHERE file_id=?
                    )
                    """,
                    (file_id,),
                )

                # Insert new relationships
                relationship_rows = []
                skipped_relationships = []
                for rel in relationships:
                    # Extract simple name from fully qualified name (e.g., "MyClass.my_method" -> "my_method")
                    # This handles cases where GraphAnalyzer generates qualified names but symbols table stores simple names
                    source_symbol_simple = rel.source_symbol.split(".")[-1] if "." in rel.source_symbol else rel.source_symbol

                    # Find symbol_id by name and file
                    symbol_row = conn.execute(
                        """
                        SELECT id FROM symbols
                        WHERE file_id=? AND name=? AND start_line<=? AND end_line>=?
                        LIMIT 1
                        """,
                        (file_id, source_symbol_simple, rel.source_line, rel.source_line),
                    ).fetchone()

                    if not symbol_row:
                        # Try matching by simple name only
                        symbol_row = conn.execute(
                            "SELECT id FROM symbols WHERE file_id=? AND name=? LIMIT 1",
                            (file_id, source_symbol_simple),
                        ).fetchone()

                    if symbol_row:
                        relationship_rows.append((
                            int(symbol_row["id"]),
                            rel.target_symbol,
                            rel.relationship_type,
                            rel.source_line,
                            rel.target_file,
                        ))
                    else:
                        # Log warning when symbol lookup fails
                        skipped_relationships.append(rel.source_symbol)

                # Log skipped relationships for debugging
                if skipped_relationships:
                    self.logger.warning(
                        "Failed to find source symbol IDs for %d relationships in %s: %s",
                        len(skipped_relationships),
                        file_path_str,
                        ", ".join(set(skipped_relationships))
                    )

                if relationship_rows:
                    conn.executemany(
                        """
                        INSERT INTO code_relationships(
                            source_symbol_id, target_qualified_name, relationship_type,
                            source_line, target_file
                        )
                        VALUES(?, ?, ?, ?, ?)
                        """,
                        relationship_rows,
                    )

                conn.commit()
                return len(relationship_rows)

            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise StorageError(f"Failed to add relationships: {exc}") from exc

    def add_files_batch(
        self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
    ) -> int:
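The lookup above is why the simple-name fallback exists: GraphAnalyzer emits qualified source names such as "MyClass.my_method", while the symbols table stores simple names. A hedged sketch of how the removed method was fed (paths are hypothetical, and `store` is assumed to be an already-open DirIndexStore):

    from codexlens.entities import CodeRelationship

    rel = CodeRelationship(
        source_symbol="MyClass.my_method",  # reduced to "my_method" for the symbol-id lookup
        target_symbol="helper",
        relationship_type="call",
        source_file="/project/module.py",
        target_file=None,
        source_line=42,
    )
    stored_count = store.add_relationships("/project/module.py", [rel])  # number of rows written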
@@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Set

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
from codexlens.semantic.graph_analyzer import GraphAnalyzer
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import ProjectInfo, RegistryStore
@@ -525,16 +524,6 @@ class IndexTreeBuilder:
                        symbols=indexed_file.symbols,
                    )

                    # Extract and store code relationships for graph visualization
                    if language_id in {"python", "javascript", "typescript"}:
                        graph_analyzer = GraphAnalyzer(language_id)
                        if graph_analyzer.is_available():
                            relationships = graph_analyzer.analyze_with_symbols(
                                text, file_path, indexed_file.symbols
                            )
                            if relationships:
                                store.add_relationships(file_path, relationships)

                    files_count += 1
                    symbols_count += len(indexed_file.symbols)

@@ -742,16 +731,6 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
                symbols=indexed_file.symbols,
            )

            # Extract and store code relationships for graph visualization
            if language_id in {"python", "javascript", "typescript"}:
                graph_analyzer = GraphAnalyzer(language_id)
                if graph_analyzer.is_available():
                    relationships = graph_analyzer.analyze_with_symbols(
                        text, item, indexed_file.symbols
                    )
                    if relationships:
                        store.add_relationships(item, relationships)

            files_count += 1
            symbols_count += len(indexed_file.symbols)
@@ -1,57 +0,0 @@
"""
Migration 003: Add code relationships storage.

This migration introduces the `code_relationships` table to store semantic
relationships between code symbols (function calls, inheritance, imports).
This enables graph-based code navigation and dependency analysis.
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add the code relationships table.

    - Creates `code_relationships` table with foreign key to symbols
    - Creates indexes for efficient relationship queries
    - Supports lazy expansion with target_symbol being qualified names

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating 'code_relationships' table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT,
            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for code_relationships...")
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
    )

    log.info("Finished creating code_relationships table and indexes.")
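Before its removal this migration ran through the normal migration runner; as a standalone illustration (database path hypothetical), applying it amounts to:

    import sqlite3

    conn = sqlite3.connect("/tmp/_index.db")
    upgrade(conn)  # creates code_relationships and its four indexes
    conn.commit()
    conn.close()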
@@ -10,7 +10,7 @@ from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

-from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
+from codexlens.entities import IndexedFile, SearchResult, Symbol
from codexlens.errors import StorageError


@@ -420,167 +420,6 @@ class SQLiteStore:
        }


    def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
        """Store code relationships for a file.

        Args:
            file_path: Path to the file containing the relationships
            relationships: List of CodeRelationship objects to store
        """
        if not relationships:
            return

        with self._lock:
            conn = self._get_connection()
            resolved_path = str(Path(file_path).resolve())

            # Get file_id
            row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
            if not row:
                raise StorageError(f"File not found in index: {file_path}")
            file_id = int(row["id"])

            # Delete existing relationships for symbols in this file
            conn.execute(
                """
                DELETE FROM code_relationships
                WHERE source_symbol_id IN (
                    SELECT id FROM symbols WHERE file_id=?
                )
                """,
                (file_id,)
            )

            # Insert new relationships
            relationship_rows = []
            for rel in relationships:
                # Find source symbol ID
                symbol_row = conn.execute(
                    """
                    SELECT id FROM symbols
                    WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
                    ORDER BY (end_line - start_line) ASC
                    LIMIT 1
                    """,
                    (file_id, rel.source_symbol, rel.source_line, rel.source_line)
                ).fetchone()

                if symbol_row:
                    source_symbol_id = int(symbol_row["id"])
                    relationship_rows.append((
                        source_symbol_id,
                        rel.target_symbol,
                        rel.relationship_type,
                        rel.source_line,
                        rel.target_file
                    ))

            if relationship_rows:
                conn.executemany(
                    """
                    INSERT INTO code_relationships(
                        source_symbol_id, target_qualified_name, relationship_type,
                        source_line, target_file
                    )
                    VALUES(?, ?, ?, ?, ?)
                    """,
                    relationship_rows
                )
                conn.commit()

    def query_relationships_by_target(
        self, target_name: str, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Query relationships by target symbol name (find all callers).

        Args:
            target_name: Name of the target symbol
            limit: Maximum number of results

        Returns:
            List of dicts containing relationship info with file paths and line numbers
        """
        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT
                    s.name AS source_symbol,
                    r.target_qualified_name,
                    r.relationship_type,
                    r.source_line,
                    f.full_path AS source_file,
                    r.target_file
                FROM code_relationships r
                JOIN symbols s ON r.source_symbol_id = s.id
                JOIN files f ON s.file_id = f.id
                WHERE r.target_qualified_name = ?
                ORDER BY f.full_path, r.source_line
                LIMIT ?
                """,
                (target_name, limit)
            ).fetchall()

            return [
                {
                    "source_symbol": row["source_symbol"],
                    "target_symbol": row["target_qualified_name"],
                    "relationship_type": row["relationship_type"],
                    "source_line": row["source_line"],
                    "source_file": row["source_file"],
                    "target_file": row["target_file"],
                }
                for row in rows
            ]

    def query_relationships_by_source(
        self, source_symbol: str, source_file: str | Path, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Query relationships by source symbol (find what a symbol calls).

        Args:
            source_symbol: Name of the source symbol
            source_file: File path containing the source symbol
            limit: Maximum number of results

        Returns:
            List of dicts containing relationship info
        """
        with self._lock:
            conn = self._get_connection()
            resolved_path = str(Path(source_file).resolve())

            rows = conn.execute(
                """
                SELECT
                    s.name AS source_symbol,
                    r.target_qualified_name,
                    r.relationship_type,
                    r.source_line,
                    f.path AS source_file,
                    r.target_file
                FROM code_relationships r
                JOIN symbols s ON r.source_symbol_id = s.id
                JOIN files f ON s.file_id = f.id
                WHERE s.name = ? AND f.path = ?
                ORDER BY r.source_line
                LIMIT ?
                """,
                (source_symbol, resolved_path, limit)
            ).fetchall()

            return [
                {
                    "source_symbol": row["source_symbol"],
                    "target_symbol": row["target_qualified_name"],
                    "relationship_type": row["relationship_type"],
                    "source_line": row["source_line"],
                    "source_file": row["source_file"],
                    "target_file": row["target_file"],
                }
                for row in rows
            ]

    def _connect(self) -> sqlite3.Connection:
        """Legacy method for backward compatibility."""
        return self._get_connection()
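A sketch of how the deleted query API was consumed, assuming SQLiteStore opens an index database as a context manager (the chain-search tests below mock it that way; the import path and database path are assumptions):

    from codexlens.storage.sqlite_store import SQLiteStore

    with SQLiteStore("/path/to/_index.db") as store:
        for row in store.query_relationships_by_target("my_function", limit=20):
            print(f'{row["source_file"]}:{row["source_line"]}  '
                  f'{row["source_symbol"]} -> {row["target_symbol"]}')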
@@ -1,644 +0,0 @@
|
||||
"""Unit tests for ChainSearchEngine.
|
||||
|
||||
Tests the graph query methods (search_callers, search_callees, search_inheritance)
|
||||
with mocked SQLiteStore dependency to test logic in isolation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock, patch, call
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from codexlens.search.chain_search import (
|
||||
ChainSearchEngine,
|
||||
SearchOptions,
|
||||
SearchStats,
|
||||
ChainSearchResult,
|
||||
)
|
||||
from codexlens.entities import SearchResult, Symbol
|
||||
from codexlens.storage.registry import RegistryStore, DirMapping
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_registry():
|
||||
"""Create a mock RegistryStore."""
|
||||
registry = Mock(spec=RegistryStore)
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mapper():
|
||||
"""Create a mock PathMapper."""
|
||||
mapper = Mock(spec=PathMapper)
|
||||
return mapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_engine(mock_registry, mock_mapper):
|
||||
"""Create a ChainSearchEngine with mocked dependencies."""
|
||||
return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_index_path():
|
||||
"""Sample index database path."""
|
||||
return Path("/test/project/_index.db")
|
||||
|
||||
|
||||
class TestChainSearchEngineCallers:
|
||||
"""Tests for search_callers method."""
|
||||
|
||||
def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callers returns caller relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock finding the start index
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
# Mock collect_index_paths to return single index
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Mock the parallel search to return caller data
|
||||
expected_callers = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller_function"
|
||||
assert result[0]["target_symbol"] == "my_function"
|
||||
assert result[0]["relationship_type"] == "calls"
|
||||
assert result[0]["source_line"] == 42
|
||||
|
||||
def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callers handles no results gracefully."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "nonexistent_function"
|
||||
|
||||
# Mock finding the start index
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
# Mock collect_index_paths
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Mock empty results
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_search_callers_no_index_found(self, search_engine, mock_registry):
|
||||
"""Test that search_callers returns empty list when no index found."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock no index found
|
||||
mock_registry.find_nearest_index.return_value = None
|
||||
|
||||
with patch.object(search_engine, '_find_start_index', return_value=None):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path):
|
||||
"""Test that search_callers respects SearchOptions."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
options = SearchOptions(depth=1, total_limit=50)
|
||||
|
||||
# Configure mapper to return a path that exists
|
||||
mock_mapper.source_to_index_db.return_value = sample_index_path
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect:
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search:
|
||||
# Patch Path.exists to return True so the exact match is found
|
||||
with patch.object(Path, 'exists', return_value=True):
|
||||
# Execute
|
||||
search_engine.search_callers(target_symbol, source_path, options)
|
||||
|
||||
# Assert that depth was passed to collect_index_paths
|
||||
mock_collect.assert_called_once_with(sample_index_path, 1)
|
||||
# Assert that total_limit was passed to parallel search
|
||||
mock_search.assert_called_once_with([sample_index_path], target_symbol, 50)
|
||||
|
||||
|
||||
class TestChainSearchEngineCallees:
|
||||
"""Tests for search_callees method."""
|
||||
|
||||
def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees returns callee relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "caller_function"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_callees = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "callee_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 15,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller_function"
|
||||
assert result[0]["target_symbol"] == "callee_function"
|
||||
assert result[0]["source_line"] == 15
|
||||
|
||||
def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees correctly handles file-specific queries."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "MyClass.method"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
# Multiple callees from same source symbol
|
||||
expected_callees = [
|
||||
{
|
||||
"source_symbol": "MyClass.method",
|
||||
"target_symbol": "helper_a",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/utils.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "MyClass.method",
|
||||
"target_symbol": "helper_b",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 20,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/utils.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0]["target_symbol"] == "helper_a"
|
||||
assert result[1]["target_symbol"] == "helper_b"
|
||||
|
||||
def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_callees handles no callees gracefully."""
|
||||
source_path = Path("/test/project")
|
||||
source_symbol = "leaf_function"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
with patch.object(search_engine, '_search_callees_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_callees(source_symbol, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestChainSearchEngineInheritance:
|
||||
"""Tests for search_inheritance method."""
|
||||
|
||||
def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that search_inheritance returns inheritance relationships."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
class_name = "BaseClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_inheritance = [
|
||||
{
|
||||
"source_symbol": "DerivedClass",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/project/derived.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "DerivedClass"
|
||||
assert result[0]["target_symbol"] == "BaseClass"
|
||||
assert result[0]["relationship_type"] == "inherits"
|
||||
|
||||
def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test inheritance search with multiple derived classes."""
|
||||
source_path = Path("/test/project")
|
||||
class_name = "BaseClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
expected_inheritance = [
|
||||
{
|
||||
"source_symbol": "DerivedClassA",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/project/derived_a.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "DerivedClassB",
|
||||
"target_symbol": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/derived_b.py",
|
||||
"target_file": "/test/project/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0]["source_symbol"] == "DerivedClassA"
|
||||
assert result[1]["source_symbol"] == "DerivedClassB"
|
||||
|
||||
def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test inheritance search with no subclasses found."""
|
||||
source_path = Path("/test/project")
|
||||
class_name = "FinalClass"
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=sample_index_path,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]):
|
||||
# Execute
|
||||
result = search_engine.search_inheritance(class_name, source_path)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestChainSearchEngineParallelSearch:
|
||||
"""Tests for parallel search aggregation."""
|
||||
|
||||
def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that parallel search aggregates results from multiple indexes."""
|
||||
# Setup
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
index_path_1 = Path("/test/project/_index.db")
|
||||
index_path_2 = Path("/test/project/subdir/_index.db")
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=index_path_1,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
||||
# Mock parallel search results from multiple indexes
|
||||
callers_from_multiple = [
|
||||
{
|
||||
"source_symbol": "caller_in_root",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/project/root.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "caller_in_subdir",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 20,
|
||||
"source_file": "/test/project/subdir/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert results from both indexes are included
|
||||
assert len(result) == 2
|
||||
assert any(r["source_file"] == "/test/project/root.py" for r in result)
|
||||
assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result)
|
||||
|
||||
def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path):
|
||||
"""Test that parallel search deduplicates results by (source_file, source_line)."""
|
||||
# Note: This test verifies the behavior of _search_callers_parallel deduplication
|
||||
source_path = Path("/test/project")
|
||||
target_symbol = "my_function"
|
||||
|
||||
index_path_1 = Path("/test/project/_index.db")
|
||||
index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate)
|
||||
|
||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
||||
id=1,
|
||||
project_id=1,
|
||||
source_path=source_path,
|
||||
index_path=index_path_1,
|
||||
depth=0,
|
||||
files_count=10,
|
||||
last_updated=0.0
|
||||
)
|
||||
|
||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
||||
# Mock duplicate results from same location
|
||||
duplicate_callers = [
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
},
|
||||
{
|
||||
"source_symbol": "caller_function",
|
||||
"target_symbol": "my_function",
|
||||
"relationship_type": "calls",
|
||||
"source_line": 42,
|
||||
"source_file": "/test/project/module.py",
|
||||
"target_file": "/test/project/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers):
|
||||
# Execute
|
||||
result = search_engine.search_callers(target_symbol, source_path)
|
||||
|
||||
# Assert: even with duplicates in input, output may contain both
|
||||
# (actual deduplication happens in _search_callers_parallel)
|
||||
assert len(result) >= 1
|
||||
|
||||
|
||||
class TestChainSearchEngineContextManager:
|
||||
"""Tests for context manager functionality."""
|
||||
|
||||
def test_context_manager_closes_executor(self, mock_registry, mock_mapper):
|
||||
"""Test that context manager properly closes executor."""
|
||||
with ChainSearchEngine(mock_registry, mock_mapper) as engine:
|
||||
# Force executor creation
|
||||
engine._get_executor()
|
||||
assert engine._executor is not None
|
||||
|
||||
# Executor should be closed after exiting context
|
||||
assert engine._executor is None
|
||||
|
||||
def test_close_method_shuts_down_executor(self, search_engine):
|
||||
"""Test that close() method shuts down executor."""
|
||||
# Create executor
|
||||
search_engine._get_executor()
|
||||
assert search_engine._executor is not None
|
||||
|
||||
# Close
|
||||
search_engine.close()
|
||||
assert search_engine._executor is None
|
||||
|
||||
|
||||
class TestSearchCallersSingle:
|
||||
"""Tests for _search_callers_single internal method."""
|
||||
|
||||
def test_search_callers_single_queries_store(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callers_single queries SQLiteStore correctly."""
|
||||
target_symbol = "my_function"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MockStore.return_value.__enter__.return_value
|
||||
mock_store_instance.query_relationships_by_target.return_value = [
|
||||
{
|
||||
"source_symbol": "caller",
|
||||
"target_symbol": target_symbol,
|
||||
"relationship_type": "calls",
|
||||
"source_line": 10,
|
||||
"source_file": "/test/file.py",
|
||||
"target_file": "/test/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "caller"
|
||||
mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol)
|
||||
|
||||
def test_search_callers_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callers_single returns empty list on error."""
|
||||
target_symbol = "my_function"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("Database error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestSearchCalleesSingle:
|
||||
"""Tests for _search_callees_single internal method."""
|
||||
|
||||
def test_search_callees_single_queries_database(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callees_single queries SQLiteStore correctly."""
|
||||
source_symbol = "caller_function"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MagicMock()
|
||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
||||
|
||||
# Mock execute_query to return relationship data (using new public API)
|
||||
mock_store_instance.execute_query.return_value = [
|
||||
{
|
||||
"source_symbol": source_symbol,
|
||||
"target_symbol": "callee_function",
|
||||
"relationship_type": "call",
|
||||
"source_line": 15,
|
||||
"source_file": "/test/module.py",
|
||||
"target_file": "/test/lib.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
||||
|
||||
# Assert - verify execute_query was called (public API)
|
||||
assert mock_store_instance.execute_query.called
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == source_symbol
|
||||
assert result[0]["target_symbol"] == "callee_function"
|
||||
|
||||
def test_search_callees_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_callees_single returns empty list on error."""
|
||||
source_symbol = "caller_function"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestSearchInheritanceSingle:
|
||||
"""Tests for _search_inheritance_single internal method."""
|
||||
|
||||
def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path):
|
||||
"""Test that _search_inheritance_single queries SQLiteStore correctly."""
|
||||
class_name = "BaseClass"
|
||||
|
||||
# Mock SQLiteStore
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
mock_store_instance = MagicMock()
|
||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
||||
|
||||
# Mock execute_query to return relationship data (using new public API)
|
||||
mock_store_instance.execute_query.return_value = [
|
||||
{
|
||||
"source_symbol": "DerivedClass",
|
||||
"target_qualified_name": "BaseClass",
|
||||
"relationship_type": "inherits",
|
||||
"source_line": 5,
|
||||
"source_file": "/test/derived.py",
|
||||
"target_file": "/test/base.py",
|
||||
}
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
||||
|
||||
# Assert
|
||||
assert mock_store_instance.execute_query.called
|
||||
assert len(result) == 1
|
||||
assert result[0]["source_symbol"] == "DerivedClass"
|
||||
assert result[0]["relationship_type"] == "inherits"
|
||||
|
||||
# Verify execute_query was called with 'inherits' filter
|
||||
call_args = mock_store_instance.execute_query.call_args
|
||||
sql_query = call_args[0][0]
|
||||
assert "relationship_type = 'inherits'" in sql_query
|
||||
|
||||
def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path):
|
||||
"""Test that _search_inheritance_single returns empty list on error."""
|
||||
class_name = "BaseClass"
|
||||
|
||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
||||
|
||||
# Execute
|
||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
||||
|
||||
# Assert - should return empty list, not raise exception
|
||||
assert result == []
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Tests for CLI search command with --enrich flag."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
from codexlens.cli.commands import app
|
||||
|
||||
|
||||
class TestCLISearchEnrich:
|
||||
"""Test CLI search command with --enrich flag integration."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_search_with_enrich_flag_help(self, runner):
|
||||
"""Test --enrich flag is documented in help."""
|
||||
result = runner.invoke(app, ["search", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "--enrich" in result.output
|
||||
assert "relationships" in result.output.lower() or "graph" in result.output.lower()
|
||||
|
||||
def test_search_with_enrich_flag_accepted(self, runner):
|
||||
"""Test --enrich flag is accepted by the CLI."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich"])
|
||||
# Should not show 'unknown option' error
|
||||
assert "No such option" not in result.output
|
||||
assert "error: unrecognized" not in result.output.lower()
|
||||
|
||||
def test_search_without_enrich_flag(self, runner):
|
||||
"""Test search without --enrich flag has no relationships."""
|
||||
result = runner.invoke(app, ["search", "test", "--json"])
|
||||
# Even without an index, JSON should be attempted
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
# If we get results, they should not have enriched=true
|
||||
if data.get("success") and "result" in data:
|
||||
assert data["result"].get("enriched", False) is False
|
||||
except json.JSONDecodeError:
|
||||
pass # Not JSON output, that's fine for error cases
|
||||
|
||||
def test_search_enrich_json_output_structure(self, runner):
|
||||
"""Test JSON output structure includes enriched flag."""
|
||||
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
|
||||
# If we get valid JSON output, check structure
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
if data.get("success") and "result" in data:
|
||||
# enriched field should exist
|
||||
assert "enriched" in data["result"]
|
||||
except json.JSONDecodeError:
|
||||
pass # Not JSON output
|
||||
|
||||
def test_search_enrich_with_mode(self, runner):
|
||||
"""Test --enrich works with different search modes."""
|
||||
modes = ["exact", "fuzzy", "hybrid"]
|
||||
for mode in modes:
|
||||
result = runner.invoke(
|
||||
app, ["search", "test", "--mode", mode, "--enrich"]
|
||||
)
|
||||
# Should not show validation errors
|
||||
assert "Invalid" not in result.output
|
||||
|
||||
|
||||
class TestEnrichFlagBehavior:
|
||||
"""Test behavioral aspects of --enrich flag."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_enrich_failure_does_not_break_search(self, runner):
|
||||
"""Test that enrichment failure doesn't prevent search from returning results."""
|
||||
# Even without proper index, search should not crash due to enrich
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
|
||||
# Should not have unhandled exception
|
||||
assert "Traceback" not in result.output
|
||||
|
||||
def test_enrich_flag_with_files_only(self, runner):
|
||||
"""Test --enrich is accepted with --files-only mode."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--files-only"])
|
||||
# Should not show option conflict error
|
||||
assert "conflict" not in result.output.lower()
|
||||
|
||||
def test_enrich_flag_with_limit(self, runner):
|
||||
"""Test --enrich works with --limit parameter."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--limit", "5"])
|
||||
# Should not show validation error
|
||||
assert "Invalid" not in result.output
|
||||
|
||||
|
||||
class TestEnrichOutputFormat:
|
||||
"""Test output format with --enrich flag."""
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create CLI test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_enrich_verbose_shows_status(self, runner):
|
||||
"""Test verbose mode shows enrichment status."""
|
||||
result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
|
||||
# Verbose mode may show enrichment info or warnings
|
||||
# Just ensure it doesn't crash
|
||||
assert result.exit_code in [0, 1] # 0 = success, 1 = no index
|
||||
|
||||
def test_json_output_has_enriched_field(self, runner):
|
||||
"""Test JSON output always has enriched field when --enrich used."""
|
||||
result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
|
||||
if result.exit_code == 0:
|
||||
try:
|
||||
data = json.loads(result.output)
|
||||
if data.get("success"):
|
||||
result_data = data.get("result", {})
|
||||
assert "enriched" in result_data
|
||||
assert isinstance(result_data["enriched"], bool)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
@@ -204,8 +204,6 @@ class TestEntitySerialization:
            "kind": "function",
            "range": (1, 10),
            "file": None,
-            "token_count": None,
-            "symbol_type": None,
        }

    def test_indexed_file_model_dump(self):
@@ -1,436 +0,0 @@
"""Tests for GraphAnalyzer - code relationship extraction."""

from pathlib import Path

import pytest

from codexlens.semantic.graph_analyzer import GraphAnalyzer


TREE_SITTER_PYTHON_AVAILABLE = True
try:
    import tree_sitter_python  # type: ignore[import-not-found]  # noqa: F401
except Exception:
    TREE_SITTER_PYTHON_AVAILABLE = False


TREE_SITTER_JS_AVAILABLE = True
try:
    import tree_sitter_javascript  # type: ignore[import-not-found]  # noqa: F401
except Exception:
    TREE_SITTER_JS_AVAILABLE = False


@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
class TestPythonGraphAnalyzer:
    """Tests for Python relationship extraction."""

    def test_simple_function_call(self):
        """Test extraction of simple function call."""
        code = """def helper():
    pass

def main():
    helper()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"
        assert rel.relationship_type == "call"
        assert rel.source_line == 5

    def test_multiple_calls_in_function(self):
        """Test extraction of multiple calls from same function."""
        code = """def foo():
    pass

def bar():
    pass

def main():
    foo()
    bar()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find main -> foo and main -> bar
        assert len(relationships) == 2
        targets = {rel.target_symbol for rel in relationships}
        assert targets == {"foo", "bar"}
        assert all(rel.source_symbol == "main" for rel in relationships)

    def test_nested_function_calls(self):
        """Test extraction of calls from nested functions."""
        code = """def inner_helper():
    pass

def outer():
    def inner():
        inner_helper()
    inner()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find outer.inner -> inner_helper and outer -> inner (with fully qualified names)
        assert len(relationships) == 2
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
        assert ("outer.inner", "inner_helper") in call_pairs
        assert ("outer", "inner") in call_pairs

    def test_method_call_in_class(self):
        """Test extraction of method calls within class."""
        code = """class Calculator:
    def add(self, a, b):
        return a + b

    def compute(self, x, y):
        result = self.add(x, y)
        return result
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find Calculator.compute -> add (with fully qualified source)
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "Calculator.compute"
        assert rel.target_symbol == "add"

    def test_module_level_call(self):
        """Test extraction of module-level function calls."""
        code = """def setup():
    pass

setup()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find <module> -> setup
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "<module>"
        assert rel.target_symbol == "setup"

    def test_async_function_call(self):
        """Test extraction of calls involving async functions."""
        code = """async def fetch_data():
    pass

async def process():
    await fetch_data()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find process -> fetch_data
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "process"
        assert rel.target_symbol == "fetch_data"

    def test_complex_python_file(self):
        """Test extraction from realistic Python file with multiple patterns."""
        code = """class DataProcessor:
    def __init__(self):
        self.data = []

    def load(self, filename):
        self.data = read_file(filename)

    def process(self):
        self.validate()
        self.transform()

    def validate(self):
        pass

    def transform(self):
        pass

def read_file(filename):
    pass

def main():
    processor = DataProcessor()
    processor.load("data.txt")
    processor.process()

main()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Extract call pairs
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Expected relationships (with fully qualified source symbols for methods)
        expected = {
            ("DataProcessor.load", "read_file"),
            ("DataProcessor.process", "validate"),
            ("DataProcessor.process", "transform"),
            ("main", "DataProcessor"),
            ("main", "load"),
            ("main", "process"),
            ("<module>", "main"),
        }

        # Should find all expected relationships
        assert call_pairs >= expected

    def test_empty_file(self):
        """Test handling of empty file."""
        code = ""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert len(relationships) == 0

    def test_file_with_no_calls(self):
        """Test handling of file with definitions but no calls."""
        code = """def func1():
    pass

def func2():
    pass
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert len(relationships) == 0


@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
class TestJavaScriptGraphAnalyzer:
    """Tests for JavaScript relationship extraction."""

    def test_simple_function_call(self):
        """Test extraction of simple JavaScript function call."""
        code = """function helper() {}

function main() {
    helper();
}
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"
        assert rel.relationship_type == "call"

    def test_arrow_function_call(self):
        """Test extraction of calls from arrow functions."""
        code = """const helper = () => {};

const main = () => {
    helper();
};
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"

    def test_class_method_call(self):
        """Test extraction of method calls in JavaScript class."""
        code = """class Calculator {
    add(a, b) {
        return a + b;
    }

    compute(x, y) {
        return this.add(x, y);
    }
}
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find Calculator.compute -> add (with fully qualified source)
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "Calculator.compute"
        assert rel.target_symbol == "add"

    def test_complex_javascript_file(self):
        """Test extraction from realistic JavaScript file."""
        code = """function readFile(filename) {
    return "";
}

class DataProcessor {
    constructor() {
        this.data = [];
    }

    load(filename) {
        this.data = readFile(filename);
    }

    process() {
        this.validate();
        this.transform();
    }

    validate() {}

    transform() {}
}

function main() {
    const processor = new DataProcessor();
    processor.load("data.txt");
    processor.process();
}

main();
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Extract call pairs
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Expected relationships (with fully qualified source symbols for methods)
        # Note: constructor calls like "new DataProcessor()" are not tracked
        expected = {
            ("DataProcessor.load", "readFile"),
            ("DataProcessor.process", "validate"),
            ("DataProcessor.process", "transform"),
            ("main", "load"),
            ("main", "process"),
            ("<module>", "main"),
        }

        # Should find all expected relationships
        assert call_pairs >= expected


class TestGraphAnalyzerEdgeCases:
    """Edge case tests for GraphAnalyzer."""

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_unavailable_language(self):
        """Test handling of unsupported language."""
        code = "some code"
        analyzer = GraphAnalyzer("rust")
        relationships = analyzer.analyze_file(code, Path("test.rs"))
        assert len(relationships) == 0

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_malformed_python_code(self):
        """Test handling of malformed Python code."""
        code = "def broken(\n    pass"
        analyzer = GraphAnalyzer("python")
        # Should not crash
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert isinstance(relationships, list)

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_file_path_in_relationship(self):
        """Test that file path is correctly set in relationships."""
        code = """def foo():
    pass

def bar():
    foo()
"""
        test_path = Path("test.py")
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, test_path)

        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_file == str(test_path.resolve())
        assert rel.target_file is None  # Intra-file

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_performance_large_file(self):
        """Test performance on larger file (1000 lines)."""
        import time

        # Generate file with many functions and calls
        lines = []
        for i in range(100):
            lines.append(f"def func_{i}():")
            if i > 0:
                lines.append(f"    func_{i-1}()")
            else:
                lines.append("    pass")

        code = "\n".join(lines)

        analyzer = GraphAnalyzer("python")
        start_time = time.time()
        relationships = analyzer.analyze_file(code, Path("test.py"))
        elapsed_ms = (time.time() - start_time) * 1000

        # Should complete in under 500ms
        assert elapsed_ms < 500

        # Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
        assert len(relationships) == 99

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_call_accuracy_rate(self):
        """Test >95% accuracy on known call graph."""
        code = """def a(): pass
def b(): pass
def c(): pass
def d(): pass
def e(): pass

def test1():
    a()
    b()

def test2():
    c()
    d()

def test3():
    e()

def main():
    test1()
    test2()
    test3()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
        expected_calls = {
            ("test1", "a"),
            ("test1", "b"),
            ("test2", "c"),
            ("test2", "d"),
            ("test3", "e"),
            ("main", "test1"),
            ("main", "test2"),
            ("main", "test3"),
        }

        found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Calculate accuracy
        correct = len(expected_calls & found_calls)
        total = len(expected_calls)
        accuracy = (correct / total) * 100 if total > 0 else 0

        # Should have >95% accuracy
        assert accuracy >= 95.0
        assert correct == total  # Should be 100% for this simple case
@@ -1,392 +0,0 @@
"""End-to-end tests for graph search CLI commands."""

import tempfile
from pathlib import Path
from typer.testing import CliRunner
import pytest

from codexlens.cli.commands import app
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.entities import IndexedFile, Symbol, CodeRelationship


runner = CliRunner()


@pytest.fixture
def temp_project():
    """Create a temporary project with indexed code and relationships."""
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir) / "test_project"
        project_root.mkdir()

        # Create test Python files
        (project_root / "main.py").write_text("""
def main():
    result = calculate(5, 3)
    print(result)

def calculate(a, b):
    return add(a, b)

def add(x, y):
    return x + y
""")

        (project_root / "utils.py").write_text("""
class BaseClass:
    def method(self):
        pass

class DerivedClass(BaseClass):
    def method(self):
        super().method()
        helper()

def helper():
    return True
""")

        # Create a custom index directory for graph testing
        # Skip the standard init to avoid schema conflicts
        mapper = PathMapper()
        index_root = mapper.source_to_index_dir(project_root)
        index_root.mkdir(parents=True, exist_ok=True)
        test_db = index_root / "_index.db"

        # Register project manually
        registry = RegistryStore()
        registry.initialize()
        project_info = registry.register_project(
            source_root=project_root,
            index_root=index_root
        )
        registry.register_dir(
            project_id=project_info.id,
            source_path=project_root,
            index_path=test_db,
            depth=0,
            files_count=2
        )

        # Initialize the store with proper SQLiteStore schema and add files
        with SQLiteStore(test_db) as store:
            # Read and add files to the store
            main_content = (project_root / "main.py").read_text()
            utils_content = (project_root / "utils.py").read_text()

            main_indexed = IndexedFile(
                path=str(project_root / "main.py"),
                language="python",
                symbols=[
                    Symbol(name="main", kind="function", range=(2, 4)),
                    Symbol(name="calculate", kind="function", range=(6, 7)),
                    Symbol(name="add", kind="function", range=(9, 10))
                ]
            )
            utils_indexed = IndexedFile(
                path=str(project_root / "utils.py"),
                language="python",
                symbols=[
                    Symbol(name="BaseClass", kind="class", range=(2, 4)),
                    Symbol(name="DerivedClass", kind="class", range=(6, 9)),
                    Symbol(name="helper", kind="function", range=(11, 12))
                ]
            )

            store.add_file(main_indexed, main_content)
            store.add_file(utils_indexed, utils_content)

        with SQLiteStore(test_db) as store:
            # Add relationships for main.py
            main_file = project_root / "main.py"
            relationships_main = [
                CodeRelationship(
                    source_symbol="main",
                    target_symbol="calculate",
                    relationship_type="call",
                    source_file=str(main_file),
                    source_line=3,
                    target_file=str(main_file)
                ),
                CodeRelationship(
                    source_symbol="calculate",
                    target_symbol="add",
                    relationship_type="call",
                    source_file=str(main_file),
                    source_line=7,
                    target_file=str(main_file)
                ),
            ]
            store.add_relationships(main_file, relationships_main)

            # Add relationships for utils.py
            utils_file = project_root / "utils.py"
            relationships_utils = [
                CodeRelationship(
                    source_symbol="DerivedClass",
                    target_symbol="BaseClass",
                    relationship_type="inherits",
                    source_file=str(utils_file),
                    source_line=6,  # DerivedClass is defined on line 6
                    target_file=str(utils_file)
                ),
                CodeRelationship(
                    source_symbol="DerivedClass.method",
                    target_symbol="helper",
                    relationship_type="call",
                    source_file=str(utils_file),
                    source_line=8,
                    target_file=str(utils_file)
                ),
            ]
            store.add_relationships(utils_file, relationships_utils)

        registry.close()

        yield project_root


class TestGraphCallers:
    """Test callers query type."""

    def test_find_callers_basic(self, temp_project):
        """Test finding functions that call a given function."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "calculate" in result.stdout
        assert "Callers of 'add'" in result.stdout

    def test_find_callers_json_mode(self, temp_project):
        """Test callers query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout
        assert "relationships" in result.stdout

    def test_find_callers_no_results(self, temp_project):
        """Test callers query when no callers exist."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "nonexistent_function",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "No callers found" in result.stdout or "0 found" in result.stdout


class TestGraphCallees:
    """Test callees query type."""

    def test_find_callees_basic(self, temp_project):
        """Test finding functions called by a given function."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "main",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "calculate" in result.stdout
        assert "Callees of 'main'" in result.stdout

    def test_find_callees_chain(self, temp_project):
        """Test finding callees in a call chain."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "calculate",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "add" in result.stdout

    def test_find_callees_json_mode(self, temp_project):
        """Test callees query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "main",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout


class TestGraphInheritance:
    """Test inheritance query type."""

    def test_find_inheritance_basic(self, temp_project):
        """Test finding inheritance relationships."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "BaseClass",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "DerivedClass" in result.stdout
        assert "Inheritance relationships" in result.stdout

    def test_find_inheritance_derived(self, temp_project):
        """Test finding inheritance from derived class perspective."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "DerivedClass",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "BaseClass" in result.stdout

    def test_find_inheritance_json_mode(self, temp_project):
        """Test inheritance query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "BaseClass",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout


class TestGraphValidation:
    """Test query validation and error handling."""

    def test_invalid_query_type(self, temp_project):
        """Test error handling for invalid query type."""
        result = runner.invoke(app, [
            "graph",
            "invalid_type",
            "symbol",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 1
        assert "Invalid query type" in result.stdout

    def test_invalid_path(self):
        """Test error handling for non-existent path."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "symbol",
            "--path", "/nonexistent/path"
        ])

        # Should handle gracefully (may exit with error or return empty results)
        assert result.exit_code in [0, 1]


class TestGraphPerformance:
    """Test graph query performance requirements."""

    def test_query_response_time(self, temp_project):
        """Verify graph queries complete in under 1 second."""
        import time

        start = time.time()
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project)
        ])
        elapsed = time.time() - start

        assert result.exit_code == 0
        assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"

    def test_multiple_query_types(self, temp_project):
        """Test all three query types complete successfully."""
        import time

        queries = [
            ("callers", "add"),
            ("callees", "main"),
            ("inheritance", "BaseClass")
        ]

        total_start = time.time()

        for query_type, symbol in queries:
            result = runner.invoke(app, [
                "graph",
                query_type,
                symbol,
                "--path", str(temp_project)
            ])
            assert result.exit_code == 0

        total_elapsed = time.time() - total_start
        assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"


class TestGraphOptions:
    """Test graph command options."""

    def test_limit_option(self, temp_project):
        """Test limit option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--limit", "1"
        ])

        assert result.exit_code == 0

    def test_depth_option(self, temp_project):
        """Test depth option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--depth", "0"
        ])

        assert result.exit_code == 0

    def test_verbose_option(self, temp_project):
        """Test verbose option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--verbose"
        ])

        assert result.exit_code == 0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@@ -1,355 +0,0 @@
"""Tests for code relationship storage."""

import sqlite3
import tempfile
from pathlib import Path

import pytest

from codexlens.entities import CodeRelationship, IndexedFile, Symbol
from codexlens.storage.migration_manager import MigrationManager
from codexlens.storage.sqlite_store import SQLiteStore


@pytest.fixture
def temp_db():
    """Create a temporary database for testing."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test.db"
        yield db_path


@pytest.fixture
def store(temp_db):
    """Create a SQLiteStore with migrations applied."""
    store = SQLiteStore(temp_db)
    store.initialize()

    # Manually apply migration_003 (code_relationships table)
    conn = store._get_connection()
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT,
            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
        )
        """
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
    )
    conn.commit()

    yield store

    # Cleanup
    store.close()


def test_relationship_table_created(store):
    """Test that the code_relationships table is created by migration."""
    conn = store._get_connection()
    cursor = conn.cursor()

    # Check table exists
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
    )
    result = cursor.fetchone()
    assert result is not None, "code_relationships table should exist"

    # Check indexes exist
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
    )
    indexes = [row[0] for row in cursor.fetchall()]
    assert "idx_relationships_source" in indexes
    assert "idx_relationships_target" in indexes
    assert "idx_relationships_type" in indexes


def test_add_relationships(store):
    """Test storing code relationships."""
    # First add a file with symbols
    indexed_file = IndexedFile(
        path=str(Path(__file__).parent / "sample.py"),
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 5)),
            Symbol(name="bar", kind="function", range=(7, 10)),
        ]
    )

    content = """def foo():
    bar()
    baz()

def bar():
    print("hello")
"""

    store.add_file(indexed_file, content)

    # Add relationships
    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=indexed_file.path,
            target_file=None,
            source_line=2
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=indexed_file.path,
            target_file=None,
            source_line=3
        ),
    ]

    store.add_relationships(indexed_file.path, relationships)

    # Verify relationships were stored
    conn = store._get_connection()
    count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0]
    assert count == 2, "Should have stored 2 relationships"


def test_query_relationships_by_target(store):
    """Test querying relationships by target symbol (find callers)."""
    # Setup: Add file and relationships
    file_path = str(Path(__file__).parent / "sample.py")
    # Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main()
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 2)),
            Symbol(name="bar", kind="function", range=(4, 5)),
            Symbol(name="main", kind="function", range=(7, 8)),
        ]
    )

    content = "def foo():\n    bar()\n\ndef bar():\n    pass\n\ndef main():\n    bar()\n"
    store.add_file(indexed_file, content)

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2  # Call inside foo (line 2)
        ),
        CodeRelationship(
            source_symbol="main",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=8  # Call inside main (line 8)
        ),
    ]

    store.add_relationships(file_path, relationships)

    # Query: Find all callers of "bar"
    callers = store.query_relationships_by_target("bar")

    assert len(callers) == 2, "Should find 2 callers of bar"
    assert any(r["source_symbol"] == "foo" for r in callers)
    assert any(r["source_symbol"] == "main" for r in callers)
    assert all(r["target_symbol"] == "bar" for r in callers)
    assert all(r["relationship_type"] == "call" for r in callers)


def test_query_relationships_by_source(store):
    """Test querying relationships by source symbol (find callees)."""
    # Setup
    file_path = str(Path(__file__).parent / "sample.py")
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 6)),
        ]
    )

    content = "def foo():\n    bar()\n    baz()\n    qux()\n"
    store.add_file(indexed_file, content)

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=3
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="qux",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=4
        ),
    ]

    store.add_relationships(file_path, relationships)

    # Query: Find all functions called by foo
    callees = store.query_relationships_by_source("foo", file_path)

    assert len(callees) == 3, "Should find 3 functions called by foo"
    targets = {r["target_symbol"] for r in callees}
    assert targets == {"bar", "baz", "qux"}
    assert all(r["source_symbol"] == "foo" for r in callees)


def test_query_performance(store):
    """Test that relationship queries execute within performance threshold."""
    import time

    # Setup: Create a file with many relationships
    file_path = str(Path(__file__).parent / "large_file.py")
    symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]

    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=symbols
    )

    content = "\n".join([f"def func_{i}():\n    pass\n" for i in range(100)])
    store.add_file(indexed_file, content)

    # Create many relationships
    relationships = []
    for i in range(100):
        for j in range(10):
            relationships.append(
                CodeRelationship(
                    source_symbol=f"func_{i}",
                    target_symbol=f"target_{j}",
                    relationship_type="call",
                    source_file=file_path,
                    target_file=None,
                    source_line=i*10 + 1
                )
            )

    store.add_relationships(file_path, relationships)

    # Query and measure time
    start = time.time()
    results = store.query_relationships_by_target("target_5")
    elapsed_ms = (time.time() - start) * 1000

    assert len(results) == 100, "Should find 100 callers"
    assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"


def test_stats_includes_relationships(store):
    """Test that stats() includes relationship count."""
    # Add a file with relationships
    file_path = str(Path(__file__).parent / "sample.py")
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
    )

    store.add_file(indexed_file, "def foo():\n    bar()\n")

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]

    store.add_relationships(file_path, relationships)

    # Check stats
    stats = store.stats()

    assert "relationships" in stats
    assert stats["relationships"] == 1
    assert stats["files"] == 1
    assert stats["symbols"] == 1


def test_update_relationships_on_file_reindex(store):
    """Test that relationships are updated when file is re-indexed."""
    file_path = str(Path(__file__).parent / "sample.py")

    # Initial index
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
    )
    store.add_file(indexed_file, "def foo():\n    bar()\n")

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]
    store.add_relationships(file_path, relationships)

    # Re-index with different relationships
    new_relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]
    store.add_relationships(file_path, new_relationships)

    # Verify old relationships are replaced
    all_rels = store.query_relationships_by_source("foo", file_path)
    assert len(all_rels) == 1
    assert all_rels[0]["target_symbol"] == "baz"
@@ -188,60 +188,3 @@ class TestTokenCountPerformance:
        # Precomputed should be at least 10% faster
        speedup = ((computed_time - precomputed_time) / computed_time) * 100
        assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"


class TestSymbolEntityTokenCount:
    """Tests for Symbol entity token_count field."""

    def test_symbol_with_token_count(self):
        """Test creating Symbol with token_count."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10),
            token_count=42
        )

        assert symbol.token_count == 42

    def test_symbol_without_token_count(self):
        """Test creating Symbol without token_count (defaults to None)."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10)
        )

        assert symbol.token_count is None

    def test_symbol_with_symbol_type(self):
        """Test creating Symbol with symbol_type."""
        symbol = Symbol(
            name="TestClass",
            kind="class",
            range=(1, 20),
            symbol_type="class_definition"
        )

        assert symbol.symbol_type == "class_definition"

    def test_symbol_token_count_validation(self):
        """Test that negative token counts are rejected."""
        with pytest.raises(ValueError, match="token_count must be >= 0"):
            Symbol(
                name="test",
                kind="function",
                range=(1, 2),
                token_count=-1
            )

    def test_symbol_zero_token_count(self):
        """Test that zero token count is allowed."""
        symbol = Symbol(
            name="empty",
            kind="function",
            range=(1, 1),
            token_count=0
        )

        assert symbol.token_count == 0