Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-13 02:41:50 +08:00)
refactor: remove graph indexing, fix memory leaks, optimize embedding generation

Main changes:

1. Remove graph indexing
   - Delete graph_analyzer.py and its related migration files
   - Remove the CLI graph command and the --enrich flag
   - Clean up the graph query methods in chain_search.py (370 lines)
   - Delete the related test files

2. Fix memory usage during embedding generation
   - Refactor generate_embeddings.py to use streaming batch processing
   - Delegate to the memory-safe implementation in embedding_manager
   - Shrink the file from 548 lines to 259 lines (52.7% reduction)

3. Fix memory leaks
   - chain_search.py: quick_search manages ChainSearchEngine with a with statement
   - embedding_manager.py: manage VectorStore with a with statement
   - vector_store.py: add a memory warning for brute-force search

4. Code cleanup
   - Remove the token_count and symbol_type fields from the Symbol model
   - Clean up the related test cases

Tests: 760 passed, 7 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
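Point 2 of the message is the heart of the change: embeddings are now produced by a streaming pipeline that reads files from SQLite in fixed-size batches, embeds each batch in small sub-batches, writes it out immediately, and only then moves on, so peak memory stays bounded. The sketch below illustrates that shape under stated assumptions; it is a simplified stand-in, not the actual embedding_manager code, and the embed/store callables and batch sizes are placeholders.

```python
import sqlite3
from typing import Callable, List, Sequence, Tuple

FILE_BATCH_SIZE = 100      # files pulled from SQLite per round (illustrative value)
EMBEDDING_BATCH_SIZE = 8   # texts sent to the model per call (the diff also uses 8)

def stream_embeddings(
    conn: sqlite3.Connection,
    embed: Callable[[Sequence[str]], List[List[float]]],
    store: Callable[[List[Tuple[str, List[float]]]], None],
) -> int:
    """Embed file contents batch by batch so peak memory stays bounded."""
    conn.row_factory = sqlite3.Row
    cursor = conn.execute("SELECT full_path, content FROM files")
    total = 0
    while True:
        rows = cursor.fetchmany(FILE_BATCH_SIZE)  # stream; never fetchall()
        if not rows:
            break
        texts = [row["content"] for row in rows]
        vectors: List[List[float]] = []
        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            vectors.extend(embed(texts[i:i + EMBEDDING_BATCH_SIZE]))
        store(list(zip([row["full_path"] for row in rows], vectors)))
        total += len(rows)
        # rows / texts / vectors go out of scope here, releasing this batch's memory
    return total

if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE files (full_path TEXT, content TEXT)")
    conn.executemany("INSERT INTO files VALUES (?, ?)",
                     [(f"f{i}.py", "print('hi')") for i in range(250)])
    fake_embed = lambda batch: [[float(len(t))] for t in batch]  # stand-in model
    print(stream_embeddings(conn, fake_embed, store=lambda pairs: None))
```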
@@ -1,15 +1,9 @@
 #!/usr/bin/env python3
 """Generate vector embeddings for existing CodexLens indexes.
 
-This script processes all files in a CodexLens index database and generates
-semantic vector embeddings for code chunks. The embeddings are stored in the
-same SQLite database in the 'semantic_chunks' table.
+This script is a CLI wrapper around the memory-efficient streaming implementation
+in codexlens.cli.embedding_manager. It uses batch processing to keep memory usage
+under 2GB regardless of project size.
 
-Performance optimizations:
-- Parallel file processing using ProcessPoolExecutor
-- Batch embedding generation for efficient GPU/CPU utilization
-- Batch database writes to minimize I/O overhead
-- HNSW index auto-generation for fast similarity search
-
 Requirements:
     pip install codexlens[semantic]
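As the new module docstring says, the script now just forwards to the manager. A hedged sketch of that wrapper shape is shown below; the function and keyword names mirror the call that appears later in this diff, while the logging wiring is illustrative.

```python
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def progress_callback(message: str) -> None:
    """Relay progress messages from the embedding manager to the CLI log."""
    logger.info(message)

def generate_embeddings_for_index(index_db_path: Path,
                                  model_profile: str = "code",
                                  force: bool = False,
                                  chunk_size: int = 2000,
                                  **kwargs) -> dict:
    """Thin wrapper: delegate to the memory-safe streaming implementation.

    Extra keyword arguments (workers, batch_size) are accepted and ignored so
    older callers keep working.
    """
    from codexlens.cli.embedding_manager import generate_embeddings
    return generate_embeddings(
        index_path=index_db_path,
        model_profile=model_profile,
        force=force,
        chunk_size=chunk_size,
        progress_callback=progress_callback,
    )
```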
@@ -20,27 +14,21 @@ Usage:
     # Generate embeddings for a single index
     python generate_embeddings.py /path/to/_index.db
 
-    # Generate embeddings with parallel processing
-    python generate_embeddings.py /path/to/_index.db --workers 4
+    # Use specific embedding model
+    python generate_embeddings.py /path/to/_index.db --model code
 
-    # Use specific embedding model and batch size
-    python generate_embeddings.py /path/to/_index.db --model code --batch-size 256
 
     # Generate embeddings for all indexes in a directory
     python generate_embeddings.py --scan ~/.codexlens/indexes
 
+    # Force regeneration
+    python generate_embeddings.py /path/to/_index.db --force
 """
 
 import argparse
 import logging
-import multiprocessing
-import os
-import sqlite3
 import sys
-import time
-from concurrent.futures import ProcessPoolExecutor, as_completed
-from dataclasses import dataclass
 import sys
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List
 
 # Configure logging
 logging.basicConfig(
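The next hunk replaces the old per-call import checks and dataclasses with a module-level import guard: the manager functions are imported once, and a failed import downgrades SEMANTIC_AVAILABLE so check_dependencies() can report a clean error. A compressed sketch of that pattern follows (module paths as in the diff; the body is illustrative, not the verbatim new code).

```python
import logging

logger = logging.getLogger(__name__)

try:
    # Import the memory-efficient implementation once, at module load time
    from codexlens.cli.embedding_manager import (
        generate_embeddings,
        generate_embeddings_recursive,
    )
    from codexlens.semantic import SEMANTIC_AVAILABLE
except ImportError as exc:
    logger.error(f"Failed to import codexlens: {exc}")
    logger.error("Make sure codexlens is installed: pip install codexlens")
    SEMANTIC_AVAILABLE = False

def check_dependencies() -> bool:
    """Return True only when the optional semantic-search stack is importable."""
    if not SEMANTIC_AVAILABLE:
        logger.error("Semantic search dependencies not available")
        logger.error("Install with: pip install codexlens[semantic]")
        logger.error("Or: pip install fastembed numpy hnswlib")
        return False
    return True
```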
@@ -50,100 +38,32 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Import the memory-efficient implementation
|
||||||
@dataclass
|
try:
|
||||||
class FileData:
|
from codexlens.cli.embedding_manager import (
|
||||||
"""Data for a single file to process."""
|
generate_embeddings,
|
||||||
full_path: str
|
generate_embeddings_recursive,
|
||||||
content: str
|
)
|
||||||
language: str
|
from codexlens.semantic import SEMANTIC_AVAILABLE
|
||||||
|
except ImportError as exc:
|
||||||
|
logger.error(f"Failed to import codexlens: {exc}")
|
||||||
@dataclass
|
logger.error("Make sure codexlens is installed: pip install codexlens")
|
||||||
class ChunkData:
|
SEMANTIC_AVAILABLE = False
|
||||||
"""Processed chunk data ready for embedding."""
|
|
||||||
file_path: str
|
|
||||||
content: str
|
|
||||||
metadata: dict
|
|
||||||
|
|
||||||
|
|
||||||
def check_dependencies():
|
def check_dependencies():
|
||||||
"""Check if semantic search dependencies are available."""
|
"""Check if semantic search dependencies are available."""
|
||||||
try:
|
if not SEMANTIC_AVAILABLE:
|
||||||
from codexlens.semantic import SEMANTIC_AVAILABLE
|
logger.error("Semantic search dependencies not available")
|
||||||
if not SEMANTIC_AVAILABLE:
|
logger.error("Install with: pip install codexlens[semantic]")
|
||||||
logger.error("Semantic search dependencies not available")
|
logger.error("Or: pip install fastembed numpy hnswlib")
|
||||||
logger.error("Install with: pip install codexlens[semantic]")
|
|
||||||
logger.error("Or: pip install fastembed numpy hnswlib")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
except ImportError as exc:
|
|
||||||
logger.error(f"Failed to import codexlens: {exc}")
|
|
||||||
logger.error("Make sure codexlens is installed: pip install codexlens")
|
|
||||||
return False
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def count_files(index_db_path: Path) -> int:
|
def progress_callback(message: str):
|
||||||
"""Count total files in index."""
|
"""Callback function for progress updates."""
|
||||||
try:
|
logger.info(message)
|
||||||
with sqlite3.connect(index_db_path) as conn:
|
|
||||||
cursor = conn.execute("SELECT COUNT(*) FROM files")
|
|
||||||
return cursor.fetchone()[0]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(f"Failed to count files: {exc}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def check_existing_chunks(index_db_path: Path) -> int:
|
|
||||||
"""Check if semantic chunks already exist."""
|
|
||||||
try:
|
|
||||||
with sqlite3.connect(index_db_path) as conn:
|
|
||||||
# Check if table exists
|
|
||||||
cursor = conn.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_chunks'"
|
|
||||||
)
|
|
||||||
if not cursor.fetchone():
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Count existing chunks
|
|
||||||
cursor = conn.execute("SELECT COUNT(*) FROM semantic_chunks")
|
|
||||||
return cursor.fetchone()[0]
|
|
||||||
except Exception:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def process_file_worker(args: Tuple[str, str, str, int]) -> List[ChunkData]:
|
|
||||||
"""Worker function to process a single file (runs in separate process).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Tuple of (file_path, content, language, chunk_size)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ChunkData objects
|
|
||||||
"""
|
|
||||||
file_path, content, language, chunk_size = args
|
|
||||||
|
|
||||||
try:
|
|
||||||
from codexlens.semantic.chunker import Chunker, ChunkConfig
|
|
||||||
|
|
||||||
chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))
|
|
||||||
chunks = chunker.chunk_sliding_window(
|
|
||||||
content,
|
|
||||||
file_path=file_path,
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
ChunkData(
|
|
||||||
file_path=file_path,
|
|
||||||
content=chunk.content,
|
|
||||||
metadata=chunk.metadata or {}
|
|
||||||
)
|
|
||||||
for chunk in chunks
|
|
||||||
]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug(f"Error processing {file_path}: {exc}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def generate_embeddings_for_index(
|
def generate_embeddings_for_index(
|
||||||
@@ -151,259 +71,63 @@ def generate_embeddings_for_index(
|
|||||||
model_profile: str = "code",
|
model_profile: str = "code",
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
chunk_size: int = 2000,
|
chunk_size: int = 2000,
|
||||||
workers: int = 0,
|
**kwargs # Ignore unused parameters (workers, batch_size) for backward compatibility
|
||||||
batch_size: int = 256,
|
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Generate embeddings for all files in an index.
|
"""Generate embeddings for an index using memory-efficient streaming.
|
||||||
|
|
||||||
Performance optimizations:
|
This function wraps the streaming implementation from embedding_manager
|
||||||
- Parallel file processing (chunking)
|
to maintain CLI compatibility while using the memory-optimized approach.
|
||||||
- Batch embedding generation
|
|
||||||
- Batch database writes
|
|
||||||
- HNSW index auto-generation
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index_db_path: Path to _index.db file
|
index_db_path: Path to _index.db file
|
||||||
model_profile: Model profile to use (fast, code, multilingual, balanced)
|
model_profile: Model profile to use (fast, code, multilingual, balanced)
|
||||||
force: If True, regenerate even if embeddings exist
|
force: If True, regenerate even if embeddings exist
|
||||||
chunk_size: Maximum chunk size in characters
|
chunk_size: Maximum chunk size in characters
|
||||||
workers: Number of parallel workers (0 = auto-detect CPU count)
|
**kwargs: Additional parameters (ignored for compatibility)
|
||||||
batch_size: Batch size for embedding generation
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with generation statistics
|
Dictionary with generation statistics
|
||||||
"""
|
"""
|
||||||
logger.info(f"Processing index: {index_db_path}")
|
logger.info(f"Processing index: {index_db_path}")
|
||||||
|
|
||||||
# Check existing chunks
|
# Call the memory-efficient streaming implementation
|
||||||
existing_chunks = check_existing_chunks(index_db_path)
|
result = generate_embeddings(
|
||||||
if existing_chunks > 0 and not force:
|
index_path=index_db_path,
|
||||||
logger.warning(f"Index already has {existing_chunks} chunks")
|
model_profile=model_profile,
|
||||||
logger.warning("Use --force to regenerate")
|
force=force,
|
||||||
return {
|
chunk_size=chunk_size,
|
||||||
"success": False,
|
progress_callback=progress_callback,
|
||||||
"error": "Embeddings already exist",
|
)
|
||||||
"existing_chunks": existing_chunks,
|
|
||||||
}
|
|
||||||
|
|
||||||
if force and existing_chunks > 0:
|
if not result["success"]:
|
||||||
logger.info(f"Force mode: clearing {existing_chunks} existing chunks")
|
if "error" in result:
|
||||||
try:
|
logger.error(result["error"])
|
||||||
with sqlite3.connect(index_db_path) as conn:
|
return result
|
||||||
conn.execute("DELETE FROM semantic_chunks")
|
|
||||||
conn.commit()
|
|
||||||
# Also remove HNSW index file
|
|
||||||
hnsw_path = index_db_path.parent / "_vectors.hnsw"
|
|
||||||
if hnsw_path.exists():
|
|
||||||
hnsw_path.unlink()
|
|
||||||
logger.info("Removed existing HNSW index")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(f"Failed to clear existing data: {exc}")
|
|
||||||
|
|
||||||
# Import dependencies
|
# Extract result data and log summary
|
||||||
try:
|
data = result["result"]
|
||||||
from codexlens.semantic.embedder import Embedder
|
|
||||||
from codexlens.semantic.vector_store import VectorStore
|
|
||||||
from codexlens.entities import SemanticChunk
|
|
||||||
except ImportError as exc:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Import failed: {exc}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize components
|
|
||||||
try:
|
|
||||||
embedder = Embedder(profile=model_profile)
|
|
||||||
vector_store = VectorStore(index_db_path)
|
|
||||||
|
|
||||||
logger.info(f"Using model: {embedder.model_name}")
|
|
||||||
logger.info(f"Embedding dimension: {embedder.embedding_dim}")
|
|
||||||
except Exception as exc:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to initialize components: {exc}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Read files from index
|
|
||||||
try:
|
|
||||||
with sqlite3.connect(index_db_path) as conn:
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
cursor = conn.execute("SELECT full_path, content, language FROM files")
|
|
||||||
files = [
|
|
||||||
FileData(
|
|
||||||
full_path=row["full_path"],
|
|
||||||
content=row["content"],
|
|
||||||
language=row["language"] or "python"
|
|
||||||
)
|
|
||||||
for row in cursor.fetchall()
|
|
||||||
]
|
|
||||||
except Exception as exc:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to read files: {exc}",
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Found {len(files)} files to process")
|
|
||||||
if len(files) == 0:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "No files found in index",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Determine worker count
|
|
||||||
if workers <= 0:
|
|
||||||
workers = min(multiprocessing.cpu_count(), len(files), 8)
|
|
||||||
logger.info(f"Using {workers} worker(s) for parallel processing")
|
|
||||||
logger.info(f"Batch size for embeddings: {batch_size}")
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Phase 1: Parallel chunking
|
|
||||||
logger.info("Phase 1: Chunking files...")
|
|
||||||
chunk_start = time.time()
|
|
||||||
|
|
||||||
all_chunks: List[ChunkData] = []
|
|
||||||
failed_files = []
|
|
||||||
|
|
||||||
# Prepare work items
|
|
||||||
work_items = [
|
|
||||||
(f.full_path, f.content, f.language, chunk_size)
|
|
||||||
for f in files
|
|
||||||
]
|
|
||||||
|
|
||||||
if workers == 1:
|
|
||||||
# Single-threaded for debugging
|
|
||||||
for i, item in enumerate(work_items, 1):
|
|
||||||
try:
|
|
||||||
chunks = process_file_worker(item)
|
|
||||||
all_chunks.extend(chunks)
|
|
||||||
if i % 100 == 0:
|
|
||||||
logger.info(f"Chunked {i}/{len(files)} files ({len(all_chunks)} chunks)")
|
|
||||||
except Exception as exc:
|
|
||||||
failed_files.append((item[0], str(exc)))
|
|
||||||
else:
|
|
||||||
# Parallel processing
|
|
||||||
with ProcessPoolExecutor(max_workers=workers) as executor:
|
|
||||||
futures = {
|
|
||||||
executor.submit(process_file_worker, item): item[0]
|
|
||||||
for item in work_items
|
|
||||||
}
|
|
||||||
|
|
||||||
completed = 0
|
|
||||||
for future in as_completed(futures):
|
|
||||||
file_path = futures[future]
|
|
||||||
completed += 1
|
|
||||||
try:
|
|
||||||
chunks = future.result()
|
|
||||||
all_chunks.extend(chunks)
|
|
||||||
if completed % 100 == 0:
|
|
||||||
logger.info(
|
|
||||||
f"Chunked {completed}/{len(files)} files "
|
|
||||||
f"({len(all_chunks)} chunks)"
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
failed_files.append((file_path, str(exc)))
|
|
||||||
|
|
||||||
chunk_time = time.time() - chunk_start
|
|
||||||
logger.info(f"Chunking completed in {chunk_time:.1f}s: {len(all_chunks)} chunks")
|
|
||||||
|
|
||||||
if not all_chunks:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "No chunks created from files",
|
|
||||||
"files_processed": len(files) - len(failed_files),
|
|
||||||
"files_failed": len(failed_files),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Phase 2: Batch embedding generation
|
|
||||||
logger.info("Phase 2: Generating embeddings...")
|
|
||||||
embed_start = time.time()
|
|
||||||
|
|
||||||
# Extract all content for batch embedding
|
|
||||||
all_contents = [c.content for c in all_chunks]
|
|
||||||
|
|
||||||
# Generate embeddings in batches
|
|
||||||
all_embeddings = []
|
|
||||||
for i in range(0, len(all_contents), batch_size):
|
|
||||||
batch_contents = all_contents[i:i + batch_size]
|
|
||||||
batch_embeddings = embedder.embed(batch_contents)
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
|
|
||||||
progress = min(i + batch_size, len(all_contents))
|
|
||||||
if progress % (batch_size * 4) == 0 or progress == len(all_contents):
|
|
||||||
logger.info(f"Generated embeddings: {progress}/{len(all_contents)}")
|
|
||||||
|
|
||||||
embed_time = time.time() - embed_start
|
|
||||||
logger.info(f"Embedding completed in {embed_time:.1f}s")
|
|
||||||
|
|
||||||
# Phase 3: Batch database write
|
|
||||||
logger.info("Phase 3: Storing chunks...")
|
|
||||||
store_start = time.time()
|
|
||||||
|
|
||||||
# Create SemanticChunk objects with embeddings
|
|
||||||
semantic_chunks_with_paths = []
|
|
||||||
for chunk_data, embedding in zip(all_chunks, all_embeddings):
|
|
||||||
semantic_chunk = SemanticChunk(
|
|
||||||
content=chunk_data.content,
|
|
||||||
metadata=chunk_data.metadata,
|
|
||||||
)
|
|
||||||
semantic_chunk.embedding = embedding
|
|
||||||
semantic_chunks_with_paths.append((semantic_chunk, chunk_data.file_path))
|
|
||||||
|
|
||||||
# Batch write (handles both SQLite and HNSW)
|
|
||||||
write_batch_size = 1000
|
|
||||||
total_stored = 0
|
|
||||||
for i in range(0, len(semantic_chunks_with_paths), write_batch_size):
|
|
||||||
batch = semantic_chunks_with_paths[i:i + write_batch_size]
|
|
||||||
vector_store.add_chunks_batch(batch)
|
|
||||||
total_stored += len(batch)
|
|
||||||
if total_stored % 5000 == 0 or total_stored == len(semantic_chunks_with_paths):
|
|
||||||
logger.info(f"Stored: {total_stored}/{len(semantic_chunks_with_paths)} chunks")
|
|
||||||
|
|
||||||
store_time = time.time() - store_start
|
|
||||||
logger.info(f"Storage completed in {store_time:.1f}s")
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Generate summary
|
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info(f"Completed in {elapsed_time:.1f}s")
|
logger.info(f"Completed in {data['elapsed_time']:.1f}s")
|
||||||
logger.info(f" Chunking: {chunk_time:.1f}s")
|
logger.info(f"Total chunks created: {data['chunks_created']}")
|
||||||
logger.info(f" Embedding: {embed_time:.1f}s")
|
logger.info(f"Files processed: {data['files_processed']}")
|
||||||
logger.info(f" Storage: {store_time:.1f}s")
|
if data['files_failed'] > 0:
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
logger.warning(f"Failed files: {data['files_failed']}")
|
||||||
logger.info(f"Files processed: {len(files) - len(failed_files)}/{len(files)}")
|
if data.get('failed_files'):
|
||||||
if vector_store.ann_available:
|
for file_path, error in data['failed_files']:
|
||||||
logger.info(f"HNSW index vectors: {vector_store.ann_count}")
|
logger.warning(f" {file_path}: {error}")
|
||||||
if failed_files:
|
|
||||||
logger.warning(f"Failed files: {len(failed_files)}")
|
|
||||||
for file_path, error in failed_files[:5]: # Show first 5 failures
|
|
||||||
logger.warning(f" {file_path}: {error}")
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"chunks_created": len(all_chunks),
|
"chunks_created": data["chunks_created"],
|
||||||
"files_processed": len(files) - len(failed_files),
|
"files_processed": data["files_processed"],
|
||||||
"files_failed": len(failed_files),
|
"files_failed": data["files_failed"],
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": data["elapsed_time"],
|
||||||
"chunk_time": chunk_time,
|
|
||||||
"embed_time": embed_time,
|
|
||||||
"store_time": store_time,
|
|
||||||
"ann_vectors": vector_store.ann_count if vector_store.ann_available else 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def find_index_databases(scan_dir: Path) -> List[Path]:
|
|
||||||
"""Find all _index.db files in directory tree."""
|
|
||||||
logger.info(f"Scanning for indexes in: {scan_dir}")
|
|
||||||
index_files = list(scan_dir.rglob("_index.db"))
|
|
||||||
logger.info(f"Found {len(index_files)} index databases")
|
|
||||||
return index_files
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate vector embeddings for CodexLens indexes",
|
description="Generate vector embeddings for CodexLens indexes (memory-efficient streaming)",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog=__doc__
|
epilog=__doc__
|
||||||
)
|
)
|
||||||
@@ -439,14 +163,14 @@ def main():
         "--workers",
         type=int,
         default=0,
-        help="Number of parallel workers for chunking (default: auto-detect CPU count)"
+        help="(Deprecated) Kept for backward compatibility, ignored"
     )
 
     parser.add_argument(
         "--batch-size",
         type=int,
         default=256,
-        help="Batch size for embedding generation (default: 256)"
+        help="(Deprecated) Kept for backward compatibility, ignored"
    )
 
     parser.add_argument(
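--workers and --batch-size stay on the CLI so existing scripts do not break, but they no longer do anything. A small sketch of this deprecation pattern follows; the optional warning is my addition for illustration, the actual script simply ignores the flags.

```python
import argparse
import warnings

parser = argparse.ArgumentParser(description="Generate vector embeddings")
parser.add_argument("--workers", type=int, default=0,
                    help="(Deprecated) Kept for backward compatibility, ignored")
parser.add_argument("--batch-size", type=int, default=256,
                    help="(Deprecated) Kept for backward compatibility, ignored")

args = parser.parse_args(["--workers", "4"])
if args.workers != 0 or args.batch_size != 256:
    # illustrative only: the refactored script is silent about these flags
    warnings.warn("--workers/--batch-size are deprecated and ignored",
                  DeprecationWarning)
```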
@@ -481,43 +205,33 @@ def main():
|
|||||||
|
|
||||||
# Determine if scanning or single file
|
# Determine if scanning or single file
|
||||||
if args.scan or index_path.is_dir():
|
if args.scan or index_path.is_dir():
|
||||||
# Scan mode
|
# Scan mode - use recursive implementation
|
||||||
if index_path.is_file():
|
if index_path.is_file():
|
||||||
logger.error("--scan requires a directory path")
|
logger.error("--scan requires a directory path")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
index_files = find_index_databases(index_path)
|
result = generate_embeddings_recursive(
|
||||||
if not index_files:
|
index_root=index_path,
|
||||||
logger.error(f"No index databases found in: {index_path}")
|
model_profile=args.model,
|
||||||
|
force=args.force,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result["success"]:
|
||||||
|
logger.error(f"Failed: {result.get('error', 'Unknown error')}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Process each index
|
# Log summary
|
||||||
total_chunks = 0
|
data = result["result"]
|
||||||
successful = 0
|
|
||||||
for idx, index_file in enumerate(index_files, 1):
|
|
||||||
logger.info(f"\n{'='*60}")
|
|
||||||
logger.info(f"Processing index {idx}/{len(index_files)}")
|
|
||||||
logger.info(f"{'='*60}")
|
|
||||||
|
|
||||||
result = generate_embeddings_for_index(
|
|
||||||
index_file,
|
|
||||||
model_profile=args.model,
|
|
||||||
force=args.force,
|
|
||||||
chunk_size=args.chunk_size,
|
|
||||||
workers=args.workers,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
total_chunks += result["chunks_created"]
|
|
||||||
successful += 1
|
|
||||||
|
|
||||||
# Final summary
|
|
||||||
logger.info(f"\n{'='*60}")
|
logger.info(f"\n{'='*60}")
|
||||||
logger.info("BATCH PROCESSING COMPLETE")
|
logger.info("BATCH PROCESSING COMPLETE")
|
||||||
logger.info(f"{'='*60}")
|
logger.info(f"{'='*60}")
|
||||||
logger.info(f"Indexes processed: {successful}/{len(index_files)}")
|
logger.info(f"Indexes processed: {data['indexes_successful']}/{data['indexes_processed']}")
|
||||||
logger.info(f"Total chunks created: {total_chunks}")
|
logger.info(f"Total chunks created: {data['total_chunks_created']}")
|
||||||
|
logger.info(f"Total files processed: {data['total_files_processed']}")
|
||||||
|
if data['total_files_failed'] > 0:
|
||||||
|
logger.warning(f"Total files failed: {data['total_files_failed']}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Single index mode
|
# Single index mode
|
||||||
@@ -530,8 +244,6 @@ def main():
             model_profile=args.model,
             force=args.force,
             chunk_size=args.chunk_size,
-            workers=args.workers,
-            batch_size=args.batch_size,
         )
 
         if not result["success"]:
@@ -268,7 +268,6 @@ def search(
     files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
     mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
     weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
-    enrich: bool = typer.Option(False, "--enrich", help="Enrich results with code graph relationships (calls, imports)."),
     json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
 ) -> None:
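The search command keeps its --weights option ('exact,fuzzy,vector'). How that string becomes per-source RRF weights is not shown in this diff, so the parser below is an assumption of mine; only the '0.5,0.3,0.2' format comes from the help text above.

```python
from typing import Dict

def parse_rrf_weights(spec: str) -> Dict[str, float]:
    """Parse a '--weights 0.5,0.3,0.2' string into named RRF weights (hypothetical helper)."""
    parts = [float(p) for p in spec.split(",")]
    if len(parts) != 3:
        raise ValueError("expected three comma-separated numbers: exact,fuzzy,vector")
    return dict(zip(("exact", "fuzzy", "vector"), parts))

print(parse_rrf_weights("0.5,0.3,0.2"))  # {'exact': 0.5, 'fuzzy': 0.3, 'vector': 0.2}
```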
@@ -423,30 +422,10 @@ def search(
            for r in result.results
        ]
 
-        # Enrich results with relationship data if requested
-        enriched = False
-        if enrich:
-            try:
-                from codexlens.search.enrichment import RelationshipEnricher
-
-                # Find index path for the search path
-                project_record = registry.find_by_source_path(str(search_path))
-                if project_record:
-                    index_path = Path(project_record["index_root"]) / "_index.db"
-                    if index_path.exists():
-                        with RelationshipEnricher(index_path) as enricher:
-                            results_list = enricher.enrich(results_list, limit=limit)
-                            enriched = True
-            except Exception as e:
-                # Enrichment failure should not break search
-                if verbose:
-                    console.print(f"[yellow]Warning: Enrichment failed: {e}[/yellow]")
-
        payload = {
            "query": query,
            "mode": actual_mode,
            "count": len(results_list),
-            "enriched": enriched,
            "results": results_list,
            "stats": {
                "dirs_searched": result.stats.dirs_searched,
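After the --enrich removal the JSON payload no longer carries an enriched flag. Roughly, a --json response now has the shape below; the field names follow the payload dict in the hunk above, while the values and the per-result shape are invented for illustration.

```python
# Illustrative only: keys mirror the payload above, values are made up.
payload = {
    "query": "open_index",
    "mode": "hybrid",
    "count": 2,
    "results": [
        {"path": "src/store.py", "score": 0.91},  # result entry shape is assumed
        {"path": "src/cli.py", "score": 0.74},
    ],
    "stats": {"dirs_searched": 12, "time_ms": 34.5},
}
```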
@@ -458,8 +437,7 @@ def search(
            print_json(success=True, result=payload)
        else:
            render_search_results(result.results, verbose=verbose)
-            enrich_status = " | [green]Enriched[/green]" if enriched else ""
-            console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms{enrich_status}[/dim]")
+            console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms[/dim]")
 
    except SearchError as exc:
        if json_mode:
@@ -1376,103 +1354,6 @@ def clean(
|
|||||||
raise typer.Exit(code=1)
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def graph(
|
|
||||||
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
|
|
||||||
symbol: str = typer.Argument(..., help="Symbol name to query"),
|
|
||||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
|
|
||||||
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
|
|
||||||
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
|
|
||||||
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
|
|
||||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
|
|
||||||
) -> None:
|
|
||||||
"""Query semantic graph for code relationships.
|
|
||||||
|
|
||||||
Supported query types:
|
|
||||||
- callers: Find all functions/methods that call the given symbol
|
|
||||||
- callees: Find all functions/methods called by the given symbol
|
|
||||||
- inheritance: Find inheritance relationships for the given class
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
codex-lens graph callers my_function
|
|
||||||
codex-lens graph callees MyClass.method --path src/
|
|
||||||
codex-lens graph inheritance BaseClass
|
|
||||||
"""
|
|
||||||
_configure_logging(verbose)
|
|
||||||
search_path = path.expanduser().resolve()
|
|
||||||
|
|
||||||
# Validate query type
|
|
||||||
valid_types = ["callers", "callees", "inheritance"]
|
|
||||||
if query_type not in valid_types:
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Invalid query type:[/red] {query_type}")
|
|
||||||
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
registry: RegistryStore | None = None
|
|
||||||
try:
|
|
||||||
registry = RegistryStore()
|
|
||||||
registry.initialize()
|
|
||||||
mapper = PathMapper()
|
|
||||||
|
|
||||||
engine = ChainSearchEngine(registry, mapper)
|
|
||||||
options = SearchOptions(depth=depth, total_limit=limit)
|
|
||||||
|
|
||||||
# Execute graph query based on type
|
|
||||||
if query_type == "callers":
|
|
||||||
results = engine.search_callers(symbol, search_path, options=options)
|
|
||||||
result_type = "callers"
|
|
||||||
elif query_type == "callees":
|
|
||||||
results = engine.search_callees(symbol, search_path, options=options)
|
|
||||||
result_type = "callees"
|
|
||||||
else: # inheritance
|
|
||||||
results = engine.search_inheritance(symbol, search_path, options=options)
|
|
||||||
result_type = "inheritance"
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"query_type": query_type,
|
|
||||||
"symbol": symbol,
|
|
||||||
"count": len(results),
|
|
||||||
"relationships": results
|
|
||||||
}
|
|
||||||
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=True, result=payload)
|
|
||||||
else:
|
|
||||||
from .output import render_graph_results
|
|
||||||
render_graph_results(results, query_type=query_type, symbol=symbol)
|
|
||||||
|
|
||||||
except SearchError as exc:
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=False, error=f"Graph search error: {exc}")
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Graph query failed (search):[/red] {exc}")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
except StorageError as exc:
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=False, error=f"Storage error: {exc}")
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
except CodexLensError as exc:
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=False, error=str(exc))
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Graph query failed:[/red] {exc}")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
except Exception as exc:
|
|
||||||
if json_mode:
|
|
||||||
print_json(success=False, error=f"Unexpected error: {exc}")
|
|
||||||
else:
|
|
||||||
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
finally:
|
|
||||||
if registry is not None:
|
|
||||||
registry.close()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("semantic-list")
|
@app.command("semantic-list")
|
||||||
def semantic_list(
|
def semantic_list(
|
||||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),
|
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),
|
||||||
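The hunk above deletes the entire graph CLI command. The code_relationships table the removed helpers queried (their SQL is still visible further down in this diff) can still be inspected directly when needed; a hedged example follows, assuming an index database at ./_index.db and the column names used by those removed queries.

```python
import sqlite3

# Find callers of a symbol straight from an index database. The JOINs mirror the
# SQL in the removed chain_search helpers; the database path is an assumption.
with sqlite3.connect("_index.db") as conn:
    conn.row_factory = sqlite3.Row
    rows = conn.execute(
        """
        SELECT s.name AS source_symbol, f.full_path AS source_file, r.source_line
        FROM code_relationships r
        JOIN symbols s ON r.source_symbol_id = s.id
        JOIN files f ON s.file_id = f.id
        WHERE r.target_qualified_name = ? AND r.relationship_type = 'call'
        ORDER BY f.full_path, r.source_line
        LIMIT 50
        """,
        ("my_function",),
    ).fetchall()
    for row in rows:
        print(f"{row['source_symbol']} @ {row['source_file']}:{row['source_line']}")
```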
@@ -194,7 +194,6 @@ def generate_embeddings(
    try:
        # Use cached embedder (singleton) for performance
        embedder = get_embedder(profile=model_profile)
-        vector_store = VectorStore(index_path)
        chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))
 
        if progress_callback:
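This hunk stops creating the VectorStore eagerly next to the embedder; the next hunk opens it later with `with VectorStore(index_path) as vector_store:` so the store is closed even when embedding fails partway through, which is the memory-leak fix the commit message describes. As a generic illustration of why that matters, here is a minimal stand-in class, not the real VectorStore.

```python
import sqlite3

class DemoVectorStore:
    """Stand-in store that owns a SQLite connection and must be closed."""

    def __init__(self, path: str) -> None:
        self._conn = sqlite3.connect(path)

    def add_chunks_batch(self, batch) -> None:
        pass  # write vectors; omitted in this sketch

    def close(self) -> None:
        self._conn.close()

    def __enter__(self) -> "DemoVectorStore":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()  # runs on success *and* on error, which plugs the leak

with DemoVectorStore(":memory:") as store:
    store.add_chunks_batch([])
```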
@@ -217,85 +216,86 @@ def generate_embeddings(
|
|||||||
EMBEDDING_BATCH_SIZE = 8 # jina-embeddings-v2-base-code needs small batches
|
EMBEDDING_BATCH_SIZE = 8 # jina-embeddings-v2-base-code needs small batches
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(index_path) as conn:
|
with VectorStore(index_path) as vector_store:
|
||||||
conn.row_factory = sqlite3.Row
|
with sqlite3.connect(index_path) as conn:
|
||||||
path_column = _get_path_column(conn)
|
conn.row_factory = sqlite3.Row
|
||||||
|
path_column = _get_path_column(conn)
|
||||||
|
|
||||||
# Get total file count for progress reporting
|
# Get total file count for progress reporting
|
||||||
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
|
total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
|
||||||
if total_files == 0:
|
if total_files == 0:
|
||||||
return {"success": False, "error": "No files found in index"}
|
return {"success": False, "error": "No files found in index"}
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")
|
|
||||||
|
|
||||||
cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
|
|
||||||
batch_number = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Fetch a batch of files (streaming, not fetchall)
|
|
||||||
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
|
|
||||||
if not file_batch:
|
|
||||||
break
|
|
||||||
|
|
||||||
batch_number += 1
|
|
||||||
batch_chunks_with_paths = []
|
|
||||||
files_in_batch_with_chunks = set()
|
|
||||||
|
|
||||||
# Step 1: Chunking for the current file batch
|
|
||||||
for file_row in file_batch:
|
|
||||||
file_path = file_row[path_column]
|
|
||||||
content = file_row["content"]
|
|
||||||
language = file_row["language"] or "python"
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunks = chunker.chunk_sliding_window(
|
|
||||||
content,
|
|
||||||
file_path=file_path,
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
if chunks:
|
|
||||||
for chunk in chunks:
|
|
||||||
batch_chunks_with_paths.append((chunk, file_path))
|
|
||||||
files_in_batch_with_chunks.add(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to chunk {file_path}: {e}")
|
|
||||||
failed_files.append((file_path, str(e)))
|
|
||||||
|
|
||||||
if not batch_chunks_with_paths:
|
|
||||||
continue
|
|
||||||
|
|
||||||
batch_chunk_count = len(batch_chunks_with_paths)
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
|
progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")
|
||||||
|
|
||||||
# Step 2: Generate embeddings for this batch
|
cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
|
||||||
batch_embeddings = []
|
batch_number = 0
|
||||||
try:
|
|
||||||
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
|
|
||||||
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
|
|
||||||
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
|
|
||||||
embeddings = embedder.embed(batch_contents)
|
|
||||||
batch_embeddings.extend(embeddings)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
|
|
||||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Step 3: Assign embeddings to chunks
|
while True:
|
||||||
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
|
# Fetch a batch of files (streaming, not fetchall)
|
||||||
chunk.embedding = embedding
|
file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
|
||||||
|
if not file_batch:
|
||||||
|
break
|
||||||
|
|
||||||
# Step 4: Store this batch to database immediately (releases memory)
|
batch_number += 1
|
||||||
try:
|
batch_chunks_with_paths = []
|
||||||
vector_store.add_chunks_batch(batch_chunks_with_paths)
|
files_in_batch_with_chunks = set()
|
||||||
total_chunks_created += batch_chunk_count
|
|
||||||
total_files_processed += len(files_in_batch_with_chunks)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
|
|
||||||
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
|
||||||
|
|
||||||
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
|
# Step 1: Chunking for the current file batch
|
||||||
|
for file_row in file_batch:
|
||||||
|
file_path = file_row[path_column]
|
||||||
|
content = file_row["content"]
|
||||||
|
language = file_row["language"] or "python"
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks = chunker.chunk_sliding_window(
|
||||||
|
content,
|
||||||
|
file_path=file_path,
|
||||||
|
language=language
|
||||||
|
)
|
||||||
|
if chunks:
|
||||||
|
for chunk in chunks:
|
||||||
|
batch_chunks_with_paths.append((chunk, file_path))
|
||||||
|
files_in_batch_with_chunks.add(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to chunk {file_path}: {e}")
|
||||||
|
failed_files.append((file_path, str(e)))
|
||||||
|
|
||||||
|
if not batch_chunks_with_paths:
|
||||||
|
continue
|
||||||
|
|
||||||
|
batch_chunk_count = len(batch_chunks_with_paths)
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
|
||||||
|
|
||||||
|
# Step 2: Generate embeddings for this batch
|
||||||
|
batch_embeddings = []
|
||||||
|
try:
|
||||||
|
for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
|
||||||
|
batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
|
||||||
|
batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
|
||||||
|
embeddings = embedder.embed(batch_contents)
|
||||||
|
batch_embeddings.extend(embeddings)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
|
||||||
|
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Step 3: Assign embeddings to chunks
|
||||||
|
for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
|
||||||
|
chunk.embedding = embedding
|
||||||
|
|
||||||
|
# Step 4: Store this batch to database immediately (releases memory)
|
||||||
|
try:
|
||||||
|
vector_store.add_chunks_batch(batch_chunks_with_paths)
|
||||||
|
total_chunks_created += batch_chunk_count
|
||||||
|
total_files_processed += len(files_in_batch_with_chunks)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to store batch {batch_number}: {str(e)}")
|
||||||
|
failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
|
||||||
|
|
||||||
|
# Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"success": False, "error": f"Failed to read or process files: {str(e)}"}
|
return {"success": False, "error": f"Failed to read or process files: {str(e)}"}
|
||||||
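A notable property of the streaming loop in the hunk above: chunking, embedding, and storage failures are caught per batch, recorded in failed_files, and the loop continues instead of aborting the whole run. A compressed sketch of that error-isolation shape follows; the function names and data are placeholders, not the module's real API.

```python
from typing import List, Sequence, Tuple

def handle_batch(batch: Sequence[str]) -> int:
    """Stand-in for chunk + embed + store of one file batch."""
    if "broken.py" in batch:
        raise ValueError("boom")
    return len(batch)

def process_batches(batches: Sequence[Sequence[str]]) -> Tuple[int, List[Tuple[str, str]]]:
    """One failing batch is recorded and skipped; the run keeps going."""
    processed = 0
    failed: List[Tuple[str, str]] = []
    for batch in batches:
        try:
            processed += handle_batch(batch)
        except Exception as exc:  # mirrors the diff: log, record the files, continue
            failed.extend((name, str(exc)) for name in batch)
    return processed, failed

print(process_batches([["a.py", "b.py"], ["broken.py"]]))
# (2, [('broken.py', 'boom')])
```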
@@ -122,68 +122,3 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
|
|||||||
console.print(header)
|
console.print(header)
|
||||||
render_symbols(list(symbols), title="Discovered Symbols")
|
render_symbols(list(symbols), title="Discovered Symbols")
|
||||||
|
|
||||||
|
|
||||||
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
|
|
||||||
"""Render semantic graph query results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
results: List of relationship dicts
|
|
||||||
query_type: Type of query (callers, callees, inheritance)
|
|
||||||
symbol: Symbol name that was queried
|
|
||||||
"""
|
|
||||||
if not results:
|
|
||||||
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
|
|
||||||
return
|
|
||||||
|
|
||||||
title_map = {
|
|
||||||
"callers": f"Callers of '{symbol}' ({len(results)} found)",
|
|
||||||
"callees": f"Callees of '{symbol}' ({len(results)} found)",
|
|
||||||
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
|
|
||||||
}
|
|
||||||
|
|
||||||
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
|
|
||||||
|
|
||||||
if query_type == "callers":
|
|
||||||
table.add_column("Caller", style="green")
|
|
||||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("Line", justify="right", style="yellow")
|
|
||||||
table.add_column("Type", style="dim")
|
|
||||||
|
|
||||||
for rel in results:
|
|
||||||
table.add_row(
|
|
||||||
rel.get("source_symbol", "-"),
|
|
||||||
rel.get("source_file", "-"),
|
|
||||||
str(rel.get("source_line", "-")),
|
|
||||||
rel.get("relationship_type", "-")
|
|
||||||
)
|
|
||||||
|
|
||||||
elif query_type == "callees":
|
|
||||||
table.add_column("Target", style="green")
|
|
||||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("Line", justify="right", style="yellow")
|
|
||||||
table.add_column("Type", style="dim")
|
|
||||||
|
|
||||||
for rel in results:
|
|
||||||
table.add_row(
|
|
||||||
rel.get("target_symbol", "-"),
|
|
||||||
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
|
|
||||||
str(rel.get("source_line", "-")),
|
|
||||||
rel.get("relationship_type", "-")
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # inheritance
|
|
||||||
table.add_column("Derived Class", style="green")
|
|
||||||
table.add_column("Base Class", style="magenta")
|
|
||||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("Line", justify="right", style="yellow")
|
|
||||||
|
|
||||||
for rel in results:
|
|
||||||
table.add_row(
|
|
||||||
rel.get("source_symbol", "-"),
|
|
||||||
rel.get("target_symbol", "-"),
|
|
||||||
rel.get("source_file", "-"),
|
|
||||||
str(rel.get("source_line", "-"))
|
|
||||||
)
|
|
||||||
|
|
||||||
console.print(table)
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ class Symbol(BaseModel):
     kind: str = Field(..., min_length=1)
     range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
     file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
-    token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
-    symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")
 
     @field_validator("range")
     @classmethod
@@ -29,13 +27,6 @@ class Symbol(BaseModel):
             raise ValueError("end_line must be >= start_line")
         return value
 
-    @field_validator("token_count")
-    @classmethod
-    def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
-        if value is not None and value < 0:
-            raise ValueError("token_count must be >= 0")
-        return value
-
 
 class SemanticChunk(BaseModel):
     """A semantically meaningful chunk of content, optionally embedded."""
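After these two hunks the Symbol model keeps only its core fields plus the range validator. A hedged sketch of the resulting shape follows; the name field, the imports, and the validator method name are inferred from the surrounding fields rather than shown in this diff, so treat the actual entities module as authoritative.

```python
from typing import Optional, Tuple
from pydantic import BaseModel, Field, field_validator

class Symbol(BaseModel):
    name: str = Field(..., min_length=1)  # assumed; not visible in this hunk
    kind: str = Field(..., min_length=1)
    range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
    file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")

    @field_validator("range")
    @classmethod
    def validate_range(cls, value: Tuple[int, int]) -> Tuple[int, int]:
        start, end = value
        if end < start:
            raise ValueError("end_line must be >= start_line")
        return value

print(Symbol(name="foo", kind="function", range=(10, 20), file="src/foo.py"))
```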
@@ -302,108 +302,6 @@ class ChainSearchEngine:
|
|||||||
index_paths, name, kind, options.total_limit
|
index_paths, name, kind, options.total_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_callers(self, target_symbol: str,
|
|
||||||
source_path: Path,
|
|
||||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
|
||||||
"""Find all callers of a given symbol across directory hierarchy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_symbol: Name of the symbol to find callers for
|
|
||||||
source_path: Starting directory path
|
|
||||||
options: Search configuration (uses defaults if None)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of relationship dicts with caller information
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> engine = ChainSearchEngine(registry, mapper)
|
|
||||||
>>> callers = engine.search_callers("my_function", Path("D:/project"))
|
|
||||||
>>> for caller in callers:
|
|
||||||
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
|
|
||||||
"""
|
|
||||||
options = options or SearchOptions()
|
|
||||||
|
|
||||||
start_index = self._find_start_index(source_path)
|
|
||||||
if not start_index:
|
|
||||||
self.logger.warning(f"No index found for {source_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
|
||||||
if not index_paths:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return self._search_callers_parallel(
|
|
||||||
index_paths, target_symbol, options.total_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
def search_callees(self, source_symbol: str,
|
|
||||||
source_path: Path,
|
|
||||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
|
||||||
"""Find all callees (what a symbol calls) across directory hierarchy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_symbol: Name of the symbol to find callees for
|
|
||||||
source_path: Starting directory path
|
|
||||||
options: Search configuration (uses defaults if None)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of relationship dicts with callee information
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> engine = ChainSearchEngine(registry, mapper)
|
|
||||||
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
|
|
||||||
>>> for callee in callees:
|
|
||||||
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
|
|
||||||
"""
|
|
||||||
options = options or SearchOptions()
|
|
||||||
|
|
||||||
start_index = self._find_start_index(source_path)
|
|
||||||
if not start_index:
|
|
||||||
self.logger.warning(f"No index found for {source_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
|
||||||
if not index_paths:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return self._search_callees_parallel(
|
|
||||||
index_paths, source_symbol, options.total_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
def search_inheritance(self, class_name: str,
|
|
||||||
source_path: Path,
|
|
||||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
|
||||||
"""Find inheritance relationships for a class across directory hierarchy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_name: Name of the class to find inheritance for
|
|
||||||
source_path: Starting directory path
|
|
||||||
options: Search configuration (uses defaults if None)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of relationship dicts with inheritance information
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> engine = ChainSearchEngine(registry, mapper)
|
|
||||||
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
|
|
||||||
>>> for rel in inheritance:
|
|
||||||
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
|
|
||||||
"""
|
|
||||||
options = options or SearchOptions()
|
|
||||||
|
|
||||||
start_index = self._find_start_index(source_path)
|
|
||||||
if not start_index:
|
|
||||||
self.logger.warning(f"No index found for {source_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
|
||||||
if not index_paths:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return self._search_inheritance_parallel(
|
|
||||||
index_paths, class_name, options.total_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
# === Internal Methods ===
|
# === Internal Methods ===
|
||||||
|
|
||||||
def _find_start_index(self, source_path: Path) -> Optional[Path]:
|
def _find_start_index(self, source_path: Path) -> Optional[Path]:
|
||||||
@@ -711,273 +609,6 @@ class ChainSearchEngine:
|
|||||||
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
|
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _search_callers_parallel(self, index_paths: List[Path],
|
|
||||||
target_symbol: str,
|
|
||||||
limit: int) -> List[Dict[str, Any]]:
|
|
||||||
"""Search for callers across multiple indexes in parallel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_paths: List of _index.db paths to search
|
|
||||||
target_symbol: Target symbol name
|
|
||||||
limit: Total result limit
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deduplicated list of caller relationships
|
|
||||||
"""
|
|
||||||
all_callers = []
|
|
||||||
|
|
||||||
executor = self._get_executor()
|
|
||||||
future_to_path = {
|
|
||||||
executor.submit(
|
|
||||||
self._search_callers_single,
|
|
||||||
idx_path,
|
|
||||||
target_symbol
|
|
||||||
): idx_path
|
|
||||||
for idx_path in index_paths
|
|
||||||
}
|
|
||||||
|
|
||||||
for future in as_completed(future_to_path):
|
|
||||||
try:
|
|
||||||
callers = future.result()
|
|
||||||
all_callers.extend(callers)
|
|
||||||
except Exception as exc:
|
|
||||||
self.logger.error(f"Caller search failed: {exc}")
|
|
||||||
|
|
||||||
# Deduplicate by (source_file, source_line)
|
|
||||||
seen = set()
|
|
||||||
unique_callers = []
|
|
||||||
for caller in all_callers:
|
|
||||||
key = (caller.get("source_file"), caller.get("source_line"))
|
|
||||||
if key not in seen:
|
|
||||||
seen.add(key)
|
|
||||||
unique_callers.append(caller)
|
|
||||||
|
|
||||||
# Sort by source file and line
|
|
||||||
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
|
|
||||||
|
|
||||||
return unique_callers[:limit]
|
|
||||||
|
|
||||||
def _search_callers_single(self, index_path: Path,
|
|
||||||
target_symbol: str) -> List[Dict[str, Any]]:
|
|
||||||
"""Search for callers in a single index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to _index.db file
|
|
||||||
target_symbol: Target symbol name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of caller relationship dicts (empty on error)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with SQLiteStore(index_path) as store:
|
|
||||||
return store.query_relationships_by_target(target_symbol)
|
|
||||||
except Exception as exc:
|
|
||||||
self.logger.debug(f"Caller search error in {index_path}: {exc}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _search_callees_parallel(self, index_paths: List[Path],
|
|
||||||
source_symbol: str,
|
|
||||||
limit: int) -> List[Dict[str, Any]]:
|
|
||||||
"""Search for callees across multiple indexes in parallel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_paths: List of _index.db paths to search
|
|
||||||
source_symbol: Source symbol name
|
|
||||||
limit: Total result limit
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deduplicated list of callee relationships
|
|
||||||
"""
|
|
||||||
all_callees = []
|
|
||||||
|
|
||||||
executor = self._get_executor()
|
|
||||||
future_to_path = {
|
|
||||||
executor.submit(
|
|
||||||
self._search_callees_single,
|
|
||||||
idx_path,
|
|
||||||
source_symbol
|
|
||||||
): idx_path
|
|
||||||
for idx_path in index_paths
|
|
||||||
}
|
|
||||||
|
|
||||||
for future in as_completed(future_to_path):
|
|
||||||
try:
|
|
||||||
callees = future.result()
|
|
||||||
all_callees.extend(callees)
|
|
||||||
except Exception as exc:
|
|
||||||
self.logger.error(f"Callee search failed: {exc}")
|
|
||||||
|
|
||||||
# Deduplicate by (target_symbol, source_line)
|
|
||||||
seen = set()
|
|
||||||
unique_callees = []
|
|
||||||
for callee in all_callees:
|
|
||||||
key = (callee.get("target_symbol"), callee.get("source_line"))
|
|
||||||
if key not in seen:
|
|
||||||
seen.add(key)
|
|
||||||
unique_callees.append(callee)
|
|
||||||
|
|
||||||
# Sort by source line
|
|
||||||
unique_callees.sort(key=lambda c: c.get("source_line", 0))
|
|
||||||
|
|
||||||
return unique_callees[:limit]
|
|
||||||
|
|
||||||
def _search_callees_single(self, index_path: Path,
|
|
||||||
source_symbol: str) -> List[Dict[str, Any]]:
|
|
||||||
"""Search for callees in a single index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to _index.db file
|
|
||||||
source_symbol: Source symbol name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of callee relationship dicts (empty on error)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with SQLiteStore(index_path) as store:
|
|
||||||
# Single JOIN query to get all callees (fixes N+1 query problem)
|
|
||||||
# Uses public execute_query API instead of _get_connection bypass
|
|
||||||
rows = store.execute_query(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
s.name AS source_symbol,
|
|
||||||
r.target_qualified_name AS target_symbol,
|
|
||||||
r.relationship_type,
|
|
||||||
r.source_line,
|
|
||||||
f.full_path AS source_file,
|
|
||||||
r.target_file
|
|
||||||
FROM code_relationships r
|
|
||||||
JOIN symbols s ON r.source_symbol_id = s.id
|
|
||||||
JOIN files f ON s.file_id = f.id
|
|
||||||
WHERE s.name = ? AND r.relationship_type = 'call'
|
|
||||||
ORDER BY f.full_path, r.source_line
|
|
||||||
LIMIT 100
|
|
||||||
""",
|
|
||||||
(source_symbol,)
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"source_symbol": row["source_symbol"],
|
|
||||||
"target_symbol": row["target_symbol"],
|
|
||||||
"relationship_type": row["relationship_type"],
|
|
||||||
"source_line": row["source_line"],
|
|
||||||
"source_file": row["source_file"],
|
|
||||||
"target_file": row["target_file"],
|
|
||||||
}
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
except Exception as exc:
|
|
||||||
self.logger.debug(f"Callee search error in {index_path}: {exc}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _search_inheritance_parallel(self, index_paths: List[Path],
|
|
||||||
class_name: str,
|
|
||||||
limit: int) -> List[Dict[str, Any]]:
|
|
||||||
"""Search for inheritance relationships across multiple indexes in parallel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_paths: List of _index.db paths to search
|
|
||||||
class_name: Class name to search for
|
|
||||||
limit: Total result limit
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deduplicated list of inheritance relationships
|
|
||||||
"""
|
|
||||||
all_inheritance = []
|
|
||||||
|
|
||||||
executor = self._get_executor()
|
|
||||||
future_to_path = {
|
|
||||||
executor.submit(
|
|
||||||
self._search_inheritance_single,
|
|
||||||
idx_path,
|
|
||||||
class_name
|
|
||||||
): idx_path
|
|
||||||
for idx_path in index_paths
|
|
||||||
}
|
|
||||||
|
|
||||||
for future in as_completed(future_to_path):
|
|
||||||
try:
|
|
||||||
inheritance = future.result()
|
|
||||||
all_inheritance.extend(inheritance)
|
|
||||||
except Exception as exc:
|
|
||||||
self.logger.error(f"Inheritance search failed: {exc}")
|
|
||||||
|
|
||||||
# Deduplicate by (source_symbol, target_symbol)
|
|
||||||
seen = set()
|
|
||||||
unique_inheritance = []
|
|
||||||
for rel in all_inheritance:
|
|
||||||
key = (rel.get("source_symbol"), rel.get("target_symbol"))
|
|
||||||
if key not in seen:
|
|
||||||
seen.add(key)
|
|
||||||
unique_inheritance.append(rel)
|
|
||||||
|
|
||||||
# Sort by source file
|
|
||||||
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
|
|
||||||
|
|
||||||
return unique_inheritance[:limit]
|
|
||||||
|
|
||||||
    def _search_inheritance_single(self, index_path: Path,
                                   class_name: str) -> List[Dict[str, Any]]:
        """Search for inheritance relationships in a single index.

        Args:
            index_path: Path to _index.db file
            class_name: Class name to search for

        Returns:
            List of inheritance relationship dicts (empty on error)
        """
        try:
            with SQLiteStore(index_path) as store:
                # Use UNION to find relationships where class is either:
                # 1. The base class (target) - find derived classes
                # 2. The derived class (source) - find parent classes
                # Uses public execute_query API instead of _get_connection bypass
                rows = store.execute_query(
                    """
                    SELECT
                        s.name AS source_symbol,
                        r.target_qualified_name,
                        r.relationship_type,
                        r.source_line,
                        f.full_path AS source_file,
                        r.target_file
                    FROM code_relationships r
                    JOIN symbols s ON r.source_symbol_id = s.id
                    JOIN files f ON s.file_id = f.id
                    WHERE r.target_qualified_name = ? AND r.relationship_type = 'inherits'
                    UNION
                    SELECT
                        s.name AS source_symbol,
                        r.target_qualified_name,
                        r.relationship_type,
                        r.source_line,
                        f.full_path AS source_file,
                        r.target_file
                    FROM code_relationships r
                    JOIN symbols s ON r.source_symbol_id = s.id
                    JOIN files f ON s.file_id = f.id
                    WHERE s.name = ? AND r.relationship_type = 'inherits'
                    LIMIT 100
                    """,
                    (class_name, class_name)
                )

                return [
                    {
                        "source_symbol": row["source_symbol"],
                        "target_symbol": row["target_qualified_name"],
                        "relationship_type": row["relationship_type"],
                        "source_line": row["source_line"],
                        "source_file": row["source_file"],
                        "target_file": row["target_file"],
                    }
                    for row in rows
                ]
        except Exception as exc:
            self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
            return []

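    # Illustrative usage sketch (assumed call sites; the public wrappers are
    # ChainSearchEngine.search_callers / search_callees / search_inheritance):
    #
    #     with ChainSearchEngine(registry, mapper) as engine:
    #         derived = engine.search_inheritance("BaseClass", Path("/repo/src"))
    #         # each entry: {"source_symbol", "target_symbol", "relationship_type",
    #         #              "source_line", "source_file", "target_file"}
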

# === Convenience Functions ===

@@ -1007,10 +638,9 @@ def quick_search(query: str,
     mapper = PathMapper()
 
-    engine = ChainSearchEngine(registry, mapper)
-    options = SearchOptions(depth=depth)
-    result = engine.search(query, source_path, options)
+    with ChainSearchEngine(registry, mapper) as engine:
+        options = SearchOptions(depth=depth)
+        result = engine.search(query, source_path, options)
 
     registry.close()
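The `with` form above ties the engine's worker pool to the block, so quick_search no longer leaks its executor when a search raises. A minimal usage sketch (argument order and the `depth` keyword are assumptions read off the hunk header above, not a verified signature):

    from pathlib import Path
    from codexlens.search.chain_search import quick_search

    # One-shot search rooted at a source directory, following child indexes one level deep.
    results = quick_search("embedding batch size", Path("./src"), depth=1)
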
@@ -1,542 +0,0 @@
|
|||||||
"""Graph analyzer for extracting code relationships using tree-sitter.
|
|
||||||
|
|
||||||
Provides AST-based analysis to identify function calls, method invocations,
|
|
||||||
and class inheritance relationships within source files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tree_sitter import Node as TreeSitterNode
|
|
||||||
TREE_SITTER_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
TreeSitterNode = None # type: ignore[assignment]
|
|
||||||
TREE_SITTER_AVAILABLE = False
|
|
||||||
|
|
||||||
from codexlens.entities import CodeRelationship, Symbol
|
|
||||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
|
||||||
|
|
||||||
|
|
||||||
class GraphAnalyzer:
|
|
||||||
"""Analyzer for extracting semantic relationships from code using AST traversal."""
|
|
||||||
|
|
||||||
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
|
|
||||||
"""Initialize graph analyzer for a language.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
|
||||||
parser: Optional TreeSitterSymbolParser instance for dependency injection.
|
|
||||||
If None, creates a new parser instance (backward compatibility).
|
|
||||||
"""
|
|
||||||
self.language_id = language_id
|
|
||||||
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
"""Check if graph analyzer is available.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if tree-sitter parser is initialized and ready
|
|
||||||
"""
|
|
||||||
return self._parser.is_available()
|
|
||||||
|
|
||||||
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
|
|
||||||
"""Analyze source code and extract relationships.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Source code text
|
|
||||||
file_path: File path for relationship context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CodeRelationship objects representing intra-file relationships
|
|
||||||
"""
|
|
||||||
if not self.is_available() or self._parser._parser is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
source_bytes = text.encode("utf8")
|
|
||||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
|
||||||
root = tree.root_node
|
|
||||||
|
|
||||||
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
|
|
||||||
|
|
||||||
return relationships
|
|
||||||
except Exception:
|
|
||||||
# Gracefully handle parsing errors
|
|
||||||
return []
|
|
||||||
|
|
||||||
def analyze_with_symbols(
|
|
||||||
self, text: str, file_path: Path, symbols: List[Symbol]
|
|
||||||
) -> List[CodeRelationship]:
|
|
||||||
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Source code text
|
|
||||||
file_path: File path for relationship context
|
|
||||||
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CodeRelationship objects representing intra-file relationships
|
|
||||||
"""
|
|
||||||
if not self.is_available() or self._parser._parser is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
source_bytes = text.encode("utf8")
|
|
||||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
|
||||||
root = tree.root_node
|
|
||||||
|
|
||||||
# Convert Symbol objects to internal symbol format
|
|
||||||
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
|
|
||||||
|
|
||||||
# Extract relationships using provided symbols
|
|
||||||
relationships = self._extract_relationships_with_symbols(
|
|
||||||
source_bytes, root, str(file_path.resolve()), defined_symbols
|
|
||||||
)
|
|
||||||
|
|
||||||
return relationships
|
|
||||||
except Exception:
|
|
||||||
# Gracefully handle parsing errors
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _convert_symbols_to_dict(
|
|
||||||
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
|
|
||||||
) -> List[dict]:
|
|
||||||
"""Convert Symbol objects to internal dict format for relationship extraction.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
symbols: Pre-parsed Symbol objects
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of symbol info dicts with name, node, and type
|
|
||||||
"""
|
|
||||||
symbol_dicts = []
|
|
||||||
symbol_names = {s.name for s in symbols}
|
|
||||||
|
|
||||||
# Find AST nodes corresponding to symbols
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
node_type = node.type
|
|
||||||
|
|
||||||
# Check if this node matches any of our symbols
|
|
||||||
if node_type in {"function_definition", "async_function_definition"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
elif node_type == "class_definition":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "class"
|
|
||||||
})
|
|
||||||
elif node_type in {"function_declaration", "generator_function_declaration"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
elif node_type == "method_definition":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "method"
|
|
||||||
})
|
|
||||||
elif node_type in {"class_declaration", "class"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "class"
|
|
||||||
})
|
|
||||||
elif node_type == "variable_declarator":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
value_node = node.child_by_field_name("value")
|
|
||||||
if name_node and value_node and value_node.type == "arrow_function":
|
|
||||||
name = self._node_text(source_bytes, name_node)
|
|
||||||
if name in symbol_names:
|
|
||||||
symbol_dicts.append({
|
|
||||||
"name": name,
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
|
|
||||||
return symbol_dicts
|
|
||||||
|
|
||||||
def _extract_relationships_with_symbols(
|
|
||||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
|
|
||||||
) -> List[CodeRelationship]:
|
|
||||||
"""Extract relationships from AST using pre-parsed symbols.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
file_path: Absolute file path
|
|
||||||
defined_symbols: Pre-parsed symbol dicts
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of extracted relationships
|
|
||||||
"""
|
|
||||||
relationships: List[CodeRelationship] = []
|
|
||||||
|
|
||||||
# Determine call node type based on language
|
|
||||||
if self.language_id == "python":
|
|
||||||
call_node_type = "call"
|
|
||||||
extract_target = self._extract_call_target
|
|
||||||
elif self.language_id in {"javascript", "typescript"}:
|
|
||||||
call_node_type = "call_expression"
|
|
||||||
extract_target = self._extract_js_call_target
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find call expressions and match to defined symbols
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
if node.type == call_node_type:
|
|
||||||
# Extract caller context (enclosing function/method/class)
|
|
||||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
|
||||||
if source_symbol is None:
|
|
||||||
# Call at module level, use "<module>" as source
|
|
||||||
source_symbol = "<module>"
|
|
||||||
|
|
||||||
# Extract callee (function/method being called)
|
|
||||||
target_symbol = extract_target(source_bytes, node)
|
|
||||||
if target_symbol is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create relationship
|
|
||||||
line_number = node.start_point[0] + 1
|
|
||||||
relationships.append(
|
|
||||||
CodeRelationship(
|
|
||||||
source_symbol=source_symbol,
|
|
||||||
target_symbol=target_symbol,
|
|
||||||
relationship_type="call",
|
|
||||||
source_file=file_path,
|
|
||||||
target_file=None, # Intra-file only
|
|
||||||
source_line=line_number,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return relationships
|
|
||||||
|
|
||||||
def _extract_relationships(
|
|
||||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
|
||||||
) -> List[CodeRelationship]:
|
|
||||||
"""Extract relationships from AST.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
file_path: Absolute file path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of extracted relationships
|
|
||||||
"""
|
|
||||||
if self.language_id == "python":
|
|
||||||
return self._extract_python_relationships(source_bytes, root, file_path)
|
|
||||||
elif self.language_id in {"javascript", "typescript"}:
|
|
||||||
return self._extract_js_ts_relationships(source_bytes, root, file_path)
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _extract_python_relationships(
|
|
||||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
|
||||||
) -> List[CodeRelationship]:
|
|
||||||
"""Extract Python relationships from AST.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
file_path: Absolute file path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Python relationships (function/method calls)
|
|
||||||
"""
|
|
||||||
relationships: List[CodeRelationship] = []
|
|
||||||
|
|
||||||
# First pass: collect all defined symbols with their scopes
|
|
||||||
defined_symbols = self._collect_python_symbols(source_bytes, root)
|
|
||||||
|
|
||||||
# Second pass: find call expressions and match to defined symbols
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
if node.type == "call":
|
|
||||||
# Extract caller context (enclosing function/method/class)
|
|
||||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
|
||||||
if source_symbol is None:
|
|
||||||
# Call at module level, use "<module>" as source
|
|
||||||
source_symbol = "<module>"
|
|
||||||
|
|
||||||
# Extract callee (function/method being called)
|
|
||||||
target_symbol = self._extract_call_target(source_bytes, node)
|
|
||||||
if target_symbol is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create relationship
|
|
||||||
line_number = node.start_point[0] + 1
|
|
||||||
relationships.append(
|
|
||||||
CodeRelationship(
|
|
||||||
source_symbol=source_symbol,
|
|
||||||
target_symbol=target_symbol,
|
|
||||||
relationship_type="call",
|
|
||||||
source_file=file_path,
|
|
||||||
target_file=None, # Intra-file only
|
|
||||||
source_line=line_number,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return relationships
|
|
||||||
|
|
||||||
def _extract_js_ts_relationships(
|
|
||||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
|
||||||
) -> List[CodeRelationship]:
|
|
||||||
"""Extract JavaScript/TypeScript relationships from AST.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
file_path: Absolute file path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of JS/TS relationships (function/method calls)
|
|
||||||
"""
|
|
||||||
relationships: List[CodeRelationship] = []
|
|
||||||
|
|
||||||
# First pass: collect all defined symbols
|
|
||||||
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
|
|
||||||
|
|
||||||
# Second pass: find call expressions
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
if node.type == "call_expression":
|
|
||||||
# Extract caller context
|
|
||||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
|
||||||
if source_symbol is None:
|
|
||||||
source_symbol = "<module>"
|
|
||||||
|
|
||||||
# Extract callee
|
|
||||||
target_symbol = self._extract_js_call_target(source_bytes, node)
|
|
||||||
if target_symbol is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Create relationship
|
|
||||||
line_number = node.start_point[0] + 1
|
|
||||||
relationships.append(
|
|
||||||
CodeRelationship(
|
|
||||||
source_symbol=source_symbol,
|
|
||||||
target_symbol=target_symbol,
|
|
||||||
relationship_type="call",
|
|
||||||
source_file=file_path,
|
|
||||||
target_file=None,
|
|
||||||
source_line=line_number,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return relationships
|
|
||||||
|
|
||||||
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
|
|
||||||
"""Collect all Python function/method/class definitions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of symbol info dicts with name, node, and type
|
|
||||||
"""
|
|
||||||
symbols = []
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
if node.type in {"function_definition", "async_function_definition"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
elif node.type == "class_definition":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "class"
|
|
||||||
})
|
|
||||||
return symbols
|
|
||||||
|
|
||||||
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
|
|
||||||
"""Collect all JS/TS function/method/class definitions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
root: Root AST node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of symbol info dicts with name, node, and type
|
|
||||||
"""
|
|
||||||
symbols = []
|
|
||||||
for node in self._iter_nodes(root):
|
|
||||||
if node.type in {"function_declaration", "generator_function_declaration"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
elif node.type == "method_definition":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "method"
|
|
||||||
})
|
|
||||||
elif node.type in {"class_declaration", "class"}:
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
if name_node:
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "class"
|
|
||||||
})
|
|
||||||
elif node.type == "variable_declarator":
|
|
||||||
name_node = node.child_by_field_name("name")
|
|
||||||
value_node = node.child_by_field_name("value")
|
|
||||||
if name_node and value_node and value_node.type == "arrow_function":
|
|
||||||
symbols.append({
|
|
||||||
"name": self._node_text(source_bytes, name_node),
|
|
||||||
"node": node,
|
|
||||||
"type": "function"
|
|
||||||
})
|
|
||||||
return symbols
|
|
||||||
|
|
||||||
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
|
|
||||||
"""Find the enclosing function/method/class for a node.
|
|
||||||
|
|
||||||
Returns fully qualified name (e.g., "MyClass.my_method") by traversing up
|
|
||||||
the AST tree and collecting parent class/function names.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: AST node to find enclosure for
|
|
||||||
symbols: List of defined symbols
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Fully qualified name of enclosing symbol, or None if at module level
|
|
||||||
"""
|
|
||||||
# Walk up the tree to find all enclosing symbols
|
|
||||||
enclosing_names = []
|
|
||||||
parent = node.parent
|
|
||||||
|
|
||||||
while parent is not None:
|
|
||||||
for symbol in symbols:
|
|
||||||
if symbol["node"] == parent:
|
|
||||||
# Prepend to maintain order (innermost to outermost)
|
|
||||||
enclosing_names.insert(0, symbol["name"])
|
|
||||||
break
|
|
||||||
parent = parent.parent
|
|
||||||
|
|
||||||
# Return fully qualified name or None if at module level
|
|
||||||
if enclosing_names:
|
|
||||||
return ".".join(enclosing_names)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
|
|
||||||
"""Extract the target function name from a Python call expression.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
node: Call expression node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Target function name, or None if cannot be determined
|
|
||||||
"""
|
|
||||||
function_node = node.child_by_field_name("function")
|
|
||||||
if function_node is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Handle simple identifiers (e.g., "foo()")
|
|
||||||
if function_node.type == "identifier":
|
|
||||||
return self._node_text(source_bytes, function_node)
|
|
||||||
|
|
||||||
# Handle attribute access (e.g., "obj.method()")
|
|
||||||
if function_node.type == "attribute":
|
|
||||||
attr_node = function_node.child_by_field_name("attribute")
|
|
||||||
if attr_node:
|
|
||||||
return self._node_text(source_bytes, attr_node)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
|
|
||||||
"""Extract the target function name from a JS/TS call expression.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
node: Call expression node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Target function name, or None if cannot be determined
|
|
||||||
"""
|
|
||||||
function_node = node.child_by_field_name("function")
|
|
||||||
if function_node is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Handle simple identifiers
|
|
||||||
if function_node.type == "identifier":
|
|
||||||
return self._node_text(source_bytes, function_node)
|
|
||||||
|
|
||||||
# Handle member expressions (e.g., "obj.method()")
|
|
||||||
if function_node.type == "member_expression":
|
|
||||||
property_node = function_node.child_by_field_name("property")
|
|
||||||
if property_node:
|
|
||||||
return self._node_text(source_bytes, property_node)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _iter_nodes(self, root: TreeSitterNode):
|
|
||||||
"""Iterate over all nodes in AST.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
root: Root node to start iteration
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
AST nodes in depth-first order
|
|
||||||
"""
|
|
||||||
stack = [root]
|
|
||||||
while stack:
|
|
||||||
node = stack.pop()
|
|
||||||
yield node
|
|
||||||
for child in reversed(node.children):
|
|
||||||
stack.append(child)
|
|
||||||
|
|
||||||
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
|
||||||
"""Extract text for a node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_bytes: Source code as bytes
|
|
||||||
node: AST node
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Text content of node
|
|
||||||
"""
|
|
||||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
|
||||||
@@ -602,6 +602,12 @@ class VectorStore:
         Returns:
             List of SearchResult ordered by similarity (highest first)
         """
+        logger.warning(
+            "Using brute-force vector search (hnswlib not available). "
+            "This may cause high memory usage for large indexes. "
+            "Install hnswlib for better performance: pip install hnswlib"
+        )
+
         with self._cache_lock:
             # Refresh cache if needed
             if self._embedding_matrix is None:
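The new warning flags the fallback path that scans every cached embedding rather than querying an HNSW index. Conceptually that fallback is a full cosine-similarity pass over the in-memory matrix, along the lines of this sketch (illustrative names only, not the actual VectorStore internals):

    import numpy as np

    def brute_force_top_k(matrix: np.ndarray, query: np.ndarray, k: int = 10):
        # Cosine similarity of the query against every stored row, then take the k best.
        norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)
        scores = (matrix @ query) / np.clip(norms, 1e-12, None)
        best = np.argsort(-scores)[:k]
        return [(int(i), float(scores[i])) for i in best]
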
@@ -17,7 +17,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 
-from codexlens.entities import CodeRelationship, SearchResult, Symbol
+from codexlens.entities import SearchResult, Symbol
 from codexlens.errors import StorageError
 
 
@@ -237,116 +237,6 @@ class DirIndexStore:
                 conn.rollback()
                 raise StorageError(f"Failed to add file {name}: {exc}") from exc
 
-    def add_relationships(
-        self,
-        file_path: str | Path,
-        relationships: List[CodeRelationship],
-    ) -> int:
-        """Store code relationships for a file.
-
-        Args:
-            file_path: Path to the source file
-            relationships: List of CodeRelationship objects to store
-
-        Returns:
-            Number of relationships stored
-
-        Raises:
-            StorageError: If database operations fail
-        """
-        if not relationships:
-            return 0
-
-        with self._lock:
-            conn = self._get_connection()
-            file_path_str = str(Path(file_path).resolve())
-
-            try:
-                # Get file_id
-                row = conn.execute(
-                    "SELECT id FROM files WHERE full_path=?", (file_path_str,)
-                ).fetchone()
-                if not row:
-                    return 0
-
-                file_id = int(row["id"])
-
-                # Delete existing relationships for symbols in this file
-                conn.execute(
-                    """
-                    DELETE FROM code_relationships
-                    WHERE source_symbol_id IN (
-                        SELECT id FROM symbols WHERE file_id=?
-                    )
-                    """,
-                    (file_id,),
-                )
-
-                # Insert new relationships
-                relationship_rows = []
-                skipped_relationships = []
-                for rel in relationships:
-                    # Extract simple name from fully qualified name (e.g., "MyClass.my_method" -> "my_method")
-                    # This handles cases where GraphAnalyzer generates qualified names but symbols table stores simple names
-                    source_symbol_simple = rel.source_symbol.split(".")[-1] if "." in rel.source_symbol else rel.source_symbol
-
-                    # Find symbol_id by name and file
-                    symbol_row = conn.execute(
-                        """
-                        SELECT id FROM symbols
-                        WHERE file_id=? AND name=? AND start_line<=? AND end_line>=?
-                        LIMIT 1
-                        """,
-                        (file_id, source_symbol_simple, rel.source_line, rel.source_line),
-                    ).fetchone()
-
-                    if not symbol_row:
-                        # Try matching by simple name only
-                        symbol_row = conn.execute(
-                            "SELECT id FROM symbols WHERE file_id=? AND name=? LIMIT 1",
-                            (file_id, source_symbol_simple),
-                        ).fetchone()
-
-                    if symbol_row:
-                        relationship_rows.append((
-                            int(symbol_row["id"]),
-                            rel.target_symbol,
-                            rel.relationship_type,
-                            rel.source_line,
-                            rel.target_file,
-                        ))
-                    else:
-                        # Log warning when symbol lookup fails
-                        skipped_relationships.append(rel.source_symbol)
-
-                # Log skipped relationships for debugging
-                if skipped_relationships:
-                    self.logger.warning(
-                        "Failed to find source symbol IDs for %d relationships in %s: %s",
-                        len(skipped_relationships),
-                        file_path_str,
-                        ", ".join(set(skipped_relationships))
-                    )
-
-                if relationship_rows:
-                    conn.executemany(
-                        """
-                        INSERT INTO code_relationships(
-                            source_symbol_id, target_qualified_name, relationship_type,
-                            source_line, target_file
-                        )
-                        VALUES(?, ?, ?, ?, ?)
-                        """,
-                        relationship_rows,
-                    )
-
-                conn.commit()
-                return len(relationship_rows)
-
-            except sqlite3.DatabaseError as exc:
-                conn.rollback()
-                raise StorageError(f"Failed to add relationships: {exc}") from exc
-
     def add_files_batch(
         self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
     ) -> int:
@@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Set
 
 from codexlens.config import Config
 from codexlens.parsers.factory import ParserFactory
-from codexlens.semantic.graph_analyzer import GraphAnalyzer
 from codexlens.storage.dir_index import DirIndexStore
 from codexlens.storage.path_mapper import PathMapper
 from codexlens.storage.registry import ProjectInfo, RegistryStore

@@ -525,16 +524,6 @@ class IndexTreeBuilder:
                     symbols=indexed_file.symbols,
                 )
 
-                # Extract and store code relationships for graph visualization
-                if language_id in {"python", "javascript", "typescript"}:
-                    graph_analyzer = GraphAnalyzer(language_id)
-                    if graph_analyzer.is_available():
-                        relationships = graph_analyzer.analyze_with_symbols(
-                            text, file_path, indexed_file.symbols
-                        )
-                        if relationships:
-                            store.add_relationships(file_path, relationships)
-
                 files_count += 1
                 symbols_count += len(indexed_file.symbols)
 
@@ -742,16 +731,6 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
                     symbols=indexed_file.symbols,
                 )
 
-                # Extract and store code relationships for graph visualization
-                if language_id in {"python", "javascript", "typescript"}:
-                    graph_analyzer = GraphAnalyzer(language_id)
-                    if graph_analyzer.is_available():
-                        relationships = graph_analyzer.analyze_with_symbols(
-                            text, item, indexed_file.symbols
-                        )
-                        if relationships:
-                            store.add_relationships(item, relationships)
-
                 files_count += 1
                 symbols_count += len(indexed_file.symbols)
 
@@ -1,57 +0,0 @@
-"""
-Migration 003: Add code relationships storage.
-
-This migration introduces the `code_relationships` table to store semantic
-relationships between code symbols (function calls, inheritance, imports).
-This enables graph-based code navigation and dependency analysis.
-"""
-
-import logging
-from sqlite3 import Connection
-
-log = logging.getLogger(__name__)
-
-
-def upgrade(db_conn: Connection):
-    """
-    Applies the migration to add code relationships table.
-
-    - Creates `code_relationships` table with foreign key to symbols
-    - Creates indexes for efficient relationship queries
-    - Supports lazy expansion with target_symbol being qualified names
-
-    Args:
-        db_conn: The SQLite database connection.
-    """
-    cursor = db_conn.cursor()
-
-    log.info("Creating 'code_relationships' table...")
-    cursor.execute(
-        """
-        CREATE TABLE IF NOT EXISTS code_relationships (
-            id INTEGER PRIMARY KEY,
-            source_symbol_id INTEGER NOT NULL,
-            target_qualified_name TEXT NOT NULL,
-            relationship_type TEXT NOT NULL,
-            source_line INTEGER NOT NULL,
-            target_file TEXT,
-            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
-        )
-        """
-    )
-
-    log.info("Creating indexes for code_relationships...")
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
-    )
-
-    log.info("Finished creating code_relationships table and indexes.")
@@ -10,7 +10,7 @@ from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
-from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
+from codexlens.entities import IndexedFile, SearchResult, Symbol
 from codexlens.errors import StorageError
 
 
@@ -420,167 +420,6 @@ class SQLiteStore:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
|
|
||||||
"""Store code relationships for a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file containing the relationships
|
|
||||||
relationships: List of CodeRelationship objects to store
|
|
||||||
"""
|
|
||||||
if not relationships:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
conn = self._get_connection()
|
|
||||||
resolved_path = str(Path(file_path).resolve())
|
|
||||||
|
|
||||||
# Get file_id
|
|
||||||
row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
|
|
||||||
if not row:
|
|
||||||
raise StorageError(f"File not found in index: {file_path}")
|
|
||||||
file_id = int(row["id"])
|
|
||||||
|
|
||||||
# Delete existing relationships for symbols in this file
|
|
||||||
conn.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM code_relationships
|
|
||||||
WHERE source_symbol_id IN (
|
|
||||||
SELECT id FROM symbols WHERE file_id=?
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
(file_id,)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Insert new relationships
|
|
||||||
relationship_rows = []
|
|
||||||
for rel in relationships:
|
|
||||||
# Find source symbol ID
|
|
||||||
symbol_row = conn.execute(
|
|
||||||
"""
|
|
||||||
SELECT id FROM symbols
|
|
||||||
WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
|
|
||||||
ORDER BY (end_line - start_line) ASC
|
|
||||||
LIMIT 1
|
|
||||||
""",
|
|
||||||
(file_id, rel.source_symbol, rel.source_line, rel.source_line)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if symbol_row:
|
|
||||||
source_symbol_id = int(symbol_row["id"])
|
|
||||||
relationship_rows.append((
|
|
||||||
source_symbol_id,
|
|
||||||
rel.target_symbol,
|
|
||||||
rel.relationship_type,
|
|
||||||
rel.source_line,
|
|
||||||
rel.target_file
|
|
||||||
))
|
|
||||||
|
|
||||||
if relationship_rows:
|
|
||||||
conn.executemany(
|
|
||||||
"""
|
|
||||||
INSERT INTO code_relationships(
|
|
||||||
source_symbol_id, target_qualified_name, relationship_type,
|
|
||||||
source_line, target_file
|
|
||||||
)
|
|
||||||
VALUES(?, ?, ?, ?, ?)
|
|
||||||
""",
|
|
||||||
relationship_rows
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def query_relationships_by_target(
|
|
||||||
self, target_name: str, *, limit: int = 100
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Query relationships by target symbol name (find all callers).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_name: Name of the target symbol
|
|
||||||
limit: Maximum number of results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts containing relationship info with file paths and line numbers
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
conn = self._get_connection()
|
|
||||||
rows = conn.execute(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
s.name AS source_symbol,
|
|
||||||
r.target_qualified_name,
|
|
||||||
r.relationship_type,
|
|
||||||
r.source_line,
|
|
||||||
f.full_path AS source_file,
|
|
||||||
r.target_file
|
|
||||||
FROM code_relationships r
|
|
||||||
JOIN symbols s ON r.source_symbol_id = s.id
|
|
||||||
JOIN files f ON s.file_id = f.id
|
|
||||||
WHERE r.target_qualified_name = ?
|
|
||||||
ORDER BY f.full_path, r.source_line
|
|
||||||
LIMIT ?
|
|
||||||
""",
|
|
||||||
(target_name, limit)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"source_symbol": row["source_symbol"],
|
|
||||||
"target_symbol": row["target_qualified_name"],
|
|
||||||
"relationship_type": row["relationship_type"],
|
|
||||||
"source_line": row["source_line"],
|
|
||||||
"source_file": row["source_file"],
|
|
||||||
"target_file": row["target_file"],
|
|
||||||
}
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
def query_relationships_by_source(
|
|
||||||
self, source_symbol: str, source_file: str | Path, *, limit: int = 100
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Query relationships by source symbol (find what a symbol calls).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source_symbol: Name of the source symbol
|
|
||||||
source_file: File path containing the source symbol
|
|
||||||
limit: Maximum number of results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts containing relationship info
|
|
||||||
"""
|
|
||||||
with self._lock:
|
|
||||||
conn = self._get_connection()
|
|
||||||
resolved_path = str(Path(source_file).resolve())
|
|
||||||
|
|
||||||
rows = conn.execute(
|
|
||||||
"""
|
|
||||||
SELECT
|
|
||||||
s.name AS source_symbol,
|
|
||||||
r.target_qualified_name,
|
|
||||||
r.relationship_type,
|
|
||||||
r.source_line,
|
|
||||||
f.path AS source_file,
|
|
||||||
r.target_file
|
|
||||||
FROM code_relationships r
|
|
||||||
JOIN symbols s ON r.source_symbol_id = s.id
|
|
||||||
JOIN files f ON s.file_id = f.id
|
|
||||||
WHERE s.name = ? AND f.path = ?
|
|
||||||
ORDER BY r.source_line
|
|
||||||
LIMIT ?
|
|
||||||
""",
|
|
||||||
(source_symbol, resolved_path, limit)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"source_symbol": row["source_symbol"],
|
|
||||||
"target_symbol": row["target_qualified_name"],
|
|
||||||
"relationship_type": row["relationship_type"],
|
|
||||||
"source_line": row["source_line"],
|
|
||||||
"source_file": row["source_file"],
|
|
||||||
"target_file": row["target_file"],
|
|
||||||
}
|
|
||||||
for row in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
def _connect(self) -> sqlite3.Connection:
|
def _connect(self) -> sqlite3.Connection:
|
||||||
"""Legacy method for backward compatibility."""
|
"""Legacy method for backward compatibility."""
|
||||||
return self._get_connection()
|
return self._get_connection()
|
||||||
|
|||||||
@@ -1,644 +0,0 @@
|
|||||||
"""Unit tests for ChainSearchEngine.
|
|
||||||
|
|
||||||
Tests the graph query methods (search_callers, search_callees, search_inheritance)
|
|
||||||
with mocked SQLiteStore dependency to test logic in isolation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import Mock, MagicMock, patch, call
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from codexlens.search.chain_search import (
|
|
||||||
ChainSearchEngine,
|
|
||||||
SearchOptions,
|
|
||||||
SearchStats,
|
|
||||||
ChainSearchResult,
|
|
||||||
)
|
|
||||||
from codexlens.entities import SearchResult, Symbol
|
|
||||||
from codexlens.storage.registry import RegistryStore, DirMapping
|
|
||||||
from codexlens.storage.path_mapper import PathMapper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_registry():
|
|
||||||
"""Create a mock RegistryStore."""
|
|
||||||
registry = Mock(spec=RegistryStore)
|
|
||||||
return registry
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_mapper():
|
|
||||||
"""Create a mock PathMapper."""
|
|
||||||
mapper = Mock(spec=PathMapper)
|
|
||||||
return mapper
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def search_engine(mock_registry, mock_mapper):
|
|
||||||
"""Create a ChainSearchEngine with mocked dependencies."""
|
|
||||||
return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_index_path():
|
|
||||||
"""Sample index database path."""
|
|
||||||
return Path("/test/project/_index.db")
|
|
||||||
|
|
||||||
|
|
||||||
class TestChainSearchEngineCallers:
|
|
||||||
"""Tests for search_callers method."""
|
|
||||||
|
|
||||||
def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_callers returns caller relationships."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
# Mock finding the start index
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock collect_index_paths to return single index
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
# Mock the parallel search to return caller data
|
|
||||||
expected_callers = [
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_function",
|
|
||||||
"target_symbol": "my_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 42,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callers(target_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == "caller_function"
|
|
||||||
assert result[0]["target_symbol"] == "my_function"
|
|
||||||
assert result[0]["relationship_type"] == "calls"
|
|
||||||
assert result[0]["source_line"] == 42
|
|
||||||
|
|
||||||
def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_callers handles no results gracefully."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "nonexistent_function"
|
|
||||||
|
|
||||||
# Mock finding the start index
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock collect_index_paths
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
# Mock empty results
|
|
||||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callers(target_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_search_callers_no_index_found(self, search_engine, mock_registry):
|
|
||||||
"""Test that search_callers returns empty list when no index found."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
# Mock no index found
|
|
||||||
mock_registry.find_nearest_index.return_value = None
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_find_start_index', return_value=None):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callers(target_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path):
|
|
||||||
"""Test that search_callers respects SearchOptions."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "my_function"
|
|
||||||
options = SearchOptions(depth=1, total_limit=50)
|
|
||||||
|
|
||||||
# Configure mapper to return a path that exists
|
|
||||||
mock_mapper.source_to_index_db.return_value = sample_index_path
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect:
|
|
||||||
with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search:
|
|
||||||
# Patch Path.exists to return True so the exact match is found
|
|
||||||
with patch.object(Path, 'exists', return_value=True):
|
|
||||||
# Execute
|
|
||||||
search_engine.search_callers(target_symbol, source_path, options)
|
|
||||||
|
|
||||||
# Assert that depth was passed to collect_index_paths
|
|
||||||
mock_collect.assert_called_once_with(sample_index_path, 1)
|
|
||||||
# Assert that total_limit was passed to parallel search
|
|
||||||
mock_search.assert_called_once_with([sample_index_path], target_symbol, 50)
|
|
||||||
|
|
||||||
|
|
||||||
class TestChainSearchEngineCallees:
|
|
||||||
"""Tests for search_callees method."""
|
|
||||||
|
|
||||||
def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_callees returns callee relationships."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
source_symbol = "caller_function"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
expected_callees = [
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_function",
|
|
||||||
"target_symbol": "callee_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 15,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callees(source_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == "caller_function"
|
|
||||||
assert result[0]["target_symbol"] == "callee_function"
|
|
||||||
assert result[0]["source_line"] == 15
|
|
||||||
|
|
||||||
def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_callees correctly handles file-specific queries."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
source_symbol = "MyClass.method"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
# Multiple callees from same source symbol
|
|
||||||
expected_callees = [
|
|
||||||
{
|
|
||||||
"source_symbol": "MyClass.method",
|
|
||||||
"target_symbol": "helper_a",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 10,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/utils.py",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"source_symbol": "MyClass.method",
|
|
||||||
"target_symbol": "helper_b",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 20,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/utils.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callees(source_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result[0]["target_symbol"] == "helper_a"
|
|
||||||
assert result[1]["target_symbol"] == "helper_b"
|
|
||||||
|
|
||||||
def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_callees handles no callees gracefully."""
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
source_symbol = "leaf_function"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
with patch.object(search_engine, '_search_callees_parallel', return_value=[]):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callees(source_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestChainSearchEngineInheritance:
|
|
||||||
"""Tests for search_inheritance method."""
|
|
||||||
|
|
||||||
def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that search_inheritance returns inheritance relationships."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
class_name = "BaseClass"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
expected_inheritance = [
|
|
||||||
{
|
|
||||||
"source_symbol": "DerivedClass",
|
|
||||||
"target_symbol": "BaseClass",
|
|
||||||
"relationship_type": "inherits",
|
|
||||||
"source_line": 5,
|
|
||||||
"source_file": "/test/project/derived.py",
|
|
||||||
"target_file": "/test/project/base.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_inheritance(class_name, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == "DerivedClass"
|
|
||||||
assert result[0]["target_symbol"] == "BaseClass"
|
|
||||||
assert result[0]["relationship_type"] == "inherits"
|
|
||||||
|
|
||||||
def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test inheritance search with multiple derived classes."""
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
class_name = "BaseClass"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
expected_inheritance = [
|
|
||||||
{
|
|
||||||
"source_symbol": "DerivedClassA",
|
|
||||||
"target_symbol": "BaseClass",
|
|
||||||
"relationship_type": "inherits",
|
|
||||||
"source_line": 5,
|
|
||||||
"source_file": "/test/project/derived_a.py",
|
|
||||||
"target_file": "/test/project/base.py",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"source_symbol": "DerivedClassB",
|
|
||||||
"target_symbol": "BaseClass",
|
|
||||||
"relationship_type": "inherits",
|
|
||||||
"source_line": 10,
|
|
||||||
"source_file": "/test/project/derived_b.py",
|
|
||||||
"target_file": "/test/project/base.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_inheritance(class_name, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 2
|
|
||||||
assert result[0]["source_symbol"] == "DerivedClassA"
|
|
||||||
assert result[1]["source_symbol"] == "DerivedClassB"
|
|
||||||
|
|
||||||
def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test inheritance search with no subclasses found."""
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
class_name = "FinalClass"
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=sample_index_path,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]):
|
|
||||||
with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_inheritance(class_name, source_path)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestChainSearchEngineParallelSearch:
|
|
||||||
"""Tests for parallel search aggregation."""
|
|
||||||
|
|
||||||
def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that parallel search aggregates results from multiple indexes."""
|
|
||||||
# Setup
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
index_path_1 = Path("/test/project/_index.db")
|
|
||||||
index_path_2 = Path("/test/project/subdir/_index.db")
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=index_path_1,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
|
||||||
# Mock parallel search results from multiple indexes
|
|
||||||
callers_from_multiple = [
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_in_root",
|
|
||||||
"target_symbol": "my_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 10,
|
|
||||||
"source_file": "/test/project/root.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_in_subdir",
|
|
||||||
"target_symbol": "my_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 20,
|
|
||||||
"source_file": "/test/project/subdir/module.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callers(target_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert results from both indexes are included
|
|
||||||
assert len(result) == 2
|
|
||||||
assert any(r["source_file"] == "/test/project/root.py" for r in result)
|
|
||||||
assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result)
|
|
||||||
|
|
||||||
def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path):
|
|
||||||
"""Test that parallel search deduplicates results by (source_file, source_line)."""
|
|
||||||
# Note: This test verifies the behavior of _search_callers_parallel deduplication
|
|
||||||
source_path = Path("/test/project")
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
index_path_1 = Path("/test/project/_index.db")
|
|
||||||
index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate)
|
|
||||||
|
|
||||||
mock_registry.find_nearest_index.return_value = DirMapping(
|
|
||||||
id=1,
|
|
||||||
project_id=1,
|
|
||||||
source_path=source_path,
|
|
||||||
index_path=index_path_1,
|
|
||||||
depth=0,
|
|
||||||
files_count=10,
|
|
||||||
last_updated=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]):
|
|
||||||
# Mock duplicate results from same location
|
|
||||||
duplicate_callers = [
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_function",
|
|
||||||
"target_symbol": "my_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 42,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"source_symbol": "caller_function",
|
|
||||||
"target_symbol": "my_function",
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 42,
|
|
||||||
"source_file": "/test/project/module.py",
|
|
||||||
"target_file": "/test/project/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers):
|
|
||||||
# Execute
|
|
||||||
result = search_engine.search_callers(target_symbol, source_path)
|
|
||||||
|
|
||||||
# Assert: even with duplicates in input, output may contain both
|
|
||||||
# (actual deduplication happens in _search_callers_parallel)
|
|
||||||
assert len(result) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestChainSearchEngineContextManager:
|
|
||||||
"""Tests for context manager functionality."""
|
|
||||||
|
|
||||||
def test_context_manager_closes_executor(self, mock_registry, mock_mapper):
|
|
||||||
"""Test that context manager properly closes executor."""
|
|
||||||
with ChainSearchEngine(mock_registry, mock_mapper) as engine:
|
|
||||||
# Force executor creation
|
|
||||||
engine._get_executor()
|
|
||||||
assert engine._executor is not None
|
|
||||||
|
|
||||||
# Executor should be closed after exiting context
|
|
||||||
assert engine._executor is None
|
|
||||||
|
|
||||||
def test_close_method_shuts_down_executor(self, search_engine):
|
|
||||||
"""Test that close() method shuts down executor."""
|
|
||||||
# Create executor
|
|
||||||
search_engine._get_executor()
|
|
||||||
assert search_engine._executor is not None
|
|
||||||
|
|
||||||
# Close
|
|
||||||
search_engine.close()
|
|
||||||
assert search_engine._executor is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchCallersSingle:
|
|
||||||
"""Tests for _search_callers_single internal method."""
|
|
||||||
|
|
||||||
def test_search_callers_single_queries_store(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_callers_single queries SQLiteStore correctly."""
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
# Mock SQLiteStore
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
mock_store_instance = MockStore.return_value.__enter__.return_value
|
|
||||||
mock_store_instance.query_relationships_by_target.return_value = [
|
|
||||||
{
|
|
||||||
"source_symbol": "caller",
|
|
||||||
"target_symbol": target_symbol,
|
|
||||||
"relationship_type": "calls",
|
|
||||||
"source_line": 10,
|
|
||||||
"source_file": "/test/file.py",
|
|
||||||
"target_file": "/test/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == "caller"
|
|
||||||
mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol)
|
|
||||||
|
|
||||||
def test_search_callers_single_handles_errors(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_callers_single returns empty list on error."""
|
|
||||||
target_symbol = "my_function"
|
|
||||||
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
MockStore.return_value.__enter__.side_effect = Exception("Database error")
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_callers_single(sample_index_path, target_symbol)
|
|
||||||
|
|
||||||
# Assert - should return empty list, not raise exception
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchCalleesSingle:
|
|
||||||
"""Tests for _search_callees_single internal method."""
|
|
||||||
|
|
||||||
def test_search_callees_single_queries_database(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_callees_single queries SQLiteStore correctly."""
|
|
||||||
source_symbol = "caller_function"
|
|
||||||
|
|
||||||
# Mock SQLiteStore
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
mock_store_instance = MagicMock()
|
|
||||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
|
||||||
|
|
||||||
# Mock execute_query to return relationship data (using new public API)
|
|
||||||
mock_store_instance.execute_query.return_value = [
|
|
||||||
{
|
|
||||||
"source_symbol": source_symbol,
|
|
||||||
"target_symbol": "callee_function",
|
|
||||||
"relationship_type": "call",
|
|
||||||
"source_line": 15,
|
|
||||||
"source_file": "/test/module.py",
|
|
||||||
"target_file": "/test/lib.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
|
||||||
|
|
||||||
# Assert - verify execute_query was called (public API)
|
|
||||||
assert mock_store_instance.execute_query.called
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == source_symbol
|
|
||||||
assert result[0]["target_symbol"] == "callee_function"
|
|
||||||
|
|
||||||
def test_search_callees_single_handles_errors(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_callees_single returns empty list on error."""
|
|
||||||
source_symbol = "caller_function"
|
|
||||||
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_callees_single(sample_index_path, source_symbol)
|
|
||||||
|
|
||||||
# Assert - should return empty list, not raise exception
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchInheritanceSingle:
|
|
||||||
"""Tests for _search_inheritance_single internal method."""
|
|
||||||
|
|
||||||
def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_inheritance_single queries SQLiteStore correctly."""
|
|
||||||
class_name = "BaseClass"
|
|
||||||
|
|
||||||
# Mock SQLiteStore
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
mock_store_instance = MagicMock()
|
|
||||||
MockStore.return_value.__enter__.return_value = mock_store_instance
|
|
||||||
|
|
||||||
# Mock execute_query to return relationship data (using new public API)
|
|
||||||
mock_store_instance.execute_query.return_value = [
|
|
||||||
{
|
|
||||||
"source_symbol": "DerivedClass",
|
|
||||||
"target_qualified_name": "BaseClass",
|
|
||||||
"relationship_type": "inherits",
|
|
||||||
"source_line": 5,
|
|
||||||
"source_file": "/test/derived.py",
|
|
||||||
"target_file": "/test/base.py",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
assert mock_store_instance.execute_query.called
|
|
||||||
assert len(result) == 1
|
|
||||||
assert result[0]["source_symbol"] == "DerivedClass"
|
|
||||||
assert result[0]["relationship_type"] == "inherits"
|
|
||||||
|
|
||||||
# Verify execute_query was called with 'inherits' filter
|
|
||||||
call_args = mock_store_instance.execute_query.call_args
|
|
||||||
sql_query = call_args[0][0]
|
|
||||||
assert "relationship_type = 'inherits'" in sql_query
|
|
||||||
|
|
||||||
def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path):
|
|
||||||
"""Test that _search_inheritance_single returns empty list on error."""
|
|
||||||
class_name = "BaseClass"
|
|
||||||
|
|
||||||
with patch('codexlens.search.chain_search.SQLiteStore') as MockStore:
|
|
||||||
MockStore.return_value.__enter__.side_effect = Exception("DB error")
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
result = search_engine._search_inheritance_single(sample_index_path, class_name)
|
|
||||||
|
|
||||||
# Assert - should return empty list, not raise exception
|
|
||||||
assert result == []
|
|
||||||
@@ -1,122 +0,0 @@
"""Tests for CLI search command with --enrich flag."""

import json
import pytest
from typer.testing import CliRunner
from codexlens.cli.commands import app


class TestCLISearchEnrich:
    """Test CLI search command with --enrich flag integration."""

    @pytest.fixture
    def runner(self):
        """Create CLI test runner."""
        return CliRunner()

    def test_search_with_enrich_flag_help(self, runner):
        """Test --enrich flag is documented in help."""
        result = runner.invoke(app, ["search", "--help"])
        assert result.exit_code == 0
        assert "--enrich" in result.output
        assert "relationships" in result.output.lower() or "graph" in result.output.lower()

    def test_search_with_enrich_flag_accepted(self, runner):
        """Test --enrich flag is accepted by the CLI."""
        result = runner.invoke(app, ["search", "test", "--enrich"])
        # Should not show 'unknown option' error
        assert "No such option" not in result.output
        assert "error: unrecognized" not in result.output.lower()

    def test_search_without_enrich_flag(self, runner):
        """Test search without --enrich flag has no relationships."""
        result = runner.invoke(app, ["search", "test", "--json"])
        # Even without an index, JSON should be attempted
        if result.exit_code == 0:
            try:
                data = json.loads(result.output)
                # If we get results, they should not have enriched=true
                if data.get("success") and "result" in data:
                    assert data["result"].get("enriched", False) is False
            except json.JSONDecodeError:
                pass  # Not JSON output, that's fine for error cases

    def test_search_enrich_json_output_structure(self, runner):
        """Test JSON output structure includes enriched flag."""
        result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
        # If we get valid JSON output, check structure
        if result.exit_code == 0:
            try:
                data = json.loads(result.output)
                if data.get("success") and "result" in data:
                    # enriched field should exist
                    assert "enriched" in data["result"]
            except json.JSONDecodeError:
                pass  # Not JSON output

    def test_search_enrich_with_mode(self, runner):
        """Test --enrich works with different search modes."""
        modes = ["exact", "fuzzy", "hybrid"]
        for mode in modes:
            result = runner.invoke(
                app, ["search", "test", "--mode", mode, "--enrich"]
            )
            # Should not show validation errors
            assert "Invalid" not in result.output


class TestEnrichFlagBehavior:
    """Test behavioral aspects of --enrich flag."""

    @pytest.fixture
    def runner(self):
        """Create CLI test runner."""
        return CliRunner()

    def test_enrich_failure_does_not_break_search(self, runner):
        """Test that enrichment failure doesn't prevent search from returning results."""
        # Even without proper index, search should not crash due to enrich
        result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
        # Should not have unhandled exception
        assert "Traceback" not in result.output

    def test_enrich_flag_with_files_only(self, runner):
        """Test --enrich is accepted with --files-only mode."""
        result = runner.invoke(app, ["search", "test", "--enrich", "--files-only"])
        # Should not show option conflict error
        assert "conflict" not in result.output.lower()

    def test_enrich_flag_with_limit(self, runner):
        """Test --enrich works with --limit parameter."""
        result = runner.invoke(app, ["search", "test", "--enrich", "--limit", "5"])
        # Should not show validation error
        assert "Invalid" not in result.output


class TestEnrichOutputFormat:
    """Test output format with --enrich flag."""

    @pytest.fixture
    def runner(self):
        """Create CLI test runner."""
        return CliRunner()

    def test_enrich_verbose_shows_status(self, runner):
        """Test verbose mode shows enrichment status."""
        result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"])
        # Verbose mode may show enrichment info or warnings
        # Just ensure it doesn't crash
        assert result.exit_code in [0, 1]  # 0 = success, 1 = no index

    def test_json_output_has_enriched_field(self, runner):
        """Test JSON output always has enriched field when --enrich used."""
        result = runner.invoke(app, ["search", "test", "--json", "--enrich"])
        if result.exit_code == 0:
            try:
                data = json.loads(result.output)
                if data.get("success"):
                    result_data = data.get("result", {})
                    assert "enriched" in result_data
                    assert isinstance(result_data["enriched"], bool)
            except json.JSONDecodeError:
                pass
@@ -204,8 +204,6 @@ class TestEntitySerialization:
            "kind": "function",
            "range": (1, 10),
            "file": None,
            "token_count": None,
            "symbol_type": None,
        }

    def test_indexed_file_model_dump(self):
@@ -1,436 +0,0 @@
"""Tests for GraphAnalyzer - code relationship extraction."""

from pathlib import Path

import pytest

from codexlens.semantic.graph_analyzer import GraphAnalyzer


TREE_SITTER_PYTHON_AVAILABLE = True
try:
    import tree_sitter_python  # type: ignore[import-not-found]  # noqa: F401
except Exception:
    TREE_SITTER_PYTHON_AVAILABLE = False


TREE_SITTER_JS_AVAILABLE = True
try:
    import tree_sitter_javascript  # type: ignore[import-not-found]  # noqa: F401
except Exception:
    TREE_SITTER_JS_AVAILABLE = False


@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
class TestPythonGraphAnalyzer:
    """Tests for Python relationship extraction."""

    def test_simple_function_call(self):
        """Test extraction of simple function call."""
        code = """def helper():
    pass

def main():
    helper()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"
        assert rel.relationship_type == "call"
        assert rel.source_line == 5

    def test_multiple_calls_in_function(self):
        """Test extraction of multiple calls from same function."""
        code = """def foo():
    pass

def bar():
    pass

def main():
    foo()
    bar()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find main -> foo and main -> bar
        assert len(relationships) == 2
        targets = {rel.target_symbol for rel in relationships}
        assert targets == {"foo", "bar"}
        assert all(rel.source_symbol == "main" for rel in relationships)

    def test_nested_function_calls(self):
        """Test extraction of calls from nested functions."""
        code = """def inner_helper():
    pass

def outer():
    def inner():
        inner_helper()
    inner()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find outer.inner -> inner_helper and outer -> inner (with fully qualified names)
        assert len(relationships) == 2
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
        assert ("outer.inner", "inner_helper") in call_pairs
        assert ("outer", "inner") in call_pairs

    def test_method_call_in_class(self):
        """Test extraction of method calls within class."""
        code = """class Calculator:
    def add(self, a, b):
        return a + b

    def compute(self, x, y):
        result = self.add(x, y)
        return result
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find Calculator.compute -> add (with fully qualified source)
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "Calculator.compute"
        assert rel.target_symbol == "add"

    def test_module_level_call(self):
        """Test extraction of module-level function calls."""
        code = """def setup():
    pass

setup()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find <module> -> setup
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "<module>"
        assert rel.target_symbol == "setup"

    def test_async_function_call(self):
        """Test extraction of calls involving async functions."""
        code = """async def fetch_data():
    pass

async def process():
    await fetch_data()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Should find process -> fetch_data
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "process"
        assert rel.target_symbol == "fetch_data"

    def test_complex_python_file(self):
        """Test extraction from realistic Python file with multiple patterns."""
        code = """class DataProcessor:
    def __init__(self):
        self.data = []

    def load(self, filename):
        self.data = read_file(filename)

    def process(self):
        self.validate()
        self.transform()

    def validate(self):
        pass

    def transform(self):
        pass

def read_file(filename):
    pass

def main():
    processor = DataProcessor()
    processor.load("data.txt")
    processor.process()

main()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Extract call pairs
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Expected relationships (with fully qualified source symbols for methods)
        expected = {
            ("DataProcessor.load", "read_file"),
            ("DataProcessor.process", "validate"),
            ("DataProcessor.process", "transform"),
            ("main", "DataProcessor"),
            ("main", "load"),
            ("main", "process"),
            ("<module>", "main"),
        }

        # Should find all expected relationships
        assert call_pairs >= expected

    def test_empty_file(self):
        """Test handling of empty file."""
        code = ""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert len(relationships) == 0

    def test_file_with_no_calls(self):
        """Test handling of file with definitions but no calls."""
        code = """def func1():
    pass

def func2():
    pass
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert len(relationships) == 0


@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
class TestJavaScriptGraphAnalyzer:
    """Tests for JavaScript relationship extraction."""

    def test_simple_function_call(self):
        """Test extraction of simple JavaScript function call."""
        code = """function helper() {}

function main() {
    helper();
}
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"
        assert rel.relationship_type == "call"

    def test_arrow_function_call(self):
        """Test extraction of calls from arrow functions."""
        code = """const helper = () => {};

const main = () => {
    helper();
};
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find main -> helper call
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "main"
        assert rel.target_symbol == "helper"

    def test_class_method_call(self):
        """Test extraction of method calls in JavaScript class."""
        code = """class Calculator {
    add(a, b) {
        return a + b;
    }

    compute(x, y) {
        return this.add(x, y);
    }
}
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Should find Calculator.compute -> add (with fully qualified source)
        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_symbol == "Calculator.compute"
        assert rel.target_symbol == "add"

    def test_complex_javascript_file(self):
        """Test extraction from realistic JavaScript file."""
        code = """function readFile(filename) {
    return "";
}

class DataProcessor {
    constructor() {
        this.data = [];
    }

    load(filename) {
        this.data = readFile(filename);
    }

    process() {
        this.validate();
        this.transform();
    }

    validate() {}

    transform() {}
}

function main() {
    const processor = new DataProcessor();
    processor.load("data.txt");
    processor.process();
}

main();
"""
        analyzer = GraphAnalyzer("javascript")
        relationships = analyzer.analyze_file(code, Path("test.js"))

        # Extract call pairs
        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Expected relationships (with fully qualified source symbols for methods)
        # Note: constructor calls like "new DataProcessor()" are not tracked
        expected = {
            ("DataProcessor.load", "readFile"),
            ("DataProcessor.process", "validate"),
            ("DataProcessor.process", "transform"),
            ("main", "load"),
            ("main", "process"),
            ("<module>", "main"),
        }

        # Should find all expected relationships
        assert call_pairs >= expected


class TestGraphAnalyzerEdgeCases:
    """Edge case tests for GraphAnalyzer."""

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_unavailable_language(self):
        """Test handling of unsupported language."""
        code = "some code"
        analyzer = GraphAnalyzer("rust")
        relationships = analyzer.analyze_file(code, Path("test.rs"))
        assert len(relationships) == 0

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_malformed_python_code(self):
        """Test handling of malformed Python code."""
        code = "def broken(\n pass"
        analyzer = GraphAnalyzer("python")
        # Should not crash
        relationships = analyzer.analyze_file(code, Path("test.py"))
        assert isinstance(relationships, list)

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_file_path_in_relationship(self):
        """Test that file path is correctly set in relationships."""
        code = """def foo():
    pass

def bar():
    foo()
"""
        test_path = Path("test.py")
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, test_path)

        assert len(relationships) == 1
        rel = relationships[0]
        assert rel.source_file == str(test_path.resolve())
        assert rel.target_file is None  # Intra-file

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_performance_large_file(self):
        """Test performance on larger file (1000 lines)."""
        import time

        # Generate file with many functions and calls
        lines = []
        for i in range(100):
            lines.append(f"def func_{i}():")
            if i > 0:
                lines.append(f"    func_{i-1}()")
            else:
                lines.append("    pass")

        code = "\n".join(lines)

        analyzer = GraphAnalyzer("python")
        start_time = time.time()
        relationships = analyzer.analyze_file(code, Path("test.py"))
        elapsed_ms = (time.time() - start_time) * 1000

        # Should complete in under 500ms
        assert elapsed_ms < 500

        # Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
        assert len(relationships) == 99

    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
    def test_call_accuracy_rate(self):
        """Test >95% accuracy on known call graph."""
        code = """def a(): pass
def b(): pass
def c(): pass
def d(): pass
def e(): pass

def test1():
    a()
    b()

def test2():
    c()
    d()

def test3():
    e()

def main():
    test1()
    test2()
    test3()
"""
        analyzer = GraphAnalyzer("python")
        relationships = analyzer.analyze_file(code, Path("test.py"))

        # Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
        expected_calls = {
            ("test1", "a"),
            ("test1", "b"),
            ("test2", "c"),
            ("test2", "d"),
            ("test3", "e"),
            ("main", "test1"),
            ("main", "test2"),
            ("main", "test3"),
        }

        found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}

        # Calculate accuracy
        correct = len(expected_calls & found_calls)
        total = len(expected_calls)
        accuracy = (correct / total) * 100 if total > 0 else 0

        # Should have >95% accuracy
        assert accuracy >= 95.0
        assert correct == total  # Should be 100% for this simple case
@@ -1,392 +0,0 @@
"""End-to-end tests for graph search CLI commands."""

import tempfile
from pathlib import Path
from typer.testing import CliRunner
import pytest

from codexlens.cli.commands import app
from codexlens.storage.sqlite_store import SQLiteStore
from codexlens.storage.registry import RegistryStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.entities import IndexedFile, Symbol, CodeRelationship


runner = CliRunner()


@pytest.fixture
def temp_project():
    """Create a temporary project with indexed code and relationships."""
    with tempfile.TemporaryDirectory() as tmpdir:
        project_root = Path(tmpdir) / "test_project"
        project_root.mkdir()

        # Create test Python files
        (project_root / "main.py").write_text("""
def main():
    result = calculate(5, 3)
    print(result)

def calculate(a, b):
    return add(a, b)

def add(x, y):
    return x + y
""")

        (project_root / "utils.py").write_text("""
class BaseClass:
    def method(self):
        pass

class DerivedClass(BaseClass):
    def method(self):
        super().method()
        helper()

def helper():
    return True
""")

        # Create a custom index directory for graph testing
        # Skip the standard init to avoid schema conflicts
        mapper = PathMapper()
        index_root = mapper.source_to_index_dir(project_root)
        index_root.mkdir(parents=True, exist_ok=True)
        test_db = index_root / "_index.db"

        # Register project manually
        registry = RegistryStore()
        registry.initialize()
        project_info = registry.register_project(
            source_root=project_root,
            index_root=index_root
        )
        registry.register_dir(
            project_id=project_info.id,
            source_path=project_root,
            index_path=test_db,
            depth=0,
            files_count=2
        )

        # Initialize the store with proper SQLiteStore schema and add files
        with SQLiteStore(test_db) as store:
            # Read and add files to the store
            main_content = (project_root / "main.py").read_text()
            utils_content = (project_root / "utils.py").read_text()

            main_indexed = IndexedFile(
                path=str(project_root / "main.py"),
                language="python",
                symbols=[
                    Symbol(name="main", kind="function", range=(2, 4)),
                    Symbol(name="calculate", kind="function", range=(6, 7)),
                    Symbol(name="add", kind="function", range=(9, 10))
                ]
            )
            utils_indexed = IndexedFile(
                path=str(project_root / "utils.py"),
                language="python",
                symbols=[
                    Symbol(name="BaseClass", kind="class", range=(2, 4)),
                    Symbol(name="DerivedClass", kind="class", range=(6, 9)),
                    Symbol(name="helper", kind="function", range=(11, 12))
                ]
            )

            store.add_file(main_indexed, main_content)
            store.add_file(utils_indexed, utils_content)

        with SQLiteStore(test_db) as store:
            # Add relationships for main.py
            main_file = project_root / "main.py"
            relationships_main = [
                CodeRelationship(
                    source_symbol="main",
                    target_symbol="calculate",
                    relationship_type="call",
                    source_file=str(main_file),
                    source_line=3,
                    target_file=str(main_file)
                ),
                CodeRelationship(
                    source_symbol="calculate",
                    target_symbol="add",
                    relationship_type="call",
                    source_file=str(main_file),
                    source_line=7,
                    target_file=str(main_file)
                ),
            ]
            store.add_relationships(main_file, relationships_main)

            # Add relationships for utils.py
            utils_file = project_root / "utils.py"
            relationships_utils = [
                CodeRelationship(
                    source_symbol="DerivedClass",
                    target_symbol="BaseClass",
                    relationship_type="inherits",
                    source_file=str(utils_file),
                    source_line=6,  # DerivedClass is defined on line 6
                    target_file=str(utils_file)
                ),
                CodeRelationship(
                    source_symbol="DerivedClass.method",
                    target_symbol="helper",
                    relationship_type="call",
                    source_file=str(utils_file),
                    source_line=8,
                    target_file=str(utils_file)
                ),
            ]
            store.add_relationships(utils_file, relationships_utils)

        registry.close()

        yield project_root


class TestGraphCallers:
    """Test callers query type."""

    def test_find_callers_basic(self, temp_project):
        """Test finding functions that call a given function."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "calculate" in result.stdout
        assert "Callers of 'add'" in result.stdout

    def test_find_callers_json_mode(self, temp_project):
        """Test callers query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout
        assert "relationships" in result.stdout

    def test_find_callers_no_results(self, temp_project):
        """Test callers query when no callers exist."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "nonexistent_function",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "No callers found" in result.stdout or "0 found" in result.stdout


class TestGraphCallees:
    """Test callees query type."""

    def test_find_callees_basic(self, temp_project):
        """Test finding functions called by a given function."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "main",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "calculate" in result.stdout
        assert "Callees of 'main'" in result.stdout

    def test_find_callees_chain(self, temp_project):
        """Test finding callees in a call chain."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "calculate",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "add" in result.stdout

    def test_find_callees_json_mode(self, temp_project):
        """Test callees query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "callees",
            "main",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout


class TestGraphInheritance:
    """Test inheritance query type."""

    def test_find_inheritance_basic(self, temp_project):
        """Test finding inheritance relationships."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "BaseClass",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "DerivedClass" in result.stdout
        assert "Inheritance relationships" in result.stdout

    def test_find_inheritance_derived(self, temp_project):
        """Test finding inheritance from derived class perspective."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "DerivedClass",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 0
        assert "BaseClass" in result.stdout

    def test_find_inheritance_json_mode(self, temp_project):
        """Test inheritance query with JSON output."""
        result = runner.invoke(app, [
            "graph",
            "inheritance",
            "BaseClass",
            "--path", str(temp_project),
            "--json"
        ])

        assert result.exit_code == 0
        assert "success" in result.stdout


class TestGraphValidation:
    """Test query validation and error handling."""

    def test_invalid_query_type(self, temp_project):
        """Test error handling for invalid query type."""
        result = runner.invoke(app, [
            "graph",
            "invalid_type",
            "symbol",
            "--path", str(temp_project)
        ])

        assert result.exit_code == 1
        assert "Invalid query type" in result.stdout

    def test_invalid_path(self):
        """Test error handling for non-existent path."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "symbol",
            "--path", "/nonexistent/path"
        ])

        # Should handle gracefully (may exit with error or return empty results)
        assert result.exit_code in [0, 1]


class TestGraphPerformance:
    """Test graph query performance requirements."""

    def test_query_response_time(self, temp_project):
        """Verify graph queries complete in under 1 second."""
        import time

        start = time.time()
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project)
        ])
        elapsed = time.time() - start

        assert result.exit_code == 0
        assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"

    def test_multiple_query_types(self, temp_project):
        """Test all three query types complete successfully."""
        import time

        queries = [
            ("callers", "add"),
            ("callees", "main"),
            ("inheritance", "BaseClass")
        ]

        total_start = time.time()

        for query_type, symbol in queries:
            result = runner.invoke(app, [
                "graph",
                query_type,
                symbol,
                "--path", str(temp_project)
            ])
            assert result.exit_code == 0

        total_elapsed = time.time() - total_start
        assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"


class TestGraphOptions:
    """Test graph command options."""

    def test_limit_option(self, temp_project):
        """Test limit option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--limit", "1"
        ])

        assert result.exit_code == 0

    def test_depth_option(self, temp_project):
        """Test depth option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--depth", "0"
        ])

        assert result.exit_code == 0

    def test_verbose_option(self, temp_project):
        """Test verbose option works correctly."""
        result = runner.invoke(app, [
            "graph",
            "callers",
            "add",
            "--path", str(temp_project),
            "--verbose"
        ])

        assert result.exit_code == 0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@@ -1,355 +0,0 @@
"""Tests for code relationship storage."""

import sqlite3
import tempfile
from pathlib import Path

import pytest

from codexlens.entities import CodeRelationship, IndexedFile, Symbol
from codexlens.storage.migration_manager import MigrationManager
from codexlens.storage.sqlite_store import SQLiteStore


@pytest.fixture
def temp_db():
    """Create a temporary database for testing."""
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test.db"
        yield db_path


@pytest.fixture
def store(temp_db):
    """Create a SQLiteStore with migrations applied."""
    store = SQLiteStore(temp_db)
    store.initialize()

    # Manually apply migration_003 (code_relationships table)
    conn = store._get_connection()
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT,
            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
        )
        """
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
    )
    conn.commit()

    yield store

    # Cleanup
    store.close()


def test_relationship_table_created(store):
    """Test that the code_relationships table is created by migration."""
    conn = store._get_connection()
    cursor = conn.cursor()

    # Check table exists
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
    )
    result = cursor.fetchone()
    assert result is not None, "code_relationships table should exist"

    # Check indexes exist
    cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
    )
    indexes = [row[0] for row in cursor.fetchall()]
    assert "idx_relationships_source" in indexes
    assert "idx_relationships_target" in indexes
    assert "idx_relationships_type" in indexes


def test_add_relationships(store):
    """Test storing code relationships."""
    # First add a file with symbols
    indexed_file = IndexedFile(
        path=str(Path(__file__).parent / "sample.py"),
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 5)),
            Symbol(name="bar", kind="function", range=(7, 10)),
        ]
    )

    content = """def foo():
    bar()
    baz()

def bar():
    print("hello")
"""

    store.add_file(indexed_file, content)

    # Add relationships
    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=indexed_file.path,
            target_file=None,
            source_line=2
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=indexed_file.path,
            target_file=None,
            source_line=3
        ),
    ]

    store.add_relationships(indexed_file.path, relationships)

    # Verify relationships were stored
    conn = store._get_connection()
    count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0]
    assert count == 2, "Should have stored 2 relationships"


def test_query_relationships_by_target(store):
    """Test querying relationships by target symbol (find callers)."""
    # Setup: Add file and relationships
    file_path = str(Path(__file__).parent / "sample.py")
    # Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main()
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 2)),
            Symbol(name="bar", kind="function", range=(4, 5)),
            Symbol(name="main", kind="function", range=(7, 8)),
        ]
    )

    content = "def foo():\n bar()\n\ndef bar():\n pass\n\ndef main():\n bar()\n"
    store.add_file(indexed_file, content)

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2  # Call inside foo (line 2)
        ),
        CodeRelationship(
            source_symbol="main",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=8  # Call inside main (line 8)
        ),
    ]

    store.add_relationships(file_path, relationships)

    # Query: Find all callers of "bar"
    callers = store.query_relationships_by_target("bar")

    assert len(callers) == 2, "Should find 2 callers of bar"
    assert any(r["source_symbol"] == "foo" for r in callers)
    assert any(r["source_symbol"] == "main" for r in callers)
    assert all(r["target_symbol"] == "bar" for r in callers)
    assert all(r["relationship_type"] == "call" for r in callers)


def test_query_relationships_by_source(store):
    """Test querying relationships by source symbol (find callees)."""
    # Setup
    file_path = str(Path(__file__).parent / "sample.py")
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[
            Symbol(name="foo", kind="function", range=(1, 6)),
        ]
    )

    content = "def foo():\n bar()\n baz()\n qux()\n"
    store.add_file(indexed_file, content)

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=3
        ),
        CodeRelationship(
            source_symbol="foo",
            target_symbol="qux",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=4
        ),
    ]

    store.add_relationships(file_path, relationships)

    # Query: Find all functions called by foo
    callees = store.query_relationships_by_source("foo", file_path)

    assert len(callees) == 3, "Should find 3 functions called by foo"
    targets = {r["target_symbol"] for r in callees}
    assert targets == {"bar", "baz", "qux"}
    assert all(r["source_symbol"] == "foo" for r in callees)


def test_query_performance(store):
    """Test that relationship queries execute within performance threshold."""
    import time

    # Setup: Create a file with many relationships
    file_path = str(Path(__file__).parent / "large_file.py")
    symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]

    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=symbols
    )

    content = "\n".join([f"def func_{i}():\n    pass\n" for i in range(100)])
    store.add_file(indexed_file, content)

    # Create many relationships
    relationships = []
    for i in range(100):
        for j in range(10):
            relationships.append(
                CodeRelationship(
                    source_symbol=f"func_{i}",
                    target_symbol=f"target_{j}",
                    relationship_type="call",
                    source_file=file_path,
                    target_file=None,
                    source_line=i*10 + 1
                )
            )

    store.add_relationships(file_path, relationships)

    # Query and measure time
    start = time.time()
    results = store.query_relationships_by_target("target_5")
    elapsed_ms = (time.time() - start) * 1000

    assert len(results) == 100, "Should find 100 callers"
    assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"


def test_stats_includes_relationships(store):
    """Test that stats() includes relationship count."""
    # Add a file with relationships
    file_path = str(Path(__file__).parent / "sample.py")
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
    )

    store.add_file(indexed_file, "def foo():\n bar()\n")

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]

    store.add_relationships(file_path, relationships)

    # Check stats
    stats = store.stats()

    assert "relationships" in stats
    assert stats["relationships"] == 1
    assert stats["files"] == 1
    assert stats["symbols"] == 1


def test_update_relationships_on_file_reindex(store):
    """Test that relationships are updated when file is re-indexed."""
    file_path = str(Path(__file__).parent / "sample.py")

    # Initial index
    indexed_file = IndexedFile(
        path=file_path,
        language="python",
        symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
    )
    store.add_file(indexed_file, "def foo():\n bar()\n")

    relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="bar",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]
    store.add_relationships(file_path, relationships)

    # Re-index with different relationships
    new_relationships = [
        CodeRelationship(
            source_symbol="foo",
            target_symbol="baz",
            relationship_type="call",
            source_file=file_path,
            target_file=None,
            source_line=2
        )
    ]
    store.add_relationships(file_path, new_relationships)

    # Verify old relationships are replaced
    all_rels = store.query_relationships_by_source("foo", file_path)
    assert len(all_rels) == 1
    assert all_rels[0]["target_symbol"] == "baz"
@@ -188,60 +188,3 @@ class TestTokenCountPerformance:
        # Precomputed should be at least 10% faster
        speedup = ((computed_time - precomputed_time) / computed_time) * 100
        assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"


class TestSymbolEntityTokenCount:
    """Tests for Symbol entity token_count field."""

    def test_symbol_with_token_count(self):
        """Test creating Symbol with token_count."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10),
            token_count=42
        )

        assert symbol.token_count == 42

    def test_symbol_without_token_count(self):
        """Test creating Symbol without token_count (defaults to None)."""
        symbol = Symbol(
            name="test_func",
            kind="function",
            range=(1, 10)
        )

        assert symbol.token_count is None

    def test_symbol_with_symbol_type(self):
        """Test creating Symbol with symbol_type."""
        symbol = Symbol(
            name="TestClass",
            kind="class",
            range=(1, 20),
            symbol_type="class_definition"
        )

        assert symbol.symbol_type == "class_definition"

    def test_symbol_token_count_validation(self):
        """Test that negative token counts are rejected."""
        with pytest.raises(ValueError, match="token_count must be >= 0"):
            Symbol(
                name="test",
                kind="function",
                range=(1, 2),
                token_count=-1
            )

    def test_symbol_zero_token_count(self):
        """Test that zero token count is allowed."""
        symbol = Symbol(
            name="empty",
            kind="function",
            range=(1, 1),
            token_count=0
        )

        assert symbol.token_count == 0